Skip to main content

diskann_utils/
strided.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{
7    fmt,
8    ops::{Index, IndexMut},
9};
10use thiserror::Error;
11
12use crate::views::{self, DenseData, MutDenseData};
13
14/// A row-major strided matrix.
15///
16/// This is a generalization of the `MatrixBase` class as it does not mandate a dense
17/// layout in memory.
18///
19/// ```text
20///            |<------ cstride ----->|
21///            |<-- ncols -->|
22///            +-------------+
23/// slice 0 -> | a0 a1 a2 a3 | a4 a5 a6     ^
24/// slice 1 -> | b0 b1 b2 b3 | b4 b5 b6     |
25/// slice 2 -> | c0 c1 c2 c3 | c4 c5 c6   nrows
26/// slice 3 -> | d0 d1 d2 d3 | d4 d5 d6     |
27/// slice 4 -> | e0 e1 e2 e3 | e4 e5 e6     |
28/// slicf 5 -> | f0 f1 f2 f3 | f4 f5 f6     v
29///            +-------------+
30///                  ^
31///                  |
32///             StridedView
33/// ```
34///
35/// This abstraction is useful when performing PQ related operations such as training or
36/// compression as it provides a convenient abstraction for working with columnar subsets
37/// of dense data in-place.
38#[derive(Debug, Clone, Copy)]
39pub struct StridedBase<T>
40where
41    T: DenseData,
42{
43    data: T,
44    nrows: usize,
45    ncols: usize,
46    // The stride along the columns. This must be greater than or equal to `ncols`.
47    cstride: usize,
48}
49
/// Return the linear length of a slice underlying a `StridedBase` with the given parameters.
///
/// When `nrows == 0` the result is `0`. Otherwise it is `(nrows - 1) * cstride + ncols`:
/// every row except the last spans a full stride, while the last row only needs its
/// `ncols` elements.
///
/// # Panics
///
/// Panics if the result does not fit in a `usize`. Silently wrapping would let the strided
/// constructors accept a slice that is too short for the requested shape, and the unchecked
/// row/element accessors rely on the slice length equalling this value.
pub fn linear_length(nrows: usize, ncols: usize, cstride: usize) -> usize {
    match nrows.checked_sub(1) {
        // No rows: nothing is stored, regardless of `ncols` and `cstride`.
        None => 0,
        Some(full_rows) => full_rows
            .checked_mul(cstride)
            .and_then(|len| len.checked_add(ncols))
            .unwrap_or_else(|| {
                panic!(
                    "strided matrix length overflows usize \
                     (nrows = {nrows}, ncols = {ncols}, cstride = {cstride})"
                )
            }),
    }
}
54
55#[derive(Debug, Error)]
56#[non_exhaustive]
57#[error(
58    "tried to construct a strided matrix with {nrows} rows and {ncols} cols and \
59     column stride {cstride} over a slice of length {} (expected {})",
60     len,
61     linear_length(self.nrows, self.ncols, self.cstride)
62)]
63pub struct TryFromErrorLight {
64    len: usize,
65    nrows: usize,
66    ncols: usize,
67    cstride: usize,
68}
69
70#[derive(Error)]
71#[non_exhaustive]
72#[error(
73    "tried to construct a strided matrix with {nrows} rows and {ncols} cols and \
74     column stride {cstride} over a slice of length {} (expected {})",
75     data.as_slice().len(),
76     linear_length(self.nrows, self.ncols, self.cstride)
77)]
78pub struct TryFromError<T: views::DenseData> {
79    data: T,
80    nrows: usize,
81    ncols: usize,
82    cstride: usize,
83}
84
85// Manually implement `fmt::Debug` so we don't require `T::Debug`.
86impl<T: DenseData> fmt::Debug for TryFromError<T> {
87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88        f.debug_struct("TryFromError")
89            .field("data_len", &self.data.as_slice().len())
90            .field("nrows", &self.nrows)
91            .field("ncols", &self.ncols)
92            .field("cstride", &self.cstride)
93            .finish()
94    }
95}
96
97impl<T: views::DenseData> TryFromError<T> {
98    /// Consume the error and return the base data.
99    pub fn into_inner(self) -> T {
100        self.data
101    }
102
103    /// Drop the data portion of the error and return an equivalent error that is guaranteed
104    /// to be `'static`.
105    pub fn as_static(&self) -> TryFromErrorLight {
106        TryFromErrorLight {
107            len: self.data.as_slice().len(),
108            nrows: self.nrows,
109            ncols: self.ncols,
110            cstride: self.cstride,
111        }
112    }
113}
114
115impl<'a, T> StridedBase<&'a [T]> {
116    /// Construct a strided view over data slice, shrinking the slice as needed.
117    ///
118    /// Returns an error if `data` is shorter than the value returned by `linear_length`.
119    ///
120    /// # Panics
121    ///
122    /// * Panics if `cstride < ncols`.
123    pub fn try_shrink_from(
124        data: &'a [T],
125        nrows: usize,
126        ncols: usize,
127        cstride: usize,
128    ) -> Result<Self, TryFromError<&'a [T]>> {
129        assert!(
130            cstride >= ncols,
131            "cstride must be greater than or equal to ncols"
132        );
133        let required_length = linear_length(nrows, ncols, cstride);
134        match data.get(..required_length) {
135            Some(data) => Ok(Self {
136                data,
137                nrows,
138                ncols,
139                cstride,
140            }),
141            None => Err(TryFromError {
142                data,
143                nrows,
144                ncols,
145                cstride,
146            }),
147        }
148    }
149}
150
151impl<'a, T> StridedBase<&'a mut [T]> {
152    /// Construct a strided view over data slice, shrinking the slice as needed.
153    ///
154    /// Returns an error if `data` is shorter than the value returned by `linear_length`.
155    ///
156    /// # Panics
157    ///
158    /// * Panics if `cstride < ncols`.
159    pub fn try_shrink_from_mut(
160        data: &'a mut [T],
161        nrows: usize,
162        ncols: usize,
163        cstride: usize,
164    ) -> Result<Self, TryFromError<&'a mut [T]>> {
165        assert!(
166            cstride >= ncols,
167            "cstride must be greater than or equal to ncols"
168        );
169        let required_length = linear_length(nrows, ncols, cstride);
170        if data.as_slice().len() >= required_length {
171            Ok(Self {
172                data: &mut data[..required_length],
173                nrows,
174                ncols,
175                cstride,
176            })
177        } else {
178            Err(TryFromError {
179                data,
180                nrows,
181                ncols,
182                cstride,
183            })
184        }
185    }
186}
187
188impl<T> StridedBase<T>
189where
190    T: DenseData,
191{
192    /// Construct a strided view over data slice, shrinking the slice as needed.
193    ///
194    /// Returns an error if `data` is not equal to the expected length as determined
195    /// by `linear_length`.
196    ///
197    /// # Panics
198    ///
199    /// * Panics if `cstride < ncols`.
200    pub fn try_from(
201        data: T,
202        nrows: usize,
203        ncols: usize,
204        cstride: usize,
205    ) -> Result<Self, TryFromError<T>> {
206        assert!(
207            cstride >= ncols,
208            "cstride must be greater than or equal to ncols"
209        );
210        // This computation needs to be set up such that:
211        // 1. When `nrows == 0`, the expected length is 0.
212        // 2. We make a tight upper-bound on the expected length for the last row.
213        let required_length = linear_length(nrows, ncols, cstride);
214        if data.as_slice().len() == required_length {
215            Ok(Self {
216                data,
217                nrows,
218                ncols,
219                cstride,
220            })
221        } else {
222            Err(TryFromError {
223                data,
224                nrows,
225                ncols,
226                cstride,
227            })
228        }
229    }
230
231    /// Return the number of columns in the matrix.
232    pub fn ncols(&self) -> usize {
233        self.ncols
234    }
235
236    /// Return the number of rows in the matrix.
237    pub fn nrows(&self) -> usize {
238        self.nrows
239    }
240
241    /// Return the count of elements between the start of each row.
242    pub fn cstride(&self) -> usize {
243        self.cstride
244    }
245
246    /// Return the underlying data as a slice.
247    ///
248    /// # Note
249    ///
250    /// The underlying representation for a strided matrix is not necessarily dense.
251    pub fn as_slice(&self) -> &[T::Elem] {
252        self.data.as_slice()
253    }
254
255    /// Return the underlying data as a slice.
256    ///
257    /// # Note
258    ///
259    /// The underlying representation for a strided matrix is not necessarily dense.
260    pub fn as_mut_slice(&mut self) -> &mut [T::Elem]
261    where
262        T: MutDenseData,
263    {
264        self.data.as_mut_slice()
265    }
266
267    /// Return row `row` as a slice.
268    ///
269    /// # Panic
270    ///
271    /// Panics if `row >= self.nrows()`.
272    pub fn row(&self, row: usize) -> &[T::Elem] {
273        assert!(
274            row < self.nrows(),
275            "tried to access row {row} of a matrix with {} rows",
276            self.nrows()
277        );
278
279        // SAFETY: `row` is in-bounds.
280        unsafe { self.get_row_unchecked(row) }
281    }
282
283    /// Returns the requested row without boundschecking.
284    ///
285    /// # Safety
286    ///
287    /// The following conditions must hold to avoid undefined behavior:
288    /// * `row < self.nrows()`.
289    pub unsafe fn get_row_unchecked(&self, row: usize) -> &[T::Elem] {
290        debug_assert!(row < self.nrows);
291        let cstride = self.cstride;
292        let ncols = self.ncols;
293        let start = row * cstride;
294
295        debug_assert!(start + ncols <= self.as_slice().len());
296        // SAFETY: The idempotency requirement of `as_slice` and our audited constructors
297        // mean that `self.as_slice()` has a length of `self.nrows * self.ncols`.
298        //
299        // Therefore, this access is in-bounds.
300        unsafe { self.as_slice().get_unchecked(start..start + ncols) }
301    }
302
303    /// Return row `row` as a mutable slice.
304    ///
305    /// # Panics
306    ///
307    /// Panics if `row >= self.nrows()`.
308    pub fn row_mut(&mut self, row: usize) -> &mut [T::Elem]
309    where
310        T: MutDenseData,
311    {
312        assert!(
313            row < self.nrows(),
314            "tried to access row {row} of a matrix with {} rows",
315            self.nrows()
316        );
317
318        // SAFETY: `row` is in-bounds.
319        unsafe { self.get_row_unchecked_mut(row) }
320    }
321
322    /// Returns the requested row without boundschecking.
323    ///
324    /// # Safety
325    ///
326    /// The following conditions must hold to avoid undefined behavior:
327    /// * `row < self.nrows()`.
328    pub unsafe fn get_row_unchecked_mut(&mut self, row: usize) -> &mut [T::Elem]
329    where
330        T: MutDenseData,
331    {
332        debug_assert!(row < self.nrows);
333        let cstride = self.cstride;
334        let ncols = self.ncols;
335        let start = row * cstride;
336
337        debug_assert!(start + ncols <= self.as_slice().len());
338        // SAFETY: The idempotency requirement of `as_mut_slice` and our audited constructors
339        // mean that `self.as_mut_slice()` has a length of `self.nrows * self.ncols`.
340        //
341        // Therefore, this access is in-bounds.
342        unsafe {
343            self.data
344                .as_mut_slice()
345                .get_unchecked_mut(start..start + ncols)
346        }
347    }
348
349    /// Return a iterator over all rows in the matrix.
350    ///
351    /// Rows are yielded sequentially beginning with row 0.
352    ///
353    /// # Panics
354    ///
355    /// Panics if `self.ncols() == 0` (because the implementation does not work correctly
356    /// in this case and it's too corner-case to bother fixing). This restriction may
357    /// be lifted in the future.
358    pub fn row_iter(&self) -> impl Iterator<Item = &[T::Elem]> {
359        assert!(self.ncols() != 0);
360        let ncols = self.ncols;
361
362        self.data
363            .as_slice()
364            .chunks(self.cstride())
365            .map(move |i| &i[..ncols])
366    }
367
368    /// Return a mutable iterator over all rows in the matrix.
369    ///
370    /// Rows are yielded sequentially beginning with row 0.
371    ///
372    /// # Panics
373    ///
374    /// Panics if `self.ncols() == 0` (because the implementation does not work correctly
375    /// in this case and it's too corner-case to bother fixing). This restriction may
376    /// be lifted in the future.
377    pub fn row_iter_mut(&mut self) -> impl Iterator<Item = &mut [T::Elem]>
378    where
379        T: MutDenseData,
380    {
381        assert!(self.ncols() != 0);
382
383        let ncols = self.ncols();
384        let cstride = self.cstride();
385        self.data
386            .as_mut_slice()
387            .chunks_mut(cstride)
388            .map(move |i| &mut i[..ncols])
389    }
390
391    /// Return a pointer to the base of the matrix.
392    pub fn as_ptr(&self) -> *const T::Elem {
393        self.as_slice().as_ptr()
394    }
395
396    /// Return a pointer to the base of the matrix.
397    pub fn as_mut_ptr(&mut self) -> *mut T::Elem
398    where
399        T: MutDenseData,
400    {
401        self.as_mut_slice().as_mut_ptr()
402    }
403
404    /// Returns a reference to an element without boundschecking.
405    ///
406    /// # Safety
407    ///
408    /// The following conditions must hold to avoid undefined behavior:
409    /// * `row < self.nrows()`.
410    /// * `col < self.ncols()`.
411    pub unsafe fn get_unchecked(&self, row: usize, col: usize) -> &T::Elem {
412        debug_assert!(row < self.nrows);
413        debug_assert!(col < self.ncols);
414        self.as_slice().get_unchecked(row * self.cstride + col)
415    }
416
417    /// Returns a mutable reference to an element without boundschecking.
418    ///
419    /// # Safety
420    ///
421    /// The following conditions must hold to avoid undefined behavior:
422    /// * `row < self.nrows()`.
423    /// * `col < self.ncols()`.
424    pub unsafe fn get_unchecked_mut(&mut self, row: usize, col: usize) -> &mut T::Elem
425    where
426        T: MutDenseData,
427    {
428        let cstride = self.cstride;
429        debug_assert!(row < self.nrows);
430        debug_assert!(col < self.ncols);
431        self.as_mut_slice().get_unchecked_mut(row * cstride + col)
432    }
433
434    /// Return a view over the matrix.
435    pub fn as_view(&self) -> StridedView<'_, T::Elem> {
436        StridedView {
437            data: self.as_slice(),
438            nrows: self.nrows,
439            ncols: self.ncols,
440            cstride: self.cstride,
441        }
442    }
443}
444
445pub type StridedView<'a, T> = StridedBase<&'a [T]>;
446pub type MutStridedView<'a, T> = StridedBase<&'a mut [T]>;
447
448/// Return a reference to the item at entry `(row, col)` in the matrix.
449///
450/// # Panics
451///
452/// Panics if `row >= self.nrows()` or `col >= self.ncols()`.
453impl<T> Index<(usize, usize)> for StridedBase<T>
454where
455    T: DenseData,
456{
457    type Output = T::Elem;
458
459    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
460        assert!(
461            row < self.nrows(),
462            "row {row} is out of bounds (max: {})",
463            self.nrows()
464        );
465        assert!(
466            col < self.ncols(),
467            "col {col} is out of bounds (max: {})",
468            self.ncols()
469        );
470        // SAFETY: We have checked that `row` and `col` are in-bounds.
471        unsafe { self.get_unchecked(row, col) }
472    }
473}
474
475/// Return a mutable reference to the item at entry `(row, col)` in the matrix.
476///
477/// # Panics
478///
479/// Panics if `row >= self.nrows()` or `col >= self.ncols()`.
480impl<T> IndexMut<(usize, usize)> for StridedBase<T>
481where
482    T: MutDenseData,
483{
484    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output {
485        assert!(
486            row < self.nrows(),
487            "row {row} is out of bounds (max: {})",
488            self.nrows()
489        );
490        assert!(
491            col < self.ncols(),
492            "col {col} is out of bounds (max: {})",
493            self.ncols()
494        );
495        // SAFETY: We have checked that `row` and `col` are in-bounds.
496        unsafe { self.get_unchecked_mut(row, col) }
497    }
498}
499
500impl<T, U> From<views::MatrixBase<T>> for StridedBase<U>
501where
502    T: DenseData,
503    U: DenseData,
504    T: Into<U>,
505{
506    fn from(matrix: views::MatrixBase<T>) -> Self {
507        let nrows = matrix.nrows();
508        let ncols = matrix.ncols();
509        Self {
510            data: matrix.into_inner().into(),
511            nrows,
512            ncols,
513            cstride: ncols,
514        }
515    }
516}
517
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear_length() {
        // If the number of rows is zero - the output should always be zero.
        assert_eq!(linear_length(0, 1, 1), 0);
        assert_eq!(linear_length(0, 2, 2), 0);
        assert_eq!(linear_length(0, 2, 3), 0);
        assert_eq!(linear_length(0, 2, 4), 0);

        // If `cstride == ncols`, then the computation should be trivial.
        for row in 1..10 {
            for col in 1..10 {
                assert_eq!(linear_length(row, col, col), row * col);
            }
        }

        // If there is only one row, then `cstride` should be ignored.
        assert_eq!(linear_length(1, 5, 10), 5);
        assert_eq!(linear_length(1, 7, 99), 7);

        // Otherwise, the computation is a block of `nrows - 1` chunks of `cstride` and then
        // `ncols`. Yes - this runs a bunch of computations.
        for row in 2..10 {
            for col in 0..10 {
                for cstride in col..12 {
                    assert_eq!(linear_length(row, col, cstride), (row - 1) * cstride + col);
                }
            }
        }
    }

    fn assert_is_static<T: 'static>(_x: &T) {}

    #[test]
    fn try_from_error_misc() {
        let x = TryFromError::<&[f32]> {
            data: &[],
            nrows: 1,
            ncols: 2,
            cstride: 3,
        };

        let display = format!("{}", x);
        assert_eq!(display, expected_error(0, 1, 2, 3));

        let debug = format!("{:?}", x);
        assert!(debug.contains("TryFromError"));
        assert!(debug.contains("data_len: 0"));
        assert!(debug.contains("nrows: 1"));
        assert!(debug.contains("ncols: 2"));
        assert!(debug.contains("cstride: 3"));

        let x = x.as_static();
        assert_is_static(&x);
        assert_eq!(
            display,
            format!("{}", x),
            "static version of the error must have the same message"
        );
    }

    fn expected_error(len: usize, nrows: usize, ncols: usize, cstride: usize) -> String {
        format!(
            "tried to construct a strided matrix with {nrows} rows and {ncols} cols and \
             column stride {cstride} over a slice of length {} (expected {})",
            len,
            linear_length(nrows, ncols, cstride)
        )
    }

    // Test that the contents of `dut` match those in the dense 2d matrix.
    fn test_indexing(dut: StridedView<'_, usize>, expected: views::MatrixView<'_, usize>) {
        assert_eq!(dut.nrows(), expected.nrows());
        assert_eq!(dut.ncols(), expected.ncols());

        // Check the underlying data.
        if dut.cstride() == dut.ncols() {
            assert_eq!(dut.as_slice(), expected.as_slice());
        } else {
            assert_ne!(dut.as_slice(), expected.as_slice());
        }

        // Compare via 2D indexing.
        for row in 0..dut.nrows() {
            for col in 0..dut.ncols() {
                assert_eq!(
                    dut[(row, col)],
                    expected[(row, col)],
                    "failed on (row, col) = ({}, {})",
                    row,
                    col
                );
            }
        }

        // Compare via row.
        for row in 0..dut.nrows() {
            assert_eq!(dut.row(row), expected.row(row), "failed on row {}", row);
        }

        // Compare via row iterators.
        assert!(dut.row_iter().eq(expected.row_iter()));
    }

    // Create a base Matrix with the following pattern:
    // ```text
    //       0         1         2 ...   ncols-1
    //   ncols   ncols+1   ncols+2 ... 2*ncols-1
    // 2*ncols 2*ncols+1 2*ncols+2 ... 3*ncols-1
    // ...
    // ```
    fn create_test_matrix(nrows: usize, ncols: usize) -> views::Matrix<usize> {
        let mut i = 0;
        views::Matrix::new(
            views::Init(|| {
                let v = i;
                i += 1;
                v
            }),
            nrows,
            ncols,
        )
    }

    #[test]
    fn test_basic_indexing() {
        let m = create_test_matrix(5, 3);

        // First - test a dense StridedView over the entire matrix.
        let ptr = m.as_ptr();
        let v = StridedView::try_from(m.as_slice(), m.nrows(), m.ncols(), m.ncols()).unwrap();
        assert_eq!(v.as_ptr(), ptr, "base pointer was not preserved");

        assert_eq!(v.nrows(), m.nrows());
        assert_eq!(v.ncols(), m.ncols());
        assert_eq!(v.cstride(), m.ncols());
        test_indexing(v, m.as_view());

        // Now - create a truly strided view over the first two columns.
        let v = StridedView::try_from(
            &(m.as_slice()[..(4 * m.ncols() + 2)]),
            m.nrows(),
            2,
            m.ncols(),
        )
        .unwrap();
        assert_eq!(v.as_ptr(), ptr, "base pointer was not preserved");

        // Create the expected matrix.
        let mut expected = views::Matrix::new(0, 5, 2);
        for row in 0..expected.nrows() {
            for col in 0..expected.ncols() {
                expected[(row, col)] = m[(row, col)];
            }
        }
        test_indexing(v, expected.as_view());

        // Create a strided view over the last two columns.
        let v = StridedView::try_from(&(m.as_slice()[1..]), m.nrows(), 2, m.ncols()).unwrap();
        let mut expected = views::Matrix::new(0, 5, 2);
        for row in 0..expected.nrows() {
            for col in 0..expected.ncols() {
                expected[(row, col)] = m[(row, col + 1)];
            }
        }
        test_indexing(v, expected.as_view());
    }

    #[test]
    fn test_mutable_indexing() {
        // The source matrix.
        let src = create_test_matrix(5, 4);

        // Initialize using 2d indexing.
        {
            let mut dst = views::Matrix::<usize>::new(0, 5, 10);

            let ptr = dst.as_ptr();

            let ncols = src.ncols();
            let nrows = src.nrows();
            let cstride = dst.ncols();
            let mut dst_view =
                MutStridedView::try_shrink_from_mut(dst.as_mut_slice(), nrows, ncols, cstride)
                    .unwrap();

            assert_eq!(dst_view.as_ptr(), ptr);
            assert_eq!(dst_view.as_mut_ptr().cast_const(), ptr);
            assert_eq!(dst_view.nrows(), nrows);
            assert_eq!(dst_view.ncols(), ncols);
            assert_eq!(dst_view.cstride(), cstride);

            // Initialize using 2D indexing.
            for row in 0..dst_view.nrows() {
                for col in 0..dst_view.ncols() {
                    dst_view[(row, col)] = src[(row, col)]
                }
            }

            // Check equality.
            test_indexing(dst_view.as_view(), src.as_view());
        }

        // Initialize using row-wise indexing.
        {
            let mut dst = views::Matrix::<usize>::new(0, 5, 10);

            let ptr = dst.as_ptr();

            let ncols = src.ncols();
            let nrows = src.nrows();
            let cstride = dst.ncols();
            let mut dst_view =
                MutStridedView::try_shrink_from_mut(dst.as_mut_slice(), nrows, ncols, cstride)
                    .unwrap();

            assert_eq!(dst_view.as_ptr(), ptr);
            assert_eq!(dst_view.as_mut_ptr().cast_const(), ptr);
            assert_eq!(dst_view.nrows(), nrows);
            assert_eq!(dst_view.ncols(), ncols);
            assert_eq!(dst_view.cstride(), cstride);

            // Initialize by looping over rows.
            for row in 0..dst_view.nrows() {
                dst_view.row_mut(row).copy_from_slice(src.row(row))
            }

            // Check equality.
            test_indexing(dst_view.as_view(), src.as_view());
        }

        // Initialize using row-iterator indexing.
        {
            let mut dst = views::Matrix::<usize>::new(0, 5, 10);

            let offset = 2;
            // SAFETY: The underlying allocation is valid for much more than 2 elements.
            let ptr = unsafe { dst.as_ptr().add(offset) };

            let ncols = src.ncols();
            let nrows = src.nrows();
            let cstride = dst.ncols();
            let mut dst_view = MutStridedView::try_shrink_from_mut(
                &mut dst.as_mut_slice()[offset..],
                nrows,
                ncols,
                cstride,
            )
            .unwrap();

            assert_eq!(dst_view.as_ptr(), ptr);
            assert_eq!(dst_view.as_mut_ptr().cast_const(), ptr);
            assert_eq!(dst_view.nrows(), nrows);
            assert_eq!(dst_view.ncols(), ncols);
            assert_eq!(dst_view.cstride(), cstride);

            // Initialize using row iterators.
            for (d, s) in std::iter::zip(dst_view.row_iter_mut(), src.row_iter()) {
                d.copy_from_slice(s)
            }

            // Check equality.
            test_indexing(dst_view.as_view(), src.as_view());
        }
    }

    #[test]
    fn matrix_conversion() {
        let m = create_test_matrix(3, 4);
        let ptr = m.as_ptr();
        let v: StridedView<_> = m.as_view().into();
        assert_eq!(v.as_ptr(), ptr);
        test_indexing(v, m.as_view());
    }

    #[test]
    fn test_zero_sized() {
        let m = create_test_matrix(5, 5);
        let v = StridedView::try_shrink_from(m.as_slice(), 0, 4, 5).unwrap();

        assert_eq!(v.nrows(), 0);
        assert_eq!(v.ncols(), 4);
        assert_eq!(v.cstride(), 5);

        let v = StridedView::try_shrink_from(m.as_slice(), 5, 0, 5).unwrap();
        assert_eq!(v.nrows(), 5);
        assert_eq!(v.ncols(), 0);
        assert_eq!(v.cstride(), 5);

        for row in 0..v.nrows() {
            let empty: &[usize] = &[];
            assert_eq!(v.row(row), empty);
        }
    }

    #[test]
    #[should_panic]
    fn test_row_iter_panics() {
        let m = create_test_matrix(5, 5);
        let v = StridedView::try_shrink_from(m.as_slice(), 5, 0, 5).unwrap();
        let _ = v.row_iter();
    }

    #[test]
    #[should_panic]
    fn test_row_iter_mut_panics() {
        let mut m = create_test_matrix(5, 5);
        let mut v = MutStridedView::try_shrink_from_mut(m.as_mut_slice(), 5, 0, 5).unwrap();
        let _ = v.row_iter_mut();
    }

    #[test]
    fn test_try_shrink_from() {
        // Exact is okay.
        let m = views::Matrix::<usize>::new(0, 10, 10);
        let nrows = m.nrows();
        let ncols = m.ncols();
        let s = StridedView::try_shrink_from(m.as_slice(), nrows, ncols, ncols).unwrap();
        assert_eq!(s.as_slice(), m.as_slice());

        // Giving a slice that is too large is okay.
        let s = StridedView::try_shrink_from(m.as_slice(), nrows, 5, ncols).unwrap();
        assert_eq!(s.as_ptr(), m.as_ptr());

        // Too small is a problem.
        let s = StridedView::try_shrink_from(m.as_slice(), nrows, ncols, ncols + 1);
        assert!(s.is_err());
        let err = s.unwrap_err();
        assert_eq!(
            err.to_string(),
            expected_error(m.as_slice().len(), nrows, ncols, ncols + 1)
        );
        assert_eq!(err.into_inner(), m.as_slice());
    }

    #[test]
    #[should_panic(expected = "cstride must be greater than or equal to ncols")]
    fn test_try_shrink_from_panics() {
        let m = views::Matrix::<usize>::new(0, 4, 4);
        let _ = StridedView::try_shrink_from(m.as_slice(), 2, 2, 1);
    }

    #[test]
    fn test_try_shrink_from_mut() {
        // Exact is okay.
        let mut m = views::Matrix::<usize>::new(0, 10, 10);

        let nrows = m.nrows();
        let ncols = m.ncols();
        let ptr = m.as_ptr();
        let len = m.as_slice().len();

        let s = MutStridedView::try_shrink_from_mut(m.as_mut_slice(), nrows, ncols, ncols).unwrap();
        assert_eq!(s.as_ptr(), ptr);
        assert_eq!(s.as_slice().len(), len);

        // Giving a slice that is too large is okay.
        let s = MutStridedView::try_shrink_from_mut(m.as_mut_slice(), nrows, 5, ncols).unwrap();
        assert_eq!(s.as_ptr(), ptr);

        // Too small is a problem.
        let s = MutStridedView::try_shrink_from_mut(m.as_mut_slice(), nrows, ncols, ncols + 1);
        assert!(s.is_err());
        let err = s.unwrap_err();
        assert_eq!(
            err.to_string(),
            expected_error(len, nrows, ncols, ncols + 1)
        );
        // The error must hand back the whole, unshrunk input.
        assert_eq!(err.into_inner().len(), len);
    }

    #[test]
    #[should_panic(expected = "cstride must be greater than or equal to ncols")]
    fn test_try_shrink_from_mut_panics() {
        let mut m = views::Matrix::<usize>::new(0, 4, 4);
        let _ = MutStridedView::try_shrink_from_mut(m.as_mut_slice(), 2, 2, 1);
    }

    #[test]
    fn test_try_from() {
        // Exact is okay.
        let m = views::Matrix::<usize>::new(0, 10, 10);
        let nrows = m.nrows();
        let ncols = m.ncols();
        let s = StridedView::try_from(m.as_slice(), nrows, ncols, ncols).unwrap();
        assert_eq!(s.as_slice(), m.as_slice());

        // Giving a slice that is too large is a problem.
        let s = StridedView::try_from(m.as_slice(), nrows, 5, ncols);
        assert!(s.is_err());
        let err = s.unwrap_err();
        assert_eq!(
            err.to_string(),
            expected_error(m.as_slice().len(), nrows, 5, ncols)
        );

        // Too small is a problem.
        let s = StridedView::try_from(m.as_slice(), nrows, ncols, ncols + 1);
        assert!(s.is_err());
        let err = s.unwrap_err();
        assert_eq!(
            err.to_string(),
            expected_error(m.as_slice().len(), nrows, ncols, ncols + 1)
        );
        assert_eq!(err.into_inner(), m.as_slice());
    }

    #[test]
    #[should_panic(expected = "cstride must be greater than or equal to ncols")]
    fn test_try_from_panics() {
        let mut m = views::Matrix::<usize>::new(0, 4, 4);
        let _ = MutStridedView::try_from(m.as_mut_slice(), 2, 2, 1);
    }

    #[test]
    #[should_panic(expected = "tried to access row 3 of a matrix with 3 rows")]
    fn test_get_row_panics() {
        let m = views::Matrix::<usize>::new(0, 3, 7);
        let v: StridedView<_> = m.as_view().into();
        v.row(3);
    }

    #[test]
    #[should_panic(expected = "tried to access row 3 of a matrix with 3 rows")]
    fn test_get_row_mut_panics() {
        let mut m = views::Matrix::<usize>::new(0, 3, 7);
        let mut v: MutStridedView<_> = m.as_mut_view().into();
        v.row_mut(3);
    }

    #[test]
    #[should_panic(expected = "row 3 is out of bounds (max: 3)")]
    fn test_index_panics_row() {
        let m = views::Matrix::<usize>::new(0, 3, 7);
        let v: StridedView<_> = m.as_view().into();
        let _ = v[(3, 2)];
    }

    #[test]
    #[should_panic(expected = "col 7 is out of bounds (max: 7)")]
    fn test_index_panics_col() {
        let m = views::Matrix::<usize>::new(0, 3, 7);
        let v: StridedView<_> = m.as_view().into();
        let _ = v[(2, 7)];
    }

    #[test]
    #[should_panic(expected = "row 3 is out of bounds (max: 3)")]
    fn test_index_mut_panics_row() {
        let mut m = views::Matrix::<usize>::new(0, 3, 7);
        let mut v: MutStridedView<_> = m.as_mut_view().into();
        v[(3, 2)] = 1;
    }

    #[test]
    #[should_panic(expected = "col 7 is out of bounds (max: 7)")]
    fn test_index_mut_panics_col() {
        let mut m = views::Matrix::<usize>::new(0, 3, 7);
        let mut v: MutStridedView<_> = m.as_mut_view().into();
        v[(2, 7)] = 1;
    }
}