1#![no_std]
4
5extern crate alloc;
6
7use alloc::vec::Vec;
8use core::fmt::{Debug, Display, Formatter};
9use core::ops::Deref;
10
11use itertools::Itertools;
12use p3_field::{
13 BasedVectorSpace, ExtensionField, Field, FieldArray, PackedFieldExtension, PackedValue,
14 PrimeCharacteristicRing,
15};
16use p3_maybe_rayon::prelude::*;
17use strided::{VerticallyStridedMatrixView, VerticallyStridedRowIndexMap};
18use tracing::instrument;
19
20use crate::dense::RowMajorMatrix;
21
22pub mod bitrev;
23pub mod dense;
24pub mod extension;
25pub mod horizontally_truncated;
26pub mod row_index_mapped;
27pub mod stack;
28pub mod strided;
29pub mod util;
30
/// The shape of a matrix.
///
/// Both the `Debug` and `Display` renderings are `"{width}x{height}"`.
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct Dimensions {
    /// Number of columns.
    pub width: usize,
    /// Number of rows.
    pub height: usize,
}

impl Display for Dimensions {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let Self { width, height } = *self;
        write!(f, "{}x{}", width, height)
    }
}

impl Debug for Dimensions {
    /// Same text as the `Display` rendering.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        Display::fmt(self, f)
    }
}
54
55pub trait Matrix<T: Send + Sync + Clone>: Send + Sync {
61 fn width(&self) -> usize;
63
64 fn height(&self) -> usize;
66
67 fn dimensions(&self) -> Dimensions {
69 Dimensions {
70 width: self.width(),
71 height: self.height(),
72 }
73 }
74
75 #[inline]
86 fn get(&self, r: usize, c: usize) -> Option<T> {
87 (r < self.height() && c < self.width()).then(|| unsafe {
88 self.get_unchecked(r, c)
90 })
91 }
92
93 #[inline]
101 unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
102 unsafe { self.row_slice_unchecked(r)[c].clone() }
103 }
104
105 #[inline]
111 fn row(
112 &self,
113 r: usize,
114 ) -> Option<impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync>> {
115 (r < self.height()).then(|| unsafe {
116 self.row_unchecked(r)
118 })
119 }
120
121 #[inline]
131 unsafe fn row_unchecked(
132 &self,
133 r: usize,
134 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
135 unsafe { self.row_subseq_unchecked(r, 0, self.width()) }
136 }
137
138 #[inline]
148 unsafe fn row_subseq_unchecked(
149 &self,
150 r: usize,
151 start: usize,
152 end: usize,
153 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
154 unsafe {
155 self.row_unchecked(r)
156 .into_iter()
157 .skip(start)
158 .take(end - start)
159 }
160 }
161
162 #[inline]
166 fn row_slice(&self, r: usize) -> Option<impl Deref<Target = [T]>> {
167 (r < self.height()).then(|| unsafe {
168 self.row_slice_unchecked(r)
170 })
171 }
172
173 #[inline]
181 unsafe fn row_slice_unchecked(&self, r: usize) -> impl Deref<Target = [T]> {
182 unsafe { self.row_subslice_unchecked(r, 0, self.width()) }
183 }
184
185 #[inline]
195 unsafe fn row_subslice_unchecked(
196 &self,
197 r: usize,
198 start: usize,
199 end: usize,
200 ) -> impl Deref<Target = [T]> {
201 unsafe {
202 self.row_subseq_unchecked(r, start, end)
203 .into_iter()
204 .collect_vec()
205 }
206 }
207
208 #[inline]
210 fn rows(&self) -> impl Iterator<Item = impl Iterator<Item = T>> + Send + Sync {
211 unsafe {
212 (0..self.height()).map(move |r| self.row_unchecked(r).into_iter())
214 }
215 }
216
217 #[inline]
219 fn par_rows(
220 &self,
221 ) -> impl IndexedParallelIterator<Item = impl Iterator<Item = T>> + Send + Sync {
222 unsafe {
223 (0..self.height())
225 .into_par_iter()
226 .map(move |r| self.row_unchecked(r).into_iter())
227 }
228 }
229
230 fn wrapping_row_slices(&self, r: usize, c: usize) -> Vec<impl Deref<Target = [T]>> {
233 unsafe {
234 (0..c)
236 .map(|i| self.row_slice_unchecked((r + i) % self.height()))
237 .collect_vec()
238 }
239 }
240
241 #[inline]
245 fn first_row(
246 &self,
247 ) -> Option<impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync>> {
248 self.row(0)
249 }
250
251 #[inline]
255 fn last_row(
256 &self,
257 ) -> Option<impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync>> {
258 if self.height() == 0 {
259 None
260 } else {
261 unsafe { Some(self.row_unchecked(self.height() - 1)) }
263 }
264 }
265
266 fn to_row_major_matrix(self) -> RowMajorMatrix<T>
268 where
269 Self: Sized,
270 T: Clone,
271 {
272 RowMajorMatrix::new(self.rows().flatten().collect(), self.width())
273 }
274
275 fn horizontally_packed_row<'a, P>(
283 &'a self,
284 r: usize,
285 ) -> (
286 impl Iterator<Item = P> + Send + Sync,
287 impl Iterator<Item = T> + Send + Sync,
288 )
289 where
290 P: PackedValue<Value = T>,
291 T: Clone + 'a,
292 {
293 assert!(r < self.height(), "Row index out of bounds.");
294 let num_packed = self.width() / P::WIDTH;
295 unsafe {
296 let mut iter = self
298 .row_subseq_unchecked(r, 0, num_packed * P::WIDTH)
299 .into_iter();
300
301 let packed =
303 (0..num_packed).map(move |_| P::from_fn(|_| iter.next().unwrap_unchecked()));
304
305 let sfx = self
306 .row_subseq_unchecked(r, num_packed * P::WIDTH, self.width())
307 .into_iter();
308 (packed, sfx)
309 }
310 }
311
312 fn padded_horizontally_packed_row<'a, P>(
319 &'a self,
320 r: usize,
321 ) -> impl Iterator<Item = P> + Send + Sync
322 where
323 P: PackedValue<Value = T>,
324 T: Clone + Default + 'a,
325 {
326 let mut row_iter = self.row(r).expect("Row index out of bounds.").into_iter();
327 let num_elems = self.width().div_ceil(P::WIDTH);
328 (0..num_elems).map(move |_| P::from_fn(|_| row_iter.next().unwrap_or_default()))
330 }
331
332 fn par_horizontally_packed_rows<'a, P>(
337 &'a self,
338 ) -> impl IndexedParallelIterator<
339 Item = (
340 impl Iterator<Item = P> + Send + Sync,
341 impl Iterator<Item = T> + Send + Sync,
342 ),
343 >
344 where
345 P: PackedValue<Value = T>,
346 T: Clone + 'a,
347 {
348 (0..self.height())
349 .into_par_iter()
350 .map(|r| self.horizontally_packed_row(r))
351 }
352
353 fn par_padded_horizontally_packed_rows<'a, P>(
357 &'a self,
358 ) -> impl IndexedParallelIterator<Item = impl Iterator<Item = P> + Send + Sync>
359 where
360 P: PackedValue<Value = T>,
361 T: Clone + Default + 'a,
362 {
363 (0..self.height())
364 .into_par_iter()
365 .map(|r| self.padded_horizontally_packed_row(r))
366 }
367
368 #[inline]
374 fn vertically_packed_row<P>(&self, r: usize) -> impl Iterator<Item = P>
375 where
376 T: Copy,
377 P: PackedValue<Value = T>,
378 {
379 let rows = self.wrapping_row_slices(r, P::WIDTH);
381
382 (0..self.width()).map(move |c| P::from_fn(|i| rows[i][c]))
384 }
385
386 #[inline]
394 fn vertically_packed_row_pair<P>(&self, r: usize, step: usize) -> Vec<P>
395 where
396 T: Copy,
397 P: PackedValue<Value = T>,
398 {
399 let rows = self.wrapping_row_slices(r, P::WIDTH);
404 let next_rows = self.wrapping_row_slices(r + step, P::WIDTH);
405
406 (0..self.width())
407 .map(|c| P::from_fn(|i| rows[i][c]))
408 .chain((0..self.width()).map(|c| P::from_fn(|i| next_rows[i][c])))
409 .collect_vec()
410 }
411
412 fn vertically_strided(self, stride: usize, offset: usize) -> VerticallyStridedMatrixView<Self>
416 where
417 Self: Sized,
418 {
419 VerticallyStridedRowIndexMap::new_view(self, stride, offset)
420 }
421
422 #[instrument(level = "debug", skip_all, fields(dims = %self.dimensions()))]
426 fn columnwise_dot_product<EF>(&self, v: &[EF]) -> Vec<EF>
427 where
428 T: Field,
429 EF: ExtensionField<T>,
430 {
431 let packed_width = self.width().div_ceil(T::Packing::WIDTH);
432
433 let packed_result = self
434 .par_padded_horizontally_packed_rows::<T::Packing>()
435 .zip(v)
436 .par_fold_reduce(
437 || EF::ExtensionPacking::zero_vec(packed_width),
438 |mut acc, (row, &scale)| {
439 let scale: EF::ExtensionPacking = scale.into();
440 acc.iter_mut().zip(row).for_each(|(l, r)| *l += scale * r);
441 acc
442 },
443 |mut acc_l, acc_r| {
444 acc_l.iter_mut().zip(&acc_r).for_each(|(l, r)| *l += *r);
445 acc_l
446 },
447 );
448
449 EF::ExtensionPacking::to_ext_iter(packed_result)
450 .take(self.width())
451 .collect()
452 }
453
454 #[instrument(level = "debug", skip_all, fields(dims = %self.dimensions()))]
461 fn columnwise_dot_product_batched<EF, const N: usize>(
462 &self,
463 vs: &[FieldArray<EF, N>],
464 ) -> Vec<FieldArray<EF, N>>
465 where
466 T: Field,
467 EF: ExtensionField<T>,
468 {
469 let packed_width = self.width().div_ceil(T::Packing::WIDTH);
470
471 let packed_results: Vec<EF::ExtensionPacking> = self
472 .par_padded_horizontally_packed_rows::<T::Packing>()
473 .zip(vs)
474 .par_fold_reduce(
475 || EF::ExtensionPacking::zero_vec(packed_width * N),
476 |mut acc, (packed_row, scales)| {
477 let packed_scales: [EF::ExtensionPacking; N] =
479 scales.map_into_array(EF::ExtensionPacking::from);
480
481 for (acc_c, row_c) in acc.chunks_exact_mut(N).zip(packed_row) {
483 for j in 0..N {
484 acc_c[j] += packed_scales[j] * row_c;
485 }
486 }
487 acc
488 },
489 |mut acc_l, acc_r| {
490 acc_l.iter_mut().zip(&acc_r).for_each(|(lj, rj)| *lj += *rj);
491 acc_l
492 },
493 );
494
495 packed_results
497 .chunks(N)
498 .flat_map(|chunk| {
499 (0..T::Packing::WIDTH)
500 .map(move |lane| FieldArray::from_fn(|j| chunk[j].extract(lane)))
501 })
502 .take(self.width())
503 .collect()
504 }
505
506 fn rowwise_packed_dot_product<EF>(
516 &self,
517 vec: &[EF::ExtensionPacking],
518 ) -> impl IndexedParallelIterator<Item = EF>
519 where
520 T: Field,
521 EF: ExtensionField<T>,
522 {
523 assert!(vec.len() >= self.width().div_ceil(T::Packing::WIDTH));
525
526 self.par_padded_horizontally_packed_rows::<T::Packing>()
529 .map(move |row_packed| {
530 let d = <EF::ExtensionPacking as BasedVectorSpace<T::Packing>>::DIMENSION;
532
533 let mut coeff_accs: [T::Packing; 8] = [T::Packing::ZERO; 8];
536 debug_assert!(d <= 8, "Extension degree > 8 not supported");
537
538 for (v, r) in vec.iter().zip(row_packed) {
540 let v_coeffs = v.as_basis_coefficients_slice();
541 for (acc, &v_coeff) in coeff_accs[..d].iter_mut().zip(v_coeffs) {
542 *acc += v_coeff * r;
543 }
544 }
545
546 let packed_result =
548 EF::ExtensionPacking::from_basis_coefficients_fn(|i| coeff_accs[i]);
549 EF::ExtensionPacking::to_ext_iter([packed_result]).sum()
550 })
551 }
552}
553
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use alloc::{format, vec};

    use itertools::izip;
    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;
    use p3_field::extension::BinomialExtensionField;
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    use super::*;

    /// The batched and unbatched dot products are checked against a naive double loop.
    #[test]
    fn test_columnwise_dot_product() {
        type F = BabyBear;
        type EF = BinomialExtensionField<BabyBear, 4>;

        let mut rng = SmallRng::seed_from_u64(1);
        let m = RowMajorMatrix::<F>::rand(&mut rng, 1 << 8, 1 << 4);
        let v = RowMajorMatrix::<EF>::rand(&mut rng, 1 << 8, 1).values;

        // Reference result: accumulate `scale * row` column by column.
        let mut expected = vec![EF::ZERO; m.width()];
        for (row, &scale) in izip!(m.rows(), &v) {
            for (sum, elem) in expected.iter_mut().zip(row) {
                *sum += scale * elem;
            }
        }

        assert_eq!(m.columnwise_dot_product(&v), expected);
    }

    /// A batch of two scale vectors must agree with two separate unbatched calls.
    #[test]
    fn test_columnwise_dot_product_batched() {
        type F = BabyBear;
        type EF = BinomialExtensionField<BabyBear, 4>;

        let mut rng = SmallRng::seed_from_u64(1);
        let m = RowMajorMatrix::<F>::rand(&mut rng, 1 << 8, 1 << 4);
        let v1 = RowMajorMatrix::<EF>::rand(&mut rng, 1 << 8, 1).values;
        let v2 = RowMajorMatrix::<EF>::rand(&mut rng, 1 << 8, 1).values;

        let expected1 = m.columnwise_dot_product(&v1);
        let expected2 = m.columnwise_dot_product(&v2);

        let mut vs: Vec<FieldArray<EF, 2>> = Vec::new();
        for (a, b) in v1.into_iter().zip(v2) {
            vs.push(FieldArray([a, b]));
        }
        let results = m.columnwise_dot_product_batched::<EF, 2>(&vs);

        let result1: Vec<EF> = results.iter().map(|pair| pair[0]).collect();
        let result2: Vec<EF> = results.iter().map(|pair| pair[1]).collect();

        assert_eq!(result1, expected1);
        assert_eq!(result2, expected2);
    }

    /// Minimal `Matrix` implementation that only overrides `row_unchecked`, so every other
    /// method runs through the trait's default bodies.
    struct MockMatrix {
        data: Vec<Vec<u32>>,
        width: usize,
        height: usize,
    }

    impl Matrix<u32> for MockMatrix {
        fn width(&self) -> usize {
            self.width
        }

        fn height(&self) -> usize {
            self.height
        }

        unsafe fn row_unchecked(
            &self,
            r: usize,
        ) -> impl IntoIterator<Item = u32, IntoIter = impl Iterator<Item = u32> + Send + Sync>
        {
            // Plain indexing is used, so an out-of-range `r` panics rather than being UB.
            self.data[r].to_vec()
        }
    }

    /// The 3x3 fixture `[[1, 2, 3], [4, 5, 6], [7, 8, 9]]` shared by most tests below.
    fn grid_3x3() -> MockMatrix {
        MockMatrix {
            data: vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]],
            width: 3,
            height: 3,
        }
    }

    #[test]
    fn test_dimensions() {
        let dims = Dimensions {
            width: 3,
            height: 5,
        };
        assert_eq!(dims.width, 3);
        assert_eq!(dims.height, 5);
        assert_eq!(format!("{dims:?}"), "3x5");
        assert_eq!(format!("{dims}"), "3x5");
    }

    #[test]
    fn test_mock_matrix_dimensions() {
        let matrix = grid_3x3();
        assert_eq!(matrix.width(), 3);
        assert_eq!(matrix.height(), 3);

        let expected = Dimensions {
            width: 3,
            height: 3,
        };
        assert_eq!(matrix.dimensions(), expected);
    }

    #[test]
    fn test_first_row() {
        let matrix = grid_3x3();
        let mut first_row = matrix.first_row().unwrap().into_iter();
        for expected in [1, 2, 3] {
            assert_eq!(first_row.next(), Some(expected));
        }
    }

    #[test]
    fn test_last_row() {
        let matrix = grid_3x3();
        let mut last_row = matrix.last_row().unwrap().into_iter();
        for expected in [7, 8, 9] {
            assert_eq!(last_row.next(), Some(expected));
        }
    }

    #[test]
    fn test_first_last_row_empty_matrix() {
        let matrix = MockMatrix {
            data: vec![],
            width: 3,
            height: 0,
        };
        let first_row = matrix.first_row();
        let last_row = matrix.last_row();
        assert!(first_row.is_none());
        assert!(last_row.is_none());
    }

    #[test]
    fn test_to_row_major_matrix() {
        let matrix = MockMatrix {
            data: vec![vec![1, 2], vec![3, 4]],
            width: 2,
            height: 2,
        };
        let row_major = matrix.to_row_major_matrix();
        assert_eq!(row_major.values, vec![1, 2, 3, 4]);
        assert_eq!(row_major.width, 2);
    }

    #[test]
    fn test_matrix_get_methods() {
        let matrix = grid_3x3();

        // Checked access inside the bounds.
        for (r, c, expected) in [(0, 0, 1), (1, 2, 6), (2, 1, 8)] {
            assert_eq!(matrix.get(r, c), Some(expected));
        }

        // Unchecked access inside the bounds.
        for (r, c, expected) in [(0, 1, 2), (1, 0, 4), (2, 2, 9)] {
            assert_eq!(unsafe { matrix.get_unchecked(r, c) }, expected);
        }

        // Row index out of bounds, then column index out of bounds.
        assert_eq!(matrix.get(3, 0), None);
        assert_eq!(matrix.get(0, 3), None);
    }

    #[test]
    fn test_matrix_row_methods_iteration() {
        let matrix = grid_3x3();

        let mut checked = matrix.row(1).unwrap().into_iter();
        for expected in [Some(4), Some(5), Some(6), None] {
            assert_eq!(checked.next(), expected);
        }

        unsafe {
            let mut unchecked = matrix.row_unchecked(2).into_iter();
            for expected in [Some(7), Some(8), Some(9), None] {
                assert_eq!(unchecked.next(), expected);
            }

            let mut subset = matrix.row_subseq_unchecked(0, 1, 3).into_iter();
            for expected in [Some(2), Some(3), None] {
                assert_eq!(subset.next(), expected);
            }
        }

        // Row index out of bounds.
        assert!(matrix.row(3).is_none());
    }

    #[test]
    fn test_row_slice_methods() {
        let matrix = grid_3x3();

        let checked = matrix.row_slice(1).unwrap();
        assert_eq!(*checked, [4, 5, 6]);

        unsafe {
            let unchecked = matrix.row_slice_unchecked(2);
            assert_eq!(*unchecked, [7, 8, 9]);

            let partial = matrix.row_subslice_unchecked(0, 1, 2);
            assert_eq!(*partial, [2]);
        }

        // Row index out of bounds.
        assert!(matrix.row_slice(3).is_none());
    }

    #[test]
    fn test_matrix_rows() {
        let matrix = grid_3x3();

        let all_rows: Vec<Vec<u32>> = matrix
            .rows()
            .map(|row| row.collect::<Vec<u32>>())
            .collect();
        assert_eq!(all_rows, vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]);
    }

    /// Compares `rowwise_packed_dot_product` with a scalar dot product for several shapes,
    /// including widths that are not a multiple of the packing width.
    #[test]
    fn test_rowwise_packed_dot_product() {
        use p3_field::PackedFieldExtension;

        type F = BabyBear;
        type EF = BinomialExtensionField<BabyBear, 4>;
        type PF = <F as p3_field::Field>::Packing;
        type EFPacked = <EF as p3_field::ExtensionField<F>>::ExtensionPacking;

        let lanes = <PF as PackedValue>::WIDTH;
        let mut rng = SmallRng::seed_from_u64(42);

        for (height, width) in [(32, 16), (64, 128), (128, 17), (256, 255)] {
            let m = RowMajorMatrix::<F>::rand(&mut rng, height, width);
            let v = RowMajorMatrix::<EF>::rand(&mut rng, width, 1).values;

            // Scalar reference: one dot product per row.
            let mut expected: Vec<EF> = Vec::new();
            for row in m.rows() {
                let dot = row
                    .into_iter()
                    .zip(v.iter())
                    .map(|(elem, &coeff)| coeff * elem)
                    .sum::<EF>();
                expected.push(dot);
            }

            // Pack `v` in groups of `lanes`, zero-padding the final partial group.
            let packed_v: Vec<EFPacked> = v
                .chunks(lanes)
                .map(|chunk| {
                    let mut padded = vec![EF::ZERO; lanes];
                    padded[..chunk.len()].copy_from_slice(chunk);
                    EFPacked::from_ext_slice(&padded)
                })
                .collect();

            let result: Vec<EF> = m.rowwise_packed_dot_product::<EF>(&packed_v).collect();

            assert_eq!(result, expected, "Mismatch for matrix {}x{}", height, width);
        }
    }
}