1#![no_std]
4
5extern crate alloc;
6
7use alloc::vec;
8use alloc::vec::Vec;
9use core::fmt::{Debug, Display, Formatter};
10use core::ops::Deref;
11
12use itertools::{izip, Itertools};
13use p3_field::{dot_product, AbstractExtensionField, ExtensionField, Field, PackedValue};
14use p3_maybe_rayon::prelude::*;
15use strided::{VerticallyStridedMatrixView, VerticallyStridedRowIndexMap};
16
17use crate::dense::RowMajorMatrix;
18
19pub mod bitrev;
20pub mod dense;
21pub mod extension;
22pub mod mul;
23pub mod row_index_mapped;
24pub mod sparse;
25pub mod stack;
26pub mod strided;
27pub mod util;
28
/// The shape of a matrix: its number of columns (`width`) and rows (`height`).
///
/// Two `Dimensions` compare equal exactly when both fields are equal.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct Dimensions {
    /// Number of columns.
    pub width: usize,
    /// Number of rows.
    pub height: usize,
}
34
35impl Debug for Dimensions {
36 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
37 write!(f, "{}x{}", self.width, self.height)
38 }
39}
40
41impl Display for Dimensions {
42 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
43 write!(f, "{}x{}", self.width, self.height)
44 }
45}
46
47pub trait Matrix<T: Send + Sync>: Send + Sync {
48 fn width(&self) -> usize;
49 fn height(&self) -> usize;
50
51 fn dimensions(&self) -> Dimensions {
52 Dimensions {
53 width: self.width(),
54 height: self.height(),
55 }
56 }
57
58 fn get(&self, r: usize, c: usize) -> T {
59 self.row(r).nth(c).unwrap()
60 }
61
62 type Row<'a>: Iterator<Item = T> + Send + Sync
63 where
64 Self: 'a;
65
66 fn row(&self, r: usize) -> Self::Row<'_>;
67
68 fn rows(&self) -> impl Iterator<Item = Self::Row<'_>> {
69 (0..self.height()).map(move |r| self.row(r))
70 }
71
72 fn par_rows(&self) -> impl IndexedParallelIterator<Item = Self::Row<'_>> {
73 (0..self.height()).into_par_iter().map(move |r| self.row(r))
74 }
75
76 fn row_slice(&self, r: usize) -> impl Deref<Target = [T]> {
78 self.row(r).collect_vec()
79 }
80
81 fn first_row(&self) -> Self::Row<'_> {
82 self.row(0)
83 }
84
85 fn last_row(&self) -> Self::Row<'_> {
86 self.row(self.height() - 1)
87 }
88
89 fn to_row_major_matrix(self) -> RowMajorMatrix<T>
90 where
91 Self: Sized,
92 T: Clone,
93 {
94 RowMajorMatrix::new(
95 (0..self.height()).flat_map(|r| self.row(r)).collect(),
96 self.width(),
97 )
98 }
99
100 fn horizontally_packed_row<'a, P>(
101 &'a self,
102 r: usize,
103 ) -> (
104 impl Iterator<Item = P> + Send + Sync,
105 impl Iterator<Item = T> + Send + Sync,
106 )
107 where
108 P: PackedValue<Value = T>,
109 T: Clone + 'a,
110 {
111 let num_packed = self.width() / P::WIDTH;
112 let packed = (0..num_packed).map(move |c| P::from_fn(|i| self.get(r, P::WIDTH * c + i)));
113 let sfx = (num_packed * P::WIDTH..self.width()).map(move |c| self.get(r, c));
114 (packed, sfx)
115 }
116
117 fn padded_horizontally_packed_row<'a, P>(
119 &'a self,
120 r: usize,
121 ) -> impl Iterator<Item = P> + Send + Sync
122 where
123 P: PackedValue<Value = T>,
124 T: Clone + Default + 'a,
125 {
126 let mut row_iter = self.row(r);
127 let num_elems = self.width().next_multiple_of(P::WIDTH);
128 (0..num_elems).map(move |_| P::from_fn(|_| row_iter.next().unwrap_or_default()))
130 }
131
132 fn par_horizontally_packed_rows<'a, P>(
133 &'a self,
134 ) -> impl IndexedParallelIterator<
135 Item = (
136 impl Iterator<Item = P> + Send + Sync,
137 impl Iterator<Item = T> + Send + Sync,
138 ),
139 >
140 where
141 P: PackedValue<Value = T>,
142 T: Clone + 'a,
143 {
144 (0..self.height())
145 .into_par_iter()
146 .map(|r| self.horizontally_packed_row(r))
147 }
148
149 fn par_padded_horizontally_packed_rows<'a, P>(
150 &'a self,
151 ) -> impl IndexedParallelIterator<Item = impl Iterator<Item = P> + Send + Sync>
152 where
153 P: PackedValue<Value = T>,
154 T: Clone + Default + 'a,
155 {
156 (0..self.height())
157 .into_par_iter()
158 .map(|r| self.padded_horizontally_packed_row(r))
159 }
160
161 fn vertically_packed_row<P>(&self, r: usize) -> impl Iterator<Item = P>
163 where
164 P: PackedValue<Value = T>,
165 {
166 (0..self.width()).map(move |c| P::from_fn(|i| self.get((r + i) % self.height(), c)))
167 }
168
169 fn vertically_strided(self, stride: usize, offset: usize) -> VerticallyStridedMatrixView<Self>
170 where
171 Self: Sized,
172 {
173 VerticallyStridedRowIndexMap::new_view(self, stride, offset)
174 }
175
176 fn columnwise_dot_product<EF>(&self, v: &[EF]) -> Vec<EF>
180 where
181 T: Field,
182 EF: ExtensionField<T>,
183 {
184 self.par_rows().zip(v).par_fold_reduce(
185 || vec![EF::zero(); self.width()],
186 |mut acc, (row, &scale)| {
187 izip!(&mut acc, row).for_each(|(a, x)| *a += scale * x);
188 acc
189 },
190 |mut acc_l, acc_r| {
191 izip!(&mut acc_l, acc_r).for_each(|(l, r)| *l += r);
192 acc_l
193 },
194 )
195 }
196
197 fn dot_ext_powers<EF>(&self, base: EF) -> impl IndexedParallelIterator<Item = EF>
199 where
200 T: Field,
201 EF: ExtensionField<T>,
202 {
203 let powers_packed = base
204 .ext_powers_packed()
205 .take(self.width().next_multiple_of(T::Packing::WIDTH))
206 .collect_vec();
207 self.par_padded_horizontally_packed_rows::<T::Packing>()
208 .map(move |row_packed| {
209 let packed_sum_of_packed: EF::ExtensionPacking =
210 dot_product(powers_packed.iter().copied(), row_packed);
211 let sum_of_packed: EF = EF::from_base_fn(|i| {
212 packed_sum_of_packed.as_base_slice()[i]
213 .as_slice()
214 .iter()
215 .copied()
216 .sum()
217 });
218 sum_of_packed
219 })
220 }
221}