1use alloc::borrow::Cow;
2use alloc::vec;
3use alloc::vec::Vec;
4use core::borrow::{Borrow, BorrowMut};
5use core::iter;
6use core::marker::PhantomData;
7use core::ops::Deref;
8
9use p3_field::{
10 ExtensionField, Field, PackedValue, par_scale_slice_in_place, scale_slice_in_place_single_core,
11};
12use p3_maybe_rayon::prelude::*;
13use rand::Rng;
14use rand::distr::{Distribution, StandardUniform};
15use serde::{Deserialize, Serialize};
16use tracing::instrument;
17
18use crate::Matrix;
19
20#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
24pub struct DenseMatrix<T, V = Vec<T>> {
25 pub values: V,
27 pub width: usize,
31 _phantom: PhantomData<T>,
35}
36
37pub type RowMajorMatrix<T> = DenseMatrix<T>;
38pub type RowMajorMatrixView<'a, T> = DenseMatrix<T, &'a [T]>;
39pub type RowMajorMatrixViewMut<'a, T> = DenseMatrix<T, &'a mut [T]>;
40pub type RowMajorMatrixCow<'a, T> = DenseMatrix<T, Cow<'a, [T]>>;
41
/// Storage that can back a dense matrix: anything viewable as a slice `[T]` that can also
/// be shared across threads.
pub trait DenseStorage<T>: Borrow<[T]> + Send + Sync {
    /// Consume the storage and return its entries as an owned `Vec<T>`, in order.
    fn to_vec(self) -> Vec<T>;
}

impl<T: Clone + Send + Sync> DenseStorage<T> for Vec<T> {
    fn to_vec(self) -> Vec<T> {
        // Already owned: hand the vector back as-is, without copying.
        self
    }
}

impl<T: Clone + Send + Sync> DenseStorage<T> for &[T] {
    fn to_vec(self) -> Vec<T> {
        // Borrowed data must be cloned element by element into a fresh vector.
        self.iter().cloned().collect()
    }
}

impl<T: Clone + Send + Sync> DenseStorage<T> for &mut [T] {
    fn to_vec(self) -> Vec<T> {
        // A mutable borrow still does not own the data, so clone it out.
        self.iter().cloned().collect()
    }
}

impl<T: Clone + Send + Sync> DenseStorage<T> for Cow<'_, [T]> {
    fn to_vec(self) -> Vec<T> {
        // Reuse the owned vector when there is one; otherwise clone the borrowed slice.
        match self {
            Cow::Owned(owned) => owned,
            Cow::Borrowed(slice) => slice.iter().cloned().collect(),
        }
    }
}
70
71impl<T: Clone + Send + Sync + Default> DenseMatrix<T> {
72 #[must_use]
75 pub fn default(width: usize, height: usize) -> Self {
76 Self::new(vec![T::default(); width * height], width)
77 }
78}
79
80impl<T: Clone + Send + Sync, S: DenseStorage<T>> DenseMatrix<T, S> {
81 #[must_use]
86 pub fn new(values: S, width: usize) -> Self {
87 debug_assert!(values.borrow().len().is_multiple_of(width));
88 Self {
89 values,
90 width,
91 _phantom: PhantomData,
92 }
93 }
94
95 #[must_use]
97 pub fn new_row(values: S) -> Self {
98 let width = values.borrow().len();
99 Self::new(values, width)
100 }
101
102 #[must_use]
104 pub fn new_col(values: S) -> Self {
105 Self::new(values, 1)
106 }
107
108 pub fn as_view(&self) -> RowMajorMatrixView<'_, T> {
110 RowMajorMatrixView::new(self.values.borrow(), self.width)
111 }
112
113 pub fn as_view_mut(&mut self) -> RowMajorMatrixViewMut<'_, T>
115 where
116 S: BorrowMut<[T]>,
117 {
118 RowMajorMatrixViewMut::new(self.values.borrow_mut(), self.width)
119 }
120
121 pub fn copy_from<S2>(&mut self, source: &DenseMatrix<T, S2>)
123 where
124 T: Copy,
125 S: BorrowMut<[T]>,
126 S2: DenseStorage<T>,
127 {
128 assert_eq!(self.dimensions(), source.dimensions());
129 self.par_rows_mut()
132 .zip(source.par_row_slices())
133 .for_each(|(dst, src)| {
134 dst.copy_from_slice(src);
135 });
136 }
137
138 pub fn flatten_to_base<F: Field>(self) -> RowMajorMatrix<F>
140 where
141 T: ExtensionField<F>,
142 {
143 let width = self.width * T::DIMENSION;
144 let values = T::flatten_to_base(self.values.to_vec());
145 RowMajorMatrix::new(values, width)
146 }
147
148 pub fn row_slices(&self) -> impl DoubleEndedIterator<Item = &[T]> {
150 self.values.borrow().chunks_exact(self.width)
151 }
152
153 pub fn par_row_slices(&self) -> impl IndexedParallelIterator<Item = &[T]>
155 where
156 T: Sync,
157 {
158 self.values.borrow().par_chunks_exact(self.width)
159 }
160
161 pub fn row_mut(&mut self, r: usize) -> &mut [T]
166 where
167 S: BorrowMut<[T]>,
168 {
169 &mut self.values.borrow_mut()[r * self.width..(r + 1) * self.width]
170 }
171
172 pub fn rows_mut(&mut self) -> impl Iterator<Item = &mut [T]>
174 where
175 S: BorrowMut<[T]>,
176 {
177 self.values.borrow_mut().chunks_exact_mut(self.width)
178 }
179
180 pub fn par_rows_mut<'a>(&'a mut self) -> impl IndexedParallelIterator<Item = &'a mut [T]>
182 where
183 T: 'a + Send,
184 S: BorrowMut<[T]>,
185 {
186 self.values.borrow_mut().par_chunks_exact_mut(self.width)
187 }
188
189 pub fn horizontally_packed_row_mut<P>(&mut self, r: usize) -> (&mut [P], &mut [T])
194 where
195 P: PackedValue<Value = T>,
196 S: BorrowMut<[T]>,
197 {
198 P::pack_slice_with_suffix_mut(self.row_mut(r))
199 }
200
201 pub fn scale_row(&mut self, r: usize, scale: T)
206 where
207 T: Field,
208 S: BorrowMut<[T]>,
209 {
210 scale_slice_in_place_single_core(self.row_mut(r), scale);
211 }
212
213 pub fn par_scale_row(&mut self, r: usize, scale: T)
222 where
223 T: Field,
224 S: BorrowMut<[T]>,
225 {
226 par_scale_slice_in_place(self.row_mut(r), scale);
227 }
228
229 pub fn scale(&mut self, scale: T)
231 where
232 T: Field,
233 S: BorrowMut<[T]>,
234 {
235 par_scale_slice_in_place(self.values.borrow_mut(), scale);
236 }
237
238 pub fn split_rows(&self, r: usize) -> (RowMajorMatrixView<'_, T>, RowMajorMatrixView<'_, T>) {
243 let (lo, hi) = self.values.borrow().split_at(r * self.width);
244 (
245 DenseMatrix::new(lo, self.width),
246 DenseMatrix::new(hi, self.width),
247 )
248 }
249
250 pub fn split_rows_mut(
255 &mut self,
256 r: usize,
257 ) -> (RowMajorMatrixViewMut<'_, T>, RowMajorMatrixViewMut<'_, T>)
258 where
259 S: BorrowMut<[T]>,
260 {
261 let (lo, hi) = self.values.borrow_mut().split_at_mut(r * self.width);
262 (
263 DenseMatrix::new(lo, self.width),
264 DenseMatrix::new(hi, self.width),
265 )
266 }
267
268 pub fn par_row_chunks(
272 &self,
273 chunk_rows: usize,
274 ) -> impl IndexedParallelIterator<Item = RowMajorMatrixView<'_, T>>
275 where
276 T: Send,
277 {
278 self.values
279 .borrow()
280 .par_chunks(self.width * chunk_rows)
281 .map(|slice| RowMajorMatrixView::new(slice, self.width))
282 }
283
284 pub fn par_row_chunks_exact(
288 &self,
289 chunk_rows: usize,
290 ) -> impl IndexedParallelIterator<Item = RowMajorMatrixView<'_, T>>
291 where
292 T: Send,
293 {
294 self.values
295 .borrow()
296 .par_chunks_exact(self.width * chunk_rows)
297 .map(|slice| RowMajorMatrixView::new(slice, self.width))
298 }
299
300 pub fn par_row_chunks_mut(
304 &mut self,
305 chunk_rows: usize,
306 ) -> impl IndexedParallelIterator<Item = RowMajorMatrixViewMut<'_, T>>
307 where
308 T: Send,
309 S: BorrowMut<[T]>,
310 {
311 self.values
312 .borrow_mut()
313 .par_chunks_mut(self.width * chunk_rows)
314 .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
315 }
316
317 pub fn row_chunks_exact_mut(
322 &mut self,
323 chunk_rows: usize,
324 ) -> impl Iterator<Item = RowMajorMatrixViewMut<'_, T>>
325 where
326 T: Send,
327 S: BorrowMut<[T]>,
328 {
329 self.values
330 .borrow_mut()
331 .chunks_exact_mut(self.width * chunk_rows)
332 .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
333 }
334
335 pub fn par_row_chunks_exact_mut(
340 &mut self,
341 chunk_rows: usize,
342 ) -> impl IndexedParallelIterator<Item = RowMajorMatrixViewMut<'_, T>>
343 where
344 T: Send,
345 S: BorrowMut<[T]>,
346 {
347 self.values
348 .borrow_mut()
349 .par_chunks_exact_mut(self.width * chunk_rows)
350 .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
351 }
352
353 pub fn row_pair_mut(&mut self, row_1: usize, row_2: usize) -> (&mut [T], &mut [T])
358 where
359 S: BorrowMut<[T]>,
360 {
361 debug_assert_ne!(row_1, row_2);
362 let start_1 = row_1 * self.width;
363 let start_2 = row_2 * self.width;
364 let (lo, hi) = self.values.borrow_mut().split_at_mut(start_2);
365 (&mut lo[start_1..][..self.width], &mut hi[..self.width])
366 }
367
368 #[allow(clippy::type_complexity)]
375 pub fn packed_row_pair_mut<P>(
376 &mut self,
377 row_1: usize,
378 row_2: usize,
379 ) -> ((&mut [P], &mut [T]), (&mut [P], &mut [T]))
380 where
381 S: BorrowMut<[T]>,
382 P: PackedValue<Value = T>,
383 {
384 let (slice_1, slice_2) = self.row_pair_mut(row_1, row_2);
385 (
386 P::pack_slice_with_suffix_mut(slice_1),
387 P::pack_slice_with_suffix_mut(slice_2),
388 )
389 }
390
391 #[instrument(level = "debug", skip_all)]
394 pub fn bit_reversed_zero_pad(self, added_bits: usize) -> RowMajorMatrix<T>
395 where
396 T: Field,
397 {
398 if added_bits == 0 {
399 return self.to_row_major_matrix();
400 }
401
402 let w = self.width;
412 let mut padded =
413 RowMajorMatrix::new(T::zero_vec(self.values.borrow().len() << added_bits), w);
414 padded
415 .par_row_chunks_exact_mut(1 << added_bits)
416 .zip(self.par_row_slices())
417 .for_each(|(mut ch, r)| ch.row_mut(0).copy_from_slice(r));
418
419 padded
420 }
421}
422
423impl<T: Clone + Send + Sync, S: DenseStorage<T>> Matrix<T> for DenseMatrix<T, S> {
424 #[inline]
425 fn width(&self) -> usize {
426 self.width
427 }
428
429 #[inline]
430 fn height(&self) -> usize {
431 if self.width == 0 {
432 0
433 } else {
434 self.values.borrow().len() / self.width
435 }
436 }
437
438 #[inline]
439 unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
440 unsafe {
441 self.values
443 .borrow()
444 .get_unchecked(r * self.width + c)
445 .clone()
446 }
447 }
448
449 #[inline]
450 unsafe fn row_subseq_unchecked(
451 &self,
452 r: usize,
453 start: usize,
454 end: usize,
455 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
456 unsafe {
457 self.values
459 .borrow()
460 .get_unchecked(r * self.width + start..r * self.width + end)
461 .iter()
462 .cloned()
463 }
464 }
465
466 #[inline]
467 unsafe fn row_subslice_unchecked(
468 &self,
469 r: usize,
470 start: usize,
471 end: usize,
472 ) -> impl Deref<Target = [T]> {
473 unsafe {
474 self.values
476 .borrow()
477 .get_unchecked(r * self.width + start..r * self.width + end)
478 }
479 }
480
481 fn to_row_major_matrix(self) -> RowMajorMatrix<T>
482 where
483 Self: Sized,
484 T: Clone,
485 {
486 RowMajorMatrix::new(self.values.to_vec(), self.width)
487 }
488
489 #[inline]
490 fn horizontally_packed_row<'a, P>(
491 &'a self,
492 r: usize,
493 ) -> (
494 impl Iterator<Item = P> + Send + Sync,
495 impl Iterator<Item = T> + Send + Sync,
496 )
497 where
498 P: PackedValue<Value = T>,
499 T: Clone + 'a,
500 {
501 let buf = &self.values.borrow()[r * self.width..(r + 1) * self.width];
502 let (packed, sfx) = P::pack_slice_with_suffix(buf);
503 (packed.iter().copied(), sfx.iter().cloned())
504 }
505
506 #[inline]
507 fn padded_horizontally_packed_row<'a, P>(
508 &'a self,
509 r: usize,
510 ) -> impl Iterator<Item = P> + Send + Sync
511 where
512 P: PackedValue<Value = T>,
513 T: Clone + Default + 'a,
514 {
515 let buf = &self.values.borrow()[r * self.width..(r + 1) * self.width];
516 let (packed, sfx) = P::pack_slice_with_suffix(buf);
517 packed.iter().copied().chain(iter::once(P::from_fn(|i| {
518 sfx.get(i).cloned().unwrap_or_default()
519 })))
520 }
521}
522
523impl<T: Clone + Default + Send + Sync> DenseMatrix<T> {
524 pub fn as_cow<'a>(self) -> RowMajorMatrixCow<'a, T> {
525 RowMajorMatrixCow::new(Cow::Owned(self.values), self.width)
526 }
527
528 pub fn rand<R: Rng>(rng: &mut R, rows: usize, cols: usize) -> Self
529 where
530 StandardUniform: Distribution<T>,
531 {
532 let values = rng.sample_iter(StandardUniform).take(rows * cols).collect();
533 Self::new(values, cols)
534 }
535
536 pub fn rand_nonzero<R: Rng>(rng: &mut R, rows: usize, cols: usize) -> Self
537 where
538 T: Field,
539 StandardUniform: Distribution<T>,
540 {
541 let values = rng
542 .sample_iter(StandardUniform)
543 .filter(|x| !x.is_zero())
544 .take(rows * cols)
545 .collect();
546 Self::new(values, cols)
547 }
548
549 pub fn pad_to_height(&mut self, new_height: usize, fill: T) {
550 assert!(new_height >= self.height());
551 self.values.resize(self.width * new_height, fill);
552 }
553
554 pub fn pad_to_power_of_two_height(&mut self, fill: T) {
564 let target_height = self.height().next_power_of_two();
566
567 self.values.resize(self.width * target_height, fill);
570 }
571}
572
573impl<T: Copy + Default + Send + Sync, V: DenseStorage<T>> DenseMatrix<T, V> {
574 pub fn transpose(&self) -> RowMajorMatrix<T> {
576 let nelts = self.height() * self.width();
577 let mut values = vec![T::default(); nelts];
578 transpose::transpose(
579 self.values.borrow(),
580 &mut values,
581 self.width(),
582 self.height(),
583 );
584 RowMajorMatrix::new(values, self.height())
585 }
586
587 pub fn transpose_into<W: DenseStorage<T> + BorrowMut<[T]>>(
589 &self,
590 other: &mut DenseMatrix<T, W>,
591 ) {
592 assert_eq!(self.height(), other.width());
593 assert_eq!(other.height(), self.width());
594 transpose::transpose(
595 self.values.borrow(),
596 other.values.borrow_mut(),
597 self.width(),
598 self.height(),
599 );
600 }
601}
602
603impl<'a, T: Clone + Default + Send + Sync> RowMajorMatrixView<'a, T> {
604 pub fn as_cow(self) -> RowMajorMatrixCow<'a, T> {
605 RowMajorMatrixCow::new(Cow::Borrowed(self.values), self.width)
606 }
607}
608
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_field::FieldArray;

    use super::*;

    /// Two-lane packing of BabyBear used by the packed-row tests.
    type Packed = FieldArray<BabyBear, 2>;

    /// Lift a list of small integers into BabyBear field elements.
    fn bb_vec(xs: &[u32]) -> Vec<BabyBear> {
        xs.iter().copied().map(BabyBear::new).collect()
    }

    /// Build a two-lane packed value from two small integers.
    fn pk(a: u32, b: u32) -> Packed {
        Packed::from([BabyBear::new(a), BabyBear::new(b)])
    }

    /// Transpose a `width` x `height` matrix holding `1..=width * height` and check that
    /// every entry moved to its transposed position.
    fn check_transpose_of_counting_matrix(width: usize, height: usize) {
        let values: Vec<usize> = (1..=width * height).collect();
        let m = RowMajorMatrix::new(values, width);
        let t = m.transpose();

        assert_eq!(t.width(), height);
        assert_eq!(t.height(), width);

        for c in 0..width {
            for r in 0..height {
                assert_eq!(m.values[r * width + c], t.values[c * height + r]);
            }
        }
    }

    #[test]
    fn test_new() {
        let m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
        assert_eq!(m.width, 2);
        assert_eq!(m.height(), 3);
        assert_eq!(m.values, vec![1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_new_row() {
        let m = RowMajorMatrix::new_row(vec![1, 2, 3]);
        assert_eq!(m.width, 3);
        assert_eq!(m.height(), 1);
    }

    #[test]
    fn test_new_col() {
        let m = RowMajorMatrix::new_col(vec![1, 2, 3]);
        assert_eq!(m.width, 1);
        assert_eq!(m.height(), 3);
    }

    #[test]
    fn test_height_with_zero_width() {
        let m: DenseMatrix<i32> = RowMajorMatrix::new(vec![], 0);
        assert_eq!(m.height(), 0);
    }

    #[test]
    fn test_get_methods() {
        let m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
        assert_eq!(m.get(0, 0), Some(1));
        assert_eq!(m.get(1, 1), Some(4));
        assert_eq!(m.get(2, 0), Some(5));
        unsafe {
            assert_eq!(m.get_unchecked(0, 1), 2);
            assert_eq!(m.get_unchecked(1, 0), 3);
            assert_eq!(m.get_unchecked(2, 1), 6);
        }
        // One row past the end, then one column past the end.
        assert_eq!(m.get(3, 0), None);
        assert_eq!(m.get(0, 2), None);
    }

    #[test]
    fn test_row_methods() {
        let m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 4);
        let second: Vec<_> = m.row(1).unwrap().into_iter().collect();
        assert_eq!(second, vec![5, 6, 7, 8]);
        unsafe {
            let first: Vec<_> = m.row_unchecked(0).into_iter().collect();
            assert_eq!(first, vec![1, 2, 3, 4]);
            let head: Vec<_> = m.row_subseq_unchecked(0, 0, 3).into_iter().collect();
            assert_eq!(head, vec![1, 2, 3]);
            let middle: Vec<_> = m.row_subseq_unchecked(0, 1, 3).into_iter().collect();
            assert_eq!(middle, vec![2, 3]);
            let tail: Vec<_> = m.row_subseq_unchecked(0, 2, 4).into_iter().collect();
            assert_eq!(tail, vec![3, 4]);
        }
        // Only two rows exist.
        assert!(m.row(2).is_none());
    }

    #[test]
    fn test_row_slice_methods() {
        let m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9], 3);
        let first = m.row_slice(0);
        let last = m.row_slice(2);
        assert_eq!(first.unwrap().deref(), &[1, 2, 3]);
        assert_eq!(last.unwrap().deref(), &[7, 8, 9]);
        unsafe {
            assert_eq!(&[1, 2, 3], m.row_slice_unchecked(0).deref());
            assert_eq!(&[7, 8, 9], m.row_slice_unchecked(2).deref());

            assert_eq!(&[1, 2, 3], m.row_subslice_unchecked(0, 0, 3).deref());
            assert_eq!(&[8], m.row_subslice_unchecked(2, 1, 2).deref());
        }
        // Only three rows exist.
        assert!(m.row_slice(3).is_none());
    }

    #[test]
    fn test_as_view() {
        let m = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        let v = m.as_view();
        assert_eq!(v.values, &[1, 2, 3, 4]);
        assert_eq!(v.width, 2);
    }

    #[test]
    fn test_as_view_mut() {
        let mut m = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        let v = m.as_view_mut();
        v.values[0] = 10;
        assert_eq!(m.values, vec![10, 2, 3, 4]);
    }

    #[test]
    fn test_copy_from() {
        let mut dst = RowMajorMatrix::new(vec![0, 0, 0, 0], 2);
        let src = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
        dst.copy_from(&src);
        assert_eq!(dst.values, vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_split_rows() {
        let m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
        let (upper, lower) = m.split_rows(1);
        assert_eq!(upper.values, vec![1, 2]);
        assert_eq!(lower.values, vec![3, 4, 5, 6]);
    }

    #[test]
    fn test_split_rows_mut() {
        let mut m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
        let (upper, lower) = m.split_rows_mut(1);
        assert_eq!(upper.values, vec![1, 2]);
        assert_eq!(lower.values, vec![3, 4, 5, 6]);
    }

    #[test]
    fn test_row_mut() {
        let mut m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
        m.row_mut(1)[0] = 10;
        assert_eq!(m.values, vec![1, 2, 10, 4, 5, 6]);
    }

    #[test]
    fn test_bit_reversed_zero_pad() {
        let m = RowMajorMatrix::new(bb_vec(&[1, 2, 3, 4]), 2);
        let padded = m.bit_reversed_zero_pad(1);
        assert_eq!(padded.width, 2);
        assert_eq!(padded.values, bb_vec(&[1, 2, 0, 0, 3, 4, 0, 0]));
    }

    #[test]
    fn test_bit_reversed_zero_pad_no_change() {
        let m = RowMajorMatrix::new(bb_vec(&[1, 2, 3, 4]), 2);
        let padded = m.bit_reversed_zero_pad(0);

        assert_eq!(padded.width, 2);
        assert_eq!(padded.values, bb_vec(&[1, 2, 3, 4]));
    }

    #[test]
    fn test_scale() {
        let mut m = RowMajorMatrix::new(bb_vec(&[1, 2, 3, 4, 5, 6]), 2);
        m.scale(BabyBear::new(2));
        assert_eq!(m.values, bb_vec(&[2, 4, 6, 8, 10, 12]));
    }

    #[test]
    fn test_scale_row() {
        let mut m = RowMajorMatrix::new(bb_vec(&[1, 2, 3, 4, 5, 6]), 2);
        m.scale_row(1, BabyBear::new(3));
        assert_eq!(m.values, bb_vec(&[1, 2, 9, 12, 5, 6]));
    }

    #[test]
    fn test_to_row_major_matrix() {
        let m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
        let owned = m.to_row_major_matrix();

        assert_eq!(owned.width, 2);
        assert_eq!(owned.height(), 3);
        assert_eq!(owned.values, vec![1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_horizontally_packed_row() {
        let m = RowMajorMatrix::new(bb_vec(&[1, 2, 3, 4, 5, 6]), 3);

        let (packed_iter, suffix_iter) = m.horizontally_packed_row::<Packed>(1);
        let packed: Vec<_> = packed_iter.collect();
        let suffix: Vec<_> = suffix_iter.collect();

        assert_eq!(packed, vec![pk(4, 5)]);
        assert_eq!(suffix, vec![BabyBear::new(6)]);
    }

    #[test]
    fn test_padded_horizontally_packed_row() {
        let m = RowMajorMatrix::new(bb_vec(&[1, 2, 3, 4, 5, 6]), 3);

        let packed: Vec<_> = m.padded_horizontally_packed_row::<Packed>(1).collect();

        assert_eq!(packed, vec![pk(4, 5), pk(6, 0)]);
    }

    #[test]
    fn test_pad_to_height() {
        let mut m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);

        m.pad_to_height(4, 9);

        assert_eq!(m.height(), 4);
        assert_eq!(m.values, vec![1, 2, 3, 4, 5, 6, 9, 9, 9, 9, 9, 9]);
    }

    #[test]
    fn test_pad_to_power_of_two_height() {
        // Height 3 grows to 4.
        let mut m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
        assert_eq!(m.height(), 3);
        m.pad_to_power_of_two_height(0);
        assert_eq!(m.height(), 4);
        assert_eq!(m.values, vec![1, 2, 3, 4, 5, 6, 0, 0]);

        // Height 4 is already a power of two and stays untouched.
        let mut m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 2);
        assert_eq!(m.height(), 4);
        m.pad_to_power_of_two_height(99);
        assert_eq!(m.height(), 4);
        assert_eq!(m.values, vec![1, 2, 3, 4, 5, 6, 7, 8]);

        // Height 1 (= 2^0) stays untouched.
        let mut m = RowMajorMatrix::new(vec![1, 2, 3], 3);
        assert_eq!(m.height(), 1);
        m.pad_to_power_of_two_height(42);
        assert_eq!(m.height(), 1);
        assert_eq!(m.values, vec![1, 2, 3]);

        // Height 5 grows to 8, and the new rows hold the fill value.
        let mut m = RowMajorMatrix::new(vec![1; 10], 2);
        assert_eq!(m.height(), 5);
        m.pad_to_power_of_two_height(-1);
        assert_eq!(m.height(), 8);
        assert_eq!(m.values.len(), 16);
        assert!(m.values[..10].iter().all(|&v| v == 1));
        assert!(m.values[10..].iter().all(|&v| v == -1));
    }

    #[test]
    fn test_pad_to_power_of_two_height_empty_matrix() {
        let mut m: RowMajorMatrix<i32> = RowMajorMatrix::new(vec![], 3);
        assert_eq!(m.height(), 0);
        assert_eq!(m.width, 3);
        m.pad_to_power_of_two_height(7);
        assert_eq!(m.height(), 1);
        assert_eq!(m.values, vec![7, 7, 7]);
    }

    #[test]
    fn test_transpose_into() {
        let m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);
        let mut out = RowMajorMatrix::new(vec![0; 6], 2);

        m.transpose_into(&mut out);

        assert_eq!(out.width, 2);
        assert_eq!(out.height(), 3);
        assert_eq!(out.values, vec![1, 4, 2, 5, 3, 6]);
    }

    #[test]
    fn test_flatten_to_base() {
        let m = RowMajorMatrix::new(bb_vec(&[2, 3, 4, 5]), 2);

        let flat: RowMajorMatrix<BabyBear> = m.flatten_to_base();

        assert_eq!(flat.width, 2);
        assert_eq!(flat.values, bb_vec(&[2, 3, 4, 5]));
    }

    #[test]
    fn test_horizontally_packed_row_mut() {
        let mut m = RowMajorMatrix::new(bb_vec(&[1, 2, 3, 4, 5, 6]), 3);

        let (packed, suffix) = m.horizontally_packed_row_mut::<Packed>(1);
        packed[0] = pk(9, 10);
        suffix[0] = BabyBear::new(11);

        assert_eq!(m.values, bb_vec(&[1, 2, 3, 9, 10, 11]));
    }

    #[test]
    fn test_par_row_chunks() {
        let m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 2);

        let chunks: Vec<_> = m.par_row_chunks(2).collect();

        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].values, vec![1, 2, 3, 4]);
        assert_eq!(chunks[1].values, vec![5, 6, 7, 8]);
    }

    #[test]
    fn test_par_row_chunks_exact() {
        let m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);

        let chunks: Vec<_> = m.par_row_chunks_exact(1).collect();

        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].values, vec![1, 2]);
        assert_eq!(chunks[1].values, vec![3, 4]);
        assert_eq!(chunks[2].values, vec![5, 6]);
    }

    #[test]
    fn test_par_row_chunks_mut() {
        let mut m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 2);

        m.par_row_chunks_mut(2)
            .for_each(|chunk| chunk.values.iter_mut().for_each(|x| *x += 10));

        assert_eq!(m.values, vec![11, 12, 13, 14, 15, 16, 17, 18]);
    }

    #[test]
    fn test_row_chunks_exact_mut() {
        let mut m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);

        m.row_chunks_exact_mut(1)
            .for_each(|chunk| chunk.values.iter_mut().for_each(|x| *x *= 2));

        assert_eq!(m.values, vec![2, 4, 6, 8, 10, 12]);
    }

    #[test]
    fn test_par_row_chunks_exact_mut() {
        let mut m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);

        m.par_row_chunks_exact_mut(1)
            .for_each(|chunk| chunk.values.iter_mut().for_each(|x| *x += 5));

        assert_eq!(m.values, vec![6, 7, 8, 9, 10, 11]);
    }

    #[test]
    fn test_row_pair_mut() {
        let mut m = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);

        let (top, bottom) = m.row_pair_mut(0, 2);
        top[0] = 9;
        bottom[1] = 10;

        assert_eq!(m.values, vec![9, 2, 3, 4, 5, 10]);
    }

    #[test]
    fn test_packed_row_pair_mut() {
        let mut m = RowMajorMatrix::new(bb_vec(&[1, 2, 3, 4, 5, 6]), 3);

        let ((packed_a, sfx_a), (packed_b, sfx_b)) = m.packed_row_pair_mut::<Packed>(0, 1);
        packed_a[0] = pk(7, 8);
        packed_b[0] = pk(33, 44);
        sfx_a[0] = BabyBear::new(99);
        sfx_b[0] = BabyBear::new(9);

        assert_eq!(m.values, bb_vec(&[7, 8, 99, 33, 44, 9]));
    }

    #[test]
    fn test_transpose_square_matrix() {
        let m = RowMajorMatrix::new((1..=9usize).collect::<Vec<_>>(), 3);
        let expected = RowMajorMatrix::new(vec![1, 4, 7, 2, 5, 8, 3, 6, 9], 3);
        assert_eq!(m.transpose(), expected);
    }

    #[test]
    fn test_transpose_row_matrix() {
        // A single column of 30 entries transposes to a single row with the same entries.
        let values: Vec<usize> = (1..=30).collect();
        let m = RowMajorMatrix::new(values.clone(), 1);
        let expected = RowMajorMatrix::new(values, 30);
        assert_eq!(m.transpose(), expected);
    }

    #[test]
    fn test_transpose_rectangular_matrix() {
        let m = RowMajorMatrix::new((1..=30usize).collect::<Vec<_>>(), 5);
        let expected = RowMajorMatrix::new(
            vec![
                1, 6, 11, 16, 21, 26, 2, 7, 12, 17, 22, 27, 3, 8, 13, 18, 23, 28, 4, 9, 14, 19,
                24, 29, 5, 10, 15, 20, 25, 30,
            ],
            6,
        );
        assert_eq!(m.transpose(), expected);
    }

    #[test]
    fn test_transpose_larger_rectangular_matrix() {
        // 256 columns x 512 rows = 131072 entries.
        check_transpose_of_counting_matrix(256, 512);
    }

    #[test]
    fn test_transpose_very_large_rectangular_matrix() {
        // 1024 columns x 1024 rows = 1048576 entries.
        check_transpose_of_counting_matrix(1024, 1024);
    }

    #[test]
    fn test_vertically_packed_row_pair() {
        let m = RowMajorMatrix::new((1..17).map(BabyBear::new).collect::<Vec<_>>(), 4);

        let packed = m.vertically_packed_row_pair::<Packed>(0, 2);

        // Rows 0/1 packed together, followed by rows 2/3 packed together.
        let expected: Vec<Packed> = (1..5).chain(9..13).map(|i| pk(i, i + 4)).collect();
        assert_eq!(packed, expected);
    }

    #[test]
    fn test_vertically_packed_row_pair_overlap() {
        let m = RowMajorMatrix::new((1..17).map(BabyBear::new).collect::<Vec<_>>(), 4);

        let packed = m.vertically_packed_row_pair::<Packed>(0, 1);

        // Rows 0/1 packed together, followed by rows 1/2 packed together.
        let expected: Vec<Packed> = (1..5).chain(5..9).map(|i| pk(i, i + 4)).collect();
        assert_eq!(packed, expected);
    }

    #[test]
    fn test_vertically_packed_row_pair_wraparound_start_1() {
        let m = RowMajorMatrix::new((1..17).map(BabyBear::new).collect::<Vec<_>>(), 4);

        let packed = m.vertically_packed_row_pair::<Packed>(1, 2);

        // Rows 1/2 packed together, then rows 3/0 (wrapping around the height).
        let expected = vec![
            pk(5, 9),
            pk(6, 10),
            pk(7, 11),
            pk(8, 12),
            pk(13, 1),
            pk(14, 2),
            pk(15, 3),
            pk(16, 4),
        ];
        assert_eq!(packed, expected);
    }
}