1use alloc::vec::Vec;
2use core::iter;
3use core::marker::PhantomData;
4use core::ops::Deref;
5
6use p3_field::{ExtensionField, Field};
7
8use crate::Matrix;
9
/// A view that presents a matrix of extension field elements as a matrix of base field
/// elements.
///
/// When `Inner` is a matrix over `EF`, the `Matrix<F>` implementation of this type expands
/// every `EF` entry into `EF::DIMENSION` consecutive base field entries. The view is therefore
/// `EF::DIMENSION` times as wide as `Inner` and has the same height.
///
/// Only the wrapped matrix is stored; `F` and `EF` are recorded through `PhantomData`.
#[derive(Debug)]
pub struct FlatMatrixView<F, EF, Inner>(Inner, PhantomData<(F, EF)>);
17
18impl<F, EF, Inner> FlatMatrixView<F, EF, Inner> {
19 pub const fn new(inner: Inner) -> Self {
20 Self(inner, PhantomData)
21 }
22}
23
24impl<F, EF, Inner> Deref for FlatMatrixView<F, EF, Inner> {
25 type Target = Inner;
26
27 fn deref(&self) -> &Self::Target {
28 &self.0
29 }
30}
31
32impl<F, EF, Inner> Matrix<F> for FlatMatrixView<F, EF, Inner>
33where
34 F: Field,
35 EF: ExtensionField<F>,
36 Inner: Matrix<EF>,
37{
38 fn width(&self) -> usize {
39 self.0.width() * EF::DIMENSION
40 }
41
42 fn height(&self) -> usize {
43 self.0.height()
44 }
45
46 unsafe fn get_unchecked(&self, r: usize, c: usize) -> F {
47 let c_inner = c / EF::DIMENSION;
50 let inner = unsafe {
51 self.0.get_unchecked(r, c_inner)
54 };
55 inner.as_basis_coefficients_slice()[c % EF::DIMENSION]
56 }
57
58 unsafe fn row_unchecked(
59 &self,
60 r: usize,
61 ) -> impl IntoIterator<Item = F, IntoIter = impl Iterator<Item = F> + Send + Sync> {
62 unsafe {
63 FlatIter {
65 inner: self.0.row_unchecked(r).into_iter().peekable(),
66 idx: 0,
67 _phantom: PhantomData,
68 }
69 }
70 }
71
72 unsafe fn row_subseq_unchecked(
73 &self,
74 r: usize,
75 start: usize,
76 end: usize,
77 ) -> impl IntoIterator<Item = F, IntoIter = impl Iterator<Item = F> + Send + Sync> {
78 let len = end - start;
80 let inner_start = start / EF::DIMENSION;
81 unsafe {
82 FlatIter {
84 inner: self
85 .0
86 .row_subseq_unchecked(r, inner_start, self.0.width())
89 .into_iter()
90 .peekable(),
91 idx: start,
92 _phantom: PhantomData,
93 }
94 .take(len)
95 }
96 }
97
98 unsafe fn row_slice_unchecked(&self, r: usize) -> impl Deref<Target = [F]> {
99 unsafe {
100 self.0
102 .row_slice_unchecked(r)
103 .iter()
104 .flat_map(|val| val.as_basis_coefficients_slice())
105 .copied()
106 .collect::<Vec<_>>()
107 }
108 }
109}
110
/// Iterator adapter that walks a sequence of extension field elements and yields their
/// basis coefficients one at a time.
pub struct FlatIter<F, I: Iterator> {
    /// Source of extension elements. It is peekable so the element at the front stays
    /// available until all of its coefficients have been yielded.
    inner: iter::Peekable<I>,
    /// Index of the next coefficient to read from the element at the front of `inner`.
    idx: usize,
    /// Records the base field type `F` without storing a value of it.
    _phantom: PhantomData<F>,
}
116
117impl<F, EF, I> Iterator for FlatIter<F, I>
118where
119 F: Field,
120 EF: ExtensionField<F>,
121 I: Iterator<Item = EF>,
122{
123 type Item = F;
124 fn next(&mut self) -> Option<Self::Item> {
125 if self.idx == EF::DIMENSION {
126 self.idx = 0;
127 self.inner.next();
128 }
129 let value = self.inner.peek()?.as_basis_coefficients_slice()[self.idx];
130 self.idx += 1;
131 Some(value)
132 }
133}
134
#[cfg(test)]
mod tests {
    use alloc::vec;

    use itertools::Itertools;
    use p3_field::extension::Complex;
    use p3_field::{BasedVectorSpace, PrimeCharacteristicRing};
    use p3_mersenne_31::Mersenne31;

    use super::*;
    use crate::dense::RowMajorMatrix;

    type F = Mersenne31;
    type EF = Complex<Mersenne31>;

    /// Builds the extension element whose `i`-th basis coefficient is `base + i`.
    fn ef(base: u8) -> EF {
        <EF as BasedVectorSpace<F>>::from_basis_coefficients_fn(|i| F::from_u8(i as u8 + base))
    }

    /// Lifts an array of small integers into base field elements.
    fn felts<const N: usize>(raw: [u8; N]) -> [F; N] {
        raw.map(F::from_u8)
    }

    #[test]
    fn flat_matrix() {
        // 3 x 2 extension matrix with elements based at 10, 20 / 30, 40 / 50, 60.
        let elements: Vec<EF> = (1..=6u8).map(|k| ef(10 * k)).collect();
        let flat = FlatMatrixView::<F, EF, _>::new(RowMajorMatrix::<EF>::new(elements, 2));

        // Each of the two inner columns expands into two base field columns.
        assert_eq!(flat.width(), 4);
        assert_eq!(flat.height(), 3);

        // Checked element access.
        for (r, c, want) in [(0, 2, 20), (1, 3, 41), (2, 0, 50)] {
            assert_eq!(flat.get(r, c), Some(F::from_u8(want)));
        }

        // Unchecked element access with in-bounds indices.
        for (r, c, want) in [(0, 1, 11), (1, 0, 30), (2, 2, 60)] {
            assert_eq!(unsafe { flat.get_unchecked(r, c) }, F::from_u8(want));
        }

        // Row slices, checked and unchecked.
        assert_eq!(&*flat.row_slice(0).unwrap(), &felts([10, 11, 20, 21]));
        unsafe {
            assert_eq!(&*flat.row_slice_unchecked(1), &felts([30, 31, 40, 41]));
            assert_eq!(
                &*flat.row_subslice_unchecked(2, 0, 3),
                &felts([50, 51, 60])
            );
        }

        // Row iterators, checked and unchecked.
        assert_eq!(
            flat.row(2).unwrap().into_iter().collect_vec(),
            felts([50, 51, 60, 61])
        );
        unsafe {
            assert_eq!(
                flat.row_unchecked(1).into_iter().collect_vec(),
                felts([30, 31, 40, 41])
            );
            assert_eq!(
                flat.row_subseq_unchecked(0, 1, 4).into_iter().collect_vec(),
                felts([11, 20, 21])
            );
        }

        // Out-of-bounds column.
        assert!(flat.get(0, 4).is_none());
        // Out-of-bounds row, through each checked accessor.
        assert!(flat.get(3, 0).is_none());
        assert!(flat.row(3).is_none());
        assert!(flat.row_slice(3).is_none());
    }

    #[test]
    fn test_flat_matrix_width() {
        // 2 x 2 extension matrix: the flat width scales with the extension degree.
        let inner = RowMajorMatrix::<EF>::new(vec![EF::default(); 4], 2);
        let flat = FlatMatrixView::<F, EF, _>::new(inner);

        let expected_width = 2 * <EF as BasedVectorSpace<F>>::DIMENSION;
        assert_eq!(flat.width(), expected_width);
    }

    #[test]
    fn test_flat_matrix_height() {
        // 2 x 3 extension matrix: flattening leaves the row count alone.
        let inner = RowMajorMatrix::<EF>::new(vec![EF::default(); 6], 3);
        let flat = FlatMatrixView::<F, EF, _>::new(inner);

        assert_eq!(flat.height(), 2);
    }

    #[test]
    fn test_flat_matrix_row_iterator() {
        // Single row holding the elements [1, 2] and [10, 11].
        let inner = RowMajorMatrix::new(vec![ef(1), ef(10)], 2);
        let flat = FlatMatrixView::<F, EF, _>::new(inner);

        let flattened: Vec<_> = flat.first_row().unwrap().into_iter().collect();

        assert_eq!(flattened, felts([1, 2, 10, 11]).to_vec());
    }

    #[test]
    fn test_flat_matrix_row_slice_correctness() {
        let inner = RowMajorMatrix::new(vec![ef(1), ef(10)], 2);
        let flat = FlatMatrixView::<F, EF, _>::new(inner);

        let first_row = flat.row_slice(0).unwrap();
        assert_eq!(&*first_row, &felts([1, 2, 10, 11]));
    }

    #[test]
    fn test_flat_matrix_empty() {
        // No elements and zero width: both flat dimensions collapse to zero.
        let inner = RowMajorMatrix::<EF>::new(vec![], 0);
        let flat = FlatMatrixView::<F, EF, _>::new(inner);

        assert_eq!(flat.height(), 0);
        assert_eq!(flat.width(), 0);
    }

    #[test]
    fn test_flat_iter_length_and_values() {
        // One row of three extension elements flattens to six base field elements.
        let inner = RowMajorMatrix::new(vec![ef(0), ef(10), ef(20)], 3);
        let flat = FlatMatrixView::<F, EF, _>::new(inner);

        let flattened: Vec<_> = flat.first_row().unwrap().into_iter().collect();

        assert_eq!(flattened, felts([0, 1, 10, 11, 20, 21]).to_vec());
    }

    #[test]
    fn test_flat_matrix_multiple_rows() {
        // 2 x 2 extension matrix; each row must flatten independently.
        let inner = RowMajorMatrix::new(vec![ef(0), ef(10), ef(20), ef(30)], 2);
        let flat = FlatMatrixView::<F, EF, _>::new(inner);

        let top: Vec<_> = flat.first_row().unwrap().into_iter().collect();
        let bottom: Vec<_> = flat.row(1).unwrap().into_iter().collect();

        assert_eq!(top, felts([0, 1, 10, 11]).to_vec());
        assert_eq!(bottom, felts([20, 21, 30, 31]).to_vec());
    }

    #[test]
    fn test_flat_iter_yields_across_multiple_efs() {
        let inner = RowMajorMatrix::new(vec![ef(0), ef(10), ef(20)], 3);
        let flat = FlatMatrixView::<F, EF, _>::new(inner);

        let mut coefficients = flat.row(0).unwrap().into_iter();

        // Step the iterator by hand so every element boundary is crossed explicitly.
        for want in felts([0, 1, 10, 11, 20, 21]) {
            assert_eq!(coefficients.next(), Some(want));
        }

        // The row is exhausted after the last coefficient.
        assert_eq!(coefficients.next(), None);
    }
}