1#![no_std]
4
5extern crate alloc;
6
7use alloc::vec::Vec;
8use core::fmt::{Debug, Display, Formatter};
9use core::ops::Deref;
10
11use itertools::{Itertools, izip};
12use p3_field::{
13 ExtensionField, Field, PackedFieldExtension, PackedValue, PrimeCharacteristicRing, dot_product,
14};
15use p3_maybe_rayon::prelude::*;
16use strided::{VerticallyStridedMatrixView, VerticallyStridedRowIndexMap};
17use tracing::instrument;
18
19use crate::dense::RowMajorMatrix;
20
21pub mod bitrev;
22pub mod dense;
23pub mod extension;
24pub mod horizontally_truncated;
25pub mod row_index_mapped;
26pub mod stack;
27pub mod strided;
28pub mod util;
29
/// The dimensions of a matrix: its number of columns (`width`) and rows (`height`).
///
/// Both the `Debug` and the `Display` representation are `"{width}x{height}"`, e.g. `3x5`.
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct Dimensions {
    /// Number of columns.
    pub width: usize,
    /// Number of rows.
    pub height: usize,
}

impl Debug for Dimensions {
    /// Formats exactly like `Display` (`"{width}x{height}"`); delegating keeps the two
    /// representations from drifting apart.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        Display::fmt(self, f)
    }
}

impl Display for Dimensions {
    /// Writes the dimensions as `"{width}x{height}"`.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        write!(f, "{}x{}", self.width, self.height)
    }
}
53
54pub trait Matrix<T: Send + Sync + Clone>: Send + Sync {
60 fn width(&self) -> usize;
62
63 fn height(&self) -> usize;
65
66 fn dimensions(&self) -> Dimensions {
68 Dimensions {
69 width: self.width(),
70 height: self.height(),
71 }
72 }
73
74 #[inline]
85 fn get(&self, r: usize, c: usize) -> Option<T> {
86 (r < self.height() && c < self.width()).then(|| unsafe {
87 self.get_unchecked(r, c)
89 })
90 }
91
92 #[inline]
100 unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
101 unsafe { self.row_slice_unchecked(r)[c].clone() }
102 }
103
104 #[inline]
110 fn row(
111 &self,
112 r: usize,
113 ) -> Option<impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync>> {
114 (r < self.height()).then(|| unsafe {
115 self.row_unchecked(r)
117 })
118 }
119
120 #[inline]
130 unsafe fn row_unchecked(
131 &self,
132 r: usize,
133 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
134 unsafe { self.row_subseq_unchecked(r, 0, self.width()) }
135 }
136
137 #[inline]
147 unsafe fn row_subseq_unchecked(
148 &self,
149 r: usize,
150 start: usize,
151 end: usize,
152 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
153 unsafe {
154 self.row_unchecked(r)
155 .into_iter()
156 .skip(start)
157 .take(end - start)
158 }
159 }
160
161 #[inline]
165 fn row_slice(&self, r: usize) -> Option<impl Deref<Target = [T]>> {
166 (r < self.height()).then(|| unsafe {
167 self.row_slice_unchecked(r)
169 })
170 }
171
172 #[inline]
180 unsafe fn row_slice_unchecked(&self, r: usize) -> impl Deref<Target = [T]> {
181 unsafe { self.row_subslice_unchecked(r, 0, self.width()) }
182 }
183
184 #[inline]
194 unsafe fn row_subslice_unchecked(
195 &self,
196 r: usize,
197 start: usize,
198 end: usize,
199 ) -> impl Deref<Target = [T]> {
200 unsafe {
201 self.row_subseq_unchecked(r, start, end)
202 .into_iter()
203 .collect_vec()
204 }
205 }
206
207 #[inline]
209 fn rows(&self) -> impl Iterator<Item = impl Iterator<Item = T>> + Send + Sync {
210 unsafe {
211 (0..self.height()).map(move |r| self.row_unchecked(r).into_iter())
213 }
214 }
215
216 #[inline]
218 fn par_rows(
219 &self,
220 ) -> impl IndexedParallelIterator<Item = impl Iterator<Item = T>> + Send + Sync {
221 unsafe {
222 (0..self.height())
224 .into_par_iter()
225 .map(move |r| self.row_unchecked(r).into_iter())
226 }
227 }
228
229 fn wrapping_row_slices(&self, r: usize, c: usize) -> Vec<impl Deref<Target = [T]>> {
232 unsafe {
233 (0..c)
235 .map(|i| self.row_slice_unchecked((r + i) % self.height()))
236 .collect_vec()
237 }
238 }
239
240 #[inline]
244 fn first_row(
245 &self,
246 ) -> Option<impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync>> {
247 self.row(0)
248 }
249
250 #[inline]
254 fn last_row(
255 &self,
256 ) -> Option<impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync>> {
257 if self.height() == 0 {
258 None
259 } else {
260 unsafe { Some(self.row_unchecked(self.height() - 1)) }
262 }
263 }
264
265 fn to_row_major_matrix(self) -> RowMajorMatrix<T>
267 where
268 Self: Sized,
269 T: Clone,
270 {
271 RowMajorMatrix::new(self.rows().flatten().collect(), self.width())
272 }
273
274 fn horizontally_packed_row<'a, P>(
282 &'a self,
283 r: usize,
284 ) -> (
285 impl Iterator<Item = P> + Send + Sync,
286 impl Iterator<Item = T> + Send + Sync,
287 )
288 where
289 P: PackedValue<Value = T>,
290 T: Clone + 'a,
291 {
292 assert!(r < self.height(), "Row index out of bounds.");
293 let num_packed = self.width() / P::WIDTH;
294 unsafe {
295 let mut iter = self
297 .row_subseq_unchecked(r, 0, num_packed * P::WIDTH)
298 .into_iter();
299
300 let packed =
302 (0..num_packed).map(move |_| P::from_fn(|_| iter.next().unwrap_unchecked()));
303
304 let sfx = self
305 .row_subseq_unchecked(r, num_packed * P::WIDTH, self.width())
306 .into_iter();
307 (packed, sfx)
308 }
309 }
310
311 fn padded_horizontally_packed_row<'a, P>(
318 &'a self,
319 r: usize,
320 ) -> impl Iterator<Item = P> + Send + Sync
321 where
322 P: PackedValue<Value = T>,
323 T: Clone + Default + 'a,
324 {
325 let mut row_iter = self.row(r).expect("Row index out of bounds.").into_iter();
326 let num_elems = self.width().div_ceil(P::WIDTH);
327 (0..num_elems).map(move |_| P::from_fn(|_| row_iter.next().unwrap_or_default()))
329 }
330
331 fn par_horizontally_packed_rows<'a, P>(
336 &'a self,
337 ) -> impl IndexedParallelIterator<
338 Item = (
339 impl Iterator<Item = P> + Send + Sync,
340 impl Iterator<Item = T> + Send + Sync,
341 ),
342 >
343 where
344 P: PackedValue<Value = T>,
345 T: Clone + 'a,
346 {
347 (0..self.height())
348 .into_par_iter()
349 .map(|r| self.horizontally_packed_row(r))
350 }
351
352 fn par_padded_horizontally_packed_rows<'a, P>(
356 &'a self,
357 ) -> impl IndexedParallelIterator<Item = impl Iterator<Item = P> + Send + Sync>
358 where
359 P: PackedValue<Value = T>,
360 T: Clone + Default + 'a,
361 {
362 (0..self.height())
363 .into_par_iter()
364 .map(|r| self.padded_horizontally_packed_row(r))
365 }
366
367 #[inline]
373 fn vertically_packed_row<P>(&self, r: usize) -> impl Iterator<Item = P>
374 where
375 T: Copy,
376 P: PackedValue<Value = T>,
377 {
378 let rows = self.wrapping_row_slices(r, P::WIDTH);
380
381 (0..self.width()).map(move |c| P::from_fn(|i| rows[i][c]))
383 }
384
385 #[inline]
393 fn vertically_packed_row_pair<P>(&self, r: usize, step: usize) -> Vec<P>
394 where
395 T: Copy,
396 P: PackedValue<Value = T>,
397 {
398 let rows = self.wrapping_row_slices(r, P::WIDTH);
403 let next_rows = self.wrapping_row_slices(r + step, P::WIDTH);
404
405 (0..self.width())
406 .map(|c| P::from_fn(|i| rows[i][c]))
407 .chain((0..self.width()).map(|c| P::from_fn(|i| next_rows[i][c])))
408 .collect_vec()
409 }
410
411 fn vertically_strided(self, stride: usize, offset: usize) -> VerticallyStridedMatrixView<Self>
415 where
416 Self: Sized,
417 {
418 VerticallyStridedRowIndexMap::new_view(self, stride, offset)
419 }
420
421 #[instrument(level = "debug", skip_all, fields(dims = %self.dimensions()))]
425 fn columnwise_dot_product<EF>(&self, v: &[EF]) -> Vec<EF>
426 where
427 T: Field,
428 EF: ExtensionField<T>,
429 {
430 let packed_width = self.width().div_ceil(T::Packing::WIDTH);
431
432 let packed_result = self
433 .par_padded_horizontally_packed_rows::<T::Packing>()
434 .zip(v)
435 .par_fold_reduce(
436 || EF::ExtensionPacking::zero_vec(packed_width),
437 |mut acc, (row, &scale)| {
438 let scale = EF::ExtensionPacking::from(scale);
439 izip!(&mut acc, row).for_each(|(l, r)| *l += scale * r);
440 acc
441 },
442 |mut acc_l, acc_r| {
443 izip!(&mut acc_l, acc_r).for_each(|(l, r)| *l += r);
444 acc_l
445 },
446 );
447
448 EF::ExtensionPacking::to_ext_iter(packed_result)
449 .take(self.width())
450 .collect()
451 }
452
453 fn rowwise_packed_dot_product<EF>(
463 &self,
464 vec: &[EF::ExtensionPacking],
465 ) -> impl IndexedParallelIterator<Item = EF>
466 where
467 T: Field,
468 EF: ExtensionField<T>,
469 {
470 assert!(vec.len() >= self.width().div_ceil(T::Packing::WIDTH));
472
473 self.par_padded_horizontally_packed_rows::<T::Packing>()
477 .map(move |row_packed| {
478 let packed_sum_of_packed: EF::ExtensionPacking =
479 dot_product(vec.iter().copied(), row_packed);
480 EF::ExtensionPacking::to_ext_iter([packed_sum_of_packed]).sum()
481 })
482 }
483}
484
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use alloc::{format, vec};

    use itertools::izip;
    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;
    use p3_field::extension::BinomialExtensionField;
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    use super::*;

    #[test]
    fn test_columnwise_dot_product() {
        type F = BabyBear;
        type EF = BinomialExtensionField<BabyBear, 4>;

        let mut rng = SmallRng::seed_from_u64(1);
        let m = RowMajorMatrix::<F>::rand(&mut rng, 1 << 8, 1 << 4);
        let v = RowMajorMatrix::<EF>::rand(&mut rng, 1 << 8, 1).values;

        // Reference result: scale each row by its entry in `v` and accumulate per column.
        let mut expected = vec![EF::ZERO; m.width()];
        for (row, &scale) in izip!(m.rows(), &v) {
            for (l, r) in izip!(&mut expected, row) {
                *l += scale * r;
            }
        }

        assert_eq!(m.columnwise_dot_product(&v), expected);
    }

    /// A minimal `Matrix<u32>` backed by a `Vec` of rows.
    ///
    /// It overrides only `row_unchecked`, so every other accessor exercised below runs through
    /// the trait's default implementations.
    struct MockMatrix {
        data: Vec<Vec<u32>>,
        width: usize,
        height: usize,
    }

    impl Matrix<u32> for MockMatrix {
        fn width(&self) -> usize {
            self.width
        }

        fn height(&self) -> usize {
            self.height
        }

        unsafe fn row_unchecked(
            &self,
            r: usize,
        ) -> impl IntoIterator<Item = u32, IntoIter = impl Iterator<Item = u32> + Send + Sync>
        {
            self.data[r].clone()
        }
    }

    /// The 3x3 matrix `[[1, 2, 3], [4, 5, 6], [7, 8, 9]]` shared by most tests.
    fn mock_3x3() -> MockMatrix {
        MockMatrix {
            data: vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]],
            width: 3,
            height: 3,
        }
    }

    #[test]
    fn test_dimensions() {
        let dims = Dimensions {
            width: 3,
            height: 5,
        };
        assert_eq!(dims.width, 3);
        assert_eq!(dims.height, 5);
        assert_eq!(format!("{dims:?}"), "3x5");
        assert_eq!(format!("{dims}"), "3x5");
    }

    #[test]
    fn test_mock_matrix_dimensions() {
        let matrix = mock_3x3();
        assert_eq!(matrix.width(), 3);
        assert_eq!(matrix.height(), 3);
        assert_eq!(
            matrix.dimensions(),
            Dimensions {
                width: 3,
                height: 3
            }
        );
    }

    #[test]
    fn test_first_row() {
        let matrix = mock_3x3();
        let mut first_row = matrix.first_row().unwrap().into_iter();
        assert_eq!(first_row.next(), Some(1));
        assert_eq!(first_row.next(), Some(2));
        assert_eq!(first_row.next(), Some(3));
    }

    #[test]
    fn test_last_row() {
        let matrix = mock_3x3();
        let mut last_row = matrix.last_row().unwrap().into_iter();
        assert_eq!(last_row.next(), Some(7));
        assert_eq!(last_row.next(), Some(8));
        assert_eq!(last_row.next(), Some(9));
    }

    #[test]
    fn test_first_last_row_single_row() {
        let matrix = MockMatrix {
            data: vec![vec![4, 2]],
            width: 2,
            height: 1,
        };
        let first: Vec<u32> = matrix.first_row().unwrap().into_iter().collect();
        let last: Vec<u32> = matrix.last_row().unwrap().into_iter().collect();
        assert_eq!(first, vec![4, 2]);
        assert_eq!(first, last);
    }

    #[test]
    fn test_first_last_row_empty_matrix() {
        let matrix = MockMatrix {
            data: vec![],
            width: 3,
            height: 0,
        };
        let first_row = matrix.first_row();
        let last_row = matrix.last_row();
        assert!(first_row.is_none());
        assert!(last_row.is_none());
    }

    #[test]
    fn test_to_row_major_matrix() {
        let matrix = MockMatrix {
            data: vec![vec![1, 2], vec![3, 4]],
            width: 2,
            height: 2,
        };
        let row_major = matrix.to_row_major_matrix();
        assert_eq!(row_major.values, vec![1, 2, 3, 4]);
        assert_eq!(row_major.width, 2);
    }

    #[test]
    fn test_matrix_get_methods() {
        let matrix = mock_3x3();
        assert_eq!(matrix.get(0, 0), Some(1));
        assert_eq!(matrix.get(1, 2), Some(6));
        assert_eq!(matrix.get(2, 1), Some(8));

        unsafe {
            assert_eq!(matrix.get_unchecked(0, 1), 2);
            assert_eq!(matrix.get_unchecked(1, 0), 4);
            assert_eq!(matrix.get_unchecked(2, 2), 9);
        }

        // Row index out of bounds.
        assert_eq!(matrix.get(3, 0), None);
        // Column index out of bounds.
        assert_eq!(matrix.get(0, 3), None);
    }

    #[test]
    fn test_matrix_row_methods_iteration() {
        let matrix = mock_3x3();

        let mut row_iter = matrix.row(1).unwrap().into_iter();
        assert_eq!(row_iter.next(), Some(4));
        assert_eq!(row_iter.next(), Some(5));
        assert_eq!(row_iter.next(), Some(6));
        assert_eq!(row_iter.next(), None);

        unsafe {
            let mut row_iter_unchecked = matrix.row_unchecked(2).into_iter();
            assert_eq!(row_iter_unchecked.next(), Some(7));
            assert_eq!(row_iter_unchecked.next(), Some(8));
            assert_eq!(row_iter_unchecked.next(), Some(9));
            assert_eq!(row_iter_unchecked.next(), None);

            let mut row_iter_subset = matrix.row_subseq_unchecked(0, 1, 3).into_iter();
            assert_eq!(row_iter_subset.next(), Some(2));
            assert_eq!(row_iter_subset.next(), Some(3));
            assert_eq!(row_iter_subset.next(), None);
        }

        // Row index out of bounds.
        assert!(matrix.row(3).is_none());
    }

    #[test]
    fn test_row_slice_methods() {
        let matrix = mock_3x3();
        let row_slice = matrix.row_slice(1).unwrap();
        assert_eq!(*row_slice, [4, 5, 6]);
        unsafe {
            let row_slice_unchecked = matrix.row_slice_unchecked(2);
            assert_eq!(*row_slice_unchecked, [7, 8, 9]);

            let row_subslice = matrix.row_subslice_unchecked(0, 1, 2);
            assert_eq!(*row_subslice, [2]);
        }

        // Row index out of bounds.
        assert!(matrix.row_slice(3).is_none());
    }

    #[test]
    fn test_matrix_rows() {
        let matrix = mock_3x3();

        let all_rows: Vec<Vec<u32>> = matrix.rows().map(|row| row.collect()).collect();
        assert_eq!(all_rows, vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]);
    }

    #[test]
    fn test_matrix_rows_empty() {
        let matrix = MockMatrix {
            data: vec![],
            width: 3,
            height: 0,
        };
        assert_eq!(matrix.rows().count(), 0);
    }

    #[test]
    fn test_wrapping_row_slices() {
        let matrix = mock_3x3();

        // Starting at the last row, four slices wrap around to rows 0, 1 and back to 2.
        let slices = matrix.wrapping_row_slices(2, 4);
        let got: Vec<Vec<u32>> = slices.iter().map(|s| s.to_vec()).collect();
        assert_eq!(
            got,
            vec![vec![7, 8, 9], vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]
        );

        // A start index beyond the height is taken modulo the height.
        let slices = matrix.wrapping_row_slices(4, 1);
        assert_eq!(*slices[0], [4, 5, 6]);
    }

    #[test]
    fn test_wrapping_row_slices_zero_count_on_empty_matrix() {
        let matrix = MockMatrix {
            data: vec![],
            width: 3,
            height: 0,
        };
        assert!(matrix.wrapping_row_slices(5, 0).is_empty());
    }
}