1use alloc::vec::Vec;
2use core::iter;
3use core::marker::PhantomData;
4use core::ops::Deref;
5
6use p3_field::{ExtensionField, Field};
7
8use crate::Matrix;
9use crate::bitrev::BitReversibleMatrix;
10
/// A wrapper around a matrix of extension-field elements `EF` that is exposed
/// elsewhere in this file as a matrix over the base field `F`.
///
/// The wrapper owns the inner matrix; `F` and `EF` are carried only as phantom
/// type parameters so that the trait implementations can name them.
#[derive(Debug)]
pub struct FlatMatrixView<F, EF, Inner>(Inner, PhantomData<(F, EF)>);

impl<F, EF, Inner> FlatMatrixView<F, EF, Inner> {
    /// Wraps `inner` in a flat view, taking ownership of it.
    pub const fn new(inner: Inner) -> Self {
        FlatMatrixView(inner, PhantomData)
    }
}
24
25impl<F, EF, Inner> Deref for FlatMatrixView<F, EF, Inner> {
26 type Target = Inner;
27
28 fn deref(&self) -> &Self::Target {
29 &self.0
30 }
31}
32
33impl<F, EF, Inner> Matrix<F> for FlatMatrixView<F, EF, Inner>
34where
35 F: Field,
36 EF: ExtensionField<F>,
37 Inner: Matrix<EF>,
38{
39 fn width(&self) -> usize {
40 self.0.width() * EF::DIMENSION
41 }
42
43 fn height(&self) -> usize {
44 self.0.height()
45 }
46
47 unsafe fn get_unchecked(&self, r: usize, c: usize) -> F {
48 let c_inner = c / EF::DIMENSION;
51 let inner = unsafe {
52 self.0.get_unchecked(r, c_inner)
55 };
56 inner.as_basis_coefficients_slice()[c % EF::DIMENSION]
57 }
58
59 unsafe fn row_unchecked(
60 &self,
61 r: usize,
62 ) -> impl IntoIterator<Item = F, IntoIter = impl Iterator<Item = F> + Send + Sync> {
63 unsafe {
64 FlatIter {
66 inner: self.0.row_unchecked(r).into_iter().peekable(),
67 idx: 0,
68 _phantom: PhantomData,
69 }
70 }
71 }
72
73 unsafe fn row_subseq_unchecked(
74 &self,
75 r: usize,
76 start: usize,
77 end: usize,
78 ) -> impl IntoIterator<Item = F, IntoIter = impl Iterator<Item = F> + Send + Sync> {
79 let len = end - start;
81 let inner_start = start / EF::DIMENSION;
82 unsafe {
83 FlatIter {
85 inner: self
86 .0
87 .row_subseq_unchecked(r, inner_start, self.0.width())
90 .into_iter()
91 .peekable(),
92 idx: start % EF::DIMENSION,
93 _phantom: PhantomData,
94 }
95 .take(len)
96 }
97 }
98
99 unsafe fn row_slice_unchecked(&self, r: usize) -> impl Deref<Target = [F]> {
100 unsafe {
101 self.0
103 .row_slice_unchecked(r)
104 .iter()
105 .flat_map(|val| val.as_basis_coefficients_slice())
106 .copied()
107 .collect::<Vec<_>>()
108 }
109 }
110}
111
112pub struct FlatIter<F, I: Iterator> {
113 inner: iter::Peekable<I>,
114 idx: usize,
115 _phantom: PhantomData<F>,
116}
117
118impl<F, EF, I> Iterator for FlatIter<F, I>
119where
120 F: Field,
121 EF: ExtensionField<F>,
122 I: Iterator<Item = EF>,
123{
124 type Item = F;
125 fn next(&mut self) -> Option<Self::Item> {
126 if self.idx == EF::DIMENSION {
127 self.idx = 0;
128 self.inner.next();
129 }
130 let value = self.inner.peek()?.as_basis_coefficients_slice()[self.idx];
131 self.idx += 1;
132 Some(value)
133 }
134}
135
136impl<F, EF, Inner> BitReversibleMatrix<F> for FlatMatrixView<F, EF, Inner>
137where
138 F: Field,
139 EF: ExtensionField<F>,
140 Inner: BitReversibleMatrix<EF>,
141{
142 type BitRev = FlatMatrixView<F, EF, Inner::BitRev>;
143
144 fn bit_reverse_rows(self) -> Self::BitRev {
145 FlatMatrixView::new(self.0.bit_reverse_rows())
146 }
147}
148
#[cfg(test)]
mod tests {
    use alloc::vec;

    use itertools::Itertools;
    use p3_field::extension::Complex;
    use p3_field::{BasedVectorSpace, PrimeCharacteristicRing};
    use p3_mersenne_31::Mersenne31;

    use super::*;
    use crate::dense::RowMajorMatrix;

    type F = Mersenne31;
    type EF = Complex<Mersenne31>;

    /// Extension element whose `i`-th basis coefficient equals `offset + i`.
    fn ef(offset: u8) -> EF {
        EF::from_basis_coefficients_fn(|i| F::from_u8(i as u8 + offset))
    }

    /// Lifts an array of small integers into the base field.
    fn base<const N: usize>(raw: [u8; N]) -> [F; N] {
        raw.map(F::from_u8)
    }

    /// Builds a flat view over a row-major extension-field matrix.
    fn flat_view(values: Vec<EF>, width: usize) -> FlatMatrixView<F, EF, RowMajorMatrix<EF>> {
        FlatMatrixView::new(RowMajorMatrix::new(values, width))
    }

    #[test]
    fn flat_matrix() {
        // Three rows of two extension elements each, i.e. a 3x4 flat matrix.
        let flat = flat_view(vec![ef(10), ef(20), ef(30), ef(40), ef(50), ef(60)], 2);

        assert_eq!((flat.width(), flat.height()), (4, 3));

        // Checked element access.
        for (r, c, want) in [(0, 2, 20), (1, 3, 41), (2, 0, 50)] {
            assert_eq!(flat.get(r, c), Some(F::from_u8(want)));
        }

        // Unchecked element access at in-bounds positions.
        for (r, c, want) in [(0, 1, 11), (1, 0, 30), (2, 2, 60)] {
            assert_eq!(unsafe { flat.get_unchecked(r, c) }, F::from_u8(want));
        }

        // Row slices.
        assert_eq!(&*flat.row_slice(0).unwrap(), &base([10, 11, 20, 21]));
        unsafe {
            assert_eq!(&*flat.row_slice_unchecked(1), &base([30, 31, 40, 41]));
            assert_eq!(
                &*flat.row_subslice_unchecked(2, 0, 3),
                &base([50, 51, 60])
            );
        }

        // Row iterators.
        assert_eq!(
            flat.row(2).unwrap().into_iter().collect_vec(),
            base([50, 51, 60, 61])
        );
        unsafe {
            assert_eq!(
                flat.row_unchecked(1).into_iter().collect_vec(),
                base([30, 31, 40, 41])
            );
            assert_eq!(
                flat.row_subseq_unchecked(0, 1, 4).into_iter().collect_vec(),
                base([11, 20, 21])
            );
        }

        // Out-of-bounds accesses are rejected by the checked API.
        assert!(flat.get(0, 4).is_none());
        assert!(flat.get(3, 0).is_none());
        assert!(flat.row(3).is_none());
        assert!(flat.row_slice(3).is_none());
    }

    #[test]
    fn test_flat_matrix_width() {
        // 2x2 extension matrix: the flat width is two extension columns wide.
        let flat = flat_view(vec![EF::default(); 4], 2);
        assert_eq!(flat.width(), 2 * <EF as BasedVectorSpace<F>>::DIMENSION);
    }

    #[test]
    fn test_flat_matrix_height() {
        // Six elements at width 3 give two rows; flattening keeps the height.
        let flat = flat_view(vec![EF::default(); 6], 3);
        assert_eq!(flat.height(), 2);
    }

    #[test]
    fn test_flat_matrix_row_iterator() {
        let flat = flat_view(vec![ef(1), ef(10)], 2);

        let row = flat.first_row().unwrap().into_iter().collect_vec();

        assert_eq!(row, base([1, 2, 10, 11]).to_vec());
    }

    #[test]
    fn test_flat_matrix_row_slice_correctness() {
        let flat = flat_view(vec![ef(1), ef(10)], 2);

        assert_eq!(&*flat.row_slice(0).unwrap(), &base([1, 2, 10, 11]));
    }

    #[test]
    fn test_flat_matrix_empty() {
        let flat = flat_view(vec![], 0);

        assert_eq!(flat.height(), 0);
        assert_eq!(flat.width(), 0);
    }

    #[test]
    fn test_flat_iter_length_and_values() {
        // A single row holding three extension elements.
        let flat = flat_view(vec![ef(0), ef(10), ef(20)], 3);

        let row: Vec<_> = flat.first_row().unwrap().into_iter().collect();

        assert_eq!(row, base([0, 1, 10, 11, 20, 21]).to_vec());
    }

    #[test]
    fn test_flat_matrix_multiple_rows() {
        let flat = flat_view(vec![ef(0), ef(10), ef(20), ef(30)], 2);

        let first: Vec<_> = flat.first_row().unwrap().into_iter().collect();
        let second: Vec<_> = flat.row(1).unwrap().into_iter().collect();

        assert_eq!(first, base([0, 1, 10, 11]).to_vec());
        assert_eq!(second, base([20, 21, 30, 31]).to_vec());
    }

    #[test]
    fn test_flat_iter_yields_across_multiple_efs() {
        let flat = flat_view(vec![ef(0), ef(10), ef(20)], 3);

        let mut coeffs = flat.row(0).unwrap().into_iter();

        // Step through manually so every element boundary is crossed explicitly.
        for want in base([0, 1, 10, 11, 20, 21]) {
            assert_eq!(coeffs.next(), Some(want));
        }

        // The iterator is exhausted after the last coefficient.
        assert_eq!(coeffs.next(), None);
    }

    #[test]
    fn test_row_subseq_start_ge_dimension() {
        let flat = flat_view(vec![ef(10), ef(20), ef(30)], 3);

        unsafe {
            // Start aligned with the second extension element.
            let aligned = flat.row_subseq_unchecked(0, 2, 5).into_iter().collect_vec();
            assert_eq!(aligned, base([20, 21, 30]).to_vec());

            // Start in the middle of the second extension element.
            let offset = flat.row_subseq_unchecked(0, 3, 6).into_iter().collect_vec();
            assert_eq!(offset, base([21, 30, 31]).to_vec());
        }
    }
}