1#![no_std]
4
5extern crate alloc;
6
7use alloc::vec;
8use alloc::vec::Vec;
9use core::fmt::{Debug, Display, Formatter};
10use core::ops::Deref;
11
12use itertools::{izip, Itertools};
13use p3_field::{ExtensionField, Field, PackedValue};
14use p3_maybe_rayon::prelude::*;
15use serde::{Deserialize, Serialize};
16use strided::{VerticallyStridedMatrixView, VerticallyStridedRowIndexMap};
17
18use crate::dense::RowMajorMatrix;
19
20pub mod bitrev;
21pub mod dense;
22pub mod extension;
23pub mod mul;
24pub mod row_index_mapped;
25pub mod sparse;
26pub mod stack;
27pub mod strided;
28pub mod util;
29
30#[derive(Clone, Copy, Serialize, Deserialize)]
31pub struct Dimensions {
32 pub width: usize,
33 pub height: usize,
34}
35
36impl Debug for Dimensions {
37 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
38 write!(f, "{}x{}", self.width, self.height)
39 }
40}
41
42impl Display for Dimensions {
43 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
44 write!(f, "{}x{}", self.width, self.height)
45 }
46}
47
48pub trait Matrix<T: Send + Sync>: Send + Sync {
49 fn width(&self) -> usize;
50 fn height(&self) -> usize;
51
52 fn dimensions(&self) -> Dimensions {
53 Dimensions {
54 width: self.width(),
55 height: self.height(),
56 }
57 }
58
59 fn get(&self, r: usize, c: usize) -> T {
60 self.row(r).nth(c).unwrap()
61 }
62
63 type Row<'a>: Iterator<Item = T> + Send + Sync
64 where
65 Self: 'a;
66
67 fn row(&self, r: usize) -> Self::Row<'_>;
68
69 fn rows(&self) -> impl Iterator<Item = Self::Row<'_>> {
70 (0..self.height()).map(move |r| self.row(r))
71 }
72
73 fn par_rows(&self) -> impl IndexedParallelIterator<Item = Self::Row<'_>> {
74 (0..self.height()).into_par_iter().map(move |r| self.row(r))
75 }
76
77 fn row_slice(&self, r: usize) -> impl Deref<Target = [T]> {
79 self.row(r).collect_vec()
80 }
81
82 fn first_row(&self) -> Self::Row<'_> {
83 self.row(0)
84 }
85
86 fn last_row(&self) -> Self::Row<'_> {
87 self.row(self.height() - 1)
88 }
89
90 fn to_row_major_matrix(self) -> RowMajorMatrix<T>
91 where
92 Self: Sized,
93 T: Clone,
94 {
95 RowMajorMatrix::new(
96 (0..self.height()).flat_map(|r| self.row(r)).collect(),
97 self.width(),
98 )
99 }
100
101 fn horizontally_packed_row<'a, P>(
102 &'a self,
103 r: usize,
104 ) -> (impl Iterator<Item = P>, impl Iterator<Item = T>)
105 where
106 P: PackedValue<Value = T>,
107 T: Clone + 'a,
108 {
109 let num_packed = self.width() / P::WIDTH;
110 let packed = (0..num_packed).map(move |c| P::from_fn(|i| self.get(r, P::WIDTH * c + i)));
111 let sfx = (num_packed * P::WIDTH..self.width()).map(move |c| self.get(r, c));
112 (packed, sfx)
113 }
114
115 fn vertically_packed_row<P>(&self, r: usize) -> impl Iterator<Item = P>
117 where
118 P: PackedValue<Value = T>,
119 {
120 (0..self.width()).map(move |c| P::from_fn(|i| self.get((r + i) % self.height(), c)))
121 }
122
123 fn vertically_strided(self, stride: usize, offset: usize) -> VerticallyStridedMatrixView<Self>
124 where
125 Self: Sized,
126 {
127 VerticallyStridedRowIndexMap::new_view(self, stride, offset)
128 }
129
130 fn columnwise_dot_product<EF>(&self, v: &[EF]) -> Vec<EF>
134 where
135 T: Field,
136 EF: ExtensionField<T>,
137 {
138 self.par_rows().zip(v).par_fold_reduce(
139 || vec![EF::zero(); self.width()],
140 |mut acc, (row, &scale)| {
141 izip!(&mut acc, row).for_each(|(a, x)| *a += scale * x);
142 acc
143 },
144 |mut acc_l, acc_r| {
145 izip!(&mut acc_l, acc_r).for_each(|(l, r)| *l += r);
146 acc_l
147 },
148 )
149 }
150}