Skip to main content

lib_q_stark_matrix/
lib.rs

1//! Matrix library.
2
3#![no_std]
4
5extern crate alloc;
6
7use alloc::vec::Vec;
8use core::fmt::{
9    Debug,
10    Display,
11    Formatter,
12};
13use core::ops::Deref;
14
15use itertools::{
16    Itertools,
17    izip,
18};
19use lib_q_stark_field::{
20    BasedVectorSpace,
21    ExtensionField,
22    Field,
23    PackedValue,
24    PrimeCharacteristicRing,
25    dot_product,
26};
27use lib_q_stark_rayon::prelude::*;
28use strided::{
29    VerticallyStridedMatrixView,
30    VerticallyStridedRowIndexMap,
31};
32
33use crate::dense::RowMajorMatrix;
34
35pub mod bitrev;
36pub mod dense;
37pub mod extension;
38pub mod horizontally_truncated;
39pub mod row_index_mapped;
40pub mod stack;
41pub mod strided;
42pub mod util;
43
/// Shape of a matrix: how many columns (`width`) and how many rows (`height`) it has.
///
/// `Dimensions` is a small `Copy` value returned when querying a matrix's shape and
/// used when printing it.
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct Dimensions {
    /// Column count of the matrix.
    pub width: usize,
    /// Row count of the matrix.
    pub height: usize,
}
55
56impl Debug for Dimensions {
57    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
58        write!(f, "{}x{}", self.width, self.height)
59    }
60}
61
62impl Display for Dimensions {
63    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
64        write!(f, "{}x{}", self.width, self.height)
65    }
66}
67
68/// A generic trait for two-dimensional matrix-like data structures.
69///
70/// The `Matrix` trait provides a uniform interface for accessing rows, elements,
71/// and computing with matrices in both sequential and parallel contexts. It supports
72/// packing strategies for SIMD optimizations and interaction with extension fields.
73pub trait Matrix<T: Send + Sync + Clone>: Send + Sync {
74    /// Returns the number of columns in the matrix.
75    fn width(&self) -> usize;
76
77    /// Returns the number of rows in the matrix.
78    fn height(&self) -> usize;
79
80    /// Returns the dimensions (width, height) of the matrix.
81    fn dimensions(&self) -> Dimensions {
82        Dimensions {
83            width: self.width(),
84            height: self.height(),
85        }
86    }
87
88    // The methods:
89    // get, get_unchecked, row, row_unchecked, row_subseq_unchecked, row_slice, row_slice_unchecked, row_subslice_unchecked
90    // are all defined in a circular manner so you only need to implement a subset of them.
91    // In particular is is enough to implement just one of: row_unchecked, row_subseq_unchecked
92    //
93    // That being said, most implementations will want to implement several methods for performance reasons.
94
95    /// Returns the element at the given row and column.
96    ///
97    /// Returns `None` if either `r >= height()` or `c >= width()`.
98    #[inline]
99    fn get(&self, r: usize, c: usize) -> Option<T> {
100        (r < self.height() && c < self.width()).then(|| unsafe {
101            // Safety: Clearly `r < self.height()` and `c < self.width()`.
102            self.get_unchecked(r, c)
103        })
104    }
105
106    /// Returns the element at the given row and column.
107    ///
108    /// For a safe alternative, see `get`.
109    ///
110    /// # Safety
111    /// The caller must ensure that `r < self.height()` and `c < self.width()`.
112    /// Breaking any of these assumptions is considered undefined behaviour.
113    #[inline]
114    unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
115        unsafe { self.row_slice_unchecked(r)[c].clone() }
116    }
117
118    /// Returns an iterator over the elements of the `r`-th row.
119    ///
120    /// The iterator will have `self.width()` elements.
121    ///
122    /// Returns `None` if `r >= height()`.
123    #[inline]
124    fn row(
125        &self,
126        r: usize,
127    ) -> Option<impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync>> {
128        (r < self.height()).then(|| unsafe {
129            // Safety: Clearly `r < self.height()`.
130            self.row_unchecked(r)
131        })
132    }
133
134    /// Returns an iterator over the elements of the `r`-th row.
135    ///
136    /// The iterator will have `self.width()` elements.
137    ///
138    /// For a safe alternative, see `row`.
139    ///
140    /// # Safety
141    /// The caller must ensure that `r < self.height()`.
142    /// Breaking this assumption is considered undefined behaviour.
143    #[inline]
144    unsafe fn row_unchecked(
145        &self,
146        r: usize,
147    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
148        unsafe { self.row_subseq_unchecked(r, 0, self.width()) }
149    }
150
151    /// Returns an iterator over the elements of the `r`-th row from position `start` to `end`.
152    ///
153    /// When `start = 0` and `end = width()`, this is equivalent to `row_unchecked`.
154    ///
155    /// For a safe alternative, use `row`, along with the `skip` and `take` iterator methods.
156    ///
157    /// # Safety
158    /// The caller must ensure that `r < self.height()` and `start <= end <= self.width()`.
159    /// Breaking any of these assumptions is considered undefined behaviour.
160    #[inline]
161    unsafe fn row_subseq_unchecked(
162        &self,
163        r: usize,
164        start: usize,
165        end: usize,
166    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
167        unsafe {
168            self.row_unchecked(r)
169                .into_iter()
170                .skip(start)
171                .take(end - start)
172        }
173    }
174
175    /// Returns the elements of the `r`-th row as something which can be coerced to a slice.
176    ///
177    /// Returns `None` if `r >= height()`.
178    #[inline]
179    fn row_slice(&self, r: usize) -> Option<impl Deref<Target = [T]>> {
180        (r < self.height()).then(|| unsafe {
181            // Safety: Clearly `r < self.height()`.
182            self.row_slice_unchecked(r)
183        })
184    }
185
186    /// Returns the elements of the `r`-th row as something which can be coerced to a slice.
187    ///
188    /// For a safe alternative, see `row_slice`.
189    ///
190    /// # Safety
191    /// The caller must ensure that `r < self.height()`.
192    /// Breaking this assumption is considered undefined behaviour.
193    #[inline]
194    unsafe fn row_slice_unchecked(&self, r: usize) -> impl Deref<Target = [T]> {
195        unsafe { self.row_subslice_unchecked(r, 0, self.width()) }
196    }
197
198    /// Returns a subset of elements of the `r`-th row as something which can be coerced to a slice.
199    ///
200    /// When `start = 0` and `end = width()`, this is equivalent to `row_slice_unchecked`.
201    ///
202    /// For a safe alternative, see `row_slice`.
203    ///
204    /// # Safety
205    /// The caller must ensure that `r < self.height()` and `start <= end <= self.width()`.
206    /// Breaking any of these assumptions is considered undefined behaviour.
207    #[inline]
208    unsafe fn row_subslice_unchecked(
209        &self,
210        r: usize,
211        start: usize,
212        end: usize,
213    ) -> impl Deref<Target = [T]> {
214        unsafe {
215            self.row_subseq_unchecked(r, start, end)
216                .into_iter()
217                .collect_vec()
218        }
219    }
220
221    /// Returns an iterator over all rows in the matrix.
222    #[inline]
223    fn rows(&self) -> impl Iterator<Item = impl Iterator<Item = T>> + Send + Sync {
224        unsafe {
225            // Safety: `r` always satisfies `r < self.height()`.
226            (0..self.height()).map(move |r| self.row_unchecked(r).into_iter())
227        }
228    }
229
230    /// Returns a parallel iterator over all rows in the matrix.
231    #[inline]
232    fn par_rows(
233        &self,
234    ) -> impl IndexedParallelIterator<Item = impl Iterator<Item = T>> + Send + Sync {
235        unsafe {
236            // Safety: `r` always satisfies `r < self.height()`.
237            (0..self.height())
238                .into_par_iter()
239                .map(move |r| self.row_unchecked(r).into_iter())
240        }
241    }
242
243    /// Collect the elements of the rows `r` through `r + c`. If anything is larger than `self.height()`
244    /// simply wrap around to the beginning of the matrix.
245    fn wrapping_row_slices(&self, r: usize, c: usize) -> Vec<impl Deref<Target = [T]>> {
246        unsafe {
247            // Safety: Thank to the `%`, the rows index is always less than `self.height()`.
248            (0..c)
249                .map(|i| self.row_slice_unchecked((r + i) % self.height()))
250                .collect_vec()
251        }
252    }
253
254    /// Returns an iterator over the first row of the matrix.
255    ///
256    /// Returns None if `height() == 0`.
257    #[inline]
258    fn first_row(
259        &self,
260    ) -> Option<impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync>> {
261        self.row(0)
262    }
263
264    /// Returns an iterator over the last row of the matrix.
265    ///
266    /// Returns None if `height() == 0`.
267    #[inline]
268    fn last_row(
269        &self,
270    ) -> Option<impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync>> {
271        if self.height() == 0 {
272            None
273        } else {
274            // Safety: Clearly `self.height() - 1 < self.height()`.
275            unsafe { Some(self.row_unchecked(self.height() - 1)) }
276        }
277    }
278
279    /// Converts the matrix into a `RowMajorMatrix` by collecting all rows into a single vector.
280    fn to_row_major_matrix(self) -> RowMajorMatrix<T>
281    where
282        Self: Sized,
283        T: Clone,
284    {
285        RowMajorMatrix::new(self.rows().flatten().collect(), self.width())
286    }
287
288    /// Get a packed iterator over the `r`-th row.
289    ///
290    /// If the row length is not divisible by the packing width, the final elements
291    /// are returned as a base iterator with length `<= P::WIDTH - 1`.
292    ///
293    /// # Panics
294    /// Panics if `r >= height()`.
295    fn horizontally_packed_row<'a, P>(
296        &'a self,
297        r: usize,
298    ) -> (
299        impl Iterator<Item = P> + Send + Sync,
300        impl Iterator<Item = T> + Send + Sync,
301    )
302    where
303        P: PackedValue<Value = T>,
304        T: Clone + 'a,
305    {
306        assert!(r < self.height(), "Row index out of bounds.");
307        let num_packed = self.width() / P::WIDTH;
308        unsafe {
309            // Safety: We have already checked that `r < height()`.
310            let mut iter = self
311                .row_subseq_unchecked(r, 0, num_packed * P::WIDTH)
312                .into_iter();
313
314            // array::from_fn is guaranteed to always call in order.
315            let packed =
316                (0..num_packed).map(move |_| P::from_fn(|_| iter.next().unwrap_unchecked()));
317
318            let sfx = self
319                .row_subseq_unchecked(r, num_packed * P::WIDTH, self.width())
320                .into_iter();
321            (packed, sfx)
322        }
323    }
324
325    /// Get a packed iterator over the `r`-th row.
326    ///
327    /// If the row length is not divisible by the packing width, the final entry will be zero-padded.
328    ///
329    /// # Panics
330    /// Panics if `r >= height()`.
331    fn padded_horizontally_packed_row<'a, P>(
332        &'a self,
333        r: usize,
334    ) -> impl Iterator<Item = P> + Send + Sync
335    where
336        P: PackedValue<Value = T>,
337        T: Clone + Default + 'a,
338    {
339        let mut row_iter = self.row(r).expect("Row index out of bounds.").into_iter();
340        let num_elems = self.width().div_ceil(P::WIDTH);
341        // array::from_fn is guaranteed to always call in order.
342        (0..num_elems).map(move |_| P::from_fn(|_| row_iter.next().unwrap_or_default()))
343    }
344
345    /// Get a parallel iterator over all packed rows of the matrix.
346    ///
347    /// If the matrix width is not divisible by the packing width, the final elements
348    /// of each row are returned as a base iterator with length `<= P::WIDTH - 1`.
349    fn par_horizontally_packed_rows<'a, P>(
350        &'a self,
351    ) -> impl IndexedParallelIterator<
352        Item = (
353            impl Iterator<Item = P> + Send + Sync,
354            impl Iterator<Item = T> + Send + Sync,
355        ),
356    >
357    where
358        P: PackedValue<Value = T>,
359        T: Clone + 'a,
360    {
361        (0..self.height())
362            .into_par_iter()
363            .map(|r| self.horizontally_packed_row(r))
364    }
365
366    /// Get a parallel iterator over all packed rows of the matrix.
367    ///
368    /// If the matrix width is not divisible by the packing width, the final entry of each row will be zero-padded.
369    fn par_padded_horizontally_packed_rows<'a, P>(
370        &'a self,
371    ) -> impl IndexedParallelIterator<Item = impl Iterator<Item = P> + Send + Sync>
372    where
373        P: PackedValue<Value = T>,
374        T: Clone + Default + 'a,
375    {
376        (0..self.height())
377            .into_par_iter()
378            .map(|r| self.padded_horizontally_packed_row(r))
379    }
380
381    /// Pack together a collection of adjacent rows from the matrix.
382    ///
383    /// Returns an iterator whose i'th element is packing of the i'th element of the
384    /// rows r through r + P::WIDTH - 1. If we exceed the height of the matrix,
385    /// wrap around and include initial rows.
386    #[inline]
387    fn vertically_packed_row<P>(&self, r: usize) -> impl Iterator<Item = P>
388    where
389        T: Copy,
390        P: PackedValue<Value = T>,
391    {
392        // Precompute row slices once to minimize redundant calls and improve performance.
393        let rows = self.wrapping_row_slices(r, P::WIDTH);
394
395        // Using precomputed rows avoids repeatedly calling `row_slice`, which is costly.
396        (0..self.width()).map(move |c| P::from_fn(|i| rows[i][c]))
397    }
398
399    /// Pack together a collection of rows and "next" rows from the matrix.
400    ///
401    /// Returns a vector corresponding to 2 packed rows. The i'th element of the first
402    /// row contains the packing of the i'th element of the rows r through r + P::WIDTH - 1.
403    /// The i'th element of the second row contains the packing of the i'th element of the
404    /// rows r + step through r + step + P::WIDTH - 1. If at some point we exceed the
405    /// height of the matrix, wrap around and include initial rows.
406    #[inline]
407    fn vertically_packed_row_pair<P>(&self, r: usize, step: usize) -> Vec<P>
408    where
409        T: Copy,
410        P: PackedValue<Value = T>,
411    {
412        // Whilst it would appear that this can be replaced by two calls to vertically_packed_row
413        // tests seem to indicate that combining them in the same function is slightly faster.
414        // It's probably allowing the compiler to make some optimizations on the fly.
415
416        let rows = self.wrapping_row_slices(r, P::WIDTH);
417        let next_rows = self.wrapping_row_slices(r + step, P::WIDTH);
418
419        (0..self.width())
420            .map(|c| P::from_fn(|i| rows[i][c]))
421            .chain((0..self.width()).map(|c| P::from_fn(|i| next_rows[i][c])))
422            .collect_vec()
423    }
424
425    /// Returns a view over a vertically strided submatrix.
426    ///
427    /// The view selects rows using `r = offset + i * stride` for each `i`.
428    fn vertically_strided(self, stride: usize, offset: usize) -> VerticallyStridedMatrixView<Self>
429    where
430        Self: Sized,
431    {
432        VerticallyStridedRowIndexMap::new_view(self, stride, offset)
433    }
434
435    /// Compute Mᵀv, aka premultiply this matrix by the given vector,
436    /// aka scale each row by the corresponding entry in `v` and take the sum across rows.
437    /// `v` can be a vector of extension elements.
438    fn columnwise_dot_product<EF>(&self, v: &[EF]) -> Vec<EF>
439    where
440        T: Field,
441        EF: ExtensionField<T>,
442    {
443        let packed_width = self.width().div_ceil(T::Packing::WIDTH);
444
445        let packed_result = self
446            .par_padded_horizontally_packed_rows::<T::Packing>()
447            .zip(v)
448            .par_fold_reduce(
449                || EF::ExtensionPacking::zero_vec(packed_width),
450                |mut acc, (row, &scale)| {
451                    let scale = EF::ExtensionPacking::from_basis_coefficients_fn(|i| {
452                        T::Packing::from(scale.as_basis_coefficients_slice()[i])
453                    });
454                    izip!(&mut acc, row).for_each(|(l, r)| *l += scale * r);
455                    acc
456                },
457                |mut acc_l, acc_r| {
458                    izip!(&mut acc_l, acc_r).for_each(|(l, r)| *l += r);
459                    acc_l
460                },
461            );
462
463        packed_result
464            .into_iter()
465            .flat_map(|p| {
466                (0..T::Packing::WIDTH).map(move |i| {
467                    EF::from_basis_coefficients_fn(|j| {
468                        p.as_basis_coefficients_slice()[j].as_slice()[i]
469                    })
470                })
471            })
472            .take(self.width())
473            .collect()
474    }
475
476    /// Compute the matrix vector product `M . vec`, aka take the dot product of each
477    /// row of `M` by `vec`. If the length of `vec` is longer than the width of `M`,
478    /// `vec` is truncated to the first `width()` elements.
479    ///
480    /// We make use of `PackedFieldExtension` to speed up computations. Thus `vec` is passed in as
481    /// a slice of `PackedFieldExtension` elements.
482    ///
483    /// # Panics
484    /// This function panics if the length of `vec` is less than `self.width().div_ceil(T::Packing::WIDTH)`.
485    fn rowwise_packed_dot_product<EF>(
486        &self,
487        vec: &[EF::ExtensionPacking],
488    ) -> impl IndexedParallelIterator<Item = EF>
489    where
490        T: Field,
491        EF: ExtensionField<T>,
492    {
493        // The length of a `padded_horizontally_packed_row` is `self.width().div_ceil(T::Packing::WIDTH)`.
494        assert!(vec.len() >= self.width().div_ceil(T::Packing::WIDTH));
495
496        // TODO: This is a base - extension dot product and so it should
497        // be possible to speed this up using ideas in `packed_linear_combination`.
498        // TODO: Perhaps we should be packing rows vertically not horizontally.
499        self.par_padded_horizontally_packed_rows::<T::Packing>()
500            .map(move |row_packed| {
501                let packed_sum_of_packed: EF::ExtensionPacking =
502                    dot_product(vec.iter().copied(), row_packed);
503                let sum_of_packed: EF = EF::from_basis_coefficients_fn(|i| {
504                    packed_sum_of_packed.as_basis_coefficients_slice()[i]
505                        .as_slice()
506                        .iter()
507                        .copied()
508                        .sum()
509                });
510                sum_of_packed
511            })
512    }
513}
514
// NOTE(review): an empty `any()` predicate is always false, so this module is never compiled
// and none of these tests currently run. Presumably it was disabled deliberately (e.g. while
// the dev-dependencies below were unavailable) — TODO confirm and restore `#[cfg(test)]`.
#[cfg(any())]
mod tests {
    use alloc::vec::Vec;
    use alloc::{
        format,
        vec,
    };

    use itertools::izip;
    use lib_q_stark_field::PrimeCharacteristicRing;
    use lib_q_stark_field::extension::{
        BinomialExtensionField,
        Complex,
    };
    use lib_q_stark_mersenne31::Mersenne31;
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    use super::*;

    #[test]
    fn test_columnwise_dot_product() {
        // The base field is Complex<Mersenne31> rather than Mersenne31, because EF extends it.
        type F = Complex<Mersenne31>;
        // Mersenne31 has no direct degree-4 extension, so take a degree-2 extension of Complex<Mersenne31>.
        type EF = BinomialExtensionField<Complex<Mersenne31>, 2>;

        let mut rng = SmallRng::seed_from_u64(1);
        let matrix = RowMajorMatrix::<F>::rand(&mut rng, 1 << 8, 1 << 4);
        let weights = RowMajorMatrix::<EF>::rand(&mut rng, 1 << 8, 1).values;

        // Naive reference: add `weight * entry` into the accumulator of each column.
        let mut expected = vec![EF::ZERO; matrix.width()];
        for (row, &weight) in izip!(matrix.rows(), &weights) {
            for (acc, entry) in izip!(&mut expected, row) {
                // Lift the base-field entry into EF before multiplying.
                *acc += weight * EF::from(entry);
            }
        }

        assert_eq!(matrix.columnwise_dot_product(&weights), expected);
    }

    /// Minimal row-vector-backed matrix used to exercise the trait's default methods.
    struct MockMatrix {
        data: Vec<Vec<u32>>,
        width: usize,
        height: usize,
    }

    impl Matrix<u32> for MockMatrix {
        fn width(&self) -> usize {
            self.width
        }

        fn height(&self) -> usize {
            self.height
        }

        unsafe fn row_unchecked(
            &self,
            r: usize,
        ) -> impl IntoIterator<Item = u32, IntoIter = impl Iterator<Item = u32> + Send + Sync>
        {
            // Test-only implementation: the simple, safe, indexed clone is good enough here.
            self.data[r].clone()
        }
    }

    /// Builds the 3x3 matrix `[[1, 2, 3], [4, 5, 6], [7, 8, 9]]` shared by most tests.
    fn three_by_three() -> MockMatrix {
        MockMatrix {
            data: vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]],
            width: 3,
            height: 3,
        }
    }

    #[test]
    fn test_dimensions() {
        let dims = Dimensions {
            width: 3,
            height: 5,
        };
        assert_eq!((dims.width, dims.height), (3, 5));
        assert_eq!(format!("{dims:?}"), "3x5");
        assert_eq!(format!("{dims}"), "3x5");
    }

    #[test]
    fn test_mock_matrix_dimensions() {
        let matrix = three_by_three();
        assert_eq!((matrix.width(), matrix.height()), (3, 3));
        assert_eq!(
            matrix.dimensions(),
            Dimensions {
                width: 3,
                height: 3
            }
        );
    }

    #[test]
    fn test_first_row() {
        let matrix = three_by_three();
        let head: Vec<u32> = matrix.first_row().unwrap().into_iter().take(3).collect();
        assert_eq!(head, vec![1, 2, 3]);
    }

    #[test]
    fn test_last_row() {
        let matrix = three_by_three();
        let tail: Vec<u32> = matrix.last_row().unwrap().into_iter().take(3).collect();
        assert_eq!(tail, vec![7, 8, 9]);
    }

    #[test]
    fn test_first_last_row_empty_matrix() {
        let empty = MockMatrix {
            data: vec![],
            width: 3,
            height: 0,
        };
        assert!(empty.first_row().is_none());
        assert!(empty.last_row().is_none());
    }

    #[test]
    fn test_to_row_major_matrix() {
        let small = MockMatrix {
            data: vec![vec![1, 2], vec![3, 4]],
            width: 2,
            height: 2,
        };
        let dense = small.to_row_major_matrix();
        assert_eq!(dense.values, vec![1, 2, 3, 4]);
        assert_eq!(dense.width, 2);
    }

    #[test]
    fn test_matrix_get_methods() {
        let matrix = three_by_three();

        for (r, c, want) in [(0, 0, 1), (1, 2, 6), (2, 1, 8)] {
            assert_eq!(matrix.get(r, c), Some(want));
        }

        unsafe {
            for (r, c, want) in [(0, 1, 2), (1, 0, 4), (2, 2, 9)] {
                assert_eq!(matrix.get_unchecked(r, c), want);
            }
        }

        assert_eq!(matrix.get(3, 0), None); // Row index out of bounds
        assert_eq!(matrix.get(0, 3), None); // Column index out of bounds
    }

    #[test]
    fn test_matrix_row_methods_iteration() {
        let matrix = three_by_three();

        let middle: Vec<u32> = matrix.row(1).unwrap().into_iter().collect();
        assert_eq!(middle, vec![4, 5, 6]);

        unsafe {
            let bottom: Vec<u32> = matrix.row_unchecked(2).into_iter().collect();
            assert_eq!(bottom, vec![7, 8, 9]);

            let top_tail: Vec<u32> = matrix.row_subseq_unchecked(0, 1, 3).into_iter().collect();
            assert_eq!(top_tail, vec![2, 3]);
        }

        assert!(matrix.row(3).is_none()); // Row index out of bounds
    }

    #[test]
    fn test_row_slice_methods() {
        let matrix = three_by_three();

        assert_eq!(*matrix.row_slice(1).unwrap(), [4, 5, 6]);
        unsafe {
            assert_eq!(*matrix.row_slice_unchecked(2), [7, 8, 9]);
            assert_eq!(*matrix.row_subslice_unchecked(0, 1, 2), [2]);
        }

        assert!(matrix.row_slice(3).is_none()); // Row index out of bounds
    }

    #[test]
    fn test_matrix_rows() {
        let matrix = three_by_three();

        let collected: Vec<Vec<u32>> = matrix
            .rows()
            .map(|row| row.collect::<Vec<u32>>())
            .collect();
        assert_eq!(collected, vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]);
    }
}