p3_matrix/
dense.rs

1use alloc::borrow::Cow;
2use alloc::vec;
3use alloc::vec::Vec;
4use core::borrow::{Borrow, BorrowMut};
5use core::iter;
6use core::marker::PhantomData;
7use core::ops::Deref;
8
9use p3_field::{
10    ExtensionField, Field, PackedValue, par_scale_slice_in_place, scale_slice_in_place_single_core,
11};
12use p3_maybe_rayon::prelude::*;
13use rand::Rng;
14use rand::distr::{Distribution, StandardUniform};
15use serde::{Deserialize, Serialize};
16use tracing::instrument;
17
18use crate::Matrix;
19
20/// A dense matrix in row-major format, with customizable backing storage.
21///
22/// The data is stored as a flat buffer, where rows are laid out consecutively.
23#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
24pub struct DenseMatrix<T, V = Vec<T>> {
25    /// Flat buffer of matrix values in row-major order.
26    pub values: V,
27    /// Number of columns in the matrix.
28    ///
29    /// The number of rows is implicitly determined as `values.len() / width`.
30    pub width: usize,
31    /// Marker for the element type `T`, unused directly.
32    ///
33    /// Required to retain type information when `V` does not own or contain `T`.
34    _phantom: PhantomData<T>,
35}
36
37pub type RowMajorMatrix<T> = DenseMatrix<T>;
38pub type RowMajorMatrixView<'a, T> = DenseMatrix<T, &'a [T]>;
39pub type RowMajorMatrixViewMut<'a, T> = DenseMatrix<T, &'a mut [T]>;
40pub type RowMajorMatrixCow<'a, T> = DenseMatrix<T, Cow<'a, [T]>>;
41
/// Backing storage for a dense matrix: anything that can be viewed as a `[T]` slice and
/// converted into an owned `Vec<T>`.
pub trait DenseStorage<T>: Borrow<[T]> + Send + Sync {
    /// Consume the storage and return its elements as an owned vector.
    fn to_vec(self) -> Vec<T>;
}

// Cow doesn't impl IntoOwned, so a blanket impl is not possible; each storage kind is
// implemented separately below.

/// Owned storage: the vector is handed back unchanged.
impl<T: Clone + Send + Sync> DenseStorage<T> for Vec<T> {
    fn to_vec(self) -> Vec<T> {
        self
    }
}

/// Shared-slice storage: the elements are cloned into a new vector.
impl<T: Clone + Send + Sync> DenseStorage<T> for &[T] {
    fn to_vec(self) -> Vec<T> {
        Vec::from(self)
    }
}

/// Mutable-slice storage: the elements are cloned into a new vector.
impl<T: Clone + Send + Sync> DenseStorage<T> for &mut [T] {
    fn to_vec(self) -> Vec<T> {
        Vec::from(self)
    }
}

/// `Cow` storage: an owned vector is reused, a borrowed slice is cloned.
impl<T: Clone + Send + Sync> DenseStorage<T> for Cow<'_, [T]> {
    fn to_vec(self) -> Vec<T> {
        match self {
            Cow::Owned(owned) => owned,
            Cow::Borrowed(slice) => Vec::from(slice),
        }
    }
}
70
71impl<T: Clone + Send + Sync + Default> DenseMatrix<T> {
72    /// Create a new dense matrix of the given dimensions, backed by a `Vec`, and filled with
73    /// default values.
74    #[must_use]
75    pub fn default(width: usize, height: usize) -> Self {
76        Self::new(vec![T::default(); width * height], width)
77    }
78}
79
80impl<T: Clone + Send + Sync, S: DenseStorage<T>> DenseMatrix<T, S> {
81    /// Create a new dense matrix of the given dimensions, backed by the given storage.
82    ///
83    /// Note that it is undefined behavior to create a matrix such that
84    /// `values.len() % width != 0`.
85    #[must_use]
86    pub fn new(values: S, width: usize) -> Self {
87        debug_assert!(width == 0 || values.borrow().len() % width == 0);
88        Self {
89            values,
90            width,
91            _phantom: PhantomData,
92        }
93    }
94
95    /// Create a new RowMajorMatrix containing a single row.
96    #[must_use]
97    pub fn new_row(values: S) -> Self {
98        let width = values.borrow().len();
99        Self::new(values, width)
100    }
101
102    /// Create a new RowMajorMatrix containing a single column.
103    #[must_use]
104    pub fn new_col(values: S) -> Self {
105        Self::new(values, 1)
106    }
107
108    /// Get a view of the matrix, i.e. a reference to the underlying data.
109    pub fn as_view(&self) -> RowMajorMatrixView<'_, T> {
110        RowMajorMatrixView::new(self.values.borrow(), self.width)
111    }
112
113    /// Get a mutable view of the matrix, i.e. a mutable reference to the underlying data.
114    pub fn as_view_mut(&mut self) -> RowMajorMatrixViewMut<'_, T>
115    where
116        S: BorrowMut<[T]>,
117    {
118        RowMajorMatrixViewMut::new(self.values.borrow_mut(), self.width)
119    }
120
121    /// Copy the values from the given matrix into this matrix.
122    pub fn copy_from<S2>(&mut self, source: &DenseMatrix<T, S2>)
123    where
124        T: Copy,
125        S: BorrowMut<[T]>,
126        S2: DenseStorage<T>,
127    {
128        assert_eq!(self.dimensions(), source.dimensions());
129        // Equivalent to:
130        // self.values.borrow_mut().copy_from_slice(source.values.borrow());
131        self.par_rows_mut()
132            .zip(source.par_row_slices())
133            .for_each(|(dst, src)| {
134                dst.copy_from_slice(src);
135            });
136    }
137
138    /// Flatten an extension field matrix to a base field matrix.
139    pub fn flatten_to_base<F: Field>(self) -> RowMajorMatrix<F>
140    where
141        T: ExtensionField<F>,
142    {
143        let width = self.width * T::DIMENSION;
144        let values = T::flatten_to_base(self.values.to_vec());
145        RowMajorMatrix::new(values, width)
146    }
147
148    /// Get an iterator over the rows of the matrix.
149    pub fn row_slices(&self) -> impl Iterator<Item = &[T]> {
150        self.values.borrow().chunks_exact(self.width)
151    }
152
153    /// Get a parallel iterator over the rows of the matrix.
154    pub fn par_row_slices(&self) -> impl IndexedParallelIterator<Item = &[T]>
155    where
156        T: Sync,
157    {
158        self.values.borrow().par_chunks_exact(self.width)
159    }
160
161    /// Returns a slice of the given row.
162    ///
163    /// # Panics
164    /// Panics if `r` larger than self.height().
165    pub fn row_mut(&mut self, r: usize) -> &mut [T]
166    where
167        S: BorrowMut<[T]>,
168    {
169        &mut self.values.borrow_mut()[r * self.width..(r + 1) * self.width]
170    }
171
172    /// Get a mutable iterator over the rows of the matrix.
173    pub fn rows_mut(&mut self) -> impl Iterator<Item = &mut [T]>
174    where
175        S: BorrowMut<[T]>,
176    {
177        self.values.borrow_mut().chunks_exact_mut(self.width)
178    }
179
180    /// Get a mutable parallel iterator over the rows of the matrix.
181    pub fn par_rows_mut<'a>(&'a mut self) -> impl IndexedParallelIterator<Item = &'a mut [T]>
182    where
183        T: 'a + Send,
184        S: BorrowMut<[T]>,
185    {
186        self.values.borrow_mut().par_chunks_exact_mut(self.width)
187    }
188
189    /// Get a mutable iterator over the rows of the matrix which packs the rows into packed values.
190    ///
191    /// If `P::WIDTH` does not divide `self.width`, the remainder of the row will be returned as a
192    /// base slice.
193    pub fn horizontally_packed_row_mut<P>(&mut self, r: usize) -> (&mut [P], &mut [T])
194    where
195        P: PackedValue<Value = T>,
196        S: BorrowMut<[T]>,
197    {
198        P::pack_slice_with_suffix_mut(self.row_mut(r))
199    }
200
201    /// Scale the given row by the given value.
202    ///
203    /// # Panics
204    /// Panics if `r` larger than `self.height()`.
205    pub fn scale_row(&mut self, r: usize, scale: T)
206    where
207        T: Field,
208        S: BorrowMut<[T]>,
209    {
210        scale_slice_in_place_single_core(self.row_mut(r), scale);
211    }
212
213    /// Scale the given row by the given value.
214    ///
215    /// # Performance
216    /// This function is parallelized, which may introduce some overhead compared to
217    /// [`Self::scale_row`] when the width is small.
218    ///
219    /// # Panics
220    /// Panics if `r` larger than `self.height()`.
221    pub fn par_scale_row(&mut self, r: usize, scale: T)
222    where
223        T: Field,
224        S: BorrowMut<[T]>,
225    {
226        par_scale_slice_in_place(self.row_mut(r), scale);
227    }
228
229    /// Scale the entire matrix by the given value.
230    pub fn scale(&mut self, scale: T)
231    where
232        T: Field,
233        S: BorrowMut<[T]>,
234    {
235        par_scale_slice_in_place(self.values.borrow_mut(), scale);
236    }
237
238    /// Split the matrix into two matrix views, one with the first `r` rows and one with the remaining rows.
239    ///
240    /// # Panics
241    /// Panics if `r` larger than `self.height()`.
242    pub fn split_rows(&self, r: usize) -> (RowMajorMatrixView<T>, RowMajorMatrixView<T>) {
243        let (lo, hi) = self.values.borrow().split_at(r * self.width);
244        (
245            DenseMatrix::new(lo, self.width),
246            DenseMatrix::new(hi, self.width),
247        )
248    }
249
250    /// Split the matrix into two mutable matrix views, one with the first `r` rows and one with the remaining rows.
251    ///
252    /// # Panics
253    /// Panics if `r` larger than `self.height()`.
254    pub fn split_rows_mut(
255        &mut self,
256        r: usize,
257    ) -> (RowMajorMatrixViewMut<T>, RowMajorMatrixViewMut<T>)
258    where
259        S: BorrowMut<[T]>,
260    {
261        let (lo, hi) = self.values.borrow_mut().split_at_mut(r * self.width);
262        (
263            DenseMatrix::new(lo, self.width),
264            DenseMatrix::new(hi, self.width),
265        )
266    }
267
268    /// Get an iterator over the rows of the matrix which takes `chunk_rows` rows at a time.
269    ///
270    /// If `chunk_rows` does not divide the height of the matrix, the last chunk will be smaller.
271    pub fn par_row_chunks(
272        &self,
273        chunk_rows: usize,
274    ) -> impl IndexedParallelIterator<Item = RowMajorMatrixView<T>>
275    where
276        T: Send,
277    {
278        self.values
279            .borrow()
280            .par_chunks(self.width * chunk_rows)
281            .map(|slice| RowMajorMatrixView::new(slice, self.width))
282    }
283
284    /// Get a parallel iterator over the rows of the matrix which takes `chunk_rows` rows at a time.
285    ///
286    /// If `chunk_rows` does not divide the height of the matrix, the last chunk will be smaller.
287    pub fn par_row_chunks_exact(
288        &self,
289        chunk_rows: usize,
290    ) -> impl IndexedParallelIterator<Item = RowMajorMatrixView<T>>
291    where
292        T: Send,
293    {
294        self.values
295            .borrow()
296            .par_chunks_exact(self.width * chunk_rows)
297            .map(|slice| RowMajorMatrixView::new(slice, self.width))
298    }
299
300    /// Get a mutable iterator over the rows of the matrix which takes `chunk_rows` rows at a time.
301    ///
302    /// If `chunk_rows` does not divide the height of the matrix, the last chunk will be smaller.
303    pub fn par_row_chunks_mut(
304        &mut self,
305        chunk_rows: usize,
306    ) -> impl IndexedParallelIterator<Item = RowMajorMatrixViewMut<T>>
307    where
308        T: Send,
309        S: BorrowMut<[T]>,
310    {
311        self.values
312            .borrow_mut()
313            .par_chunks_mut(self.width * chunk_rows)
314            .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
315    }
316
317    /// Get a mutable iterator over the rows of the matrix which takes `chunk_rows` rows at a time.
318    ///
319    /// If `chunk_rows` does not divide the height of the matrix, the last up to `chunk_rows - 1` rows
320    /// of the matrix will be omitted.
321    pub fn row_chunks_exact_mut(
322        &mut self,
323        chunk_rows: usize,
324    ) -> impl Iterator<Item = RowMajorMatrixViewMut<T>>
325    where
326        T: Send,
327        S: BorrowMut<[T]>,
328    {
329        self.values
330            .borrow_mut()
331            .chunks_exact_mut(self.width * chunk_rows)
332            .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
333    }
334
335    /// Get a parallel mutable iterator over the rows of the matrix which takes `chunk_rows` rows at a time.
336    ///
337    /// If `chunk_rows` does not divide the height of the matrix, the last up to `chunk_rows - 1` rows
338    /// of the matrix will be omitted.
339    pub fn par_row_chunks_exact_mut(
340        &mut self,
341        chunk_rows: usize,
342    ) -> impl IndexedParallelIterator<Item = RowMajorMatrixViewMut<T>>
343    where
344        T: Send,
345        S: BorrowMut<[T]>,
346    {
347        self.values
348            .borrow_mut()
349            .par_chunks_exact_mut(self.width * chunk_rows)
350            .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
351    }
352
353    /// Get a pair of mutable slices of the given rows.
354    ///
355    /// # Panics
356    /// Panics if `row_1` or `row_2` are out of bounds or if `row_1 >= row_2`.
357    pub fn row_pair_mut(&mut self, row_1: usize, row_2: usize) -> (&mut [T], &mut [T])
358    where
359        S: BorrowMut<[T]>,
360    {
361        debug_assert_ne!(row_1, row_2);
362        let start_1 = row_1 * self.width;
363        let start_2 = row_2 * self.width;
364        let (lo, hi) = self.values.borrow_mut().split_at_mut(start_2);
365        (&mut lo[start_1..][..self.width], &mut hi[..self.width])
366    }
367
368    /// Get a pair of mutable slices of the given rows, both packed into packed field elements.
369    ///
370    /// If `P:WIDTH` does not divide `self.width`, the remainder of the row will be returned as a base slice.
371    ///
372    /// # Panics
373    /// Panics if `row_1` or `row_2` are out of bounds or if `row_1 >= row_2`.
374    #[allow(clippy::type_complexity)]
375    pub fn packed_row_pair_mut<P>(
376        &mut self,
377        row_1: usize,
378        row_2: usize,
379    ) -> ((&mut [P], &mut [T]), (&mut [P], &mut [T]))
380    where
381        S: BorrowMut<[T]>,
382        P: PackedValue<Value = T>,
383    {
384        let (slice_1, slice_2) = self.row_pair_mut(row_1, row_2);
385        (
386            P::pack_slice_with_suffix_mut(slice_1),
387            P::pack_slice_with_suffix_mut(slice_2),
388        )
389    }
390
391    /// Append zeros to the "end" of the given matrix, except that the matrix is in bit-reversed order,
392    /// so in actuality we're interleaving zero rows.
393    #[instrument(level = "debug", skip_all)]
394    pub fn bit_reversed_zero_pad(self, added_bits: usize) -> RowMajorMatrix<T>
395    where
396        T: Field,
397    {
398        if added_bits == 0 {
399            return self.to_row_major_matrix();
400        }
401
402        // This is equivalent to:
403        //     reverse_matrix_index_bits(mat);
404        //     mat
405        //         .values
406        //         .resize(mat.values.len() << added_bits, F::ZERO);
407        //     reverse_matrix_index_bits(mat);
408        // But rather than implement it with bit reversals, we directly construct the resulting matrix,
409        // whose rows are zero except for rows whose low `added_bits` bits are zero.
410
411        let w = self.width;
412        let mut padded =
413            RowMajorMatrix::new(T::zero_vec(self.values.borrow().len() << added_bits), w);
414        padded
415            .par_row_chunks_exact_mut(1 << added_bits)
416            .zip(self.par_row_slices())
417            .for_each(|(mut ch, r)| ch.row_mut(0).copy_from_slice(r));
418
419        padded
420    }
421}
422
423impl<T: Clone + Send + Sync, S: DenseStorage<T>> Matrix<T> for DenseMatrix<T, S> {
424    #[inline]
425    fn width(&self) -> usize {
426        self.width
427    }
428
429    #[inline]
430    fn height(&self) -> usize {
431        if self.width == 0 {
432            0
433        } else {
434            self.values.borrow().len() / self.width
435        }
436    }
437
438    #[inline]
439    unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
440        unsafe {
441            // Safety: The caller must ensure that r < self.height() and c < self.width().
442            self.values
443                .borrow()
444                .get_unchecked(r * self.width + c)
445                .clone()
446        }
447    }
448
449    #[inline]
450    unsafe fn row_subseq_unchecked(
451        &self,
452        r: usize,
453        start: usize,
454        end: usize,
455    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
456        unsafe {
457            // Safety: The caller must ensure that r < self.height() and start <= end <= self.width().
458            self.values
459                .borrow()
460                .get_unchecked(r * self.width + start..r * self.width + end)
461                .iter()
462                .cloned()
463        }
464    }
465
466    #[inline]
467    unsafe fn row_subslice_unchecked(
468        &self,
469        r: usize,
470        start: usize,
471        end: usize,
472    ) -> impl Deref<Target = [T]> {
473        unsafe {
474            // Safety: The caller must ensure that r < self.height()
475            self.values
476                .borrow()
477                .get_unchecked(r * self.width + start..r * self.width + end)
478        }
479    }
480
481    fn to_row_major_matrix(self) -> RowMajorMatrix<T>
482    where
483        Self: Sized,
484        T: Clone,
485    {
486        RowMajorMatrix::new(self.values.to_vec(), self.width)
487    }
488
489    #[inline]
490    fn horizontally_packed_row<'a, P>(
491        &'a self,
492        r: usize,
493    ) -> (
494        impl Iterator<Item = P> + Send + Sync,
495        impl Iterator<Item = T> + Send + Sync,
496    )
497    where
498        P: PackedValue<Value = T>,
499        T: Clone + 'a,
500    {
501        let buf = &self.values.borrow()[r * self.width..(r + 1) * self.width];
502        let (packed, sfx) = P::pack_slice_with_suffix(buf);
503        (packed.iter().copied(), sfx.iter().cloned())
504    }
505
506    #[inline]
507    fn padded_horizontally_packed_row<'a, P>(
508        &'a self,
509        r: usize,
510    ) -> impl Iterator<Item = P> + Send + Sync
511    where
512        P: PackedValue<Value = T>,
513        T: Clone + Default + 'a,
514    {
515        let buf = &self.values.borrow()[r * self.width..(r + 1) * self.width];
516        let (packed, sfx) = P::pack_slice_with_suffix(buf);
517        packed.iter().copied().chain(iter::once(P::from_fn(|i| {
518            sfx.get(i).cloned().unwrap_or_default()
519        })))
520    }
521}
522
523impl<T: Clone + Default + Send + Sync> DenseMatrix<T> {
524    pub fn as_cow<'a>(self) -> RowMajorMatrixCow<'a, T> {
525        RowMajorMatrixCow::new(Cow::Owned(self.values), self.width)
526    }
527
528    pub fn rand<R: Rng>(rng: &mut R, rows: usize, cols: usize) -> Self
529    where
530        StandardUniform: Distribution<T>,
531    {
532        let values = rng.sample_iter(StandardUniform).take(rows * cols).collect();
533        Self::new(values, cols)
534    }
535
536    pub fn rand_nonzero<R: Rng>(rng: &mut R, rows: usize, cols: usize) -> Self
537    where
538        T: Field,
539        StandardUniform: Distribution<T>,
540    {
541        let values = rng
542            .sample_iter(StandardUniform)
543            .filter(|x| !x.is_zero())
544            .take(rows * cols)
545            .collect();
546        Self::new(values, cols)
547    }
548
549    pub fn pad_to_height(&mut self, new_height: usize, fill: T) {
550        assert!(new_height >= self.height());
551        self.values.resize(self.width * new_height, fill);
552    }
553}
554
555impl<T: Copy + Default + Send + Sync, V: DenseStorage<T>> DenseMatrix<T, V> {
556    /// Return the transpose of this matrix.
557    pub fn transpose(&self) -> RowMajorMatrix<T> {
558        let nelts = self.height() * self.width();
559        let mut values = vec![T::default(); nelts];
560        transpose::transpose(
561            self.values.borrow(),
562            &mut values,
563            self.width(),
564            self.height(),
565        );
566        RowMajorMatrix::new(values, self.height())
567    }
568
569    /// Transpose the matrix returning the result in `other` without intermediate allocation.
570    pub fn transpose_into<W: DenseStorage<T> + BorrowMut<[T]>>(
571        &self,
572        other: &mut DenseMatrix<T, W>,
573    ) {
574        assert_eq!(self.height(), other.width());
575        assert_eq!(other.height(), self.width());
576        transpose::transpose(
577            self.values.borrow(),
578            other.values.borrow_mut(),
579            self.width(),
580            self.height(),
581        );
582    }
583}
584
585impl<'a, T: Clone + Default + Send + Sync> RowMajorMatrixView<'a, T> {
586    pub fn as_cow(self) -> RowMajorMatrixCow<'a, T> {
587        RowMajorMatrixCow::new(Cow::Borrowed(self.values), self.width)
588    }
589}
590
591#[cfg(test)]
592mod tests {
593    use p3_baby_bear::BabyBear;
594    use p3_field::FieldArray;
595
596    use super::*;
597
598    #[test]
599    fn test_new() {
600        let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
601        assert_eq!(matrix.width, 2);
602        assert_eq!(matrix.height(), 3);
603        assert_eq!(matrix.values, vec![1, 2, 3, 4, 5, 6]);
604    }
605
606    #[test]
607    fn test_new_row() {
608        let matrix = RowMajorMatrix::new_row(vec![1, 2, 3]);
609        assert_eq!(matrix.width, 3);
610        assert_eq!(matrix.height(), 1);
611    }
612
613    #[test]
614    fn test_new_col() {
615        let matrix = RowMajorMatrix::new_col(vec![1, 2, 3]);
616        assert_eq!(matrix.width, 1);
617        assert_eq!(matrix.height(), 3);
618    }
619
620    #[test]
621    fn test_height_with_zero_width() {
622        let matrix: DenseMatrix<i32> = RowMajorMatrix::new(vec![], 0);
623        assert_eq!(matrix.height(), 0);
624    }
625
626    #[test]
627    fn test_get_methods() {
628        let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2); // Height = 3, Width = 2
629        assert_eq!(matrix.get(0, 0), Some(1));
630        assert_eq!(matrix.get(1, 1), Some(4));
631        assert_eq!(matrix.get(2, 0), Some(5));
632        unsafe {
633            assert_eq!(matrix.get_unchecked(0, 1), 2);
634            assert_eq!(matrix.get_unchecked(1, 0), 3);
635            assert_eq!(matrix.get_unchecked(2, 1), 6);
636        }
637        assert_eq!(matrix.get(3, 0), None); // Height out of bounds
638        assert_eq!(matrix.get(0, 2), None); // Width out of bounds
639    }
640
641    #[test]
642    fn test_row_methods() {
643        let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 4); // Height = 2, Width = 4
644        let row: Vec<_> = matrix.row(1).unwrap().into_iter().collect();
645        assert_eq!(row, vec![5, 6, 7, 8]);
646        unsafe {
647            let row: Vec<_> = matrix.row_unchecked(0).into_iter().collect();
648            assert_eq!(row, vec![1, 2, 3, 4]);
649            let row: Vec<_> = matrix.row_subseq_unchecked(0, 0, 3).into_iter().collect();
650            assert_eq!(row, vec![1, 2, 3]);
651            let row: Vec<_> = matrix.row_subseq_unchecked(0, 1, 3).into_iter().collect();
652            assert_eq!(row, vec![2, 3]);
653            let row: Vec<_> = matrix.row_subseq_unchecked(0, 2, 4).into_iter().collect();
654            assert_eq!(row, vec![3, 4]);
655        }
656        assert!(matrix.row(2).is_none()); // Height out of bounds
657    }
658
659    #[test]
660    fn test_row_slice_methods() {
661        let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9], 3); // Height = 3, Width = 3
662        let slice0 = matrix.row_slice(0);
663        let slice2 = matrix.row_slice(2);
664        assert_eq!(slice0.unwrap().deref(), &[1, 2, 3]);
665        assert_eq!(slice2.unwrap().deref(), &[7, 8, 9]);
666        unsafe {
667            assert_eq!(&[1, 2, 3], matrix.row_slice_unchecked(0).deref());
668            assert_eq!(&[7, 8, 9], matrix.row_slice_unchecked(2).deref());
669
670            assert_eq!(&[1, 2, 3], matrix.row_subslice_unchecked(0, 0, 3).deref());
671            assert_eq!(&[8], matrix.row_subslice_unchecked(2, 1, 2).deref());
672        }
673        assert!(matrix.row_slice(3).is_none()); // Height out of bounds
674    }
675
676    #[test]
677    fn test_as_view() {
678        let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
679        let view = matrix.as_view();
680        assert_eq!(view.values, &[1, 2, 3, 4]);
681        assert_eq!(view.width, 2);
682    }
683
684    #[test]
685    fn test_as_view_mut() {
686        let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
687        let view = matrix.as_view_mut();
688        view.values[0] = 10;
689        assert_eq!(matrix.values, vec![10, 2, 3, 4]);
690    }
691
692    #[test]
693    fn test_copy_from() {
694        let mut matrix1 = RowMajorMatrix::new(vec![0, 0, 0, 0], 2);
695        let matrix2 = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
696        matrix1.copy_from(&matrix2);
697        assert_eq!(matrix1.values, vec![1, 2, 3, 4]);
698    }
699
700    #[test]
701    fn test_split_rows() {
702        let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
703        let (top, bottom) = matrix.split_rows(1);
704        assert_eq!(top.values, vec![1, 2]);
705        assert_eq!(bottom.values, vec![3, 4, 5, 6]);
706    }
707
708    #[test]
709    fn test_split_rows_mut() {
710        let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
711        let (top, bottom) = matrix.split_rows_mut(1);
712        assert_eq!(top.values, vec![1, 2]);
713        assert_eq!(bottom.values, vec![3, 4, 5, 6]);
714    }
715
716    #[test]
717    fn test_row_mut() {
718        let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
719        matrix.row_mut(1)[0] = 10;
720        assert_eq!(matrix.values, vec![1, 2, 10, 4, 5, 6]);
721    }
722
723    #[test]
724    fn test_bit_reversed_zero_pad() {
725        let matrix = RowMajorMatrix::new(
726            vec![
727                BabyBear::new(1),
728                BabyBear::new(2),
729                BabyBear::new(3),
730                BabyBear::new(4),
731            ],
732            2,
733        );
734        let padded = matrix.bit_reversed_zero_pad(1);
735        assert_eq!(padded.width, 2);
736        assert_eq!(
737            padded.values,
738            vec![
739                BabyBear::new(1),
740                BabyBear::new(2),
741                BabyBear::new(0),
742                BabyBear::new(0),
743                BabyBear::new(3),
744                BabyBear::new(4),
745                BabyBear::new(0),
746                BabyBear::new(0)
747            ]
748        );
749    }
750
751    #[test]
752    fn test_bit_reversed_zero_pad_no_change() {
753        let matrix = RowMajorMatrix::new(
754            vec![
755                BabyBear::new(1),
756                BabyBear::new(2),
757                BabyBear::new(3),
758                BabyBear::new(4),
759            ],
760            2,
761        );
762        let padded = matrix.bit_reversed_zero_pad(0);
763
764        assert_eq!(padded.width, 2);
765        assert_eq!(
766            padded.values,
767            vec![
768                BabyBear::new(1),
769                BabyBear::new(2),
770                BabyBear::new(3),
771                BabyBear::new(4),
772            ]
773        );
774    }
775
776    #[test]
777    fn test_scale() {
778        let mut matrix = RowMajorMatrix::new(
779            vec![
780                BabyBear::new(1),
781                BabyBear::new(2),
782                BabyBear::new(3),
783                BabyBear::new(4),
784                BabyBear::new(5),
785                BabyBear::new(6),
786            ],
787            2,
788        );
789        matrix.scale(BabyBear::new(2));
790        assert_eq!(
791            matrix.values,
792            vec![
793                BabyBear::new(2),
794                BabyBear::new(4),
795                BabyBear::new(6),
796                BabyBear::new(8),
797                BabyBear::new(10),
798                BabyBear::new(12)
799            ]
800        );
801    }
802
803    #[test]
804    fn test_scale_row() {
805        let mut matrix = RowMajorMatrix::new(
806            vec![
807                BabyBear::new(1),
808                BabyBear::new(2),
809                BabyBear::new(3),
810                BabyBear::new(4),
811                BabyBear::new(5),
812                BabyBear::new(6),
813            ],
814            2,
815        );
816        matrix.scale_row(1, BabyBear::new(3));
817        assert_eq!(
818            matrix.values,
819            vec![
820                BabyBear::new(1),
821                BabyBear::new(2),
822                BabyBear::new(9),
823                BabyBear::new(12),
824                BabyBear::new(5),
825                BabyBear::new(6),
826            ]
827        );
828    }
829
830    #[test]
831    fn test_to_row_major_matrix() {
832        let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
833        let converted = matrix.to_row_major_matrix();
834
835        // The converted matrix should have the same values and width
836        assert_eq!(converted.width, 2);
837        assert_eq!(converted.height(), 3);
838        assert_eq!(converted.values, vec![1, 2, 3, 4, 5, 6]);
839    }
840
841    #[test]
842    fn test_horizontally_packed_row() {
843        type Packed = FieldArray<BabyBear, 2>;
844
845        let matrix = RowMajorMatrix::new(
846            vec![
847                BabyBear::new(1),
848                BabyBear::new(2),
849                BabyBear::new(3),
850                BabyBear::new(4),
851                BabyBear::new(5),
852                BabyBear::new(6),
853            ],
854            3,
855        );
856
857        let (packed_iter, suffix_iter) = matrix.horizontally_packed_row::<Packed>(1);
858
859        let packed: Vec<_> = packed_iter.collect();
860        let suffix: Vec<_> = suffix_iter.collect();
861
862        assert_eq!(
863            packed,
864            vec![Packed::from([BabyBear::new(4), BabyBear::new(5)])]
865        );
866        assert_eq!(suffix, vec![BabyBear::new(6)]);
867    }
868
869    #[test]
870    fn test_padded_horizontally_packed_row() {
871        use p3_baby_bear::BabyBear;
872
873        type Packed = FieldArray<BabyBear, 2>;
874
875        let matrix = RowMajorMatrix::new(
876            vec![
877                BabyBear::new(1),
878                BabyBear::new(2),
879                BabyBear::new(3),
880                BabyBear::new(4),
881                BabyBear::new(5),
882                BabyBear::new(6),
883            ],
884            3,
885        );
886
887        let packed_iter = matrix.padded_horizontally_packed_row::<Packed>(1);
888        let packed: Vec<_> = packed_iter.collect();
889
890        assert_eq!(
891            packed,
892            vec![
893                Packed::from([BabyBear::new(4), BabyBear::new(5)]),
894                Packed::from([BabyBear::new(6), BabyBear::new(0)])
895            ]
896        );
897    }
898
899    #[test]
900    fn test_pad_to_height() {
901        let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);
902
903        // Original matrix:
904        // [ 1  2  3 ]
905        // [ 4  5  6 ] (height = 2)
906
907        matrix.pad_to_height(4, 9);
908
909        // Expected matrix after padding:
910        // [ 1  2  3 ]
911        // [ 4  5  6 ]
912        // [ 9  9  9 ]  <-- Newly added row
913        // [ 9  9  9 ]  <-- Newly added row
914
915        assert_eq!(matrix.height(), 4);
916        assert_eq!(matrix.values, vec![1, 2, 3, 4, 5, 6, 9, 9, 9, 9, 9, 9]);
917    }
918
919    #[test]
920    fn test_transpose_into() {
921        let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);
922
923        // Original matrix:
924        // [ 1  2  3 ]
925        // [ 4  5  6 ]
926
927        let mut transposed = RowMajorMatrix::new(vec![0; 6], 2);
928
929        matrix.transpose_into(&mut transposed);
930
931        // Expected transposed matrix:
932        // [ 1  4 ]
933        // [ 2  5 ]
934        // [ 3  6 ]
935
936        assert_eq!(transposed.width, 2);
937        assert_eq!(transposed.height(), 3);
938        assert_eq!(transposed.values, vec![1, 4, 2, 5, 3, 6]);
939    }
940
941    #[test]
942    fn test_flatten_to_base() {
943        let matrix = RowMajorMatrix::new(
944            vec![
945                BabyBear::new(2),
946                BabyBear::new(3),
947                BabyBear::new(4),
948                BabyBear::new(5),
949            ],
950            2,
951        );
952
953        let flattened: RowMajorMatrix<BabyBear> = matrix.flatten_to_base();
954
955        assert_eq!(flattened.width, 2);
956        assert_eq!(
957            flattened.values,
958            vec![
959                BabyBear::new(2),
960                BabyBear::new(3),
961                BabyBear::new(4),
962                BabyBear::new(5),
963            ]
964        );
965    }
966
967    #[test]
968    fn test_horizontally_packed_row_mut() {
969        type Packed = FieldArray<BabyBear, 2>;
970
971        let mut matrix = RowMajorMatrix::new(
972            vec![
973                BabyBear::new(1),
974                BabyBear::new(2),
975                BabyBear::new(3),
976                BabyBear::new(4),
977                BabyBear::new(5),
978                BabyBear::new(6),
979            ],
980            3,
981        );
982
983        let (packed, suffix) = matrix.horizontally_packed_row_mut::<Packed>(1);
984        packed[0] = Packed::from([BabyBear::new(9), BabyBear::new(10)]);
985        suffix[0] = BabyBear::new(11);
986
987        assert_eq!(
988            matrix.values,
989            vec![
990                BabyBear::new(1),
991                BabyBear::new(2),
992                BabyBear::new(3),
993                BabyBear::new(9),
994                BabyBear::new(10),
995                BabyBear::new(11),
996            ]
997        );
998    }
999
1000    #[test]
1001    fn test_par_row_chunks() {
1002        let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 2);
1003
1004        let chunks: Vec<_> = matrix.par_row_chunks(2).collect();
1005
1006        assert_eq!(chunks.len(), 2);
1007        assert_eq!(chunks[0].values, vec![1, 2, 3, 4]);
1008        assert_eq!(chunks[1].values, vec![5, 6, 7, 8]);
1009    }
1010
1011    #[test]
1012    fn test_par_row_chunks_exact() {
1013        let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
1014
1015        let chunks: Vec<_> = matrix.par_row_chunks_exact(1).collect();
1016
1017        assert_eq!(chunks.len(), 3);
1018        assert_eq!(chunks[0].values, vec![1, 2]);
1019        assert_eq!(chunks[1].values, vec![3, 4]);
1020        assert_eq!(chunks[2].values, vec![5, 6]);
1021    }
1022
1023    #[test]
1024    fn test_par_row_chunks_mut() {
1025        let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 2);
1026
1027        matrix
1028            .par_row_chunks_mut(2)
1029            .for_each(|chunk| chunk.values.iter_mut().for_each(|x| *x += 10));
1030
1031        assert_eq!(matrix.values, vec![11, 12, 13, 14, 15, 16, 17, 18]);
1032    }
1033
1034    #[test]
1035    fn test_row_chunks_exact_mut() {
1036        let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
1037
1038        for chunk in matrix.row_chunks_exact_mut(1) {
1039            chunk.values.iter_mut().for_each(|x| *x *= 2);
1040        }
1041
1042        assert_eq!(matrix.values, vec![2, 4, 6, 8, 10, 12]);
1043    }
1044
1045    #[test]
1046    fn test_par_row_chunks_exact_mut() {
1047        let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
1048
1049        matrix
1050            .par_row_chunks_exact_mut(1)
1051            .for_each(|chunk| chunk.values.iter_mut().for_each(|x| *x += 5));
1052
1053        assert_eq!(matrix.values, vec![6, 7, 8, 9, 10, 11]);
1054    }
1055
1056    #[test]
1057    fn test_row_pair_mut() {
1058        let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
1059
1060        let (row1, row2) = matrix.row_pair_mut(0, 2);
1061        row1[0] = 9;
1062        row2[1] = 10;
1063
1064        assert_eq!(matrix.values, vec![9, 2, 3, 4, 5, 10]);
1065    }
1066
1067    #[test]
1068    fn test_packed_row_pair_mut() {
1069        type Packed = FieldArray<BabyBear, 2>;
1070
1071        let mut matrix = RowMajorMatrix::new(
1072            vec![
1073                BabyBear::new(1),
1074                BabyBear::new(2),
1075                BabyBear::new(3),
1076                BabyBear::new(4),
1077                BabyBear::new(5),
1078                BabyBear::new(6),
1079            ],
1080            3,
1081        );
1082
1083        let ((packed1, sfx1), (packed2, sfx2)) = matrix.packed_row_pair_mut::<Packed>(0, 1);
1084        packed1[0] = Packed::from([BabyBear::new(7), BabyBear::new(8)]);
1085        packed2[0] = Packed::from([BabyBear::new(33), BabyBear::new(44)]);
1086        sfx1[0] = BabyBear::new(99);
1087        sfx2[0] = BabyBear::new(9);
1088
1089        assert_eq!(
1090            matrix.values,
1091            vec![
1092                BabyBear::new(7),
1093                BabyBear::new(8),
1094                BabyBear::new(99),
1095                BabyBear::new(33),
1096                BabyBear::new(44),
1097                BabyBear::new(9),
1098            ]
1099        );
1100    }
1101
1102    #[test]
1103    fn test_transpose_square_matrix() {
1104        const START_INDEX: usize = 1;
1105        const VALUE_LEN: usize = 9;
1106        const WIDTH: usize = 3;
1107        const HEIGHT: usize = 3;
1108
1109        let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
1110        let matrix = RowMajorMatrix::new(matrix_values, WIDTH);
1111        let transposed = matrix.transpose();
1112        let should_be_transposed_values = vec![1, 4, 7, 2, 5, 8, 3, 6, 9];
1113        let should_be_transposed = RowMajorMatrix::new(should_be_transposed_values, HEIGHT);
1114        assert_eq!(transposed, should_be_transposed);
1115    }
1116
1117    #[test]
1118    fn test_transpose_row_matrix() {
1119        const START_INDEX: usize = 1;
1120        const VALUE_LEN: usize = 30;
1121        const WIDTH: usize = 1;
1122        const HEIGHT: usize = 30;
1123
1124        let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
1125        let matrix = RowMajorMatrix::new(matrix_values.clone(), WIDTH);
1126        let transposed = matrix.transpose();
1127        let should_be_transposed = RowMajorMatrix::new(matrix_values, HEIGHT);
1128        assert_eq!(transposed, should_be_transposed);
1129    }
1130
1131    #[test]
1132    fn test_transpose_rectangular_matrix() {
1133        const START_INDEX: usize = 1;
1134        const VALUE_LEN: usize = 30;
1135        const WIDTH: usize = 5;
1136        const HEIGHT: usize = 6;
1137
1138        let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
1139        let matrix = RowMajorMatrix::new(matrix_values, WIDTH);
1140        let transposed = matrix.transpose();
1141        let should_be_transposed_values = vec![
1142            1, 6, 11, 16, 21, 26, 2, 7, 12, 17, 22, 27, 3, 8, 13, 18, 23, 28, 4, 9, 14, 19, 24, 29,
1143            5, 10, 15, 20, 25, 30,
1144        ];
1145        let should_be_transposed = RowMajorMatrix::new(should_be_transposed_values, HEIGHT);
1146        assert_eq!(transposed, should_be_transposed);
1147    }
1148
1149    #[test]
1150    fn test_transpose_larger_rectangular_matrix() {
1151        const START_INDEX: usize = 1;
1152        const VALUE_LEN: usize = 131072; // 512 * 256
1153        const WIDTH: usize = 256;
1154        const HEIGHT: usize = 512;
1155
1156        let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
1157        let matrix = RowMajorMatrix::new(matrix_values, WIDTH);
1158        let transposed = matrix.transpose();
1159
1160        assert_eq!(transposed.width(), HEIGHT);
1161        assert_eq!(transposed.height(), WIDTH);
1162
1163        for col_index in 0..WIDTH {
1164            for row_index in 0..HEIGHT {
1165                assert_eq!(
1166                    matrix.values[row_index * WIDTH + col_index],
1167                    transposed.values[col_index * HEIGHT + row_index]
1168                );
1169            }
1170        }
1171    }
1172
1173    #[test]
1174    fn test_transpose_very_large_rectangular_matrix() {
1175        const START_INDEX: usize = 1;
1176        const VALUE_LEN: usize = 1048576; // 512 * 256
1177        const WIDTH: usize = 1024;
1178        const HEIGHT: usize = 1024;
1179
1180        let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
1181        let matrix = RowMajorMatrix::new(matrix_values, WIDTH);
1182        let transposed = matrix.transpose();
1183
1184        assert_eq!(transposed.width(), HEIGHT);
1185        assert_eq!(transposed.height(), WIDTH);
1186
1187        for col_index in 0..WIDTH {
1188            for row_index in 0..HEIGHT {
1189                assert_eq!(
1190                    matrix.values[row_index * WIDTH + col_index],
1191                    transposed.values[col_index * HEIGHT + row_index]
1192                );
1193            }
1194        }
1195    }
1196
1197    #[test]
1198    fn test_vertically_packed_row_pair() {
1199        type Packed = FieldArray<BabyBear, 2>;
1200
1201        let matrix = RowMajorMatrix::new((1..17).map(BabyBear::new).collect::<Vec<_>>(), 4);
1202
1203        // Calling the function with r = 0 and step = 2
1204        let packed = matrix.vertically_packed_row_pair::<Packed>(0, 2);
1205
1206        // Matrix visualization:
1207        //
1208        // [  1   2   3   4  ]  <-- Row 0
1209        // [  5   6   7   8  ]  <-- Row 1
1210        // [  9  10  11  12  ]  <-- Row 2
1211        // [ 13  14  15  16  ]  <-- Row 3
1212        //
1213        // Packing rows 0-1 together, then rows 2-3 together:
1214        //
1215        // Packed result:
1216        // [
1217        //   (1, 5), (2, 6), (3, 7), (4, 8),   // First packed row (Row 0 & Row 1)
1218        //   (9, 13), (10, 14), (11, 15), (12, 16),   // Second packed row (Row 2 & Row 3)
1219        // ]
1220
1221        assert_eq!(
1222            packed,
1223            (1..5)
1224                .chain(9..13)
1225                .map(|i| [BabyBear::new(i), BabyBear::new(i + 4)].into())
1226                .collect::<Vec<_>>(),
1227        );
1228    }
1229
1230    #[test]
1231    fn test_vertically_packed_row_pair_overlap() {
1232        type Packed = FieldArray<BabyBear, 2>;
1233
1234        let matrix = RowMajorMatrix::new((1..17).map(BabyBear::new).collect::<Vec<_>>(), 4);
1235
1236        // Original matrix visualization:
1237        //
1238        // [  1   2   3   4  ]  <-- Row 0
1239        // [  5   6   7   8  ]  <-- Row 1
1240        // [  9  10  11  12  ]  <-- Row 2
1241        // [ 13  14  15  16  ]  <-- Row 3
1242        //
1243        // Packing rows 0-1 together, then rows 1-2 together:
1244        //
1245        // Expected packed result:
1246        // [
1247        //   (1, 5), (2, 6), (3, 7), (4, 8),   // First packed row (Row 0 & Row 1)
1248        //   (5, 9), (6, 10), (7, 11), (8, 12) // Second packed row (Row 1 & Row 2)
1249        // ]
1250
1251        // Calling the function with overlapping rows (r = 0 and step = 1)
1252        let packed = matrix.vertically_packed_row_pair::<Packed>(0, 1);
1253
1254        assert_eq!(
1255            packed,
1256            (1..5)
1257                .chain(5..9)
1258                .map(|i| [BabyBear::new(i), BabyBear::new(i + 4)].into())
1259                .collect::<Vec<_>>(),
1260        );
1261    }
1262
1263    #[test]
1264    fn test_vertically_packed_row_pair_wraparound_start_1() {
1265        use p3_baby_bear::BabyBear;
1266        use p3_field::FieldArray;
1267
1268        type Packed = FieldArray<BabyBear, 2>;
1269
1270        let matrix = RowMajorMatrix::new((1..17).map(BabyBear::new).collect::<Vec<_>>(), 4);
1271
1272        // Original matrix visualization:
1273        //
1274        // [  1   2   3   4  ]  <-- Row 0
1275        // [  5   6   7   8  ]  <-- Row 1
1276        // [  9  10  11  12  ]  <-- Row 2
1277        // [ 13  14  15  16  ]  <-- Row 3
1278        //
1279        // Packing starts from row 1, skipping 2 rows (step = 2):
1280        // - The first packed row should contain row 1 & row 2.
1281        // - The second packed row should contain row 3 & row 1 (wraparound case).
1282        //
1283        // Expected packed result:
1284        // [
1285        //   (5, 9), (6, 10), (7, 11), (8, 12),   // Packed row (Row 1 & Row 2)
1286        //   (13, 1), (14, 2), (15, 3), (16, 4)    // Packed row (Row 3 & Row 1)
1287        // ]
1288
1289        // Calling the function with wraparound scenario (starting at r = 1 with step = 2)
1290        let packed = matrix.vertically_packed_row_pair::<Packed>(1, 2);
1291
1292        assert_eq!(
1293            packed,
1294            vec![
1295                Packed::from([BabyBear::new(5), BabyBear::new(9)]),
1296                Packed::from([BabyBear::new(6), BabyBear::new(10)]),
1297                Packed::from([BabyBear::new(7), BabyBear::new(11)]),
1298                Packed::from([BabyBear::new(8), BabyBear::new(12)]),
1299                Packed::from([BabyBear::new(13), BabyBear::new(1)]),
1300                Packed::from([BabyBear::new(14), BabyBear::new(2)]),
1301                Packed::from([BabyBear::new(15), BabyBear::new(3)]),
1302                Packed::from([BabyBear::new(16), BabyBear::new(4)]),
1303            ]
1304        );
1305    }
1306}