1use alloc::borrow::Cow;
2use alloc::vec;
3use alloc::vec::Vec;
4use core::borrow::{Borrow, BorrowMut};
5use core::marker::PhantomData;
6use core::ops::Deref;
7
8use p3_field::{
9 ExtensionField, Field, PackedValue, par_scale_slice_in_place, scale_slice_in_place_single_core,
10};
11use p3_maybe_rayon::prelude::*;
12use rand::distr::{Distribution, StandardUniform};
13use rand::{Rng, RngExt};
14use serde::{Deserialize, Serialize};
15use tracing::instrument;
16
17use crate::Matrix;
18
/// A dense matrix stored in row-major order, generic over its backing storage `V`.
///
/// `V` defaults to an owned `Vec<T>`. The aliases `RowMajorMatrix`, `RowMajorMatrixView`,
/// `RowMajorMatrixViewMut` and `RowMajorMatrixCow` name the owned, borrowed, mutably
/// borrowed and copy-on-write variants.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct DenseMatrix<T, V = Vec<T>> {
    /// Row-major element storage: entry `(r, c)` lives at index `r * width + c`.
    pub values: V,
    /// Number of columns. The row count is not stored; it is derived from the
    /// storage length divided by this width (a zero width reports zero rows).
    pub width: usize,
    /// Zero-sized marker that ties the element type `T` to the matrix type even
    /// though only `V` is stored. Because this field is private, code outside this
    /// module cannot build the struct with a literal and must go through a
    /// constructor such as `DenseMatrix::new`.
    _phantom: PhantomData<T>,
}
35
36pub type RowMajorMatrix<T> = DenseMatrix<T>;
37pub type RowMajorMatrixView<'a, T> = DenseMatrix<T, &'a [T]>;
38pub type RowMajorMatrixViewMut<'a, T> = DenseMatrix<T, &'a mut [T]>;
39pub type RowMajorMatrixCow<'a, T> = DenseMatrix<T, Cow<'a, [T]>>;
40
/// Storage usable as the backing buffer of a `DenseMatrix`.
///
/// Implementors expose their contents as a `[T]` slice (via `Borrow`), are shareable
/// across threads, and can be turned into an owned `Vec<T>`. This file implements it for
/// `Vec<T>`, `&[T]`, `&mut [T]` and `Cow<[T]>`.
pub trait DenseStorage<T>
where
    Self: Borrow<[T]> + Send + Sync,
{
    /// Consumes the storage and returns its elements as an owned vector.
    fn to_vec(self) -> Vec<T>;
}
44
45impl<T: Clone + Send + Sync> DenseStorage<T> for Vec<T> {
47 fn to_vec(self) -> Self {
48 self
49 }
50}
51
52impl<T: Clone + Send + Sync> DenseStorage<T> for &[T] {
53 fn to_vec(self) -> Vec<T> {
54 <[T]>::to_vec(self)
55 }
56}
57
58impl<T: Clone + Send + Sync> DenseStorage<T> for &mut [T] {
59 fn to_vec(self) -> Vec<T> {
60 <[T]>::to_vec(self)
61 }
62}
63
64impl<T: Clone + Send + Sync> DenseStorage<T> for Cow<'_, [T]> {
65 fn to_vec(self) -> Vec<T> {
66 self.into_owned()
67 }
68}
69
70impl<T: Clone + Send + Sync + Default> DenseMatrix<T> {
71 #[must_use]
74 pub fn default(width: usize, height: usize) -> Self {
75 Self::new(vec![T::default(); width * height], width)
76 }
77}
78
79impl<T: Clone + Send + Sync, S: DenseStorage<T>> DenseMatrix<T, S> {
80 #[must_use]
85 pub fn new(values: S, width: usize) -> Self {
86 debug_assert!(values.borrow().len().is_multiple_of(width));
87 Self {
88 values,
89 width,
90 _phantom: PhantomData,
91 }
92 }
93
94 #[must_use]
96 pub fn new_row(values: S) -> Self {
97 let width = values.borrow().len();
98 Self::new(values, width)
99 }
100
101 #[must_use]
103 pub fn new_col(values: S) -> Self {
104 Self::new(values, 1)
105 }
106
107 pub fn as_view(&self) -> RowMajorMatrixView<'_, T> {
109 RowMajorMatrixView::new(self.values.borrow(), self.width)
110 }
111
112 pub fn as_view_mut(&mut self) -> RowMajorMatrixViewMut<'_, T>
114 where
115 S: BorrowMut<[T]>,
116 {
117 RowMajorMatrixViewMut::new(self.values.borrow_mut(), self.width)
118 }
119
120 pub fn copy_from<S2>(&mut self, source: &DenseMatrix<T, S2>)
122 where
123 T: Copy,
124 S: BorrowMut<[T]>,
125 S2: DenseStorage<T>,
126 {
127 assert_eq!(self.dimensions(), source.dimensions());
128 self.par_rows_mut()
131 .zip(source.par_row_slices())
132 .for_each(|(dst, src)| {
133 dst.copy_from_slice(src);
134 });
135 }
136
137 pub fn flatten_to_base<F: Field>(self) -> RowMajorMatrix<F>
139 where
140 T: ExtensionField<F>,
141 {
142 let width = self.width * T::DIMENSION;
143 let values = T::flatten_to_base(self.values.to_vec());
144 RowMajorMatrix::new(values, width)
145 }
146
147 pub fn row_slices(&self) -> impl DoubleEndedIterator<Item = &[T]> {
149 self.values.borrow().chunks_exact(self.width)
150 }
151
152 pub fn par_row_slices(&self) -> impl IndexedParallelIterator<Item = &[T]>
154 where
155 T: Sync,
156 {
157 self.values.borrow().par_chunks_exact(self.width)
158 }
159
160 pub fn row_mut(&mut self, r: usize) -> &mut [T]
165 where
166 S: BorrowMut<[T]>,
167 {
168 &mut self.values.borrow_mut()[r * self.width..(r + 1) * self.width]
169 }
170
171 pub fn rows_mut(&mut self) -> impl Iterator<Item = &mut [T]>
173 where
174 S: BorrowMut<[T]>,
175 {
176 self.values.borrow_mut().chunks_exact_mut(self.width)
177 }
178
179 pub fn par_rows_mut<'a>(&'a mut self) -> impl IndexedParallelIterator<Item = &'a mut [T]>
181 where
182 T: 'a + Send,
183 S: BorrowMut<[T]>,
184 {
185 self.values.borrow_mut().par_chunks_exact_mut(self.width)
186 }
187
188 pub fn horizontally_packed_row_mut<P>(&mut self, r: usize) -> (&mut [P], &mut [T])
193 where
194 P: PackedValue<Value = T>,
195 S: BorrowMut<[T]>,
196 {
197 P::pack_slice_with_suffix_mut(self.row_mut(r))
198 }
199
200 pub fn scale_row(&mut self, r: usize, scale: T)
205 where
206 T: Field,
207 S: BorrowMut<[T]>,
208 {
209 scale_slice_in_place_single_core(self.row_mut(r), scale);
210 }
211
212 pub fn par_scale_row(&mut self, r: usize, scale: T)
221 where
222 T: Field,
223 S: BorrowMut<[T]>,
224 {
225 par_scale_slice_in_place(self.row_mut(r), scale);
226 }
227
228 pub fn scale(&mut self, scale: T)
230 where
231 T: Field,
232 S: BorrowMut<[T]>,
233 {
234 par_scale_slice_in_place(self.values.borrow_mut(), scale);
235 }
236
237 pub fn split_rows(&self, r: usize) -> (RowMajorMatrixView<'_, T>, RowMajorMatrixView<'_, T>) {
242 let (lo, hi) = self.values.borrow().split_at(r * self.width);
243 (
244 DenseMatrix::new(lo, self.width),
245 DenseMatrix::new(hi, self.width),
246 )
247 }
248
249 pub fn split_rows_mut(
254 &mut self,
255 r: usize,
256 ) -> (RowMajorMatrixViewMut<'_, T>, RowMajorMatrixViewMut<'_, T>)
257 where
258 S: BorrowMut<[T]>,
259 {
260 let (lo, hi) = self.values.borrow_mut().split_at_mut(r * self.width);
261 (
262 DenseMatrix::new(lo, self.width),
263 DenseMatrix::new(hi, self.width),
264 )
265 }
266
267 pub fn par_row_chunks(
271 &self,
272 chunk_rows: usize,
273 ) -> impl IndexedParallelIterator<Item = RowMajorMatrixView<'_, T>>
274 where
275 T: Send,
276 {
277 self.values
278 .borrow()
279 .par_chunks(self.width * chunk_rows)
280 .map(|slice| RowMajorMatrixView::new(slice, self.width))
281 }
282
283 pub fn par_row_chunks_exact(
287 &self,
288 chunk_rows: usize,
289 ) -> impl IndexedParallelIterator<Item = RowMajorMatrixView<'_, T>>
290 where
291 T: Send,
292 {
293 self.values
294 .borrow()
295 .par_chunks_exact(self.width * chunk_rows)
296 .map(|slice| RowMajorMatrixView::new(slice, self.width))
297 }
298
299 pub fn par_row_chunks_mut(
303 &mut self,
304 chunk_rows: usize,
305 ) -> impl IndexedParallelIterator<Item = RowMajorMatrixViewMut<'_, T>>
306 where
307 T: Send,
308 S: BorrowMut<[T]>,
309 {
310 self.values
311 .borrow_mut()
312 .par_chunks_mut(self.width * chunk_rows)
313 .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
314 }
315
316 pub fn row_chunks_exact_mut(
321 &mut self,
322 chunk_rows: usize,
323 ) -> impl Iterator<Item = RowMajorMatrixViewMut<'_, T>>
324 where
325 T: Send,
326 S: BorrowMut<[T]>,
327 {
328 self.values
329 .borrow_mut()
330 .chunks_exact_mut(self.width * chunk_rows)
331 .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
332 }
333
334 pub fn par_row_chunks_exact_mut(
339 &mut self,
340 chunk_rows: usize,
341 ) -> impl IndexedParallelIterator<Item = RowMajorMatrixViewMut<'_, T>>
342 where
343 T: Send,
344 S: BorrowMut<[T]>,
345 {
346 self.values
347 .borrow_mut()
348 .par_chunks_exact_mut(self.width * chunk_rows)
349 .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
350 }
351
352 pub fn row_pair_mut(&mut self, row_1: usize, row_2: usize) -> (&mut [T], &mut [T])
357 where
358 S: BorrowMut<[T]>,
359 {
360 debug_assert_ne!(row_1, row_2);
361 let start_1 = row_1 * self.width;
362 let start_2 = row_2 * self.width;
363 let (lo, hi) = self.values.borrow_mut().split_at_mut(start_2);
364 (&mut lo[start_1..][..self.width], &mut hi[..self.width])
365 }
366
367 #[allow(clippy::type_complexity)]
374 pub fn packed_row_pair_mut<P>(
375 &mut self,
376 row_1: usize,
377 row_2: usize,
378 ) -> ((&mut [P], &mut [T]), (&mut [P], &mut [T]))
379 where
380 S: BorrowMut<[T]>,
381 P: PackedValue<Value = T>,
382 {
383 let (slice_1, slice_2) = self.row_pair_mut(row_1, row_2);
384 (
385 P::pack_slice_with_suffix_mut(slice_1),
386 P::pack_slice_with_suffix_mut(slice_2),
387 )
388 }
389
390 #[instrument(level = "debug", skip_all)]
393 pub fn bit_reversed_zero_pad(self, added_bits: usize) -> RowMajorMatrix<T>
394 where
395 T: Field,
396 {
397 if added_bits == 0 {
398 return self.to_row_major_matrix();
399 }
400
401 let w = self.width;
411 let mut padded =
412 RowMajorMatrix::new(T::zero_vec(self.values.borrow().len() << added_bits), w);
413 padded
414 .par_row_chunks_exact_mut(1 << added_bits)
415 .zip(self.par_row_slices())
416 .for_each(|(mut ch, r)| ch.row_mut(0).copy_from_slice(r));
417
418 padded
419 }
420}
421
422impl<T: Clone + Send + Sync, S: DenseStorage<T>> Matrix<T> for DenseMatrix<T, S> {
423 #[inline]
424 fn width(&self) -> usize {
425 self.width
426 }
427
428 #[inline]
429 fn height(&self) -> usize {
430 self.values
431 .borrow()
432 .len()
433 .checked_div(self.width)
434 .unwrap_or(0)
435 }
436
437 #[inline]
438 unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
439 unsafe {
440 self.values
442 .borrow()
443 .get_unchecked(r * self.width + c)
444 .clone()
445 }
446 }
447
448 #[inline]
449 unsafe fn row_subseq_unchecked(
450 &self,
451 r: usize,
452 start: usize,
453 end: usize,
454 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
455 unsafe {
456 self.values
458 .borrow()
459 .get_unchecked(r * self.width + start..r * self.width + end)
460 .iter()
461 .cloned()
462 }
463 }
464
465 #[inline]
466 unsafe fn row_subslice_unchecked(
467 &self,
468 r: usize,
469 start: usize,
470 end: usize,
471 ) -> impl Deref<Target = [T]> {
472 unsafe {
473 self.values
475 .borrow()
476 .get_unchecked(r * self.width + start..r * self.width + end)
477 }
478 }
479
480 fn to_row_major_matrix(self) -> RowMajorMatrix<T>
481 where
482 Self: Sized,
483 T: Clone,
484 {
485 RowMajorMatrix::new(self.values.to_vec(), self.width)
486 }
487
488 #[inline]
489 fn horizontally_packed_row<'a, P>(
490 &'a self,
491 r: usize,
492 ) -> (
493 impl Iterator<Item = P> + Send + Sync,
494 impl Iterator<Item = T> + Send + Sync,
495 )
496 where
497 P: PackedValue<Value = T>,
498 T: Clone + 'a,
499 {
500 let buf = &self.values.borrow()[r * self.width..(r + 1) * self.width];
501 let (packed, sfx) = P::pack_slice_with_suffix(buf);
502 (packed.iter().copied(), sfx.iter().cloned())
503 }
504
505 #[inline]
506 fn padded_horizontally_packed_row<'a, P>(
507 &'a self,
508 r: usize,
509 ) -> impl Iterator<Item = P> + Send + Sync
510 where
511 P: PackedValue<Value = T>,
512 T: Clone + Default + 'a,
513 {
514 let buf = &self.values.borrow()[r * self.width..(r + 1) * self.width];
515 let (packed, sfx) = P::pack_slice_with_suffix(buf);
516 packed.iter().copied().chain(
517 (!sfx.is_empty()).then(|| P::from_fn(|i| sfx.get(i).cloned().unwrap_or_default())),
518 )
519 }
520}
521
522impl<T: Clone + Default + Send + Sync> DenseMatrix<T> {
523 pub fn as_cow<'a>(self) -> RowMajorMatrixCow<'a, T> {
524 RowMajorMatrixCow::new(Cow::Owned(self.values), self.width)
525 }
526
527 pub fn rand<R: Rng>(rng: &mut R, rows: usize, cols: usize) -> Self
528 where
529 StandardUniform: Distribution<T>,
530 {
531 let values = rng.sample_iter(StandardUniform).take(rows * cols).collect();
532 Self::new(values, cols)
533 }
534
535 pub fn rand_nonzero<R: Rng>(rng: &mut R, rows: usize, cols: usize) -> Self
536 where
537 T: Field,
538 StandardUniform: Distribution<T>,
539 {
540 let values = rng
541 .sample_iter(StandardUniform)
542 .filter(|x| !x.is_zero())
543 .take(rows * cols)
544 .collect();
545 Self::new(values, cols)
546 }
547
548 pub fn pad_to_height(&mut self, new_height: usize, fill: T) {
549 assert!(new_height >= self.height());
550 self.values.resize(self.width * new_height, fill);
551 }
552
553 pub fn pad_to_power_of_two_height(&mut self, fill: T) {
563 let target_height = self.height().next_power_of_two();
565
566 self.values.resize(self.width * target_height, fill);
569 }
570}
571
572impl<T: Copy + Default + Send + Sync, V: DenseStorage<T>> DenseMatrix<T, V> {
573 pub fn transpose(&self) -> RowMajorMatrix<T> {
575 let nelts = self.height() * self.width();
576 let mut values = vec![T::default(); nelts];
577 p3_util::transpose::transpose(
578 self.values.borrow(),
579 &mut values,
580 self.width(),
581 self.height(),
582 );
583 RowMajorMatrix::new(values, self.height())
584 }
585
586 pub fn transpose_into<W: DenseStorage<T> + BorrowMut<[T]>>(
588 &self,
589 other: &mut DenseMatrix<T, W>,
590 ) {
591 assert_eq!(self.height(), other.width());
592 assert_eq!(other.height(), self.width());
593 p3_util::transpose::transpose(
594 self.values.borrow(),
595 other.values.borrow_mut(),
596 self.width(),
597 self.height(),
598 );
599 }
600}
601
602impl<'a, T: Clone + Default + Send + Sync> RowMajorMatrixView<'a, T> {
603 pub fn as_cow(self) -> RowMajorMatrixCow<'a, T> {
604 RowMajorMatrixCow::new(Cow::Borrowed(self.values), self.width)
605 }
606}
607
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::FieldArray;

    use super::*;

    /// Two-lane packed type shared by the packing tests.
    type Packed = FieldArray<BabyBear, 2>;

    /// Maps small integers to `BabyBear` elements.
    fn bb(xs: &[u32]) -> Vec<BabyBear> {
        xs.iter().copied().map(BabyBear::new).collect()
    }

    /// Builds a two-lane packed value from two small integers.
    fn pk(a: u32, b: u32) -> Packed {
        Packed::from([BabyBear::new(a), BabyBear::new(b)])
    }

    /// 4x4 matrix holding 1..=16 in row-major order.
    fn counting_4x4() -> RowMajorMatrix<BabyBear> {
        RowMajorMatrix::new((1..17).map(BabyBear::new).collect(), 4)
    }

    /// Asserts element-wise that `t` is the transpose of the `height x width` matrix `m`.
    fn assert_transposed(
        m: &RowMajorMatrix<usize>,
        t: &RowMajorMatrix<usize>,
        width: usize,
        height: usize,
    ) {
        assert_eq!(t.width(), height);
        assert_eq!(t.height(), width);
        for col in 0..width {
            for row in 0..height {
                assert_eq!(m.values[row * width + col], t.values[col * height + row]);
            }
        }
    }

    #[test]
    fn test_new() {
        let m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
        assert_eq!((m.width, m.height()), (2, 3));
        assert_eq!(m.values, vec![1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_new_row() {
        let m = RowMajorMatrix::new_row(vec![1, 2, 3]);
        assert_eq!((m.width, m.height()), (3, 1));
    }

    #[test]
    fn test_new_col() {
        let m = RowMajorMatrix::new_col(vec![1, 2, 3]);
        assert_eq!((m.width, m.height()), (1, 3));
    }

    #[test]
    fn test_height_with_zero_width() {
        let m = RowMajorMatrix::<i32>::new(Vec::new(), 0);
        assert_eq!(m.height(), 0);
    }

    #[test]
    fn test_get_methods() {
        // 3x2 matrix: [1, 2], [3, 4], [5, 6].
        let m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
        for (r, c, want) in [(0, 0, 1), (1, 1, 4), (2, 0, 5)] {
            assert_eq!(m.get(r, c), Some(want));
        }
        for (r, c, want) in [(0, 1, 2), (1, 0, 3), (2, 1, 6)] {
            // SAFETY: every (r, c) pair lies inside the 3x2 matrix.
            assert_eq!(unsafe { m.get_unchecked(r, c) }, want);
        }
        // Row index out of range, then column index out of range.
        assert_eq!(m.get(3, 0), None);
        assert_eq!(m.get(0, 2), None);
    }

    #[test]
    fn test_row_methods() {
        // 2x4 matrix: [1, 2, 3, 4], [5, 6, 7, 8].
        let m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 4);
        let second: Vec<_> = m.row(1).unwrap().into_iter().collect();
        assert_eq!(second, vec![5, 6, 7, 8]);

        // SAFETY: row 0 exists.
        let first: Vec<_> = unsafe { m.row_unchecked(0) }.into_iter().collect();
        assert_eq!(first, vec![1, 2, 3, 4]);

        for (start, end, want) in [(0, 3, vec![1, 2, 3]), (1, 3, vec![2, 3]), (2, 4, vec![3, 4])] {
            // SAFETY: row 0 exists and start <= end <= 4.
            let got: Vec<_> = unsafe { m.row_subseq_unchecked(0, start, end) }
                .into_iter()
                .collect();
            assert_eq!(got, want);
        }

        assert!(m.row(2).is_none());
    }

    #[test]
    fn test_row_slice_methods() {
        // 3x3 matrix holding 1..=9.
        let m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9], 3);
        assert_eq!(m.row_slice(0).unwrap().deref(), &[1, 2, 3]);
        assert_eq!(m.row_slice(2).unwrap().deref(), &[7, 8, 9]);

        // SAFETY: rows 0 and 2 exist and every column range lies within width 3.
        unsafe {
            assert_eq!(&[1, 2, 3], m.row_slice_unchecked(0).deref());
            assert_eq!(&[7, 8, 9], m.row_slice_unchecked(2).deref());
            assert_eq!(&[1, 2, 3], m.row_subslice_unchecked(0, 0, 3).deref());
            assert_eq!(&[8], m.row_subslice_unchecked(2, 1, 2).deref());
        }

        assert!(m.row_slice(3).is_none());
    }

    #[test]
    fn test_as_view() {
        let m = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        let v = m.as_view();
        assert_eq!(v.values, &[1, 2, 3, 4]);
        assert_eq!(v.width, 2);
    }

    #[test]
    fn test_as_view_mut() {
        let mut m = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        let v = m.as_view_mut();
        v.values[0] = 10;
        assert_eq!(m.values, vec![10, 2, 3, 4]);
    }

    #[test]
    fn test_copy_from() {
        let mut dst = RowMajorMatrix::new(vec![0; 4], 2);
        let src = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        dst.copy_from(&src);
        assert_eq!(dst.values, vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_split_rows() {
        let m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
        let (head, tail) = m.split_rows(1);
        assert_eq!(head.values, vec![1, 2]);
        assert_eq!(tail.values, vec![3, 4, 5, 6]);
    }

    #[test]
    fn test_split_rows_mut() {
        let mut m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
        let (head, tail) = m.split_rows_mut(1);
        assert_eq!(head.values, vec![1, 2]);
        assert_eq!(tail.values, vec![3, 4, 5, 6]);
    }

    #[test]
    fn test_row_mut() {
        let mut m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
        m.row_mut(1)[0] = 10;
        assert_eq!(m.values, vec![1, 2, 10, 4, 5, 6]);
    }

    #[test]
    fn test_bit_reversed_zero_pad() {
        let m = RowMajorMatrix::new(bb(&[1, 2, 3, 4]), 2);
        let padded = m.bit_reversed_zero_pad(1);
        assert_eq!(padded.width, 2);
        // Original rows land at even positions, odd rows are zero.
        assert_eq!(padded.values, bb(&[1, 2, 0, 0, 3, 4, 0, 0]));
    }

    #[test]
    fn test_bit_reversed_zero_pad_no_change() {
        let m = RowMajorMatrix::new(bb(&[1, 2, 3, 4]), 2);
        let padded = m.bit_reversed_zero_pad(0);
        assert_eq!(padded.width, 2);
        assert_eq!(padded.values, bb(&[1, 2, 3, 4]));
    }

    #[test]
    fn test_scale() {
        let mut m = RowMajorMatrix::new(bb(&[1, 2, 3, 4, 5, 6]), 2);
        m.scale(BabyBear::new(2));
        assert_eq!(m.values, bb(&[2, 4, 6, 8, 10, 12]));
    }

    #[test]
    fn test_scale_row() {
        let mut m = RowMajorMatrix::new(bb(&[1, 2, 3, 4, 5, 6]), 2);
        m.scale_row(1, BabyBear::new(3));
        assert_eq!(m.values, bb(&[1, 2, 9, 12, 5, 6]));
    }

    #[test]
    fn test_to_row_major_matrix() {
        let m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
        let owned = m.to_row_major_matrix();
        assert_eq!(owned.width, 2);
        assert_eq!(owned.height(), 3);
        assert_eq!(owned.values, vec![1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_horizontally_packed_row() {
        let m = RowMajorMatrix::new(bb(&[1, 2, 3, 4, 5, 6]), 3);
        let (packed_iter, suffix_iter) = m.horizontally_packed_row::<Packed>(1);
        let packed: Vec<_> = packed_iter.collect();
        let suffix: Vec<_> = suffix_iter.collect();
        assert_eq!(packed, vec![pk(4, 5)]);
        assert_eq!(suffix, bb(&[6]));
    }

    #[test]
    fn test_padded_horizontally_packed_row() {
        let m = RowMajorMatrix::new(bb(&[1, 2, 3, 4, 5, 6]), 3);
        let packed: Vec<_> = m.padded_horizontally_packed_row::<Packed>(1).collect();
        // The lone leftover scalar is padded with a zero lane.
        assert_eq!(packed, vec![pk(4, 5), pk(6, 0)]);
    }

    #[test]
    fn test_padded_horizontally_packed_row_exact_width() {
        // Width 4 is a multiple of the packing width, so no padding lane is added.
        let m = RowMajorMatrix::new(bb(&[1, 2, 3, 4, 5, 6, 7, 8]), 4);
        let packed: Vec<_> = m.padded_horizontally_packed_row::<Packed>(1).collect();
        assert_eq!(packed.len(), 2);
        assert_eq!(packed, vec![pk(5, 6), pk(7, 8)]);
    }

    #[test]
    fn test_pad_to_height() {
        let mut m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);
        m.pad_to_height(4, 9);
        assert_eq!(m.height(), 4);
        assert_eq!(m.values, vec![1, 2, 3, 4, 5, 6, 9, 9, 9, 9, 9, 9]);
    }

    #[test]
    fn test_pad_to_power_of_two_height() {
        // Height 3 rounds up to 4.
        let mut m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
        assert_eq!(m.height(), 3);
        m.pad_to_power_of_two_height(0);
        assert_eq!(m.height(), 4);
        assert_eq!(m.values, vec![1, 2, 3, 4, 5, 6, 0, 0]);

        // Height 4 is already a power of two: nothing changes.
        let mut m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 2);
        assert_eq!(m.height(), 4);
        m.pad_to_power_of_two_height(99);
        assert_eq!(m.height(), 4);
        assert_eq!(m.values, vec![1, 2, 3, 4, 5, 6, 7, 8]);

        // Height 1 is a power of two as well.
        let mut m = RowMajorMatrix::new(vec![1, 2, 3], 3);
        assert_eq!(m.height(), 1);
        m.pad_to_power_of_two_height(42);
        assert_eq!(m.height(), 1);
        assert_eq!(m.values, vec![1, 2, 3]);

        // Height 5 rounds up to 8, and the padding carries the fill value.
        let mut m = RowMajorMatrix::new(vec![1; 10], 2);
        assert_eq!(m.height(), 5);
        m.pad_to_power_of_two_height(-1);
        assert_eq!(m.height(), 8);
        assert_eq!(m.values.len(), 16);
        let (kept, padding) = m.values.split_at(10);
        assert!(kept.iter().all(|&v| v == 1));
        assert!(padding.iter().all(|&v| v == -1));
    }

    #[test]
    fn test_pad_to_power_of_two_height_empty_matrix() {
        let mut m = RowMajorMatrix::<i32>::new(Vec::new(), 3);
        assert_eq!(m.height(), 0);
        assert_eq!(m.width, 3);
        // An empty matrix grows to exactly one row of fill values.
        m.pad_to_power_of_two_height(7);
        assert_eq!(m.height(), 1);
        assert_eq!(m.values, vec![7, 7, 7]);
    }

    #[test]
    fn test_transpose_into() {
        let m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);
        let mut t = RowMajorMatrix::new(vec![0; 6], 2);
        m.transpose_into(&mut t);
        assert_eq!(t.width, 2);
        assert_eq!(t.height(), 3);
        assert_eq!(t.values, vec![1, 4, 2, 5, 3, 6]);
    }

    #[test]
    fn test_flatten_to_base() {
        let m = RowMajorMatrix::new(bb(&[2, 3, 4, 5]), 2);
        let flat = m.flatten_to_base::<BabyBear>();
        assert_eq!(flat.width, 2);
        assert_eq!(flat.values, bb(&[2, 3, 4, 5]));
    }

    #[test]
    fn test_horizontally_packed_row_mut() {
        let mut m = RowMajorMatrix::new(bb(&[1, 2, 3, 4, 5, 6]), 3);
        let (packed, suffix) = m.horizontally_packed_row_mut::<Packed>(1);
        packed[0] = pk(9, 10);
        suffix[0] = BabyBear::new(11);
        assert_eq!(m.values, bb(&[1, 2, 3, 9, 10, 11]));
    }

    #[test]
    fn test_par_row_chunks() {
        let m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 2);
        let chunks: Vec<_> = m.par_row_chunks(2).collect();
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].values, vec![1, 2, 3, 4]);
        assert_eq!(chunks[1].values, vec![5, 6, 7, 8]);
    }

    #[test]
    fn test_par_row_chunks_exact() {
        let m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
        let chunks: Vec<_> = m.par_row_chunks_exact(1).collect();
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].values, vec![1, 2]);
        assert_eq!(chunks[1].values, vec![3, 4]);
        assert_eq!(chunks[2].values, vec![5, 6]);
    }

    #[test]
    fn test_par_row_chunks_mut() {
        let mut m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 2);
        m.par_row_chunks_mut(2).for_each(|chunk| {
            for x in chunk.values.iter_mut() {
                *x += 10;
            }
        });
        assert_eq!(m.values, vec![11, 12, 13, 14, 15, 16, 17, 18]);
    }

    #[test]
    fn test_row_chunks_exact_mut() {
        let mut m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
        m.row_chunks_exact_mut(1).for_each(|chunk| {
            for x in chunk.values.iter_mut() {
                *x *= 2;
            }
        });
        assert_eq!(m.values, vec![2, 4, 6, 8, 10, 12]);
    }

    #[test]
    fn test_par_row_chunks_exact_mut() {
        let mut m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
        m.par_row_chunks_exact_mut(1).for_each(|chunk| {
            for x in chunk.values.iter_mut() {
                *x += 5;
            }
        });
        assert_eq!(m.values, vec![6, 7, 8, 9, 10, 11]);
    }

    #[test]
    fn test_row_pair_mut() {
        let mut m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
        let (first, third) = m.row_pair_mut(0, 2);
        first[0] = 9;
        third[1] = 10;
        assert_eq!(m.values, vec![9, 2, 3, 4, 5, 10]);
    }

    #[test]
    fn test_packed_row_pair_mut() {
        let mut m = RowMajorMatrix::new(bb(&[1, 2, 3, 4, 5, 6]), 3);
        let ((packed_a, suffix_a), (packed_b, suffix_b)) = m.packed_row_pair_mut::<Packed>(0, 1);
        packed_a[0] = pk(7, 8);
        packed_b[0] = pk(33, 44);
        suffix_a[0] = BabyBear::new(99);
        suffix_b[0] = BabyBear::new(9);
        assert_eq!(m.values, bb(&[7, 8, 99, 33, 44, 9]));
    }

    #[test]
    fn test_transpose_square_matrix() {
        let m = RowMajorMatrix::new((1..=9).collect::<Vec<usize>>(), 3);
        let expected = RowMajorMatrix::new(vec![1, 4, 7, 2, 5, 8, 3, 6, 9], 3);
        assert_eq!(m.transpose(), expected);
    }

    #[test]
    fn test_transpose_row_matrix() {
        // A single column of 30 entries transposes into a single row of the same values.
        let values: Vec<usize> = (1..=30).collect();
        let m = RowMajorMatrix::new(values.clone(), 1);
        assert_eq!(m.transpose(), RowMajorMatrix::new(values, 30));
    }

    #[test]
    fn test_transpose_rectangular_matrix() {
        // 6x5 matrix holding 1..=30.
        let m = RowMajorMatrix::new((1..=30).collect::<Vec<usize>>(), 5);
        let expected = RowMajorMatrix::new(
            vec![
                1, 6, 11, 16, 21, 26, 2, 7, 12, 17, 22, 27, 3, 8, 13, 18, 23, 28, 4, 9, 14, 19,
                24, 29, 5, 10, 15, 20, 25, 30,
            ],
            6,
        );
        assert_eq!(m.transpose(), expected);
    }

    #[test]
    fn test_transpose_larger_rectangular_matrix() {
        const WIDTH: usize = 256;
        const HEIGHT: usize = 512;
        let m = RowMajorMatrix::new((1..=WIDTH * HEIGHT).collect::<Vec<usize>>(), WIDTH);
        assert_transposed(&m, &m.transpose(), WIDTH, HEIGHT);
    }

    #[test]
    fn test_transpose_very_large_rectangular_matrix() {
        const WIDTH: usize = 1024;
        const HEIGHT: usize = 1024;
        let m = RowMajorMatrix::new((1..=WIDTH * HEIGHT).collect::<Vec<usize>>(), WIDTH);
        assert_transposed(&m, &m.transpose(), WIDTH, HEIGHT);
    }

    #[test]
    fn test_vertically_packed_row_pair() {
        let m = counting_4x4();
        let packed = m.vertically_packed_row_pair::<Packed>(0, 2);
        // Rows 0-1 packed column-wise, then rows 2-3.
        let expected: Vec<_> = (1..5).chain(9..13).map(|i| pk(i, i + 4)).collect();
        assert_eq!(packed, expected);
    }

    #[test]
    fn test_vertically_packed_row_pair_overlap() {
        let m = counting_4x4();
        let packed = m.vertically_packed_row_pair::<Packed>(0, 1);
        // Rows 0-1 followed by the overlapping rows 1-2.
        let expected: Vec<_> = (1..9).map(|i| pk(i, i + 4)).collect();
        assert_eq!(packed, expected);
    }

    #[test]
    fn test_vertically_packed_row_pair_wraparound_start_1() {
        let m = counting_4x4();
        let packed = m.vertically_packed_row_pair::<Packed>(1, 2);
        // Rows 1-2, then rows 3 and 0 (wrapping past the bottom of the matrix).
        assert_eq!(
            packed,
            vec![
                pk(5, 9),
                pk(6, 10),
                pk(7, 11),
                pk(8, 12),
                pk(13, 1),
                pk(14, 2),
                pk(15, 3),
                pk(16, 4),
            ]
        );
    }
}