1use alloc::borrow::Cow;
2use alloc::vec;
3use alloc::vec::Vec;
4use core::borrow::{Borrow, BorrowMut};
5use core::iter;
6use core::marker::PhantomData;
7use core::ops::Deref;
8
9use p3_field::{
10 ExtensionField, Field, PackedValue, par_scale_slice_in_place, scale_slice_in_place_single_core,
11};
12use p3_maybe_rayon::prelude::*;
13use rand::Rng;
14use rand::distr::{Distribution, StandardUniform};
15use serde::{Deserialize, Serialize};
16use tracing::instrument;
17
18use crate::Matrix;
19
/// A dense matrix whose elements live in one flat buffer, laid out row after row.
///
/// Element `(r, c)` is stored at flat index `r * width + c` of `values`. The storage type `V`
/// defaults to an owned `Vec<T>`; borrowed slices and `Cow` are supported as well.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct DenseMatrix<T, V = Vec<T>> {
    /// Flat buffer holding every element of the matrix in row-major order.
    pub values: V,
    /// Number of columns, i.e. the length of each row inside `values`.
    pub width: usize,
    /// Marker tying the element type `T` to the struct, since `V` need not mention it.
    _phantom: PhantomData<T>,
}
36
/// A dense matrix that owns its elements in a `Vec<T>`.
pub type RowMajorMatrix<T> = DenseMatrix<T>;
/// A dense matrix that borrows its elements immutably from a slice.
pub type RowMajorMatrixView<'a, T> = DenseMatrix<T, &'a [T]>;
/// A dense matrix that borrows its elements mutably from a slice.
pub type RowMajorMatrixViewMut<'a, T> = DenseMatrix<T, &'a mut [T]>;
/// A dense matrix whose elements are either borrowed or owned (`Cow`).
pub type RowMajorMatrixCow<'a, T> = DenseMatrix<T, Cow<'a, [T]>>;
41
/// Backing storage usable by [`DenseMatrix`]: anything that can be viewed as a `[T]` slice and
/// shared across threads.
pub trait DenseStorage<T>: Borrow<[T]> + Send + Sync {
    /// Consume the storage and produce an owned `Vec<T>` with the same elements.
    fn to_vec(self) -> Vec<T>;
}

impl<T> DenseStorage<T> for Vec<T>
where
    T: Clone + Send + Sync,
{
    /// Already owned: the vector is returned as-is, without copying.
    fn to_vec(self) -> Vec<T> {
        self
    }
}

impl<T> DenseStorage<T> for &[T]
where
    T: Clone + Send + Sync,
{
    /// Clones every element of the borrowed slice into a fresh vector.
    fn to_vec(self) -> Vec<T> {
        Vec::from(self)
    }
}

impl<T> DenseStorage<T> for &mut [T]
where
    T: Clone + Send + Sync,
{
    /// Clones every element of the mutably borrowed slice into a fresh vector.
    fn to_vec(self) -> Vec<T> {
        Vec::from(self)
    }
}

impl<T> DenseStorage<T> for Cow<'_, [T]>
where
    T: Clone + Send + Sync,
{
    /// Reuses the owned vector when there is one; otherwise clones the borrowed slice.
    fn to_vec(self) -> Vec<T> {
        match self {
            Cow::Owned(owned) => owned,
            Cow::Borrowed(slice) => Vec::from(slice),
        }
    }
}
70
71impl<T: Clone + Send + Sync + Default> DenseMatrix<T> {
72 #[must_use]
75 pub fn default(width: usize, height: usize) -> Self {
76 Self::new(vec![T::default(); width * height], width)
77 }
78}
79
80impl<T: Clone + Send + Sync, S: DenseStorage<T>> DenseMatrix<T, S> {
    /// Wrap `values` as a row-major matrix with `width` columns.
    ///
    /// In debug builds this asserts that the length of `values` is a multiple of `width`;
    /// a `width` of zero is accepted without any length check. Release builds skip the check.
    #[must_use]
    pub fn new(values: S, width: usize) -> Self {
        debug_assert!(width == 0 || values.borrow().len() % width == 0);
        Self {
            values,
            width,
            _phantom: PhantomData,
        }
    }
94
95 #[must_use]
97 pub fn new_row(values: S) -> Self {
98 let width = values.borrow().len();
99 Self::new(values, width)
100 }
101
102 #[must_use]
104 pub fn new_col(values: S) -> Self {
105 Self::new(values, 1)
106 }
107
108 pub fn as_view(&self) -> RowMajorMatrixView<'_, T> {
110 RowMajorMatrixView::new(self.values.borrow(), self.width)
111 }
112
113 pub fn as_view_mut(&mut self) -> RowMajorMatrixViewMut<'_, T>
115 where
116 S: BorrowMut<[T]>,
117 {
118 RowMajorMatrixViewMut::new(self.values.borrow_mut(), self.width)
119 }
120
121 pub fn copy_from<S2>(&mut self, source: &DenseMatrix<T, S2>)
123 where
124 T: Copy,
125 S: BorrowMut<[T]>,
126 S2: DenseStorage<T>,
127 {
128 assert_eq!(self.dimensions(), source.dimensions());
129 self.par_rows_mut()
132 .zip(source.par_row_slices())
133 .for_each(|(dst, src)| {
134 dst.copy_from_slice(src);
135 });
136 }
137
138 pub fn flatten_to_base<F: Field>(self) -> RowMajorMatrix<F>
140 where
141 T: ExtensionField<F>,
142 {
143 let width = self.width * T::DIMENSION;
144 let values = T::flatten_to_base(self.values.to_vec());
145 RowMajorMatrix::new(values, width)
146 }
147
148 pub fn row_slices(&self) -> impl Iterator<Item = &[T]> {
150 self.values.borrow().chunks_exact(self.width)
151 }
152
153 pub fn par_row_slices(&self) -> impl IndexedParallelIterator<Item = &[T]>
155 where
156 T: Sync,
157 {
158 self.values.borrow().par_chunks_exact(self.width)
159 }
160
161 pub fn row_mut(&mut self, r: usize) -> &mut [T]
166 where
167 S: BorrowMut<[T]>,
168 {
169 &mut self.values.borrow_mut()[r * self.width..(r + 1) * self.width]
170 }
171
172 pub fn rows_mut(&mut self) -> impl Iterator<Item = &mut [T]>
174 where
175 S: BorrowMut<[T]>,
176 {
177 self.values.borrow_mut().chunks_exact_mut(self.width)
178 }
179
180 pub fn par_rows_mut<'a>(&'a mut self) -> impl IndexedParallelIterator<Item = &'a mut [T]>
182 where
183 T: 'a + Send,
184 S: BorrowMut<[T]>,
185 {
186 self.values.borrow_mut().par_chunks_exact_mut(self.width)
187 }
188
189 pub fn horizontally_packed_row_mut<P>(&mut self, r: usize) -> (&mut [P], &mut [T])
194 where
195 P: PackedValue<Value = T>,
196 S: BorrowMut<[T]>,
197 {
198 P::pack_slice_with_suffix_mut(self.row_mut(r))
199 }
200
201 pub fn scale_row(&mut self, r: usize, scale: T)
206 where
207 T: Field,
208 S: BorrowMut<[T]>,
209 {
210 scale_slice_in_place_single_core(self.row_mut(r), scale);
211 }
212
213 pub fn par_scale_row(&mut self, r: usize, scale: T)
222 where
223 T: Field,
224 S: BorrowMut<[T]>,
225 {
226 par_scale_slice_in_place(self.row_mut(r), scale);
227 }
228
229 pub fn scale(&mut self, scale: T)
231 where
232 T: Field,
233 S: BorrowMut<[T]>,
234 {
235 par_scale_slice_in_place(self.values.borrow_mut(), scale);
236 }
237
238 pub fn split_rows(&self, r: usize) -> (RowMajorMatrixView<T>, RowMajorMatrixView<T>) {
243 let (lo, hi) = self.values.borrow().split_at(r * self.width);
244 (
245 DenseMatrix::new(lo, self.width),
246 DenseMatrix::new(hi, self.width),
247 )
248 }
249
250 pub fn split_rows_mut(
255 &mut self,
256 r: usize,
257 ) -> (RowMajorMatrixViewMut<T>, RowMajorMatrixViewMut<T>)
258 where
259 S: BorrowMut<[T]>,
260 {
261 let (lo, hi) = self.values.borrow_mut().split_at_mut(r * self.width);
262 (
263 DenseMatrix::new(lo, self.width),
264 DenseMatrix::new(hi, self.width),
265 )
266 }
267
268 pub fn par_row_chunks(
272 &self,
273 chunk_rows: usize,
274 ) -> impl IndexedParallelIterator<Item = RowMajorMatrixView<T>>
275 where
276 T: Send,
277 {
278 self.values
279 .borrow()
280 .par_chunks(self.width * chunk_rows)
281 .map(|slice| RowMajorMatrixView::new(slice, self.width))
282 }
283
284 pub fn par_row_chunks_exact(
288 &self,
289 chunk_rows: usize,
290 ) -> impl IndexedParallelIterator<Item = RowMajorMatrixView<T>>
291 where
292 T: Send,
293 {
294 self.values
295 .borrow()
296 .par_chunks_exact(self.width * chunk_rows)
297 .map(|slice| RowMajorMatrixView::new(slice, self.width))
298 }
299
300 pub fn par_row_chunks_mut(
304 &mut self,
305 chunk_rows: usize,
306 ) -> impl IndexedParallelIterator<Item = RowMajorMatrixViewMut<T>>
307 where
308 T: Send,
309 S: BorrowMut<[T]>,
310 {
311 self.values
312 .borrow_mut()
313 .par_chunks_mut(self.width * chunk_rows)
314 .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
315 }
316
317 pub fn row_chunks_exact_mut(
322 &mut self,
323 chunk_rows: usize,
324 ) -> impl Iterator<Item = RowMajorMatrixViewMut<T>>
325 where
326 T: Send,
327 S: BorrowMut<[T]>,
328 {
329 self.values
330 .borrow_mut()
331 .chunks_exact_mut(self.width * chunk_rows)
332 .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
333 }
334
335 pub fn par_row_chunks_exact_mut(
340 &mut self,
341 chunk_rows: usize,
342 ) -> impl IndexedParallelIterator<Item = RowMajorMatrixViewMut<T>>
343 where
344 T: Send,
345 S: BorrowMut<[T]>,
346 {
347 self.values
348 .borrow_mut()
349 .par_chunks_exact_mut(self.width * chunk_rows)
350 .map(|slice| RowMajorMatrixViewMut::new(slice, self.width))
351 }
352
353 pub fn row_pair_mut(&mut self, row_1: usize, row_2: usize) -> (&mut [T], &mut [T])
358 where
359 S: BorrowMut<[T]>,
360 {
361 debug_assert_ne!(row_1, row_2);
362 let start_1 = row_1 * self.width;
363 let start_2 = row_2 * self.width;
364 let (lo, hi) = self.values.borrow_mut().split_at_mut(start_2);
365 (&mut lo[start_1..][..self.width], &mut hi[..self.width])
366 }
367
368 #[allow(clippy::type_complexity)]
375 pub fn packed_row_pair_mut<P>(
376 &mut self,
377 row_1: usize,
378 row_2: usize,
379 ) -> ((&mut [P], &mut [T]), (&mut [P], &mut [T]))
380 where
381 S: BorrowMut<[T]>,
382 P: PackedValue<Value = T>,
383 {
384 let (slice_1, slice_2) = self.row_pair_mut(row_1, row_2);
385 (
386 P::pack_slice_with_suffix_mut(slice_1),
387 P::pack_slice_with_suffix_mut(slice_2),
388 )
389 }
390
391 #[instrument(level = "debug", skip_all)]
394 pub fn bit_reversed_zero_pad(self, added_bits: usize) -> RowMajorMatrix<T>
395 where
396 T: Field,
397 {
398 if added_bits == 0 {
399 return self.to_row_major_matrix();
400 }
401
402 let w = self.width;
412 let mut padded =
413 RowMajorMatrix::new(T::zero_vec(self.values.borrow().len() << added_bits), w);
414 padded
415 .par_row_chunks_exact_mut(1 << added_bits)
416 .zip(self.par_row_slices())
417 .for_each(|(mut ch, r)| ch.row_mut(0).copy_from_slice(r));
418
419 padded
420 }
421}
422
423impl<T: Clone + Send + Sync, S: DenseStorage<T>> Matrix<T> for DenseMatrix<T, S> {
    /// Number of columns, read directly from the `width` field.
    #[inline]
    fn width(&self) -> usize {
        self.width
    }
428
429 #[inline]
430 fn height(&self) -> usize {
431 if self.width == 0 {
432 0
433 } else {
434 self.values.borrow().len() / self.width
435 }
436 }
437
438 #[inline]
439 unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
440 unsafe {
441 self.values
443 .borrow()
444 .get_unchecked(r * self.width + c)
445 .clone()
446 }
447 }
448
449 #[inline]
450 unsafe fn row_subseq_unchecked(
451 &self,
452 r: usize,
453 start: usize,
454 end: usize,
455 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
456 unsafe {
457 self.values
459 .borrow()
460 .get_unchecked(r * self.width + start..r * self.width + end)
461 .iter()
462 .cloned()
463 }
464 }
465
466 #[inline]
467 unsafe fn row_subslice_unchecked(
468 &self,
469 r: usize,
470 start: usize,
471 end: usize,
472 ) -> impl Deref<Target = [T]> {
473 unsafe {
474 self.values
476 .borrow()
477 .get_unchecked(r * self.width + start..r * self.width + end)
478 }
479 }
480
481 fn to_row_major_matrix(self) -> RowMajorMatrix<T>
482 where
483 Self: Sized,
484 T: Clone,
485 {
486 RowMajorMatrix::new(self.values.to_vec(), self.width)
487 }
488
489 #[inline]
490 fn horizontally_packed_row<'a, P>(
491 &'a self,
492 r: usize,
493 ) -> (
494 impl Iterator<Item = P> + Send + Sync,
495 impl Iterator<Item = T> + Send + Sync,
496 )
497 where
498 P: PackedValue<Value = T>,
499 T: Clone + 'a,
500 {
501 let buf = &self.values.borrow()[r * self.width..(r + 1) * self.width];
502 let (packed, sfx) = P::pack_slice_with_suffix(buf);
503 (packed.iter().copied(), sfx.iter().cloned())
504 }
505
506 #[inline]
507 fn padded_horizontally_packed_row<'a, P>(
508 &'a self,
509 r: usize,
510 ) -> impl Iterator<Item = P> + Send + Sync
511 where
512 P: PackedValue<Value = T>,
513 T: Clone + Default + 'a,
514 {
515 let buf = &self.values.borrow()[r * self.width..(r + 1) * self.width];
516 let (packed, sfx) = P::pack_slice_with_suffix(buf);
517 packed.iter().copied().chain(iter::once(P::from_fn(|i| {
518 sfx.get(i).cloned().unwrap_or_default()
519 })))
520 }
521}
522
523impl<T: Clone + Default + Send + Sync> DenseMatrix<T> {
524 pub fn as_cow<'a>(self) -> RowMajorMatrixCow<'a, T> {
525 RowMajorMatrixCow::new(Cow::Owned(self.values), self.width)
526 }
527
528 pub fn rand<R: Rng>(rng: &mut R, rows: usize, cols: usize) -> Self
529 where
530 StandardUniform: Distribution<T>,
531 {
532 let values = rng.sample_iter(StandardUniform).take(rows * cols).collect();
533 Self::new(values, cols)
534 }
535
536 pub fn rand_nonzero<R: Rng>(rng: &mut R, rows: usize, cols: usize) -> Self
537 where
538 T: Field,
539 StandardUniform: Distribution<T>,
540 {
541 let values = rng
542 .sample_iter(StandardUniform)
543 .filter(|x| !x.is_zero())
544 .take(rows * cols)
545 .collect();
546 Self::new(values, cols)
547 }
548
549 pub fn pad_to_height(&mut self, new_height: usize, fill: T) {
550 assert!(new_height >= self.height());
551 self.values.resize(self.width * new_height, fill);
552 }
553}
554
555impl<T: Copy + Default + Send + Sync, V: DenseStorage<T>> DenseMatrix<T, V> {
556 pub fn transpose(&self) -> RowMajorMatrix<T> {
558 let nelts = self.height() * self.width();
559 let mut values = vec![T::default(); nelts];
560 transpose::transpose(
561 self.values.borrow(),
562 &mut values,
563 self.width(),
564 self.height(),
565 );
566 RowMajorMatrix::new(values, self.height())
567 }
568
569 pub fn transpose_into<W: DenseStorage<T> + BorrowMut<[T]>>(
571 &self,
572 other: &mut DenseMatrix<T, W>,
573 ) {
574 assert_eq!(self.height(), other.width());
575 assert_eq!(other.height(), self.width());
576 transpose::transpose(
577 self.values.borrow(),
578 other.values.borrow_mut(),
579 self.width(),
580 self.height(),
581 );
582 }
583}
584
585impl<'a, T: Clone + Default + Send + Sync> RowMajorMatrixView<'a, T> {
586 pub fn as_cow(self) -> RowMajorMatrixCow<'a, T> {
587 RowMajorMatrixCow::new(Cow::Borrowed(self.values), self.width)
588 }
589}
590
591#[cfg(test)]
592mod tests {
593 use p3_baby_bear::BabyBear;
594 use p3_field::FieldArray;
595
596 use super::*;
597
598 #[test]
599 fn test_new() {
600 let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
601 assert_eq!(matrix.width, 2);
602 assert_eq!(matrix.height(), 3);
603 assert_eq!(matrix.values, vec![1, 2, 3, 4, 5, 6]);
604 }
605
606 #[test]
607 fn test_new_row() {
608 let matrix = RowMajorMatrix::new_row(vec![1, 2, 3]);
609 assert_eq!(matrix.width, 3);
610 assert_eq!(matrix.height(), 1);
611 }
612
613 #[test]
614 fn test_new_col() {
615 let matrix = RowMajorMatrix::new_col(vec![1, 2, 3]);
616 assert_eq!(matrix.width, 1);
617 assert_eq!(matrix.height(), 3);
618 }
619
620 #[test]
621 fn test_height_with_zero_width() {
622 let matrix: DenseMatrix<i32> = RowMajorMatrix::new(vec![], 0);
623 assert_eq!(matrix.height(), 0);
624 }
625
626 #[test]
627 fn test_get_methods() {
628 let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2); assert_eq!(matrix.get(0, 0), Some(1));
630 assert_eq!(matrix.get(1, 1), Some(4));
631 assert_eq!(matrix.get(2, 0), Some(5));
632 unsafe {
633 assert_eq!(matrix.get_unchecked(0, 1), 2);
634 assert_eq!(matrix.get_unchecked(1, 0), 3);
635 assert_eq!(matrix.get_unchecked(2, 1), 6);
636 }
637 assert_eq!(matrix.get(3, 0), None); assert_eq!(matrix.get(0, 2), None); }
640
641 #[test]
642 fn test_row_methods() {
643 let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 4); let row: Vec<_> = matrix.row(1).unwrap().into_iter().collect();
645 assert_eq!(row, vec![5, 6, 7, 8]);
646 unsafe {
647 let row: Vec<_> = matrix.row_unchecked(0).into_iter().collect();
648 assert_eq!(row, vec![1, 2, 3, 4]);
649 let row: Vec<_> = matrix.row_subseq_unchecked(0, 0, 3).into_iter().collect();
650 assert_eq!(row, vec![1, 2, 3]);
651 let row: Vec<_> = matrix.row_subseq_unchecked(0, 1, 3).into_iter().collect();
652 assert_eq!(row, vec![2, 3]);
653 let row: Vec<_> = matrix.row_subseq_unchecked(0, 2, 4).into_iter().collect();
654 assert_eq!(row, vec![3, 4]);
655 }
656 assert!(matrix.row(2).is_none()); }
658
659 #[test]
660 fn test_row_slice_methods() {
661 let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9], 3); let slice0 = matrix.row_slice(0);
663 let slice2 = matrix.row_slice(2);
664 assert_eq!(slice0.unwrap().deref(), &[1, 2, 3]);
665 assert_eq!(slice2.unwrap().deref(), &[7, 8, 9]);
666 unsafe {
667 assert_eq!(&[1, 2, 3], matrix.row_slice_unchecked(0).deref());
668 assert_eq!(&[7, 8, 9], matrix.row_slice_unchecked(2).deref());
669
670 assert_eq!(&[1, 2, 3], matrix.row_subslice_unchecked(0, 0, 3).deref());
671 assert_eq!(&[8], matrix.row_subslice_unchecked(2, 1, 2).deref());
672 }
673 assert!(matrix.row_slice(3).is_none()); }
675
676 #[test]
677 fn test_as_view() {
678 let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
679 let view = matrix.as_view();
680 assert_eq!(view.values, &[1, 2, 3, 4]);
681 assert_eq!(view.width, 2);
682 }
683
684 #[test]
685 fn test_as_view_mut() {
686 let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
687 let view = matrix.as_view_mut();
688 view.values[0] = 10;
689 assert_eq!(matrix.values, vec![10, 2, 3, 4]);
690 }
691
692 #[test]
693 fn test_copy_from() {
694 let mut matrix1 = RowMajorMatrix::new(vec![0, 0, 0, 0], 2);
695 let matrix2 = RowMajorMatrix::new(vec![1, 2, 3, 4], 2);
696 matrix1.copy_from(&matrix2);
697 assert_eq!(matrix1.values, vec![1, 2, 3, 4]);
698 }
699
700 #[test]
701 fn test_split_rows() {
702 let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
703 let (top, bottom) = matrix.split_rows(1);
704 assert_eq!(top.values, vec![1, 2]);
705 assert_eq!(bottom.values, vec![3, 4, 5, 6]);
706 }
707
708 #[test]
709 fn test_split_rows_mut() {
710 let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
711 let (top, bottom) = matrix.split_rows_mut(1);
712 assert_eq!(top.values, vec![1, 2]);
713 assert_eq!(bottom.values, vec![3, 4, 5, 6]);
714 }
715
716 #[test]
717 fn test_row_mut() {
718 let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
719 matrix.row_mut(1)[0] = 10;
720 assert_eq!(matrix.values, vec![1, 2, 10, 4, 5, 6]);
721 }
722
723 #[test]
724 fn test_bit_reversed_zero_pad() {
725 let matrix = RowMajorMatrix::new(
726 vec![
727 BabyBear::new(1),
728 BabyBear::new(2),
729 BabyBear::new(3),
730 BabyBear::new(4),
731 ],
732 2,
733 );
734 let padded = matrix.bit_reversed_zero_pad(1);
735 assert_eq!(padded.width, 2);
736 assert_eq!(
737 padded.values,
738 vec![
739 BabyBear::new(1),
740 BabyBear::new(2),
741 BabyBear::new(0),
742 BabyBear::new(0),
743 BabyBear::new(3),
744 BabyBear::new(4),
745 BabyBear::new(0),
746 BabyBear::new(0)
747 ]
748 );
749 }
750
751 #[test]
752 fn test_bit_reversed_zero_pad_no_change() {
753 let matrix = RowMajorMatrix::new(
754 vec![
755 BabyBear::new(1),
756 BabyBear::new(2),
757 BabyBear::new(3),
758 BabyBear::new(4),
759 ],
760 2,
761 );
762 let padded = matrix.bit_reversed_zero_pad(0);
763
764 assert_eq!(padded.width, 2);
765 assert_eq!(
766 padded.values,
767 vec![
768 BabyBear::new(1),
769 BabyBear::new(2),
770 BabyBear::new(3),
771 BabyBear::new(4),
772 ]
773 );
774 }
775
776 #[test]
777 fn test_scale() {
778 let mut matrix = RowMajorMatrix::new(
779 vec![
780 BabyBear::new(1),
781 BabyBear::new(2),
782 BabyBear::new(3),
783 BabyBear::new(4),
784 BabyBear::new(5),
785 BabyBear::new(6),
786 ],
787 2,
788 );
789 matrix.scale(BabyBear::new(2));
790 assert_eq!(
791 matrix.values,
792 vec![
793 BabyBear::new(2),
794 BabyBear::new(4),
795 BabyBear::new(6),
796 BabyBear::new(8),
797 BabyBear::new(10),
798 BabyBear::new(12)
799 ]
800 );
801 }
802
803 #[test]
804 fn test_scale_row() {
805 let mut matrix = RowMajorMatrix::new(
806 vec![
807 BabyBear::new(1),
808 BabyBear::new(2),
809 BabyBear::new(3),
810 BabyBear::new(4),
811 BabyBear::new(5),
812 BabyBear::new(6),
813 ],
814 2,
815 );
816 matrix.scale_row(1, BabyBear::new(3));
817 assert_eq!(
818 matrix.values,
819 vec![
820 BabyBear::new(1),
821 BabyBear::new(2),
822 BabyBear::new(9),
823 BabyBear::new(12),
824 BabyBear::new(5),
825 BabyBear::new(6),
826 ]
827 );
828 }
829
830 #[test]
831 fn test_to_row_major_matrix() {
832 let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
833 let converted = matrix.to_row_major_matrix();
834
835 assert_eq!(converted.width, 2);
837 assert_eq!(converted.height(), 3);
838 assert_eq!(converted.values, vec![1, 2, 3, 4, 5, 6]);
839 }
840
841 #[test]
842 fn test_horizontally_packed_row() {
843 type Packed = FieldArray<BabyBear, 2>;
844
845 let matrix = RowMajorMatrix::new(
846 vec![
847 BabyBear::new(1),
848 BabyBear::new(2),
849 BabyBear::new(3),
850 BabyBear::new(4),
851 BabyBear::new(5),
852 BabyBear::new(6),
853 ],
854 3,
855 );
856
857 let (packed_iter, suffix_iter) = matrix.horizontally_packed_row::<Packed>(1);
858
859 let packed: Vec<_> = packed_iter.collect();
860 let suffix: Vec<_> = suffix_iter.collect();
861
862 assert_eq!(
863 packed,
864 vec![Packed::from([BabyBear::new(4), BabyBear::new(5)])]
865 );
866 assert_eq!(suffix, vec![BabyBear::new(6)]);
867 }
868
869 #[test]
870 fn test_padded_horizontally_packed_row() {
871 use p3_baby_bear::BabyBear;
872
873 type Packed = FieldArray<BabyBear, 2>;
874
875 let matrix = RowMajorMatrix::new(
876 vec![
877 BabyBear::new(1),
878 BabyBear::new(2),
879 BabyBear::new(3),
880 BabyBear::new(4),
881 BabyBear::new(5),
882 BabyBear::new(6),
883 ],
884 3,
885 );
886
887 let packed_iter = matrix.padded_horizontally_packed_row::<Packed>(1);
888 let packed: Vec<_> = packed_iter.collect();
889
890 assert_eq!(
891 packed,
892 vec![
893 Packed::from([BabyBear::new(4), BabyBear::new(5)]),
894 Packed::from([BabyBear::new(6), BabyBear::new(0)])
895 ]
896 );
897 }
898
899 #[test]
900 fn test_pad_to_height() {
901 let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);
902
903 matrix.pad_to_height(4, 9);
908
909 assert_eq!(matrix.height(), 4);
916 assert_eq!(matrix.values, vec![1, 2, 3, 4, 5, 6, 9, 9, 9, 9, 9, 9]);
917 }
918
919 #[test]
920 fn test_transpose_into() {
921 let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 3);
922
923 let mut transposed = RowMajorMatrix::new(vec![0; 6], 2);
928
929 matrix.transpose_into(&mut transposed);
930
931 assert_eq!(transposed.width, 2);
937 assert_eq!(transposed.height(), 3);
938 assert_eq!(transposed.values, vec![1, 4, 2, 5, 3, 6]);
939 }
940
941 #[test]
942 fn test_flatten_to_base() {
943 let matrix = RowMajorMatrix::new(
944 vec![
945 BabyBear::new(2),
946 BabyBear::new(3),
947 BabyBear::new(4),
948 BabyBear::new(5),
949 ],
950 2,
951 );
952
953 let flattened: RowMajorMatrix<BabyBear> = matrix.flatten_to_base();
954
955 assert_eq!(flattened.width, 2);
956 assert_eq!(
957 flattened.values,
958 vec![
959 BabyBear::new(2),
960 BabyBear::new(3),
961 BabyBear::new(4),
962 BabyBear::new(5),
963 ]
964 );
965 }
966
967 #[test]
968 fn test_horizontally_packed_row_mut() {
969 type Packed = FieldArray<BabyBear, 2>;
970
971 let mut matrix = RowMajorMatrix::new(
972 vec![
973 BabyBear::new(1),
974 BabyBear::new(2),
975 BabyBear::new(3),
976 BabyBear::new(4),
977 BabyBear::new(5),
978 BabyBear::new(6),
979 ],
980 3,
981 );
982
983 let (packed, suffix) = matrix.horizontally_packed_row_mut::<Packed>(1);
984 packed[0] = Packed::from([BabyBear::new(9), BabyBear::new(10)]);
985 suffix[0] = BabyBear::new(11);
986
987 assert_eq!(
988 matrix.values,
989 vec![
990 BabyBear::new(1),
991 BabyBear::new(2),
992 BabyBear::new(3),
993 BabyBear::new(9),
994 BabyBear::new(10),
995 BabyBear::new(11),
996 ]
997 );
998 }
999
1000 #[test]
1001 fn test_par_row_chunks() {
1002 let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 2);
1003
1004 let chunks: Vec<_> = matrix.par_row_chunks(2).collect();
1005
1006 assert_eq!(chunks.len(), 2);
1007 assert_eq!(chunks[0].values, vec![1, 2, 3, 4]);
1008 assert_eq!(chunks[1].values, vec![5, 6, 7, 8]);
1009 }
1010
1011 #[test]
1012 fn test_par_row_chunks_exact() {
1013 let matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
1014
1015 let chunks: Vec<_> = matrix.par_row_chunks_exact(1).collect();
1016
1017 assert_eq!(chunks.len(), 3);
1018 assert_eq!(chunks[0].values, vec![1, 2]);
1019 assert_eq!(chunks[1].values, vec![3, 4]);
1020 assert_eq!(chunks[2].values, vec![5, 6]);
1021 }
1022
1023 #[test]
1024 fn test_par_row_chunks_mut() {
1025 let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6, 7, 8], 2);
1026
1027 matrix
1028 .par_row_chunks_mut(2)
1029 .for_each(|chunk| chunk.values.iter_mut().for_each(|x| *x += 10));
1030
1031 assert_eq!(matrix.values, vec![11, 12, 13, 14, 15, 16, 17, 18]);
1032 }
1033
1034 #[test]
1035 fn test_row_chunks_exact_mut() {
1036 let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
1037
1038 for chunk in matrix.row_chunks_exact_mut(1) {
1039 chunk.values.iter_mut().for_each(|x| *x *= 2);
1040 }
1041
1042 assert_eq!(matrix.values, vec![2, 4, 6, 8, 10, 12]);
1043 }
1044
1045 #[test]
1046 fn test_par_row_chunks_exact_mut() {
1047 let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
1048
1049 matrix
1050 .par_row_chunks_exact_mut(1)
1051 .for_each(|chunk| chunk.values.iter_mut().for_each(|x| *x += 5));
1052
1053 assert_eq!(matrix.values, vec![6, 7, 8, 9, 10, 11]);
1054 }
1055
1056 #[test]
1057 fn test_row_pair_mut() {
1058 let mut matrix = RowMajorMatrix::new(vec![1, 2, 3, 4, 5, 6], 2);
1059
1060 let (row1, row2) = matrix.row_pair_mut(0, 2);
1061 row1[0] = 9;
1062 row2[1] = 10;
1063
1064 assert_eq!(matrix.values, vec![9, 2, 3, 4, 5, 10]);
1065 }
1066
1067 #[test]
1068 fn test_packed_row_pair_mut() {
1069 type Packed = FieldArray<BabyBear, 2>;
1070
1071 let mut matrix = RowMajorMatrix::new(
1072 vec![
1073 BabyBear::new(1),
1074 BabyBear::new(2),
1075 BabyBear::new(3),
1076 BabyBear::new(4),
1077 BabyBear::new(5),
1078 BabyBear::new(6),
1079 ],
1080 3,
1081 );
1082
1083 let ((packed1, sfx1), (packed2, sfx2)) = matrix.packed_row_pair_mut::<Packed>(0, 1);
1084 packed1[0] = Packed::from([BabyBear::new(7), BabyBear::new(8)]);
1085 packed2[0] = Packed::from([BabyBear::new(33), BabyBear::new(44)]);
1086 sfx1[0] = BabyBear::new(99);
1087 sfx2[0] = BabyBear::new(9);
1088
1089 assert_eq!(
1090 matrix.values,
1091 vec![
1092 BabyBear::new(7),
1093 BabyBear::new(8),
1094 BabyBear::new(99),
1095 BabyBear::new(33),
1096 BabyBear::new(44),
1097 BabyBear::new(9),
1098 ]
1099 );
1100 }
1101
1102 #[test]
1103 fn test_transpose_square_matrix() {
1104 const START_INDEX: usize = 1;
1105 const VALUE_LEN: usize = 9;
1106 const WIDTH: usize = 3;
1107 const HEIGHT: usize = 3;
1108
1109 let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
1110 let matrix = RowMajorMatrix::new(matrix_values, WIDTH);
1111 let transposed = matrix.transpose();
1112 let should_be_transposed_values = vec![1, 4, 7, 2, 5, 8, 3, 6, 9];
1113 let should_be_transposed = RowMajorMatrix::new(should_be_transposed_values, HEIGHT);
1114 assert_eq!(transposed, should_be_transposed);
1115 }
1116
1117 #[test]
1118 fn test_transpose_row_matrix() {
1119 const START_INDEX: usize = 1;
1120 const VALUE_LEN: usize = 30;
1121 const WIDTH: usize = 1;
1122 const HEIGHT: usize = 30;
1123
1124 let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
1125 let matrix = RowMajorMatrix::new(matrix_values.clone(), WIDTH);
1126 let transposed = matrix.transpose();
1127 let should_be_transposed = RowMajorMatrix::new(matrix_values, HEIGHT);
1128 assert_eq!(transposed, should_be_transposed);
1129 }
1130
1131 #[test]
1132 fn test_transpose_rectangular_matrix() {
1133 const START_INDEX: usize = 1;
1134 const VALUE_LEN: usize = 30;
1135 const WIDTH: usize = 5;
1136 const HEIGHT: usize = 6;
1137
1138 let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
1139 let matrix = RowMajorMatrix::new(matrix_values, WIDTH);
1140 let transposed = matrix.transpose();
1141 let should_be_transposed_values = vec![
1142 1, 6, 11, 16, 21, 26, 2, 7, 12, 17, 22, 27, 3, 8, 13, 18, 23, 28, 4, 9, 14, 19, 24, 29,
1143 5, 10, 15, 20, 25, 30,
1144 ];
1145 let should_be_transposed = RowMajorMatrix::new(should_be_transposed_values, HEIGHT);
1146 assert_eq!(transposed, should_be_transposed);
1147 }
1148
1149 #[test]
1150 fn test_transpose_larger_rectangular_matrix() {
1151 const START_INDEX: usize = 1;
1152 const VALUE_LEN: usize = 131072; const WIDTH: usize = 256;
1154 const HEIGHT: usize = 512;
1155
1156 let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
1157 let matrix = RowMajorMatrix::new(matrix_values, WIDTH);
1158 let transposed = matrix.transpose();
1159
1160 assert_eq!(transposed.width(), HEIGHT);
1161 assert_eq!(transposed.height(), WIDTH);
1162
1163 for col_index in 0..WIDTH {
1164 for row_index in 0..HEIGHT {
1165 assert_eq!(
1166 matrix.values[row_index * WIDTH + col_index],
1167 transposed.values[col_index * HEIGHT + row_index]
1168 );
1169 }
1170 }
1171 }
1172
1173 #[test]
1174 fn test_transpose_very_large_rectangular_matrix() {
1175 const START_INDEX: usize = 1;
1176 const VALUE_LEN: usize = 1048576; const WIDTH: usize = 1024;
1178 const HEIGHT: usize = 1024;
1179
1180 let matrix_values = (START_INDEX..=VALUE_LEN).collect::<Vec<_>>();
1181 let matrix = RowMajorMatrix::new(matrix_values, WIDTH);
1182 let transposed = matrix.transpose();
1183
1184 assert_eq!(transposed.width(), HEIGHT);
1185 assert_eq!(transposed.height(), WIDTH);
1186
1187 for col_index in 0..WIDTH {
1188 for row_index in 0..HEIGHT {
1189 assert_eq!(
1190 matrix.values[row_index * WIDTH + col_index],
1191 transposed.values[col_index * HEIGHT + row_index]
1192 );
1193 }
1194 }
1195 }
1196
1197 #[test]
1198 fn test_vertically_packed_row_pair() {
1199 type Packed = FieldArray<BabyBear, 2>;
1200
1201 let matrix = RowMajorMatrix::new((1..17).map(BabyBear::new).collect::<Vec<_>>(), 4);
1202
1203 let packed = matrix.vertically_packed_row_pair::<Packed>(0, 2);
1205
1206 assert_eq!(
1222 packed,
1223 (1..5)
1224 .chain(9..13)
1225 .map(|i| [BabyBear::new(i), BabyBear::new(i + 4)].into())
1226 .collect::<Vec<_>>(),
1227 );
1228 }
1229
1230 #[test]
1231 fn test_vertically_packed_row_pair_overlap() {
1232 type Packed = FieldArray<BabyBear, 2>;
1233
1234 let matrix = RowMajorMatrix::new((1..17).map(BabyBear::new).collect::<Vec<_>>(), 4);
1235
1236 let packed = matrix.vertically_packed_row_pair::<Packed>(0, 1);
1253
1254 assert_eq!(
1255 packed,
1256 (1..5)
1257 .chain(5..9)
1258 .map(|i| [BabyBear::new(i), BabyBear::new(i + 4)].into())
1259 .collect::<Vec<_>>(),
1260 );
1261 }
1262
1263 #[test]
1264 fn test_vertically_packed_row_pair_wraparound_start_1() {
1265 use p3_baby_bear::BabyBear;
1266 use p3_field::FieldArray;
1267
1268 type Packed = FieldArray<BabyBear, 2>;
1269
1270 let matrix = RowMajorMatrix::new((1..17).map(BabyBear::new).collect::<Vec<_>>(), 4);
1271
1272 let packed = matrix.vertically_packed_row_pair::<Packed>(1, 2);
1291
1292 assert_eq!(
1293 packed,
1294 vec![
1295 Packed::from([BabyBear::new(5), BabyBear::new(9)]),
1296 Packed::from([BabyBear::new(6), BabyBear::new(10)]),
1297 Packed::from([BabyBear::new(7), BabyBear::new(11)]),
1298 Packed::from([BabyBear::new(8), BabyBear::new(12)]),
1299 Packed::from([BabyBear::new(13), BabyBear::new(1)]),
1300 Packed::from([BabyBear::new(14), BabyBear::new(2)]),
1301 Packed::from([BabyBear::new(15), BabyBear::new(3)]),
1302 Packed::from([BabyBear::new(16), BabyBear::new(4)]),
1303 ]
1304 );
1305 }
1306}