1#![no_std]
4
5extern crate alloc;
6
7use alloc::vec::Vec;
8use core::fmt::{
9 Debug,
10 Display,
11 Formatter,
12};
13use core::ops::Deref;
14
15use itertools::{
16 Itertools,
17 izip,
18};
19use lib_q_stark_field::{
20 BasedVectorSpace,
21 ExtensionField,
22 Field,
23 PackedValue,
24 PrimeCharacteristicRing,
25 dot_product,
26};
27use lib_q_stark_rayon::prelude::*;
28use strided::{
29 VerticallyStridedMatrixView,
30 VerticallyStridedRowIndexMap,
31};
32
33use crate::dense::RowMajorMatrix;
34
35pub mod bitrev;
36pub mod dense;
37pub mod extension;
38pub mod horizontally_truncated;
39pub mod row_index_mapped;
40pub mod stack;
41pub mod strided;
42pub mod util;
43
/// The shape of a matrix: how many columns (`width`) and rows (`height`) it has.
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct Dimensions {
    /// Number of columns.
    pub width: usize,
    /// Number of rows.
    pub height: usize,
}
55
56impl Debug for Dimensions {
57 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
58 write!(f, "{}x{}", self.width, self.height)
59 }
60}
61
62impl Display for Dimensions {
63 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
64 write!(f, "{}x{}", self.width, self.height)
65 }
66}
67
68pub trait Matrix<T: Send + Sync + Clone>: Send + Sync {
74 fn width(&self) -> usize;
76
77 fn height(&self) -> usize;
79
80 fn dimensions(&self) -> Dimensions {
82 Dimensions {
83 width: self.width(),
84 height: self.height(),
85 }
86 }
87
88 #[inline]
99 fn get(&self, r: usize, c: usize) -> Option<T> {
100 (r < self.height() && c < self.width()).then(|| unsafe {
101 self.get_unchecked(r, c)
103 })
104 }
105
106 #[inline]
114 unsafe fn get_unchecked(&self, r: usize, c: usize) -> T {
115 unsafe { self.row_slice_unchecked(r)[c].clone() }
116 }
117
118 #[inline]
124 fn row(
125 &self,
126 r: usize,
127 ) -> Option<impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync>> {
128 (r < self.height()).then(|| unsafe {
129 self.row_unchecked(r)
131 })
132 }
133
134 #[inline]
144 unsafe fn row_unchecked(
145 &self,
146 r: usize,
147 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
148 unsafe { self.row_subseq_unchecked(r, 0, self.width()) }
149 }
150
151 #[inline]
161 unsafe fn row_subseq_unchecked(
162 &self,
163 r: usize,
164 start: usize,
165 end: usize,
166 ) -> impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync> {
167 unsafe {
168 self.row_unchecked(r)
169 .into_iter()
170 .skip(start)
171 .take(end - start)
172 }
173 }
174
175 #[inline]
179 fn row_slice(&self, r: usize) -> Option<impl Deref<Target = [T]>> {
180 (r < self.height()).then(|| unsafe {
181 self.row_slice_unchecked(r)
183 })
184 }
185
186 #[inline]
194 unsafe fn row_slice_unchecked(&self, r: usize) -> impl Deref<Target = [T]> {
195 unsafe { self.row_subslice_unchecked(r, 0, self.width()) }
196 }
197
198 #[inline]
208 unsafe fn row_subslice_unchecked(
209 &self,
210 r: usize,
211 start: usize,
212 end: usize,
213 ) -> impl Deref<Target = [T]> {
214 unsafe {
215 self.row_subseq_unchecked(r, start, end)
216 .into_iter()
217 .collect_vec()
218 }
219 }
220
221 #[inline]
223 fn rows(&self) -> impl Iterator<Item = impl Iterator<Item = T>> + Send + Sync {
224 unsafe {
225 (0..self.height()).map(move |r| self.row_unchecked(r).into_iter())
227 }
228 }
229
230 #[inline]
232 fn par_rows(
233 &self,
234 ) -> impl IndexedParallelIterator<Item = impl Iterator<Item = T>> + Send + Sync {
235 unsafe {
236 (0..self.height())
238 .into_par_iter()
239 .map(move |r| self.row_unchecked(r).into_iter())
240 }
241 }
242
243 fn wrapping_row_slices(&self, r: usize, c: usize) -> Vec<impl Deref<Target = [T]>> {
246 unsafe {
247 (0..c)
249 .map(|i| self.row_slice_unchecked((r + i) % self.height()))
250 .collect_vec()
251 }
252 }
253
254 #[inline]
258 fn first_row(
259 &self,
260 ) -> Option<impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync>> {
261 self.row(0)
262 }
263
264 #[inline]
268 fn last_row(
269 &self,
270 ) -> Option<impl IntoIterator<Item = T, IntoIter = impl Iterator<Item = T> + Send + Sync>> {
271 if self.height() == 0 {
272 None
273 } else {
274 unsafe { Some(self.row_unchecked(self.height() - 1)) }
276 }
277 }
278
279 fn to_row_major_matrix(self) -> RowMajorMatrix<T>
281 where
282 Self: Sized,
283 T: Clone,
284 {
285 RowMajorMatrix::new(self.rows().flatten().collect(), self.width())
286 }
287
288 fn horizontally_packed_row<'a, P>(
296 &'a self,
297 r: usize,
298 ) -> (
299 impl Iterator<Item = P> + Send + Sync,
300 impl Iterator<Item = T> + Send + Sync,
301 )
302 where
303 P: PackedValue<Value = T>,
304 T: Clone + 'a,
305 {
306 assert!(r < self.height(), "Row index out of bounds.");
307 let num_packed = self.width() / P::WIDTH;
308 unsafe {
309 let mut iter = self
311 .row_subseq_unchecked(r, 0, num_packed * P::WIDTH)
312 .into_iter();
313
314 let packed =
316 (0..num_packed).map(move |_| P::from_fn(|_| iter.next().unwrap_unchecked()));
317
318 let sfx = self
319 .row_subseq_unchecked(r, num_packed * P::WIDTH, self.width())
320 .into_iter();
321 (packed, sfx)
322 }
323 }
324
325 fn padded_horizontally_packed_row<'a, P>(
332 &'a self,
333 r: usize,
334 ) -> impl Iterator<Item = P> + Send + Sync
335 where
336 P: PackedValue<Value = T>,
337 T: Clone + Default + 'a,
338 {
339 let mut row_iter = self.row(r).expect("Row index out of bounds.").into_iter();
340 let num_elems = self.width().div_ceil(P::WIDTH);
341 (0..num_elems).map(move |_| P::from_fn(|_| row_iter.next().unwrap_or_default()))
343 }
344
345 fn par_horizontally_packed_rows<'a, P>(
350 &'a self,
351 ) -> impl IndexedParallelIterator<
352 Item = (
353 impl Iterator<Item = P> + Send + Sync,
354 impl Iterator<Item = T> + Send + Sync,
355 ),
356 >
357 where
358 P: PackedValue<Value = T>,
359 T: Clone + 'a,
360 {
361 (0..self.height())
362 .into_par_iter()
363 .map(|r| self.horizontally_packed_row(r))
364 }
365
366 fn par_padded_horizontally_packed_rows<'a, P>(
370 &'a self,
371 ) -> impl IndexedParallelIterator<Item = impl Iterator<Item = P> + Send + Sync>
372 where
373 P: PackedValue<Value = T>,
374 T: Clone + Default + 'a,
375 {
376 (0..self.height())
377 .into_par_iter()
378 .map(|r| self.padded_horizontally_packed_row(r))
379 }
380
381 #[inline]
387 fn vertically_packed_row<P>(&self, r: usize) -> impl Iterator<Item = P>
388 where
389 T: Copy,
390 P: PackedValue<Value = T>,
391 {
392 let rows = self.wrapping_row_slices(r, P::WIDTH);
394
395 (0..self.width()).map(move |c| P::from_fn(|i| rows[i][c]))
397 }
398
399 #[inline]
407 fn vertically_packed_row_pair<P>(&self, r: usize, step: usize) -> Vec<P>
408 where
409 T: Copy,
410 P: PackedValue<Value = T>,
411 {
412 let rows = self.wrapping_row_slices(r, P::WIDTH);
417 let next_rows = self.wrapping_row_slices(r + step, P::WIDTH);
418
419 (0..self.width())
420 .map(|c| P::from_fn(|i| rows[i][c]))
421 .chain((0..self.width()).map(|c| P::from_fn(|i| next_rows[i][c])))
422 .collect_vec()
423 }
424
425 fn vertically_strided(self, stride: usize, offset: usize) -> VerticallyStridedMatrixView<Self>
429 where
430 Self: Sized,
431 {
432 VerticallyStridedRowIndexMap::new_view(self, stride, offset)
433 }
434
435 fn columnwise_dot_product<EF>(&self, v: &[EF]) -> Vec<EF>
439 where
440 T: Field,
441 EF: ExtensionField<T>,
442 {
443 let packed_width = self.width().div_ceil(T::Packing::WIDTH);
444
445 let packed_result = self
446 .par_padded_horizontally_packed_rows::<T::Packing>()
447 .zip(v)
448 .par_fold_reduce(
449 || EF::ExtensionPacking::zero_vec(packed_width),
450 |mut acc, (row, &scale)| {
451 let scale = EF::ExtensionPacking::from_basis_coefficients_fn(|i| {
452 T::Packing::from(scale.as_basis_coefficients_slice()[i])
453 });
454 izip!(&mut acc, row).for_each(|(l, r)| *l += scale * r);
455 acc
456 },
457 |mut acc_l, acc_r| {
458 izip!(&mut acc_l, acc_r).for_each(|(l, r)| *l += r);
459 acc_l
460 },
461 );
462
463 packed_result
464 .into_iter()
465 .flat_map(|p| {
466 (0..T::Packing::WIDTH).map(move |i| {
467 EF::from_basis_coefficients_fn(|j| {
468 p.as_basis_coefficients_slice()[j].as_slice()[i]
469 })
470 })
471 })
472 .take(self.width())
473 .collect()
474 }
475
476 fn rowwise_packed_dot_product<EF>(
486 &self,
487 vec: &[EF::ExtensionPacking],
488 ) -> impl IndexedParallelIterator<Item = EF>
489 where
490 T: Field,
491 EF: ExtensionField<T>,
492 {
493 assert!(vec.len() >= self.width().div_ceil(T::Packing::WIDTH));
495
496 self.par_padded_horizontally_packed_rows::<T::Packing>()
500 .map(move |row_packed| {
501 let packed_sum_of_packed: EF::ExtensionPacking =
502 dot_product(vec.iter().copied(), row_packed);
503 let sum_of_packed: EF = EF::from_basis_coefficients_fn(|i| {
504 packed_sum_of_packed.as_basis_coefficients_slice()[i]
505 .as_slice()
506 .iter()
507 .copied()
508 .sum()
509 });
510 sum_of_packed
511 })
512 }
513}
514
// NOTE(review): `cfg(any())` has an empty predicate list and is therefore always false, so this
// module is never compiled, not even under `cargo test`. It was presumably disabled on purpose
// (e.g. `rand` or the Mersenne31 crate not being dev-dependencies). TODO: confirm, and switch
// to `#[cfg(test)]` once it builds.
#[cfg(any())]
mod tests {
    use alloc::vec::Vec;
    use alloc::{
        format,
        vec,
    };

    use itertools::izip;
    use lib_q_stark_field::PrimeCharacteristicRing;
    use lib_q_stark_field::extension::{
        BinomialExtensionField,
        Complex,
    };
    use lib_q_stark_mersenne31::Mersenne31;
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    use super::*;

    /// Minimal `Matrix` implementation storing one owned `Vec` per row.
    struct MockMatrix {
        data: Vec<Vec<u32>>,
        width: usize,
        height: usize,
    }

    impl Matrix<u32> for MockMatrix {
        fn width(&self) -> usize {
            self.width
        }

        fn height(&self) -> usize {
            self.height
        }

        unsafe fn row_unchecked(
            &self,
            r: usize,
        ) -> impl IntoIterator<Item = u32, IntoIter = impl Iterator<Item = u32> + Send + Sync>
        {
            // Cloning hands back an owned row whose iterator is `Send + Sync`.
            self.data[r].clone()
        }
    }

    /// Builds the 3x3 mock matrix holding `1..=9` in row-major order.
    fn mock_3x3() -> MockMatrix {
        MockMatrix {
            data: vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]],
            width: 3,
            height: 3,
        }
    }

    #[test]
    fn test_columnwise_dot_product() {
        type F = Complex<Mersenne31>;
        type EF = BinomialExtensionField<Complex<Mersenne31>, 2>;

        let mut rng = SmallRng::seed_from_u64(1);
        let m = RowMajorMatrix::<F>::rand(&mut rng, 1 << 8, 1 << 4);
        let v = RowMajorMatrix::<EF>::rand(&mut rng, 1 << 8, 1).values;

        // Naive reference: fold each scaled row into running per-column sums.
        let expected = izip!(m.rows(), &v).fold(
            vec![EF::ZERO; m.width()],
            |mut sums, (row, &scale)| {
                for (sum, entry) in izip!(&mut sums, row) {
                    *sum += scale * EF::from(entry);
                }
                sums
            },
        );

        assert_eq!(m.columnwise_dot_product(&v), expected);
    }

    #[test]
    fn test_dimensions() {
        let dims = Dimensions {
            width: 3,
            height: 5,
        };
        assert_eq!((dims.width, dims.height), (3, 5));
        for rendered in [format!("{dims:?}"), format!("{dims}")] {
            assert_eq!(rendered, "3x5");
        }
    }

    #[test]
    fn test_mock_matrix_dimensions() {
        let matrix = mock_3x3();
        assert_eq!((matrix.width(), matrix.height()), (3, 3));
        assert_eq!(
            matrix.dimensions(),
            Dimensions {
                width: 3,
                height: 3,
            }
        );
    }

    #[test]
    fn test_first_row() {
        let matrix = mock_3x3();
        let mut first_row = matrix.first_row().unwrap().into_iter();
        for expected in [1, 2, 3] {
            assert_eq!(first_row.next(), Some(expected));
        }
    }

    #[test]
    fn test_last_row() {
        let matrix = mock_3x3();
        let mut last_row = matrix.last_row().unwrap().into_iter();
        for expected in [7, 8, 9] {
            assert_eq!(last_row.next(), Some(expected));
        }
    }

    #[test]
    fn test_first_last_row_empty_matrix() {
        let empty = MockMatrix {
            data: vec![],
            width: 3,
            height: 0,
        };
        assert!(empty.first_row().is_none());
        assert!(empty.last_row().is_none());
    }

    #[test]
    fn test_to_row_major_matrix() {
        let row_major = MockMatrix {
            data: vec![vec![1, 2], vec![3, 4]],
            width: 2,
            height: 2,
        }
        .to_row_major_matrix();
        assert_eq!(row_major.values, vec![1, 2, 3, 4]);
        assert_eq!(row_major.width, 2);
    }

    #[test]
    fn test_matrix_get_methods() {
        let matrix = mock_3x3();

        for (r, c, expected) in [(0, 0, 1), (1, 2, 6), (2, 1, 8)] {
            assert_eq!(matrix.get(r, c), Some(expected));
        }

        for (r, c, expected) in [(0, 1, 2), (1, 0, 4), (2, 2, 9)] {
            // SAFETY: every (r, c) pair lies inside the 3x3 bounds.
            assert_eq!(unsafe { matrix.get_unchecked(r, c) }, expected);
        }

        // One past the edge on either axis is out of bounds.
        assert_eq!(matrix.get(3, 0), None);
        assert_eq!(matrix.get(0, 3), None);
    }

    #[test]
    fn test_matrix_row_methods_iteration() {
        let matrix = mock_3x3();

        assert_eq!(
            matrix.row(1).unwrap().into_iter().collect::<Vec<_>>(),
            vec![4, 5, 6]
        );

        // SAFETY: rows 2 and 0 exist, and columns 1..3 lie within the width.
        unsafe {
            assert_eq!(
                matrix.row_unchecked(2).into_iter().collect::<Vec<_>>(),
                vec![7, 8, 9]
            );
            assert_eq!(
                matrix
                    .row_subseq_unchecked(0, 1, 3)
                    .into_iter()
                    .collect::<Vec<_>>(),
                vec![2, 3]
            );
        }

        // Row 3 is past the last row.
        assert!(matrix.row(3).is_none());
    }

    #[test]
    fn test_row_slice_methods() {
        let matrix = mock_3x3();

        assert_eq!(*matrix.row_slice(1).unwrap(), [4, 5, 6]);

        // SAFETY: rows 2 and 0 exist, and columns 1..2 lie within the width.
        unsafe {
            assert_eq!(*matrix.row_slice_unchecked(2), [7, 8, 9]);
            assert_eq!(*matrix.row_subslice_unchecked(0, 1, 2), [2]);
        }

        // Row 3 is past the last row.
        assert!(matrix.row_slice(3).is_none());
    }

    #[test]
    fn test_matrix_rows() {
        let matrix = mock_3x3();
        let all_rows: Vec<Vec<u32>> = matrix
            .rows()
            .map(|row| row.collect::<Vec<_>>())
            .collect();
        assert_eq!(all_rows, vec![vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]]);
    }
}