1#![no_std]
4
5extern crate alloc;
6
7use alloc::vec::Vec;
8use core::fmt::{Debug, Display, Formatter};
9use core::ops::Deref;
10
11use itertools::{Itertools, izip};
12use p3_field::{
13 BasedVectorSpace, ExtensionField, Field, PackedValue, PrimeCharacteristicRing, dot_product,
14};
15use p3_maybe_rayon::prelude::*;
16use strided::{VerticallyStridedMatrixView, VerticallyStridedRowIndexMap};
17use tracing::instrument;
18
19use crate::dense::RowMajorMatrix;
20
21pub mod bitrev;
22pub mod dense;
23pub mod extension;
24pub mod horizontally_truncated;
25pub mod row_index_mapped;
26pub mod stack;
27pub mod strided;
28pub mod util;
29
/// The width and height of a matrix, as reported by [`Matrix::dimensions`].
///
/// Both `Debug` and `Display` render it as `{width}x{height}` (e.g. `3x5`).
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
pub struct Dimensions {
    /// Number of columns (elements per row).
    pub width: usize,
    /// Number of rows.
    pub height: usize,
}
41
42impl Debug for Dimensions {
43 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
44 write!(f, "{}x{}", self.width, self.height)
45 }
46}
47
48impl Display for Dimensions {
49 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
50 write!(f, "{}x{}", self.width, self.height)
51 }
52}
53
54pub trait Matrix<T: Send + Sync + Clone>: Send + Sync {
60 fn width(&self) -> usize;
62
63 fn height(&self) -> usize;
65
66 fn dimensions(&self) -> Dimensions {
68 Dimensions {
69 width: self.width(),
70 height: self.height(),
71 }
72 }
73
74 #[inline]
85 fn get(&self, r: usize, c: usize) -> Option<T> {
86 (r < self.height() && c < self.width()).then(|| unsafe {
87 self.get_unchecked(r, c)
89 })
90 }
91
92 #[inline]
100 unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
101 unsafe { self.row_slice_unchecked(r)[c].clone() }
102 }
103
104 #[inline]
110 fn row(
111 &self,
112 r: usize,
113 ) -> Option<impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync>> {
114 (r < self.height()).then(|| unsafe {
115 self.row_unchecked(r)
117 })
118 }
119
120 #[inline]
130 unsafe fn row_unchecked(
131 &self,
132 r: usize,
133 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
134 unsafe { self.row_subseq_unchecked(r, 0, self.width()) }
135 }
136
137 #[inline]
147 unsafe fn row_subseq_unchecked(
148 &self,
149 r: usize,
150 start: usize,
151 end: usize,
152 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
153 unsafe {
154 self.row_unchecked(r)
155 .into_iter()
156 .skip(start)
157 .take(end - start)
158 }
159 }
160
161 #[inline]
165 fn row_slice(&self, r: usize) -> Option<impl Deref<Target = [T]>> {
166 (r < self.height()).then(|| unsafe {
167 self.row_slice_unchecked(r)
169 })
170 }
171
172 #[inline]
180 unsafe fn row_slice_unchecked(&self, r: usize) -> impl Deref<Target = [T]> {
181 unsafe { self.row_subslice_unchecked(r, 0, self.width()) }
182 }
183
184 #[inline]
194 unsafe fn row_subslice_unchecked(
195 &self,
196 r: usize,
197 start: usize,
198 end: usize,
199 ) -> impl Deref<Target = [T]> {
200 unsafe {
201 self.row_subseq_unchecked(r, start, end)
202 .into_iter()
203 .collect_vec()
204 }
205 }
206
207 #[inline]
209 fn rows(&self) -> impl Iterator<Item = impl Iterator<Item = T>> + Send + Sync {
210 unsafe {
211 (0..self.height()).map(move |r| self.row_unchecked(r).into_iter())
213 }
214 }
215
216 #[inline]
218 fn par_rows(
219 &self,
220 ) -> impl IndexedParallelIterator<Item = impl Iterator<Item = T>> + Send + Sync {
221 unsafe {
222 (0..self.height())
224 .into_par_iter()
225 .map(move |r| self.row_unchecked(r).into_iter())
226 }
227 }
228
229 fn wrapping_row_slices(&self, r: usize, c: usize) -> Vec<impl Deref<Target = [T]>> {
232 unsafe {
233 (0..c)
235 .map(|i| self.row_slice_unchecked((r + i) % self.height()))
236 .collect_vec()
237 }
238 }
239
240 #[inline]
244 fn first_row(
245 &self,
246 ) -> Option<impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync>> {
247 self.row(0)
248 }
249
250 #[inline]
254 fn last_row(
255 &self,
256 ) -> Option<impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync>> {
257 if self.height() == 0 {
258 None
259 } else {
260 unsafe { Some(self.row_unchecked(self.height() - 1)) }
262 }
263 }
264
265 fn to_row_major_matrix(self) -> RowMajorMatrix<T>
267 where
268 Self: Sized,
269 T: Clone,
270 {
271 RowMajorMatrix::new(self.rows().flatten().collect(), self.width())
272 }
273
274 fn horizontally_packed_row<'a, P>(
282 &'a self,
283 r: usize,
284 ) -> (
285 impl Iterator<Item = P> + Send + Sync,
286 impl Iterator<Item = T> + Send + Sync,
287 )
288 where
289 P: PackedValue<Value = T>,
290 T: Clone + 'a,
291 {
292 assert!(r < self.height(), "Row index out of bounds.");
293 let num_packed = self.width() / P::WIDTH;
294 unsafe {
295 let mut iter = self
297 .row_subseq_unchecked(r, 0, num_packed * P::WIDTH)
298 .into_iter();
299
300 let packed =
302 (0..num_packed).map(move |_| P::from_fn(|_| iter.next().unwrap_unchecked()));
303
304 let sfx = self
305 .row_subseq_unchecked(r, num_packed * P::WIDTH, self.width())
306 .into_iter();
307 (packed, sfx)
308 }
309 }
310
311 fn padded_horizontally_packed_row<'a, P>(
318 &'a self,
319 r: usize,
320 ) -> impl Iterator<Item = P> + Send + Sync
321 where
322 P: PackedValue<Value = T>,
323 T: Clone + Default + 'a,
324 {
325 let mut row_iter = self.row(r).expect("Row index out of bounds.").into_iter();
326 let num_elems = self.width().div_ceil(P::WIDTH);
327 (0..num_elems).map(move |_| P::from_fn(|_| row_iter.next().unwrap_or_default()))
329 }
330
331 fn par_horizontally_packed_rows<'a, P>(
336 &'a self,
337 ) -> impl IndexedParallelIterator<
338 Item = (
339 impl Iterator<Item = P> + Send + Sync,
340 impl Iterator<Item = T> + Send + Sync,
341 ),
342 >
343 where
344 P: PackedValue<Value = T>,
345 T: Clone + 'a,
346 {
347 (0..self.height())
348 .into_par_iter()
349 .map(|r| self.horizontally_packed_row(r))
350 }
351
352 fn par_padded_horizontally_packed_rows<'a, P>(
356 &'a self,
357 ) -> impl IndexedParallelIterator<Item = impl Iterator<Item = P> + Send + Sync>
358 where
359 P: PackedValue<Value = T>,
360 T: Clone + Default + 'a,
361 {
362 (0..self.height())
363 .into_par_iter()
364 .map(|r| self.padded_horizontally_packed_row(r))
365 }
366
367 #[inline]
373 fn vertically_packed_row<P>(&self, r: usize) -> impl Iterator<Item = P>
374 where
375 T: Copy,
376 P: PackedValue<Value = T>,
377 {
378 let rows = self.wrapping_row_slices(r, P::WIDTH);
380
381 (0..self.width()).map(move |c| P::from_fn(|i| rows[i][c]))
383 }
384
385 #[inline]
393 fn vertically_packed_row_pair<P>(&self, r: usize, step: usize) -> Vec<P>
394 where
395 T: Copy,
396 P: PackedValue<Value = T>,
397 {
398 let rows = self.wrapping_row_slices(r, P::WIDTH);
403 let next_rows = self.wrapping_row_slices(r + step, P::WIDTH);
404
405 (0..self.width())
406 .map(|c| P::from_fn(|i| rows[i][c]))
407 .chain((0..self.width()).map(|c| P::from_fn(|i| next_rows[i][c])))
408 .collect_vec()
409 }
410
411 fn vertically_strided(self, stride: usize, offset: usize) -> VerticallyStridedMatrixView<Self>
415 where
416 Self: Sized,
417 {
418 VerticallyStridedRowIndexMap::new_view(self, stride, offset)
419 }
420
421 #[instrument(level = "debug", skip_all, fields(dims = %self.dimensions()))]
425 fn columnwise_dot_product<EF>(&self, v: &[EF]) -> Vec<EF>
426 where
427 T: Field,
428 EF: ExtensionField<T>,
429 {
430 let packed_width = self.width().div_ceil(T::Packing::WIDTH);
431
432 let packed_result = self
433 .par_padded_horizontally_packed_rows::<T::Packing>()
434 .zip(v)
435 .par_fold_reduce(
436 || EF::ExtensionPacking::zero_vec(packed_width),
437 |mut acc, (row, &scale)| {
438 let scale = EF::ExtensionPacking::from_basis_coefficients_fn(|i| {
439 T::Packing::from(scale.as_basis_coefficients_slice()[i])
440 });
441 izip!(&mut acc, row).for_each(|(l, r)| *l += scale * r);
442 acc
443 },
444 |mut acc_l, acc_r| {
445 izip!(&mut acc_l, acc_r).for_each(|(l, r)| *l += r);
446 acc_l
447 },
448 );
449
450 packed_result
451 .into_iter()
452 .flat_map(|p| {
453 (0..T::Packing::WIDTH).map(move |i| {
454 EF::from_basis_coefficients_fn(|j| {
455 p.as_basis_coefficients_slice()[j].as_slice()[i]
456 })
457 })
458 })
459 .take(self.width())
460 .collect()
461 }
462
463 fn rowwise_packed_dot_product<EF>(
473 &self,
474 vec: &[EF::ExtensionPacking],
475 ) -> impl IndexedParallelIterator<Item = EF>
476 where
477 T: Field,
478 EF: ExtensionField<T>,
479 {
480 assert!(vec.len() >= self.width().div_ceil(T::Packing::WIDTH));
482
483 self.par_padded_horizontally_packed_rows::<T::Packing>()
487 .map(move |row_packed| {
488 let packed_sum_of_packed: EF::ExtensionPacking =
489 dot_product(vec.iter().copied(), row_packed);
490 let sum_of_packed: EF = EF::from_basis_coefficients_fn(|i| {
491 packed_sum_of_packed.as_basis_coefficients_slice()[i]
492 .as_slice()
493 .iter()
494 .copied()
495 .sum()
496 });
497 sum_of_packed
498 })
499 }
500}
501
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use alloc::{format, vec};

    use itertools::izip;
    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;
    use p3_field::extension::BinomialExtensionField;
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    use super::*;

    #[test]
    fn test_columnwise_dot_product() {
        type F = BabyBear;
        type EF = BinomialExtensionField<BabyBear, 4>;

        let mut rng = SmallRng::seed_from_u64(1);
        let m = RowMajorMatrix::<F>::rand(&mut rng, 1 << 8, 1 << 4);
        let v = RowMajorMatrix::<EF>::rand(&mut rng, 1 << 8, 1).values;

        let mut expected = vec![EF::ZERO; m.width()];
        for (row, &scale) in izip!(m.rows(), &v) {
            for (l, r) in izip!(&mut expected, row) {
                *l += scale * r;
            }
        }

        assert_eq!(m.columnwise_dot_product(&v), expected);
    }

    /// Minimal `Matrix` backed by a vector of rows. Only `row_unchecked` is implemented, so
    /// the tests exercise the trait's default methods.
    struct MockMatrix {
        data: Vec<Vec<u32>>,
        width: usize,
        height: usize,
    }

    impl Matrix<u32> for MockMatrix {
        fn width(&self) -> usize {
            self.width
        }

        fn height(&self) -> usize {
            self.height
        }

        unsafe fn row_unchecked(
            &self,
            r: usize,
        ) -> impl IntoIterator<Item = u32, IntoIter = impl Iterator<Item = u32> + Send + Sync>
        {
            self.data[r].clone()
        }
    }

    /// The 3x3 fixture shared by most tests: rows `[1, 2, 3]`, `[4, 5, 6]`, `[7, 8, 9]`.
    fn mock_3x3() -> MockMatrix {
        MockMatrix {
            data: vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]],
            width: 3,
            height: 3,
        }
    }

    #[test]
    fn test_dimensions() {
        let dims = Dimensions {
            width: 3,
            height: 5,
        };
        assert_eq!(dims.width, 3);
        assert_eq!(dims.height, 5);
        assert_eq!(format!("{:?}", dims), "3x5");
        assert_eq!(format!("{}", dims), "3x5");
    }

    #[test]
    fn test_mock_matrix_dimensions() {
        let matrix = mock_3x3();
        assert_eq!(matrix.width(), 3);
        assert_eq!(matrix.height(), 3);
        assert_eq!(
            matrix.dimensions(),
            Dimensions {
                width: 3,
                height: 3
            }
        );
    }

    #[test]
    fn test_first_row() {
        let matrix = mock_3x3();
        let mut first_row = matrix.first_row().unwrap().into_iter();
        assert_eq!(first_row.next(), Some(1));
        assert_eq!(first_row.next(), Some(2));
        assert_eq!(first_row.next(), Some(3));
    }

    #[test]
    fn test_last_row() {
        let matrix = mock_3x3();
        let mut last_row = matrix.last_row().unwrap().into_iter();
        assert_eq!(last_row.next(), Some(7));
        assert_eq!(last_row.next(), Some(8));
        assert_eq!(last_row.next(), Some(9));
    }

    #[test]
    fn test_first_last_row_empty_matrix() {
        let matrix = MockMatrix {
            data: vec![],
            width: 3,
            height: 0,
        };
        let first_row = matrix.first_row();
        let last_row = matrix.last_row();
        assert!(first_row.is_none());
        assert!(last_row.is_none());
    }

    #[test]
    fn test_to_row_major_matrix() {
        let matrix = MockMatrix {
            data: vec![vec![1, 2], vec![3, 4]],
            width: 2,
            height: 2,
        };
        let row_major = matrix.to_row_major_matrix();
        assert_eq!(row_major.values, vec![1, 2, 3, 4]);
        assert_eq!(row_major.width, 2);
    }

    #[test]
    fn test_matrix_get_methods() {
        let matrix = mock_3x3();
        assert_eq!(matrix.get(0, 0), Some(1));
        assert_eq!(matrix.get(1, 2), Some(6));
        assert_eq!(matrix.get(2, 1), Some(8));

        unsafe {
            assert_eq!(matrix.get_unchecked(0, 1), 2);
            assert_eq!(matrix.get_unchecked(1, 0), 4);
            assert_eq!(matrix.get_unchecked(2, 2), 9);
        }

        // Out-of-bounds row and column.
        assert_eq!(matrix.get(3, 0), None);
        assert_eq!(matrix.get(0, 3), None);
    }

    #[test]
    fn test_matrix_row_methods_iteration() {
        let matrix = mock_3x3();

        let mut row_iter = matrix.row(1).unwrap().into_iter();
        assert_eq!(row_iter.next(), Some(4));
        assert_eq!(row_iter.next(), Some(5));
        assert_eq!(row_iter.next(), Some(6));
        assert_eq!(row_iter.next(), None);

        unsafe {
            let mut row_iter_unchecked = matrix.row_unchecked(2).into_iter();
            assert_eq!(row_iter_unchecked.next(), Some(7));
            assert_eq!(row_iter_unchecked.next(), Some(8));
            assert_eq!(row_iter_unchecked.next(), Some(9));
            assert_eq!(row_iter_unchecked.next(), None);

            let mut row_iter_subset = matrix.row_subseq_unchecked(0, 1, 3).into_iter();
            assert_eq!(row_iter_subset.next(), Some(2));
            assert_eq!(row_iter_subset.next(), Some(3));
            assert_eq!(row_iter_subset.next(), None);
        }

        // Out-of-bounds row.
        assert!(matrix.row(3).is_none());
    }

    #[test]
    fn test_row_slice_methods() {
        let matrix = mock_3x3();
        let row_slice = matrix.row_slice(1).unwrap();
        assert_eq!(*row_slice, [4, 5, 6]);
        unsafe {
            let row_slice_unchecked = matrix.row_slice_unchecked(2);
            assert_eq!(*row_slice_unchecked, [7, 8, 9]);

            let row_subslice = matrix.row_subslice_unchecked(0, 1, 2);
            assert_eq!(*row_subslice, [2]);
        }

        // Out-of-bounds row.
        assert!(matrix.row_slice(3).is_none());
    }

    #[test]
    fn test_wrapping_row_slices() {
        let matrix = mock_3x3();
        // Starting at the last row, the second slice must wrap around to row 0.
        let slices = matrix.wrapping_row_slices(2, 2);
        assert_eq!(slices.len(), 2);
        assert_eq!(*slices[0], [7, 8, 9]);
        assert_eq!(*slices[1], [1, 2, 3]);
    }

    #[test]
    fn test_matrix_rows() {
        let matrix = mock_3x3();

        let all_rows: Vec<Vec<u32>> = matrix.rows().map(|row| row.collect()).collect();
        assert_eq!(all_rows, vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]);
    }
}