p3_matrix/
dense.rs

1use alloc::borrow::Cow;
2use alloc::vec;
3use alloc::vec::Vec;
4use core::borrow::{Borrow, BorrowMut};
5use core::iter;
6use core::marker::PhantomData;
7use core::ops::Deref;
8
9use p3_field::{
10    ExtensionField, Field, PackedValue, par_scale_slice_in_place, scale_slice_in_place_single_core,
11};
12use p3_maybe_rayon::prelude::*;
13use rand::Rng;
14use rand::distr::{Distribution, StandardUniform};
15use serde::{Deserialize, Serialize};
16use tracing::instrument;
17
18use crate::Matrix;
19
/// A dense matrix in row-major format, with customizable backing storage.
///
/// The data is stored as a flat buffer, where rows are laid out consecutively: the element at
/// row `r`, column `c` lives at index `r * width + c` of `values`.
///
/// The storage type `V` defaults to an owned `Vec<T>`; borrowed, mutably-borrowed and
/// copy-on-write variants are available through the `RowMajorMatrix*` type aliases.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct DenseMatrix<T, V = Vec<T>> {
    /// Flat buffer of matrix values in row-major order.
    ///
    /// Its length is expected to be a multiple of `width`; the constructor only checks this in
    /// debug builds.
    pub values: V,
    /// Number of columns in the matrix.
    ///
    /// The number of rows is implicitly determined as `values.len() / width` (or zero when
    /// `width` is zero).
    pub width: usize,
    /// Marker for the element type `T`, unused directly.
    ///
    /// Required to retain type information when `V` does not own or contain `T`.
    _phantom: PhantomData<T>,
}
36
/// A dense row-major matrix that owns its values in a `Vec<T>`.
pub type RowMajorMatrix<T> = DenseMatrix<T>;
/// A dense row-major matrix that immutably borrows its values from a slice.
pub type RowMajorMatrixView<'a, T> = DenseMatrix<T, &'a [T]>;
/// A dense row-major matrix that mutably borrows its values from a slice.
pub type RowMajorMatrixViewMut<'a, T> = DenseMatrix<T, &'a mut [T]>;
/// A dense row-major matrix whose values are either borrowed or owned (`Cow`).
pub type RowMajorMatrixCow<'a, T> = DenseMatrix<T, Cow<'a, [T]>>;
41
/// Backing storage for a dense matrix: anything that can be viewed as a `[T]` slice, can be
/// shared across threads, and can be turned into an owned `Vec<T>`.
pub trait DenseStorage<T>: Borrow<[T]> + Send + Sync {
    /// Consume the storage and produce an owned vector of its elements.
    fn to_vec(self) -> Vec<T>;
}

// There is no `IntoOwned`-style trait implemented by `Cow`, so rather than a single blanket
// impl, each supported storage type gets its own impl below.

impl<T: Clone + Send + Sync> DenseStorage<T> for Vec<T> {
    /// The data is already an owned vector, so it is returned as-is without copying.
    fn to_vec(self) -> Vec<T> {
        self
    }
}

impl<T: Clone + Send + Sync> DenseStorage<T> for &[T] {
    /// Clone every borrowed element into a fresh vector.
    fn to_vec(self) -> Vec<T> {
        Vec::from(self)
    }
}

impl<T: Clone + Send + Sync> DenseStorage<T> for &mut [T] {
    /// Clone every (mutably) borrowed element into a fresh vector.
    fn to_vec(self) -> Vec<T> {
        Vec::from(self)
    }
}

impl<T: Clone + Send + Sync> DenseStorage<T> for Cow<'_, [T]> {
    /// Reuse the owned buffer if there is one, otherwise clone the borrowed slice.
    fn to_vec(self) -> Vec<T> {
        Vec::from(self)
    }
}
70
71impl<T: Clone + Send + Sync + Default> DenseMatrix<T> {
72    /// Create a new dense matrix of the given dimensions, backed by a `Vec`, and filled with
73    /// default values.
74    #[must_use]
75    pub fn default(width: usize, height: usize) -> Self {
76        Self::new(vec![T::default(); width * height], width)
77    }
78}
79
80impl<T: Clone + Send + Sync, S: DenseStorage<T>> DenseMatrix<T, S> {
81    /// Create a new dense matrix of the given dimensions, backed by the given storage.
82    ///
83    /// Note that it is undefined behavior to create a matrix such that
84    /// `values.len() % width != 0`.
85    #[must_use]
86    pub fn new(values: S, width: usize) -> Self {
87        debug_assert!(values.borrow().len().is_multiple_of(width));
88        Self {
89            values,
90            width,
91            _phantom: PhantomData,
92        }
93    }
94
95    /// Create a new RowMajorMatrix containing a single row.
96    #[must_use]
97    pub fn new_row(values: S) -> Self {
98        let width = values.borrow().len();
99        Self::new(values, width)
100    }
101
102    /// Create a new RowMajorMatrix containing a single column.
103    #[must_use]
104    pub fn new_col(values: S) -> Self {
105        Self::new(values, 1)
106    }
107
108    /// Get a view of the matrix, i.e. a reference to the underlying data.
109    pub fn as_view(&self) -> RowMajorMatrixView<'_, T> {
110        RowMajorMatrixView::new(self.values.borrow(), self.width)
111    }
112
113    /// Get a mutable view of the matrix, i.e. a mutable reference to the underlying data.
114    pub fn as_view_mut(&mut self) -> RowMajorMatrixViewMut<'_, T>
115    where
116        S: BorrowMut<[T]>,
117    {
118        RowMajorMatrixViewMut::new(self.values.borrow_mut(), self.width)
119    }
120
121    /// Copy the values from the given matrix into this matrix.
122    pub fn copy_from<S2>(&mut self, source: &DenseMatrix<T, S2>)
123    where
124        T: Copy,
125        S: BorrowMut<[T]>,
126        S2: DenseStorage<T>,
127    {
128        assert_eq!(self.dimensions(), source.dimensions());
129        // Equivalent to:
130        // self.values.borrow_mut().copy_from_slice(source.values.borrow());
131        self.par_rows_mut()
132            .zip(source.par_row_slices())
133            .for_each(|(dst, src)| {
134                dst.copy_from_slice(src);
135            });
136    }
137
138    /// Flatten an extension field matrix to a base field matrix.
139    pub fn flatten_to_base<F: Field>(self) -> RowMajorMatrix<F>
140    where
141        T: ExtensionField<F>,
142    {
143        let width = self.width * T::DIMENSION;
144        let values = T::flatten_to_base(self.values.to_vec());
145        RowMajorMatrix::new(values, width)
146    }
147
148    /// Get an iterator over the rows of the matrix.
149    pub fn row_slices(&self) -> impl DoubleEndedIterator<Item = &[T]> {
150        self.values.borrow().chunks_exact(self.width)
151    }
152
153    /// Get a parallel iterator over the rows of the matrix.
154    pub fn par_row_slices(&self) -> impl IndexedParallelIterator<Item = &[T]>
155    where
156        T: Sync,
157    {
158        self.values.borrow().par_chunks_exact(self.width)
159    }
160
161    /// Returns a slice of the given row.
162    ///
163    /// # Panics
164    /// Panics if `r` larger than self.height().
165    pub fn row_mut(&mut self, r: usize) -> &mut [T]
166    where
167        S: BorrowMut<[T]>,
168    {
169        &mut self.values.borrow_mut()[r * self.width..(r + 1) * self.width]
170    }
171
172    /// Get a mutable iterator over the rows of the matrix.
173    pub fn rows_mut(&mut self) -> impl Iterator<Item = &mut [T]>
174    where
175        S: BorrowMut<[T]>,
176    {
177        self.values.borrow_mut().chunks_exact_mut(self.width)
178    }
179
180    /// Get a mutable parallel iterator over the rows of the matrix.
181    pub fn par_rows_mut<'a>(&'a mut self) -> impl IndexedParallelIterator<Item = &'a mut [T]>
182    where
183        T: 'a + Send,
184        S: BorrowMut<[T]>,
185    {
186        self.values.borrow_mut().par_chunks_exact_mut(self.width)
187    }
188
189    /// Get a mutable iterator over the rows of the matrix which packs the rows into packed values.
190    ///
191    /// If `P::WIDTH` does not divide `self.width`, the remainder of the row will be returned as a
192    /// base slice.
193    pub fn horizontally_packed_row_mut<P>(&mut self, r: usize) -> (&mut [P], &mut [T])
194    where
195        P: PackedValue<Value = T>,
196        S: BorrowMut<[T]>,
197    {
198        P::pack_slice_with_suffix_mut(self.row_mut(r))
199    }
200
201    /// Scale the given row by the given value.
202    ///
203    /// # Panics
204    /// Panics if `r` larger than `self.height()`.
205    pub fn scale_row(&mut self, r: usize, scale: T)
206    where
207        T: Field,
208        S: BorrowMut<[T]>,
209    {
210        scale_slice_in_place_single_core(self.row_mut(r), scale);
211    }
212
213    /// Scale the given row by the given value.
214    ///
215    /// # Performance
216    /// This function is parallelized, which may introduce some overhead compared to
217    /// [`Self::scale_row`] when the width is small.
218    ///
219    /// # Panics
220    /// Panics if `r` larger than `self.height()`.
221    pub fn par_scale_row(&mut self, r: usize, scale: T)
222    where
223        T: Field,
224        S: BorrowMut<[T]>,
225    {
226        par_scale_slice_in_place(self.row_mut(r), scale);
227    }
228
229    /// Scale the entire matrix by the given value.
230    pub fn scale(&mut self, scale: T)
231    where
232        T: Field,
233        S: BorrowMut<[T]>,
234    {
235        par_scale_slice_in_place(self.values.borrow_mut(), scale);
236    }
237
238    /// Split the matrix into two matrix views, one with the first `r` rows and one with the remaining rows.
239    ///
240    /// # Panics
241    /// Panics if `r` larger than `self.height()`.
242    pub fn split_rows(&self, r: usize) -> (RowMajorMatrixView<'_, T>, RowMajorMatrixView<'_, T>) {
243        let (lo, hi) = self.values.borrow().split_at(r * self.width);
244        (
245            DenseMatrix::new(lo, self.width),
246            DenseMatrix::new(hi, self.width),
247        )
248    }
249
250    /// Split the matrix into two mutable matrix views, one with the first `r` rows and one with the remaining rows.
251    ///
252    /// # Panics
253    /// Panics if `r` larger than `self.height()`.
254    pub fn split_rows_mut(
255        &mut self,
256        r: usize,
257    ) -> (RowMajorMatrixViewMut<'_, T>, RowMajorMatrixViewMut<'_, T>)
258    where
259        S: BorrowMut<[T]>,
260    {
261        let (lo, hi) = self.values.borrow_mut().split_at_mut(r * self.width);
262        (
263            DenseMatrix::new(lo, self.width),
264            DenseMatrix::new(hi, self.width),
265        )
266    }
267
268    /// Get an iterator over the rows of the matrix which takes `chunk_rows` rows at a time.
269    ///
270    /// If `chunk_rows` does not divide the height of the matrix, the last chunk will be smaller.
271    pub fn par_row_chunks(
272        &self,
273        chunk_rows: usize,
274    ) -> impl IndexedParallelIterator<Item = RowMajorMatrixView<'_, T>>
275    where
276        T: Send,
277    {
278        self.values
279            .borrow()
280            .par_chunks(self.width * chunk_rows)
281            .map(|slice| RowMajorMatrixView::new(slice, self.width))
282    }
283
284    /// Get a parallel iterator over the rows of the matrix which takes `chunk_rows` rows at a time.
285    ///
286    /// If `chunk_rows` does not divide the height of the matrix, the last chunk will be smaller.
287    pub fn par_row_chunks_exact(
288        &self,
289        chunk_rows: usize,
290    ) -> impl IndexedParallelIterator<Item = RowMajorMatrixView<'_, T>>
291    where
292        T: Send,
293    {
294        self.values
295            .borrow()
296            .par_chunks_exact(self.width * chunk_rows)
297            .map(|slice| RowMajorMatrixView::new(slice, self.width))
298    }
299
300    /// Get a mutable iterator over the rows of the matrix which takes `chunk_rows` rows at a time.
301    ///
302    /// If `chunk_rows` does not divide the height of the matrix, the last chunk will be smaller.
303    pub fn par_row_chunks_mut(
304        &mut self,
305        chunk_rows: usize,
306    ) -> impl IndexedParallelIterator<Item = RowMajorMatrixViewMut<'_, T>>
307    where
308        T: Send,
309        S: BorrowMut<[T]>,
310    {
311        self.values
312            .borrow_mut()
313            .par_chunks_mut(self.width * chunk_rows)
314            .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
315    }
316
317    /// Get a mutable iterator over the rows of the matrix which takes `chunk_rows` rows at a time.
318    ///
319    /// If `chunk_rows` does not divide the height of the matrix, the last up to `chunk_rows - 1` rows
320    /// of the matrix will be omitted.
321    pub fn row_chunks_exact_mut(
322        &mut self,
323        chunk_rows: usize,
324    ) -> impl Iterator<Item = RowMajorMatrixViewMut<'_, T>>
325    where
326        T: Send,
327        S: BorrowMut<[T]>,
328    {
329        self.values
330            .borrow_mut()
331            .chunks_exact_mut(self.width * chunk_rows)
332            .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
333    }
334
335    /// Get a parallel mutable iterator over the rows of the matrix which takes `chunk_rows` rows at a time.
336    ///
337    /// If `chunk_rows` does not divide the height of the matrix, the last up to `chunk_rows - 1` rows
338    /// of the matrix will be omitted.
339    pub fn par_row_chunks_exact_mut(
340        &mut self,
341        chunk_rows: usize,
342    ) -> impl IndexedParallelIterator<Item = RowMajorMatrixViewMut<'_, T>>
343    where
344        T: Send,
345        S: BorrowMut<[T]>,
346    {
347        self.values
348            .borrow_mut()
349            .par_chunks_exact_mut(self.width * chunk_rows)
350            .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
351    }
352
353    /// Get a pair of mutable slices of the given rows.
354    ///
355    /// # Panics
356    /// Panics if `row_1` or `row_2` are out of bounds or if `row_1 >= row_2`.
357    pub fn row_pair_mut(&mut self, row_1: usize, row_2: usize) -> (&mut [T], &mut [T])
358    where
359        S: BorrowMut<[T]>,
360    {
361        debug_assert_ne!(row_1, row_2);
362        let start_1 = row_1 * self.width;
363        let start_2 = row_2 * self.width;
364        let (lo, hi) = self.values.borrow_mut().split_at_mut(start_2);
365        (&mut lo[start_1..][..self.width], &mut hi[..self.width])
366    }
367
368    /// Get a pair of mutable slices of the given rows, both packed into packed field elements.
369    ///
370    /// If `P:WIDTH` does not divide `self.width`, the remainder of the row will be returned as a base slice.
371    ///
372    /// # Panics
373    /// Panics if `row_1` or `row_2` are out of bounds or if `row_1 >= row_2`.
374    #[allow(clippy::type_complexity)]
375    pub fn packed_row_pair_mut<P>(
376        &mut self,
377        row_1: usize,
378        row_2: usize,
379    ) -> ((&mut [P], &mut [T]), (&mut [P], &mut [T]))
380    where
381        S: BorrowMut<[T]>,
382        P: PackedValue<Value = T>,
383    {
384        let (slice_1, slice_2) = self.row_pair_mut(row_1, row_2);
385        (
386            P::pack_slice_with_suffix_mut(slice_1),
387            P::pack_slice_with_suffix_mut(slice_2),
388        )
389    }
390
391    /// Append zeros to the "end" of the given matrix, except that the matrix is in bit-reversed order,
392    /// so in actuality we're interleaving zero rows.
393    #[instrument(level = "debug", skip_all)]
394    pub fn bit_reversed_zero_pad(self, added_bits: usize) -> RowMajorMatrix<T>
395    where
396        T: Field,
397    {
398        if added_bits == 0 {
399            return self.to_row_major_matrix();
400        }
401
402        // This is equivalent to:
403        //     reverse_matrix_index_bits(mat);
404        //     mat
405        //         .values
406        //         .resize(mat.values.len() << added_bits, F::ZERO);
407        //     reverse_matrix_index_bits(mat);
408        // But rather than implement it with bit reversals, we directly construct the resulting matrix,
409        // whose rows are zero except for rows whose low `added_bits` bits are zero.
410
411        let w = self.width;
412        let mut padded =
413            RowMajorMatrix::new(T::zero_vec(self.values.borrow().len() << added_bits), w);
414        padded
415            .par_row_chunks_exact_mut(1 << added_bits)
416            .zip(self.par_row_slices())
417            .for_each(|(mut ch, r)| ch.row_mut(0).copy_from_slice(r));
418
419        padded
420    }
421}
422
423impl<T: Clone + Send + Sync, S: DenseStorage<T>> Matrix<T> for DenseMatrix<T, S> {
424    #[inline]
425    fn width(&self) -> usize {
426        self.width
427    }
428
429    #[inline]
430    fn height(&self) -> usize {
431        if self.width == 0 {
432            0
433        } else {
434            self.values.borrow().len() / self.width
435        }
436    }
437
438    #[inline]
439    unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
440        unsafe {
441            // Safety: The caller must ensure that r < self.height() and c < self.width().
442            self.values
443                .borrow()
444                .get_unchecked(r * self.width + c)
445                .clone()
446        }
447    }
448
449    #[inline]
450    unsafe fn row_subseq_unchecked(
451        &self,
452        r: usize,
453        start: usize,
454        end: usize,
455    ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
456        unsafe {
457            // Safety: The caller must ensure that r < self.height() and start <= end <= self.width().
458            self.values
459                .borrow()
460                .get_unchecked(r * self.width + start..r * self.width + end)
461                .iter()
462                .cloned()
463        }
464    }
465
466    #[inline]
467    unsafe fn row_subslice_unchecked(
468        &self,
469        r: usize,
470        start: usize,
471        end: usize,
472    ) -> impl Deref<Target = [T]> {
473        unsafe {
474            // Safety: The caller must ensure that r < self.height()
475            self.values
476                .borrow()
477                .get_unchecked(r * self.width + start..r * self.width + end)
478        }
479    }
480
481    fn to_row_major_matrix(self) -> RowMajorMatrix<T>
482    where
483        Self: Sized,
484        T: Clone,
485    {
486        RowMajorMatrix::new(self.values.to_vec(), self.width)
487    }
488
489    #[inline]
490    fn horizontally_packed_row<'a, P>(
491        &'a self,
492        r: usize,
493    ) -> (
494        impl Iterator<Item = P> + Send + Sync,
495        impl Iterator<Item = T> + Send + Sync,
496    )
497    where
498        P: PackedValue<Value = T>,
499        T: Clone + 'a,
500    {
501        let buf = &self.values.borrow()[r * self.width..(r + 1) * self.width];
502        let (packed, sfx) = P::pack_slice_with_suffix(buf);
503        (packed.iter().copied(), sfx.iter().cloned())
504    }
505
506    #[inline]
507    fn padded_horizontally_packed_row<'a, P>(
508        &'a self,
509        r: usize,
510    ) -> impl Iterator<Item = P> + Send + Sync
511    where
512        P: PackedValue<Value = T>,
513        T: Clone + Default + 'a,
514    {
515        let buf = &self.values.borrow()[r * self.width..(r + 1) * self.width];
516        let (packed, sfx) = P::pack_slice_with_suffix(buf);
517        packed.iter().copied().chain(iter::once(P::from_fn(|i| {
518            sfx.get(i).cloned().unwrap_or_default()
519        })))
520    }
521}
522
523impl<T: Clone + Default + Send + Sync> DenseMatrix<T> {
524    pub fn as_cow<'a>(self) -> RowMajorMatrixCow<'a, T> {
525        RowMajorMatrixCow::new(Cow::Owned(self.values), self.width)
526    }
527
528    pub fn rand<R: Rng>(rng: &mut R, rows: usize, cols: usize) -> Self
529    where
530        StandardUniform: Distribution<T>,
531    {
532        let values = rng.sample_iter(StandardUniform).take(rows * cols).collect();
533        Self::new(values, cols)
534    }
535
536    pub fn rand_nonzero<R: Rng>(rng: &mut R, rows: usize, cols: usize) -> Self
537    where
538        T: Field,
539        StandardUniform: Distribution<T>,
540    {
541        let values = rng
542            .sample_iter(StandardUniform)
543            .filter(|x| !x.is_zero())
544            .take(rows * cols)
545            .collect();
546        Self::new(values, cols)
547    }
548
549    pub fn pad_to_height(&mut self, new_height: usize, fill: T) {
550        assert!(new_height >= self.height());
551        self.values.resize(self.width * new_height, fill);
552    }
553
554    /// Pad the matrix height to the next power of two by appending rows filled with `fill`.
555    ///
556    /// This is commonly used in proof systems where trace matrices must have power-of-two heights.
557    ///
558    /// # Behavior
559    ///
560    /// - If the matrix is empty (height = 0), it is padded to have exactly one row of `fill` values.
561    /// - If the height is already a power of two, the matrix is unchanged.
562    /// - Otherwise, the matrix is padded to the next power of two height.
563    pub fn pad_to_power_of_two_height(&mut self, fill: T) {
564        // Compute the target height as the next power of two.
565        let target_height = self.height().next_power_of_two();
566
567        // If target_height == height, resize will have no effect.
568        // Otherwise we pad the matrix to a power of two height by filling with the supplied value.
569        self.values.resize(self.width * target_height, fill);
570    }
571}
572
573impl<T: Copy + Default + Send + Sync, V: DenseStorage<T>> DenseMatrix<T, V> {
574    /// Return the transpose of this matrix.
575    pub fn transpose(&self) -> RowMajorMatrix<T> {
576        let nelts = self.height() * self.width();
577        let mut values = vec![T::default(); nelts];
578        transpose::transpose(
579            self.values.borrow(),
580            &mut values,
581            self.width(),
582            self.height(),
583        );
584        RowMajorMatrix::new(values, self.height())
585    }
586
587    /// Transpose the matrix returning the result in `other` without intermediate allocation.
588    pub fn transpose_into<W: DenseStorage<T> + BorrowMut<[T]>>(
589        &self,
590        other: &mut DenseMatrix<T, W>,
591    ) {
592        assert_eq!(self.height(), other.width());
593        assert_eq!(other.height(), self.width());
594        transpose::transpose(
595            self.values.borrow(),
596            other.values.borrow_mut(),
597            self.width(),
598            self.height(),
599        );
600    }
601}
602
603impl<'a, T: Clone + Default + Send + Sync> RowMajorMatrixView<'a, T> {
604    pub fn as_cow(self) -> RowMajorMatrixCow<'a, T> {
605        RowMajorMatrixCow::new(Cow::Borrowed(self.values), self.width)
606    }
607}
608
609#[cfg(test)]
610mod tests {
611    use p3_baby_bear::BabyBear;
612    use p3_field::FieldArray;
613
614    use super::*;
615
616    #[test]
617    fn test_new() {
618        let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
619        assert_eq!(matrix.width, 2);
620        assert_eq!(matrix.height(), 3);
621        assert_eq!(matrix.values, vec![1, 2, 3, 4, 5, 6]);
622    }
623
624    #[test]
625    fn test_new_row() {
626        let matrix = RowMajorMatrix::new_row(vec![1, 2, 3]);
627        assert_eq!(matrix.width, 3);
628        assert_eq!(matrix.height(), 1);
629    }
630
631    #[test]
632    fn test_new_col() {
633        let matrix = RowMajorMatrix::new_col(vec![1, 2, 3]);
634        assert_eq!(matrix.width, 1);
635        assert_eq!(matrix.height(), 3);
636    }
637
638    #[test]
639    fn test_height_with_zero_width() {
640        let matrix: DenseMatrix<i32> = RowMajorMatrix::new(vec![], 0);
641        assert_eq!(matrix.height(), 0);
642    }
643
644    #[test]
645    fn test_get_methods() {
646        let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2); // Height = 3, Width = 2
647        assert_eq!(matrix.get(0, 0), Some(1));
648        assert_eq!(matrix.get(1, 1), Some(4));
649        assert_eq!(matrix.get(2, 0), Some(5));
650        unsafe {
651            assert_eq!(matrix.get_unchecked(0, 1), 2);
652            assert_eq!(matrix.get_unchecked(1, 0), 3);
653            assert_eq!(matrix.get_unchecked(2, 1), 6);
654        }
655        assert_eq!(matrix.get(3, 0), None); // Height out of bounds
656        assert_eq!(matrix.get(0, 2), None); // Width out of bounds
657    }
658
659    #[test]
660    fn test_row_methods() {
661        let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 4); // Height = 2, Width = 4
662        let row: Vec<_> = matrix.row(1).unwrap().into_iter().collect();
663        assert_eq!(row, vec![5, 6, 7, 8]);
664        unsafe {
665            let row: Vec<_> = matrix.row_unchecked(0).into_iter().collect();
666            assert_eq!(row, vec![1, 2, 3, 4]);
667            let row: Vec<_> = matrix.row_subseq_unchecked(0, 0, 3).into_iter().collect();
668            assert_eq!(row, vec![1, 2, 3]);
669            let row: Vec<_> = matrix.row_subseq_unchecked(0, 1, 3).into_iter().collect();
670            assert_eq!(row, vec![2, 3]);
671            let row: Vec<_> = matrix.row_subseq_unchecked(0, 2, 4).into_iter().collect();
672            assert_eq!(row, vec![3, 4]);
673        }
674        assert!(matrix.row(2).is_none()); // Height out of bounds
675    }
676
677    #[test]
678    fn test_row_slice_methods() {
679        let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9], 3); // Height = 3, Width = 3
680        let slice0 = matrix.row_slice(0);
681        let slice2 = matrix.row_slice(2);
682        assert_eq!(slice0.unwrap().deref(), &[1, 2, 3]);
683        assert_eq!(slice2.unwrap().deref(), &[7, 8, 9]);
684        unsafe {
685            assert_eq!(&[1, 2, 3], matrix.row_slice_unchecked(0).deref());
686            assert_eq!(&[7, 8, 9], matrix.row_slice_unchecked(2).deref());
687
688            assert_eq!(&[1, 2, 3], matrix.row_subslice_unchecked(0, 0, 3).deref());
689            assert_eq!(&[8], matrix.row_subslice_unchecked(2, 1, 2).deref());
690        }
691        assert!(matrix.row_slice(3).is_none()); // Height out of bounds
692    }
693
694    #[test]
695    fn test_as_view() {
696        let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
697        let view = matrix.as_view();
698        assert_eq!(view.values, &[1, 2, 3, 4]);
699        assert_eq!(view.width, 2);
700    }
701
702    #[test]
703    fn test_as_view_mut() {
704        let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
705        let view = matrix.as_view_mut();
706        view.values[0] = 10;
707        assert_eq!(matrix.values, vec![10, 2, 3, 4]);
708    }
709
710    #[test]
711    fn test_copy_from() {
712        let mut matrix1 = RowMajorMatrix::new(vec![0, 0, 0, 0], 2);
713        let matrix2 = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
714        matrix1.copy_from(&matrix2);
715        assert_eq!(matrix1.values, vec![1, 2, 3, 4]);
716    }
717
718    #[test]
719    fn test_split_rows() {
720        let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
721        let (top, bottom) = matrix.split_rows(1);
722        assert_eq!(top.values, vec![1, 2]);
723        assert_eq!(bottom.values, vec![3, 4, 5, 6]);
724    }
725
726    #[test]
727    fn test_split_rows_mut() {
728        let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
729        let (top, bottom) = matrix.split_rows_mut(1);
730        assert_eq!(top.values, vec![1, 2]);
731        assert_eq!(bottom.values, vec![3, 4, 5, 6]);
732    }
733
734    #[test]
735    fn test_row_mut() {
736        let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
737        matrix.row_mut(1)[0] = 10;
738        assert_eq!(matrix.values, vec![1, 2, 10, 4, 5, 6]);
739    }
740
741    #[test]
742    fn test_bit_reversed_zero_pad() {
743        let matrix = RowMajorMatrix::new(
744            vec![
745                BabyBear::new(1),
746                BabyBear::new(2),
747                BabyBear::new(3),
748                BabyBear::new(4),
749            ],
750            2,
751        );
752        let padded = matrix.bit_reversed_zero_pad(1);
753        assert_eq!(padded.width, 2);
754        assert_eq!(
755            padded.values,
756            vec![
757                BabyBear::new(1),
758                BabyBear::new(2),
759                BabyBear::new(0),
760                BabyBear::new(0),
761                BabyBear::new(3),
762                BabyBear::new(4),
763                BabyBear::new(0),
764                BabyBear::new(0)
765            ]
766        );
767    }
768
769    #[test]
770    fn test_bit_reversed_zero_pad_no_change() {
771        let matrix = RowMajorMatrix::new(
772            vec![
773                BabyBear::new(1),
774                BabyBear::new(2),
775                BabyBear::new(3),
776                BabyBear::new(4),
777            ],
778            2,
779        );
780        let padded = matrix.bit_reversed_zero_pad(0);
781
782        assert_eq!(padded.width, 2);
783        assert_eq!(
784            padded.values,
785            vec![
786                BabyBear::new(1),
787                BabyBear::new(2),
788                BabyBear::new(3),
789                BabyBear::new(4),
790            ]
791        );
792    }
793
794    #[test]
795    fn test_scale() {
796        let mut matrix = RowMajorMatrix::new(
797            vec![
798                BabyBear::new(1),
799                BabyBear::new(2),
800                BabyBear::new(3),
801                BabyBear::new(4),
802                BabyBear::new(5),
803                BabyBear::new(6),
804            ],
805            2,
806        );
807        matrix.scale(BabyBear::new(2));
808        assert_eq!(
809            matrix.values,
810            vec![
811                BabyBear::new(2),
812                BabyBear::new(4),
813                BabyBear::new(6),
814                BabyBear::new(8),
815                BabyBear::new(10),
816                BabyBear::new(12)
817            ]
818        );
819    }
820
821    #[test]
822    fn test_scale_row() {
823        let mut matrix = RowMajorMatrix::new(
824            vec![
825                BabyBear::new(1),
826                BabyBear::new(2),
827                BabyBear::new(3),
828                BabyBear::new(4),
829                BabyBear::new(5),
830                BabyBear::new(6),
831            ],
832            2,
833        );
834        matrix.scale_row(1, BabyBear::new(3));
835        assert_eq!(
836            matrix.values,
837            vec![
838                BabyBear::new(1),
839                BabyBear::new(2),
840                BabyBear::new(9),
841                BabyBear::new(12),
842                BabyBear::new(5),
843                BabyBear::new(6),
844            ]
845        );
846    }
847
848    #[test]
849    fn test_to_row_major_matrix() {
850        let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
851        let converted = matrix.to_row_major_matrix();
852
853        // The converted matrix should have the same values and width
854        assert_eq!(converted.width, 2);
855        assert_eq!(converted.height(), 3);
856        assert_eq!(converted.values, vec![1, 2, 3, 4, 5, 6]);
857    }
858
859    #[test]
860    fn test_horizontally_packed_row() {
861        type Packed = FieldArray<BabyBear, 2>;
862
863        let matrix = RowMajorMatrix::new(
864            vec![
865                BabyBear::new(1),
866                BabyBear::new(2),
867                BabyBear::new(3),
868                BabyBear::new(4),
869                BabyBear::new(5),
870                BabyBear::new(6),
871            ],
872            3,
873        );
874
875        let (packed_iter, suffix_iter) = matrix.horizontally_packed_row::<Packed>(1);
876
877        let packed: Vec<_> = packed_iter.collect();
878        let suffix: Vec<_> = suffix_iter.collect();
879
880        assert_eq!(
881            packed,
882            vec![Packed::from([BabyBear::new(4), BabyBear::new(5)])]
883        );
884        assert_eq!(suffix, vec![BabyBear::new(6)]);
885    }
886
887    #[test]
888    fn test_padded_horizontally_packed_row() {
889        use p3_baby_bear::BabyBear;
890
891        type Packed = FieldArray<BabyBear, 2>;
892
893        let matrix = RowMajorMatrix::new(
894            vec![
895                BabyBear::new(1),
896                BabyBear::new(2),
897                BabyBear::new(3),
898                BabyBear::new(4),
899                BabyBear::new(5),
900                BabyBear::new(6),
901            ],
902            3,
903        );
904
905        let packed_iter = matrix.padded_horizontally_packed_row::<Packed>(1);
906        let packed: Vec<_> = packed_iter.collect();
907
908        assert_eq!(
909            packed,
910            vec![
911                Packed::from([BabyBear::new(4), BabyBear::new(5)]),
912                Packed::from([BabyBear::new(6), BabyBear::new(0)])
913            ]
914        );
915    }
916
917    #[test]
918    fn test_pad_to_height() {
919        let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);
920
921        // Original matrix:
922        // [ 1  2  3 ]
923        // [ 4  5  6 ] (height = 2)
924
925        matrix.pad_to_height(4, 9);
926
927        // Expected matrix after padding:
928        // [ 1  2  3 ]
929        // [ 4  5  6 ]
930        // [ 9  9  9 ]  <-- Newly added row
931        // [ 9  9  9 ]  <-- Newly added row
932
933        assert_eq!(matrix.height(), 4);
934        assert_eq!(matrix.values, vec![1, 2, 3, 4, 5, 6, 9, 9, 9, 9, 9, 9]);
935    }
936
937    #[test]
938    fn test_pad_to_power_of_two_height() {
939        // Test 1: Non-power-of-two height (3 rows -> 4 rows) with fill value 0.
940        //
941        // - Original matrix has 3 rows, which is not a power of two.
942        // - After padding, it should have 4 rows (next power of two).
943        let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
944        assert_eq!(matrix.height(), 3);
945        matrix.pad_to_power_of_two_height(0);
946        assert_eq!(matrix.height(), 4);
947        // Original values preserved, new row filled with 0.
948        assert_eq!(matrix.values, vec![1, 2, 3, 4, 5, 6, 0, 0]);
949
950        // Test 2: Already power-of-two height (4 rows -> 4 rows, unchanged).
951        //
952        // Matrix height is already a power of two, so no padding occurs.
953        // Fill value is ignored when no padding is needed.
954        let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 2);
955        assert_eq!(matrix.height(), 4);
956        matrix.pad_to_power_of_two_height(99);
957        assert_eq!(matrix.height(), 4);
958        // Values unchanged (fill value not used).
959        assert_eq!(matrix.values, vec![1, 2, 3, 4, 5, 6, 7, 8]);
960
961        // Test 3: Single row matrix (1 row -> 1 row, unchanged).
962        //
963        // Height of 1 is a power of two (2^0 = 1).
964        let mut matrix = RowMajorMatrix::new(vec![1, 2, 3], 3);
965        assert_eq!(matrix.height(), 1);
966        matrix.pad_to_power_of_two_height(42);
967        assert_eq!(matrix.height(), 1);
968        assert_eq!(matrix.values, vec![1, 2, 3]);
969
970        // Test 4: 5 rows -> 8 rows with custom fill value (-1).
971        //
972        // Demonstrates padding across a larger gap with a non-zero fill value.
973        let mut matrix = RowMajorMatrix::new(vec![1; 10], 2);
974        assert_eq!(matrix.height(), 5);
975        matrix.pad_to_power_of_two_height(-1);
976        assert_eq!(matrix.height(), 8);
977        // Original 10 values plus 6 fill values (3 new rows * 2 width).
978        assert_eq!(matrix.values.len(), 16);
979        assert!(matrix.values[..10].iter().all(|&v| v == 1));
980        assert!(matrix.values[10..].iter().all(|&v| v == -1));
981    }
982
983    #[test]
984    fn test_pad_to_power_of_two_height_empty_matrix() {
985        // Empty matrix (0 rows) should be padded to 1 row of fill values.
986        // This ensures the matrix is valid for downstream operations.
987        let mut matrix: RowMajorMatrix<i32> = RowMajorMatrix::new(vec![], 3);
988        assert_eq!(matrix.height(), 0);
989        assert_eq!(matrix.width, 3);
990        matrix.pad_to_power_of_two_height(7);
991        // After padding: 1 row with 3 columns, all filled with 7.
992        assert_eq!(matrix.height(), 1);
993        assert_eq!(matrix.values, vec![7, 7, 7]);
994    }
995
996    #[test]
997    fn test_transpose_into() {
998        let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);
999
1000        // Original matrix:
1001        // [ 1  2  3 ]
1002        // [ 4  5  6 ]
1003
1004        let mut transposed = RowMajorMatrix::new(vec![0; 6], 2);
1005
1006        matrix.transpose_into(&mut transposed);
1007
1008        // Expected transposed matrix:
1009        // [ 1  4 ]
1010        // [ 2  5 ]
1011        // [ 3  6 ]
1012
1013        assert_eq!(transposed.width, 2);
1014        assert_eq!(transposed.height(), 3);
1015        assert_eq!(transposed.values, vec![1, 4, 2, 5, 3, 6]);
1016    }
1017
1018    #[test]
1019    fn test_flatten_to_base() {
1020        let matrix = RowMajorMatrix::new(
1021            vec![
1022                BabyBear::new(2),
1023                BabyBear::new(3),
1024                BabyBear::new(4),
1025                BabyBear::new(5),
1026            ],
1027            2,
1028        );
1029
1030        let flattened: RowMajorMatrix<BabyBear> = matrix.flatten_to_base();
1031
1032        assert_eq!(flattened.width, 2);
1033        assert_eq!(
1034            flattened.values,
1035            vec![
1036                BabyBear::new(2),
1037                BabyBear::new(3),
1038                BabyBear::new(4),
1039                BabyBear::new(5),
1040            ]
1041        );
1042    }
1043
1044    #[test]
1045    fn test_horizontally_packed_row_mut() {
1046        type Packed = FieldArray<BabyBear, 2>;
1047
1048        let mut matrix = RowMajorMatrix::new(
1049            vec![
1050                BabyBear::new(1),
1051                BabyBear::new(2),
1052                BabyBear::new(3),
1053                BabyBear::new(4),
1054                BabyBear::new(5),
1055                BabyBear::new(6),
1056            ],
1057            3,
1058        );
1059
1060        let (packed, suffix) = matrix.horizontally_packed_row_mut::<Packed>(1);
1061        packed[0] = Packed::from([BabyBear::new(9), BabyBear::new(10)]);
1062        suffix[0] = BabyBear::new(11);
1063
1064        assert_eq!(
1065            matrix.values,
1066            vec![
1067                BabyBear::new(1),
1068                BabyBear::new(2),
1069                BabyBear::new(3),
1070                BabyBear::new(9),
1071                BabyBear::new(10),
1072                BabyBear::new(11),
1073            ]
1074        );
1075    }
1076
1077    #[test]
1078    fn test_par_row_chunks() {
1079        let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 2);
1080
1081        let chunks: Vec<_> = matrix.par_row_chunks(2).collect();
1082
1083        assert_eq!(chunks.len(), 2);
1084        assert_eq!(chunks[0].values, vec![1, 2, 3, 4]);
1085        assert_eq!(chunks[1].values, vec![5, 6, 7, 8]);
1086    }
1087
1088    #[test]
1089    fn test_par_row_chunks_exact() {
1090        let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
1091
1092        let chunks: Vec<_> = matrix.par_row_chunks_exact(1).collect();
1093
1094        assert_eq!(chunks.len(), 3);
1095        assert_eq!(chunks[0].values, vec![1, 2]);
1096        assert_eq!(chunks[1].values, vec![3, 4]);
1097        assert_eq!(chunks[2].values, vec![5, 6]);
1098    }
1099
1100    #[test]
1101    fn test_par_row_chunks_mut() {
1102        let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 2);
1103
1104        matrix
1105            .par_row_chunks_mut(2)
1106            .for_each(|chunk| chunk.values.iter_mut().for_each(|x| *x += 10));
1107
1108        assert_eq!(matrix.values, vec![11, 12, 13, 14, 15, 16, 17, 18]);
1109    }
1110
1111    #[test]
1112    fn test_row_chunks_exact_mut() {
1113        let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
1114
1115        for chunk in matrix.row_chunks_exact_mut(1) {
1116            chunk.values.iter_mut().for_each(|x| *x *= 2);
1117        }
1118
1119        assert_eq!(matrix.values, vec![2, 4, 6, 8, 10, 12]);
1120    }
1121
1122    #[test]
1123    fn test_par_row_chunks_exact_mut() {
1124        let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
1125
1126        matrix
1127            .par_row_chunks_exact_mut(1)
1128            .for_each(|chunk| chunk.values.iter_mut().for_each(|x| *x += 5));
1129
1130        assert_eq!(matrix.values, vec![6, 7, 8, 9, 10, 11]);
1131    }
1132
1133    #[test]
1134    fn test_row_pair_mut() {
1135        let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
1136
1137        let (row1, row2) = matrix.row_pair_mut(0, 2);
1138        row1[0] = 9;
1139        row2[1] = 10;
1140
1141        assert_eq!(matrix.values, vec![9, 2, 3, 4, 5, 10]);
1142    }
1143
1144    #[test]
1145    fn test_packed_row_pair_mut() {
1146        type Packed = FieldArray<BabyBear, 2>;
1147
1148        let mut matrix = RowMajorMatrix::new(
1149            vec![
1150                BabyBear::new(1),
1151                BabyBear::new(2),
1152                BabyBear::new(3),
1153                BabyBear::new(4),
1154                BabyBear::new(5),
1155                BabyBear::new(6),
1156            ],
1157            3,
1158        );
1159
1160        let ((packed1, sfx1), (packed2, sfx2)) = matrix.packed_row_pair_mut::<Packed>(0, 1);
1161        packed1[0] = Packed::from([BabyBear::new(7), BabyBear::new(8)]);
1162        packed2[0] = Packed::from([BabyBear::new(33), BabyBear::new(44)]);
1163        sfx1[0] = BabyBear::new(99);
1164        sfx2[0] = BabyBear::new(9);
1165
1166        assert_eq!(
1167            matrix.values,
1168            vec![
1169                BabyBear::new(7),
1170                BabyBear::new(8),
1171                BabyBear::new(99),
1172                BabyBear::new(33),
1173                BabyBear::new(44),
1174                BabyBear::new(9),
1175            ]
1176        );
1177    }
1178
1179    #[test]
1180    fn test_transpose_square_matrix() {
1181        const START_INDEX: usize = 1;
1182        const VALUE_LEN: usize = 9;
1183        const WIDTH: usize = 3;
1184        const HEIGHT: usize = 3;
1185
1186        let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
1187        let matrix = RowMajorMatrix::new(matrix_values, WIDTH);
1188        let transposed = matrix.transpose();
1189        let should_be_transposed_values = vec![1, 4, 7, 2, 5, 8, 3, 6, 9];
1190        let should_be_transposed = RowMajorMatrix::new(should_be_transposed_values, HEIGHT);
1191        assert_eq!(transposed, should_be_transposed);
1192    }
1193
1194    #[test]
1195    fn test_transpose_row_matrix() {
1196        const START_INDEX: usize = 1;
1197        const VALUE_LEN: usize = 30;
1198        const WIDTH: usize = 1;
1199        const HEIGHT: usize = 30;
1200
1201        let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
1202        let matrix = RowMajorMatrix::new(matrix_values.clone(), WIDTH);
1203        let transposed = matrix.transpose();
1204        let should_be_transposed = RowMajorMatrix::new(matrix_values, HEIGHT);
1205        assert_eq!(transposed, should_be_transposed);
1206    }
1207
1208    #[test]
1209    fn test_transpose_rectangular_matrix() {
1210        const START_INDEX: usize = 1;
1211        const VALUE_LEN: usize = 30;
1212        const WIDTH: usize = 5;
1213        const HEIGHT: usize = 6;
1214
1215        let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
1216        let matrix = RowMajorMatrix::new(matrix_values, WIDTH);
1217        let transposed = matrix.transpose();
1218        let should_be_transposed_values = vec![
1219            1, 6, 11, 16, 21, 26, 2, 7, 12, 17, 22, 27, 3, 8, 13, 18, 23, 28, 4, 9, 14, 19, 24, 29,
1220            5, 10, 15, 20, 25, 30,
1221        ];
1222        let should_be_transposed = RowMajorMatrix::new(should_be_transposed_values, HEIGHT);
1223        assert_eq!(transposed, should_be_transposed);
1224    }
1225
1226    #[test]
1227    fn test_transpose_larger_rectangular_matrix() {
1228        const START_INDEX: usize = 1;
1229        const VALUE_LEN: usize = 131072; // 512 * 256
1230        const WIDTH: usize = 256;
1231        const HEIGHT: usize = 512;
1232
1233        let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
1234        let matrix = RowMajorMatrix::new(matrix_values, WIDTH);
1235        let transposed = matrix.transpose();
1236
1237        assert_eq!(transposed.width(), HEIGHT);
1238        assert_eq!(transposed.height(), WIDTH);
1239
1240        for col_index in 0..WIDTH {
1241            for row_index in 0..HEIGHT {
1242                assert_eq!(
1243                    matrix.values[row_index * WIDTH + col_index],
1244                    transposed.values[col_index * HEIGHT + row_index]
1245                );
1246            }
1247        }
1248    }
1249
1250    #[test]
1251    fn test_transpose_very_large_rectangular_matrix() {
1252        const START_INDEX: usize = 1;
1253        const VALUE_LEN: usize = 1048576; // 512 * 256
1254        const WIDTH: usize = 1024;
1255        const HEIGHT: usize = 1024;
1256
1257        let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
1258        let matrix = RowMajorMatrix::new(matrix_values, WIDTH);
1259        let transposed = matrix.transpose();
1260
1261        assert_eq!(transposed.width(), HEIGHT);
1262        assert_eq!(transposed.height(), WIDTH);
1263
1264        for col_index in 0..WIDTH {
1265            for row_index in 0..HEIGHT {
1266                assert_eq!(
1267                    matrix.values[row_index * WIDTH + col_index],
1268                    transposed.values[col_index * HEIGHT + row_index]
1269                );
1270            }
1271        }
1272    }
1273
1274    #[test]
1275    fn test_vertically_packed_row_pair() {
1276        type Packed = FieldArray<BabyBear, 2>;
1277
1278        let matrix = RowMajorMatrix::new((1..17).map(BabyBear::new).collect::<Vec<_>>(), 4);
1279
1280        // Calling the function with r = 0 and step = 2
1281        let packed = matrix.vertically_packed_row_pair::<Packed>(0, 2);
1282
1283        // Matrix visualization:
1284        //
1285        // [  1   2   3   4  ]  <-- Row 0
1286        // [  5   6   7   8  ]  <-- Row 1
1287        // [  9  10  11  12  ]  <-- Row 2
1288        // [ 13  14  15  16  ]  <-- Row 3
1289        //
1290        // Packing rows 0-1 together, then rows 2-3 together:
1291        //
1292        // Packed result:
1293        // [
1294        //   (1, 5), (2, 6), (3, 7), (4, 8),   // First packed row (Row 0 & Row 1)
1295        //   (9, 13), (10, 14), (11, 15), (12, 16),   // Second packed row (Row 2 & Row 3)
1296        // ]
1297
1298        assert_eq!(
1299            packed,
1300            (1..5)
1301                .chain(9..13)
1302                .map(|i| [BabyBear::new(i), BabyBear::new(i + 4)].into())
1303                .collect::<Vec<_>>(),
1304        );
1305    }
1306
1307    #[test]
1308    fn test_vertically_packed_row_pair_overlap() {
1309        type Packed = FieldArray<BabyBear, 2>;
1310
1311        let matrix = RowMajorMatrix::new((1..17).map(BabyBear::new).collect::<Vec<_>>(), 4);
1312
1313        // Original matrix visualization:
1314        //
1315        // [  1   2   3   4  ]  <-- Row 0
1316        // [  5   6   7   8  ]  <-- Row 1
1317        // [  9  10  11  12  ]  <-- Row 2
1318        // [ 13  14  15  16  ]  <-- Row 3
1319        //
1320        // Packing rows 0-1 together, then rows 1-2 together:
1321        //
1322        // Expected packed result:
1323        // [
1324        //   (1, 5), (2, 6), (3, 7), (4, 8),   // First packed row (Row 0 & Row 1)
1325        //   (5, 9), (6, 10), (7, 11), (8, 12) // Second packed row (Row 1 & Row 2)
1326        // ]
1327
1328        // Calling the function with overlapping rows (r = 0 and step = 1)
1329        let packed = matrix.vertically_packed_row_pair::<Packed>(0, 1);
1330
1331        assert_eq!(
1332            packed,
1333            (1..5)
1334                .chain(5..9)
1335                .map(|i| [BabyBear::new(i), BabyBear::new(i + 4)].into())
1336                .collect::<Vec<_>>(),
1337        );
1338    }
1339
1340    #[test]
1341    fn test_vertically_packed_row_pair_wraparound_start_1() {
1342        use p3_baby_bear::BabyBear;
1343        use p3_field::FieldArray;
1344
1345        type Packed = FieldArray<BabyBear, 2>;
1346
1347        let matrix = RowMajorMatrix::new((1..17).map(BabyBear::new).collect::<Vec<_>>(), 4);
1348
1349        // Original matrix visualization:
1350        //
1351        // [  1   2   3   4  ]  <-- Row 0
1352        // [  5   6   7   8  ]  <-- Row 1
1353        // [  9  10  11  12  ]  <-- Row 2
1354        // [ 13  14  15  16  ]  <-- Row 3
1355        //
1356        // Packing starts from row 1, skipping 2 rows (step = 2):
1357        // - The first packed row should contain row 1 & row 2.
1358        // - The second packed row should contain row 3 & row 1 (wraparound case).
1359        //
1360        // Expected packed result:
1361        // [
1362        //   (5, 9), (6, 10), (7, 11), (8, 12),   // Packed row (Row 1 & Row 2)
1363        //   (13, 1), (14, 2), (15, 3), (16, 4)    // Packed row (Row 3 & Row 1)
1364        // ]
1365
1366        // Calling the function with wraparound scenario (starting at r = 1 with step = 2)
1367        let packed = matrix.vertically_packed_row_pair::<Packed>(1, 2);
1368
1369        assert_eq!(
1370            packed,
1371            vec![
1372                Packed::from([BabyBear::new(5), BabyBear::new(9)]),
1373                Packed::from([BabyBear::new(6), BabyBear::new(10)]),
1374                Packed::from([BabyBear::new(7), BabyBear::new(11)]),
1375                Packed::from([BabyBear::new(8), BabyBear::new(12)]),
1376                Packed::from([BabyBear::new(13), BabyBear::new(1)]),
1377                Packed::from([BabyBear::new(14), BabyBear::new(2)]),
1378                Packed::from([BabyBear::new(15), BabyBear::new(3)]),
1379                Packed::from([BabyBear::new(16), BabyBear::new(4)]),
1380            ]
1381        );
1382    }
1383}