Skip to main content

p3_matrix/
lib.rs

1//! Matrix library.
2
3#![no_std]
4
5extern crate alloc;
6
7use alloc::vec::Vec;
8use core::fmt::{Debug, Display, Formatter};
9use core::ops::Deref;
10
11use itertools::Itertools;
12use p3_field::{
13    BasedVectorSpace, ExtensionField, Field, FieldArray, PackedFieldExtension, PackedValue,
14    PrimeCharacteristicRing,
15};
16use p3_maybe_rayon::prelude::*;
17use strided::{VerticallyStridedMatrixView, VerticallyStridedRowIndexMap};
18use tracing::instrument;
19
20use crate::dense::RowMajorMatrix;
21
22pub mod bitrev;
23pub mod dense;
24pub mod extension;
25pub mod horizontally_truncated;
26pub mod row_index_mapped;
27pub mod stack;
28pub mod strided;
29pub mod util;
30
/// A simple struct representing the shape of a matrix.
///
/// The `Dimensions` type stores the number of columns (`width`) and rows (`height`)
/// of a matrix. It is commonly used for querying and displaying matrix shapes; both its
/// `Debug` and `Display` output are `{width}x{height}` (e.g. `3x5` for 3 columns, 5 rows).
///
/// `Hash` is derived alongside `Eq` so shapes can be used as keys in hash-based
/// collections (e.g. grouping matrices by shape).
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub struct Dimensions {
    /// Number of columns in the matrix.
    pub width: usize,
    /// Number of rows in the matrix.
    pub height: usize,
}
42
43impl Debug for Dimensions {
44    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
45        write!(f, "{}x{}", self.width, self.height)
46    }
47}
48
49impl Display for Dimensions {
50    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
51        write!(f, "{}x{}", self.width, self.height)
52    }
53}
54
55/// A generic trait for two-dimensional matrix-like data structures.
56///
57/// The `Matrix` trait provides a uniform interface for accessing rows, elements,
58/// and computing with matrices in both sequential and parallel contexts. It supports
59/// packing strategies for SIMD optimizations and interaction with extension fields.
60pub trait Matrix<T: Send + Sync + Clone>: Send + Sync {
61    /// Returns the number of columns in the matrix.
62    fn width(&self) -> usize;
63
64    /// Returns the number of rows in the matrix.
65    fn height(&self) -> usize;
66
67    /// Returns the dimensions (width, height) of the matrix.
68    fn dimensions(&self) -> Dimensions {
69        Dimensions {
70            width: self.width(),
71            height: self.height(),
72        }
73    }
74
75    // The methods:
76    // get, get_unchecked, row, row_unchecked, row_subseq_unchecked, row_slice, row_slice_unchecked, row_subslice_unchecked
77    // are all defined in a circular manner so you only need to implement a subset of them.
78    // In particular is is enough to implement just one of: row_unchecked, row_subseq_unchecked
79    //
80    // That being said, most implementations will want to implement several methods for performance reasons.
81
82    /// Returns the element at the given row and column.
83    ///
84    /// Returns `None` if either `r >= height()` or `c >= width()`.
85    #[inline]
86    fn get(&self, r: usize, c: usize) -> Option<T> {
87        (r < self.height() && c < self.width()).then(|| unsafe {
88            // Safety: Clearly `r < self.height()` and `c < self.width()`.
89            self.get_unchecked(r, c)
90        })
91    }
92
93    /// Returns the element at the given row and column.
94    ///
95    /// For a safe alternative, see [`get`].
96    ///
97    /// # Safety
98    /// The caller must ensure that `r < self.height()` and `c < self.width()`.
99    /// Breaking any of these assumptions is considered undefined behaviour.
100    #[inline]
101    unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
102        unsafe { self.row_slice_unchecked(r)[c].clone() }
103    }
104
105    /// Returns an iterator over the elements of the `r`-th row.
106    ///
107    /// The iterator will have `self.width()` elements.
108    ///
109    /// Returns `None` if `r >= height()`.
110    #[inline]
111    fn row(
112        &self,
113        r: usize,
114    ) -> Option<impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync>> {
115        (r < self.height()).then(|| unsafe {
116            // Safety: Clearly `r < self.height()`.
117            self.row_unchecked(r)
118        })
119    }
120
121    /// Returns an iterator over the elements of the `r`-th row.
122    ///
123    /// The iterator will have `self.width()` elements.
124    ///
125    /// For a safe alternative, see [`row`].
126    ///
127    /// # Safety
128    /// The caller must ensure that `r < self.height()`.
129    /// Breaking this assumption is considered undefined behaviour.
130    #[inline]
131    unsafe fn row_unchecked(
132        &self,
133        r: usize,
134    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
135        unsafe { self.row_subseq_unchecked(r, 0, self.width()) }
136    }
137
138    /// Returns an iterator over the elements of the `r`-th row from position `start` to `end`.
139    ///
140    /// When `start = 0` and `end = width()`, this is equivalent to [`row_unchecked`].
141    ///
142    /// For a safe alternative, use [`row`], along with the `skip` and `take` iterator methods.
143    ///
144    /// # Safety
145    /// The caller must ensure that `r < self.height()` and `start <= end <= self.width()`.
146    /// Breaking any of these assumptions is considered undefined behaviour.
147    #[inline]
148    unsafe fn row_subseq_unchecked(
149        &self,
150        r: usize,
151        start: usize,
152        end: usize,
153    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
154        unsafe {
155            self.row_unchecked(r)
156                .into_iter()
157                .skip(start)
158                .take(end - start)
159        }
160    }
161
162    /// Returns the elements of the `r`-th row as something which can be coerced to a slice.
163    ///
164    /// Returns `None` if `r >= height()`.
165    #[inline]
166    fn row_slice(&self, r: usize) -> Option<impl Deref<Target = [T]>> {
167        (r < self.height()).then(|| unsafe {
168            // Safety: Clearly `r < self.height()`.
169            self.row_slice_unchecked(r)
170        })
171    }
172
173    /// Returns the elements of the `r`-th row as something which can be coerced to a slice.
174    ///
175    /// For a safe alternative, see [`row_slice`].
176    ///
177    /// # Safety
178    /// The caller must ensure that `r < self.height()`.
179    /// Breaking this assumption is considered undefined behaviour.
180    #[inline]
181    unsafe fn row_slice_unchecked(&self, r: usize) -> impl Deref<Target = [T]> {
182        unsafe { self.row_subslice_unchecked(r, 0, self.width()) }
183    }
184
185    /// Returns a subset of elements of the `r`-th row as something which can be coerced to a slice.
186    ///
187    /// When `start = 0` and `end = width()`, this is equivalent to [`row_slice_unchecked`].
188    ///
189    /// For a safe alternative, see [`row_slice`].
190    ///
191    /// # Safety
192    /// The caller must ensure that `r < self.height()` and `start <= end <= self.width()`.
193    /// Breaking any of these assumptions is considered undefined behaviour.
194    #[inline]
195    unsafe fn row_subslice_unchecked(
196        &self,
197        r: usize,
198        start: usize,
199        end: usize,
200    ) -> impl Deref<Target = [T]> {
201        unsafe {
202            self.row_subseq_unchecked(r, start, end)
203                .into_iter()
204                .collect_vec()
205        }
206    }
207
208    /// Returns an iterator over all rows in the matrix.
209    #[inline]
210    fn rows(&self) -> impl Iterator<Item = impl Iterator<Item = T>> + Send + Sync {
211        unsafe {
212            // Safety: `r` always satisfies `r < self.height()`.
213            (0..self.height()).map(move |r| self.row_unchecked(r).into_iter())
214        }
215    }
216
217    /// Returns a parallel iterator over all rows in the matrix.
218    #[inline]
219    fn par_rows(
220        &self,
221    ) -> impl IndexedParallelIterator<Item = impl Iterator<Item = T>> + Send + Sync {
222        unsafe {
223            // Safety: `r` always satisfies `r < self.height()`.
224            (0..self.height())
225                .into_par_iter()
226                .map(move |r| self.row_unchecked(r).into_iter())
227        }
228    }
229
230    /// Collect the elements of the rows `r` through `r + c`. If anything is larger than `self.height()`
231    /// simply wrap around to the beginning of the matrix.
232    fn wrapping_row_slices(&self, r: usize, c: usize) -> Vec<impl Deref<Target = [T]>> {
233        unsafe {
234            // Safety: Thank to the `%`, the rows index is always less than `self.height()`.
235            (0..c)
236                .map(|i| self.row_slice_unchecked((r + i) % self.height()))
237                .collect_vec()
238        }
239    }
240
241    /// Returns an iterator over the first row of the matrix.
242    ///
243    /// Returns None if `height() == 0`.
244    #[inline]
245    fn first_row(
246        &self,
247    ) -> Option<impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync>> {
248        self.row(0)
249    }
250
251    /// Returns an iterator over the last row of the matrix.
252    ///
253    /// Returns None if `height() == 0`.
254    #[inline]
255    fn last_row(
256        &self,
257    ) -> Option<impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync>> {
258        if self.height() == 0 {
259            None
260        } else {
261            // Safety: Clearly `self.height() - 1 < self.height()`.
262            unsafe { Some(self.row_unchecked(self.height() - 1)) }
263        }
264    }
265
266    /// Converts the matrix into a `RowMajorMatrix` by collecting all rows into a single vector.
267    fn to_row_major_matrix(self) -> RowMajorMatrix<T>
268    where
269        Self: Sized,
270        T: Clone,
271    {
272        RowMajorMatrix::new(self.rows().flatten().collect(), self.width())
273    }
274
275    /// Get a packed iterator over the `r`-th row.
276    ///
277    /// If the row length is not divisible by the packing width, the final elements
278    /// are returned as a base iterator with length `<= P::WIDTH - 1`.
279    ///
280    /// # Panics
281    /// Panics if `r >= height()`.
282    fn horizontally_packed_row<'a, P>(
283        &'a self,
284        r: usize,
285    ) -> (
286        impl Iterator<Item = P> + Send + Sync,
287        impl Iterator<Item = T> + Send + Sync,
288    )
289    where
290        P: PackedValue<Value = T>,
291        T: Clone + 'a,
292    {
293        assert!(r < self.height(), "Row index out of bounds.");
294        let num_packed = self.width() / P::WIDTH;
295        unsafe {
296            // Safety: We have already checked that `r < height()`.
297            let mut iter = self
298                .row_subseq_unchecked(r, 0, num_packed * P::WIDTH)
299                .into_iter();
300
301            // array::from_fn is guaranteed to always call in order.
302            let packed =
303                (0..num_packed).map(move |_| P::from_fn(|_| iter.next().unwrap_unchecked()));
304
305            let sfx = self
306                .row_subseq_unchecked(r, num_packed * P::WIDTH, self.width())
307                .into_iter();
308            (packed, sfx)
309        }
310    }
311
312    /// Get a packed iterator over the `r`-th row.
313    ///
314    /// If the row length is not divisible by the packing width, the final entry will be zero-padded.
315    ///
316    /// # Panics
317    /// Panics if `r >= height()`.
318    fn padded_horizontally_packed_row<'a, P>(
319        &'a self,
320        r: usize,
321    ) -> impl Iterator<Item = P> + Send + Sync
322    where
323        P: PackedValue<Value = T>,
324        T: Clone + Default + 'a,
325    {
326        let mut row_iter = self.row(r).expect("Row index out of bounds.").into_iter();
327        let num_elems = self.width().div_ceil(P::WIDTH);
328        // array::from_fn is guaranteed to always call in order.
329        (0..num_elems).map(move |_| P::from_fn(|_| row_iter.next().unwrap_or_default()))
330    }
331
332    /// Get a parallel iterator over all packed rows of the matrix.
333    ///
334    /// If the matrix width is not divisible by the packing width, the final elements
335    /// of each row are returned as a base iterator with length `<= P::WIDTH - 1`.
336    fn par_horizontally_packed_rows<'a, P>(
337        &'a self,
338    ) -> impl IndexedParallelIterator<
339        Item = (
340            impl Iterator<Item = P> + Send + Sync,
341            impl Iterator<Item = T> + Send + Sync,
342        ),
343    >
344    where
345        P: PackedValue<Value = T>,
346        T: Clone + 'a,
347    {
348        (0..self.height())
349            .into_par_iter()
350            .map(|r| self.horizontally_packed_row(r))
351    }
352
353    /// Get a parallel iterator over all packed rows of the matrix.
354    ///
355    /// If the matrix width is not divisible by the packing width, the final entry of each row will be zero-padded.
356    fn par_padded_horizontally_packed_rows<'a, P>(
357        &'a self,
358    ) -> impl IndexedParallelIterator<Item = impl Iterator<Item = P> + Send + Sync>
359    where
360        P: PackedValue<Value = T>,
361        T: Clone + Default + 'a,
362    {
363        (0..self.height())
364            .into_par_iter()
365            .map(|r| self.padded_horizontally_packed_row(r))
366    }
367
368    /// Pack together a collection of adjacent rows from the matrix.
369    ///
370    /// Returns an iterator whose i'th element is packing of the i'th element of the
371    /// rows r through r + P::WIDTH - 1. If we exceed the height of the matrix,
372    /// wrap around and include initial rows.
373    #[inline]
374    fn vertically_packed_row<P>(&self, r: usize) -> impl Iterator<Item = P>
375    where
376        T: Copy,
377        P: PackedValue<Value = T>,
378    {
379        // Precompute row slices once to minimize redundant calls and improve performance.
380        let rows = self.wrapping_row_slices(r, P::WIDTH);
381
382        // Using precomputed rows avoids repeatedly calling `row_slice`, which is costly.
383        (0..self.width()).map(move |c| P::from_fn(|i| rows[i][c]))
384    }
385
386    /// Pack together a collection of rows and "next" rows from the matrix.
387    ///
388    /// Returns a vector corresponding to 2 packed rows. The i'th element of the first
389    /// row contains the packing of the i'th element of the rows r through r + P::WIDTH - 1.
390    /// The i'th element of the second row contains the packing of the i'th element of the
391    /// rows r + step through r + step + P::WIDTH - 1. If at some point we exceed the
392    /// height of the matrix, wrap around and include initial rows.
393    #[inline]
394    fn vertically_packed_row_pair<P>(&self, r: usize, step: usize) -> Vec<P>
395    where
396        T: Copy,
397        P: PackedValue<Value = T>,
398    {
399        // Whilst it would appear that this can be replaced by two calls to vertically_packed_row
400        // tests seem to indicate that combining them in the same function is slightly faster.
401        // It's probably allowing the compiler to make some optimizations on the fly.
402
403        let rows = self.wrapping_row_slices(r, P::WIDTH);
404        let next_rows = self.wrapping_row_slices(r + step, P::WIDTH);
405
406        (0..self.width())
407            .map(|c| P::from_fn(|i| rows[i][c]))
408            .chain((0..self.width()).map(|c| P::from_fn(|i| next_rows[i][c])))
409            .collect_vec()
410    }
411
412    /// Returns a view over a vertically strided submatrix.
413    ///
414    /// The view selects rows using `r = offset + i * stride` for each `i`.
415    fn vertically_strided(self, stride: usize, offset: usize) -> VerticallyStridedMatrixView<Self>
416    where
417        Self: Sized,
418    {
419        VerticallyStridedRowIndexMap::new_view(self, stride, offset)
420    }
421
422    /// Compute Mᵀv, aka premultiply this matrix by the given vector,
423    /// aka scale each row by the corresponding entry in `v` and take the sum across rows.
424    /// `v` can be a vector of extension elements.
425    #[instrument(level = "debug", skip_all, fields(dims = %self.dimensions()))]
426    fn columnwise_dot_product<EF>(&self, v: &[EF]) -> Vec<EF>
427    where
428        T: Field,
429        EF: ExtensionField<T>,
430    {
431        let packed_width = self.width().div_ceil(T::Packing::WIDTH);
432
433        let packed_result = self
434            .par_padded_horizontally_packed_rows::<T::Packing>()
435            .zip(v)
436            .par_fold_reduce(
437                || EF::ExtensionPacking::zero_vec(packed_width),
438                |mut acc, (row, &scale)| {
439                    let scale: EF::ExtensionPacking = scale.into();
440                    acc.iter_mut().zip(row).for_each(|(l, r)| *l += scale * r);
441                    acc
442                },
443                |mut acc_l, acc_r| {
444                    acc_l.iter_mut().zip(&acc_r).for_each(|(l, r)| *l += *r);
445                    acc_l
446                },
447            );
448
449        EF::ExtensionPacking::to_ext_iter(packed_result)
450            .take(self.width())
451            .collect()
452    }
453
454    /// Compute Mᵀ · [v₀, v₁, ..., vₙ₋₁] for N weight vectors simultaneously.
455    ///
456    /// Computes `result[col][j] = Σᵣ M[r, col] · vⱼ[r]` for all columns and all j ∈ [0, N).
457    ///
458    /// Batching N weight vectors reduces memory bandwidth: each matrix row is loaded once
459    /// instead of N times. Uses SIMD packing (width W) to process W columns in parallel.
460    #[instrument(level = "debug", skip_all, fields(dims = %self.dimensions()))]
461    fn columnwise_dot_product_batched<EF, const N: usize>(
462        &self,
463        vs: &[FieldArray<EF, N>],
464    ) -> Vec<FieldArray<EF, N>>
465    where
466        T: Field,
467        EF: ExtensionField<T>,
468    {
469        let packed_width = self.width().div_ceil(T::Packing::WIDTH);
470
471        let packed_results: Vec<EF::ExtensionPacking> = self
472            .par_padded_horizontally_packed_rows::<T::Packing>()
473            .zip(vs)
474            .par_fold_reduce(
475                || EF::ExtensionPacking::zero_vec(packed_width * N),
476                |mut acc, (packed_row, scales)| {
477                    // Broadcast each scalar scale to all SIMD lanes
478                    let packed_scales: [EF::ExtensionPacking; N] =
479                        scales.map_into_array(EF::ExtensionPacking::from);
480
481                    // acc[c][j] += scales[j] · row[c] for column batch c, point j
482                    for (acc_c, row_c) in acc.chunks_exact_mut(N).zip(packed_row) {
483                        for j in 0..N {
484                            acc_c[j] += packed_scales[j] * row_c;
485                        }
486                    }
487                    acc
488                },
489                |mut acc_l, acc_r| {
490                    acc_l.iter_mut().zip(&acc_r).for_each(|(lj, rj)| *lj += *rj);
491                    acc_l
492                },
493            );
494
495        // Unpack: chunk[j].lane(i) → result[c·W + i][j] for column batch c
496        packed_results
497            .chunks(N)
498            .flat_map(|chunk| {
499                (0..T::Packing::WIDTH)
500                    .map(move |lane| FieldArray::from_fn(|j| chunk[j].extract(lane)))
501            })
502            .take(self.width())
503            .collect()
504    }
505
506    /// Compute the matrix vector product `M . vec`, aka take the dot product of each
507    /// row of `M` by `vec`. If the length of `vec` is longer than the width of `M`,
508    /// `vec` is truncated to the first `width()` elements.
509    ///
510    /// We make use of `PackedFieldExtension` to speed up computations. Thus `vec` is passed in as
511    /// a slice of `PackedFieldExtension` elements.
512    ///
513    /// # Panics
514    /// This function panics if the length of `vec` is less than `self.width().div_ceil(T::Packing::WIDTH)`.
515    fn rowwise_packed_dot_product<EF>(
516        &self,
517        vec: &[EF::ExtensionPacking],
518    ) -> impl IndexedParallelIterator<Item = EF>
519    where
520        T: Field,
521        EF: ExtensionField<T>,
522    {
523        // The length of a `padded_horizontally_packed_row` is `self.width().div_ceil(T::Packing::WIDTH)`.
524        assert!(vec.len() >= self.width().div_ceil(T::Packing::WIDTH));
525
526        // Instead of creating N intermediate ExtPacking products and summing them,
527        // we track D separate BasePacking accumulators (one per extension coefficient).
528        self.par_padded_horizontally_packed_rows::<T::Packing>()
529            .map(move |row_packed| {
530                // Get the extension dimension from the first vec element's coefficients
531                let d = <EF::ExtensionPacking as BasedVectorSpace<T::Packing>>::DIMENSION;
532
533                // Initialize D accumulators for each coefficient of the extension
534                // In practice, we set D to 8, which is the maximum degree of the extension field supported.
535                let mut coeff_accs: [T::Packing; 8] = [T::Packing::ZERO; 8];
536                debug_assert!(d <= 8, "Extension degree > 8 not supported");
537
538                // Accumulate coefficient-wise: for each (v, r) pair, acc[i] += v.coefficient(i) * r
539                for (v, r) in vec.iter().zip(row_packed) {
540                    let v_coeffs = v.as_basis_coefficients_slice();
541                    for (acc, &v_coeff) in coeff_accs[..d].iter_mut().zip(v_coeffs) {
542                        *acc += v_coeff * r;
543                    }
544                }
545
546                // Construct the result ExtPacking from the accumulators and sum the coefficients.
547                let packed_result =
548                    EF::ExtensionPacking::from_basis_coefficients_fn(|i| coeff_accs[i]);
549                EF::ExtensionPacking::to_ext_iter([packed_result]).sum()
550            })
551    }
552}
553
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use alloc::{format, vec};

    use itertools::izip;
    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;
    use p3_field::extension::BinomialExtensionField;
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    use super::*;

    #[test]
    fn test_columnwise_dot_product() {
        type F = BabyBear;
        type EF = BinomialExtensionField<BabyBear, 4>;

        let mut rng = SmallRng::seed_from_u64(1);
        let m = RowMajorMatrix::<F>::rand(&mut rng, 1 << 8, 1 << 4);
        let v = RowMajorMatrix::<EF>::rand(&mut rng, 1 << 8, 1).values;

        let mut expected = vec![EF::ZERO; m.width()];
        for (row, &scale) in izip!(m.rows(), &v) {
            for (l, r) in izip!(&mut expected, row) {
                *l += scale * r;
            }
        }

        assert_eq!(m.columnwise_dot_product(&v), expected);
    }

    #[test]
    fn test_columnwise_dot_product_batched() {
        type F = BabyBear;
        type EF = BinomialExtensionField<BabyBear, 4>;

        let mut rng = SmallRng::seed_from_u64(1);
        let m = RowMajorMatrix::<F>::rand(&mut rng, 1 << 8, 1 << 4);
        let v1 = RowMajorMatrix::<EF>::rand(&mut rng, 1 << 8, 1).values;
        let v2 = RowMajorMatrix::<EF>::rand(&mut rng, 1 << 8, 1).values;

        // Compute expected via two separate calls
        let expected1 = m.columnwise_dot_product(&v1);
        let expected2 = m.columnwise_dot_product(&v2);

        // Compute via batched call - returns Vec<FieldArray<EF, 2>> where result[col] = [dot1, dot2]
        let vs: Vec<FieldArray<EF, 2>> = v1
            .into_iter()
            .zip(v2)
            .map(|(a, b)| FieldArray([a, b]))
            .collect();
        let results = m.columnwise_dot_product_batched::<EF, 2>(&vs);

        // Extract each point's results
        let result1: Vec<EF> = results.iter().map(|r| r[0]).collect();
        let result2: Vec<EF> = results.iter().map(|r| r[1]).collect();

        assert_eq!(result1, expected1);
        assert_eq!(result2, expected2);
    }

    #[test]
    fn test_columnwise_dot_product_batched_ragged_width() {
        type F = BabyBear;
        type EF = BinomialExtensionField<BabyBear, 4>;
        const N: usize = 3;

        let mut rng = SmallRng::seed_from_u64(7);
        let height = 1 << 6;
        // 17 is not a multiple of any SIMD packing width, so the zero-padded tail is exercised.
        let m = RowMajorMatrix::<F>::rand(&mut rng, height, 17);
        let weights: [Vec<EF>; N] =
            core::array::from_fn(|_| RowMajorMatrix::<EF>::rand(&mut rng, height, 1).values);

        let vs: Vec<FieldArray<EF, N>> = (0..height)
            .map(|r| FieldArray(core::array::from_fn(|j| weights[j][r])))
            .collect();
        let results = m.columnwise_dot_product_batched::<EF, N>(&vs);
        assert_eq!(results.len(), m.width());

        for (j, w) in weights.iter().enumerate() {
            let got: Vec<EF> = results.iter().map(|r| r[j]).collect();
            assert_eq!(
                got,
                m.columnwise_dot_product(w),
                "Mismatch for weight vector {j}"
            );
        }
    }

    // Mock implementation for testing purposes
    struct MockMatrix {
        data: Vec<Vec<u32>>,
        width: usize,
        height: usize,
    }

    impl Matrix<u32> for MockMatrix {
        fn width(&self) -> usize {
            self.width
        }

        fn height(&self) -> usize {
            self.height
        }

        unsafe fn row_unchecked(
            &self,
            r: usize,
        ) -> impl IntoIterator<Item = u32, IntoIter = impl Iterator<Item = u32> + Send + Sync>
        {
            // Just a mock implementation so we just do the easy safe thing.
            self.data[r].clone()
        }
    }

    /// Builds the 3x3 mock matrix `[[1, 2, 3], [4, 5, 6], [7, 8, 9]]` shared by several tests.
    fn sample_matrix() -> MockMatrix {
        MockMatrix {
            data: vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]],
            width: 3,
            height: 3,
        }
    }

    #[test]
    fn test_dimensions() {
        let dims = Dimensions {
            width: 3,
            height: 5,
        };
        assert_eq!(dims.width, 3);
        assert_eq!(dims.height, 5);
        assert_eq!(format!("{dims:?}"), "3x5");
        assert_eq!(format!("{dims}"), "3x5");
    }

    #[test]
    fn test_mock_matrix_dimensions() {
        let matrix = sample_matrix();
        assert_eq!(matrix.width(), 3);
        assert_eq!(matrix.height(), 3);
        assert_eq!(
            matrix.dimensions(),
            Dimensions {
                width: 3,
                height: 3
            }
        );
    }

    #[test]
    fn test_first_row() {
        let matrix = sample_matrix();
        let mut first_row = matrix.first_row().unwrap().into_iter();
        assert_eq!(first_row.next(), Some(1));
        assert_eq!(first_row.next(), Some(2));
        assert_eq!(first_row.next(), Some(3));
        assert_eq!(first_row.next(), None);
    }

    #[test]
    fn test_last_row() {
        let matrix = sample_matrix();
        let mut last_row = matrix.last_row().unwrap().into_iter();
        assert_eq!(last_row.next(), Some(7));
        assert_eq!(last_row.next(), Some(8));
        assert_eq!(last_row.next(), Some(9));
        assert_eq!(last_row.next(), None);
    }

    #[test]
    fn test_first_last_row_empty_matrix() {
        let matrix = MockMatrix {
            data: vec![],
            width: 3,
            height: 0,
        };
        let first_row = matrix.first_row();
        let last_row = matrix.last_row();
        assert!(first_row.is_none());
        assert!(last_row.is_none());
    }

    #[test]
    fn test_to_row_major_matrix() {
        let matrix = MockMatrix {
            data: vec![vec![1, 2], vec![3, 4]],
            width: 2,
            height: 2,
        };
        let row_major = matrix.to_row_major_matrix();
        assert_eq!(row_major.values, vec![1, 2, 3, 4]);
        assert_eq!(row_major.width, 2);
    }

    #[test]
    fn test_matrix_get_methods() {
        let matrix = sample_matrix();
        assert_eq!(matrix.get(0, 0), Some(1));
        assert_eq!(matrix.get(1, 2), Some(6));
        assert_eq!(matrix.get(2, 1), Some(8));

        unsafe {
            assert_eq!(matrix.get_unchecked(0, 1), 2);
            assert_eq!(matrix.get_unchecked(1, 0), 4);
            assert_eq!(matrix.get_unchecked(2, 2), 9);
        }

        assert_eq!(matrix.get(3, 0), None); // Height out of bounds
        assert_eq!(matrix.get(0, 3), None); // Width out of bounds
    }

    #[test]
    fn test_matrix_row_methods_iteration() {
        let matrix = sample_matrix();

        let mut row_iter = matrix.row(1).unwrap().into_iter();
        assert_eq!(row_iter.next(), Some(4));
        assert_eq!(row_iter.next(), Some(5));
        assert_eq!(row_iter.next(), Some(6));
        assert_eq!(row_iter.next(), None);

        unsafe {
            let mut row_iter_unchecked = matrix.row_unchecked(2).into_iter();
            assert_eq!(row_iter_unchecked.next(), Some(7));
            assert_eq!(row_iter_unchecked.next(), Some(8));
            assert_eq!(row_iter_unchecked.next(), Some(9));
            assert_eq!(row_iter_unchecked.next(), None);

            let mut row_iter_subset = matrix.row_subseq_unchecked(0, 1, 3).into_iter();
            assert_eq!(row_iter_subset.next(), Some(2));
            assert_eq!(row_iter_subset.next(), Some(3));
            assert_eq!(row_iter_subset.next(), None);
        }

        assert!(matrix.row(3).is_none()); // Height out of bounds
    }

    #[test]
    fn test_row_slice_methods() {
        let matrix = sample_matrix();
        let row_slice = matrix.row_slice(1).unwrap();
        assert_eq!(*row_slice, [4, 5, 6]);
        unsafe {
            let row_slice_unchecked = matrix.row_slice_unchecked(2);
            assert_eq!(*row_slice_unchecked, [7, 8, 9]);

            let row_subslice = matrix.row_subslice_unchecked(0, 1, 2);
            assert_eq!(*row_subslice, [2]);
        }

        assert!(matrix.row_slice(3).is_none()); // Height out of bounds
    }

    #[test]
    fn test_wrapping_row_slices() {
        let matrix = sample_matrix();

        // Starting at the last row, four rows wrap around past the end of the matrix.
        let rows: Vec<Vec<u32>> = matrix
            .wrapping_row_slices(2, 4)
            .iter()
            .map(|row| row.to_vec())
            .collect();
        assert_eq!(
            rows,
            vec![vec![7, 8, 9], vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]
        );

        // Requesting zero rows yields nothing.
        assert!(matrix.wrapping_row_slices(1, 0).is_empty());
    }

    #[test]
    fn test_matrix_rows() {
        let matrix = sample_matrix();

        let all_rows: Vec<Vec<u32>> = matrix.rows().map(|row| row.collect()).collect();
        assert_eq!(all_rows, vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]);
    }

    #[test]
    fn test_rowwise_packed_dot_product() {
        type F = BabyBear;
        type EF = BinomialExtensionField<BabyBear, 4>;
        type PF = <F as Field>::Packing;
        type EFPacked = <EF as ExtensionField<F>>::ExtensionPacking;

        let mut rng = SmallRng::seed_from_u64(42);

        // Test with various matrix dimensions to cover edge cases.
        for (height, width) in [(32, 16), (64, 128), (128, 17), (256, 255)] {
            let m = RowMajorMatrix::<F>::rand(&mut rng, height, width);
            let v = RowMajorMatrix::<EF>::rand(&mut rng, width, 1).values;

            // Compute expected result naively: for each row, compute dot product with v.
            let expected: Vec<EF> = m
                .rows()
                .map(|row| {
                    row.into_iter()
                        .zip(v.iter())
                        .map(|(r, &ve)| ve * r)
                        .sum::<EF>()
                })
                .collect();

            // Pack the vector for the optimized function, zero-padding the final chunk.
            let packed_v: Vec<EFPacked> = v
                .chunks(<PF as PackedValue>::WIDTH)
                .map(|chunk| {
                    let mut padded = vec![EF::ZERO; <PF as PackedValue>::WIDTH];
                    padded[..chunk.len()].copy_from_slice(chunk);
                    EFPacked::from_ext_slice(&padded)
                })
                .collect();

            // Compute using the optimized function.
            let result: Vec<EF> = m.rowwise_packed_dot_product::<EF>(&packed_v).collect();

            assert_eq!(result, expected, "Mismatch for matrix {height}x{width}");
        }
    }
}