p3_matrix/
extension.rs

1use alloc::vec::Vec;
2use core::iter;
3use core::marker::PhantomData;
4use core::ops::Deref;
5
6use p3_field::{ExtensionField, Field};
7
8use crate::Matrix;
9
/// A view that flattens a matrix of extension field elements into a matrix of base field elements.
///
/// Each element of the original matrix is an extension field element `EF`, composed of
/// `EF::DIMENSION` base field coefficients `F`. This view expands each `EF` element into those
/// coefficients (in the order returned by `as_basis_coefficients_slice`). The width is therefore
/// multiplied by `EF::DIMENSION`, while the number of rows is unchanged.
///
/// Concretely, flat column `c` holds coefficient `c % EF::DIMENSION` of the element stored in
/// inner column `c / EF::DIMENSION`.
///
/// The wrapped matrix is stored as-is; building the view copies no data.
///
/// Type parameters:
/// - `F`: the base field the view exposes.
/// - `EF`: the extension field element type stored in the wrapped matrix.
/// - `Inner`: the wrapped matrix (a `Matrix<EF>` when used through the `Matrix<F>` impl).
#[derive(Debug)]
pub struct FlatMatrixView<F, EF, Inner>(Inner, PhantomData<(F, EF)>);
17
18impl<F, EF, Inner> FlatMatrixView<F, EF, Inner> {
19    pub const fn new(inner: Inner) -> Self {
20        Self(inner, PhantomData)
21    }
22}
23
24impl<F, EF, Inner> Deref for FlatMatrixView<F, EF, Inner> {
25    type Target = Inner;
26
27    fn deref(&self) -> &Self::Target {
28        &self.0
29    }
30}
31
32impl<F, EF, Inner> Matrix<F> for FlatMatrixView<F, EF, Inner>
33where
34    F: Field,
35    EF: ExtensionField<F>,
36    Inner: Matrix<EF>,
37{
38    fn width(&self) -> usize {
39        self.0.width() * EF::DIMENSION
40    }
41
42    fn height(&self) -> usize {
43        self.0.height()
44    }
45
46    unsafe fn get_unchecked(&self, r: usize, c: usize) -> F {
47        // The c'th base field element in a row of extension field elements is
48        // at index c % EF::DIMENSION in the c / EF::DIMENSION'th extension element.
49        let c_inner = c / EF::DIMENSION;
50        let inner = unsafe {
51            // Safety: The caller must ensure that r < self.height() and c < self.width().
52            // Assuming this, c / EF::DIMENSION < self.0.width().
53            self.0.get_unchecked(r, c_inner)
54        };
55        inner.as_basis_coefficients_slice()[c % EF::DIMENSION]
56    }
57
58    unsafe fn row_unchecked(
59        &self,
60        r: usize,
61    ) -> impl IntoIterator<Item = F, IntoIter = impl Iterator<Item = F> + Send + Sync> {
62        unsafe {
63            // Safety: The caller must ensure that r < self.height().
64            FlatIter {
65                inner: self.0.row_unchecked(r).into_iter().peekable(),
66                idx: 0,
67                _phantom: PhantomData,
68            }
69        }
70    }
71
72    unsafe fn row_subseq_unchecked(
73        &self,
74        r: usize,
75        start: usize,
76        end: usize,
77    ) -> impl IntoIterator<Item = F, IntoIter = impl Iterator<Item = F> + Send + Sync> {
78        // We can skip the first start / EF::DIMENSION elements in the row.
79        let len = end - start;
80        let inner_start = start / EF::DIMENSION;
81        unsafe {
82            // Safety: The caller must ensure that r < self.height(), start <= end and end < self.width().
83            FlatIter {
84                inner: self
85                    .0
86                    // We set end to be the width of the inner matrix and use take to ensure we get the right
87                    // number of elements.
88                    .row_subseq_unchecked(r, inner_start, self.0.width())
89                    .into_iter()
90                    .peekable(),
91                idx: start,
92                _phantom: PhantomData,
93            }
94            .take(len)
95        }
96    }
97
98    unsafe fn row_slice_unchecked(&self, r: usize) -> impl Deref<Target = [F]> {
99        unsafe {
100            // Safety: The caller must ensure that r < self.height().
101            self.0
102                .row_slice_unchecked(r)
103                .iter()
104                .flat_map(|val| val.as_basis_coefficients_slice())
105                .copied()
106                .collect::<Vec<_>>()
107        }
108    }
109}
110
/// Iterator that expands a stream of extension field elements into their base field coefficients.
///
/// `inner` yields extension elements. Each one is expanded, in order, into its basis
/// coefficients. The element currently being expanded is kept as the peeked value of `inner`.
/// `idx` is the position of the next coefficient to read from that element.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct FlatIter<F, I: Iterator> {
    /// Source of extension field elements; the element being expanded stays peeked.
    inner: iter::Peekable<I>,
    /// Index of the next basis coefficient to yield from the current element.
    idx: usize,
    /// Records the output base field type `F` without storing a value of it.
    _phantom: PhantomData<F>,
}
116
117impl<F, EF, I> Iterator for FlatIter<F, I>
118where
119    F: Field,
120    EF: ExtensionField<F>,
121    I: Iterator<Item = EF>,
122{
123    type Item = F;
124    fn next(&mut self) -> Option<Self::Item> {
125        if self.idx == EF::DIMENSION {
126            self.idx = 0;
127            self.inner.next();
128        }
129        let value = self.inner.peek()?.as_basis_coefficients_slice()[self.idx];
130        self.idx += 1;
131        Some(value)
132    }
133}
134
#[cfg(test)]
mod tests {
    use alloc::vec;

    use itertools::Itertools;
    use p3_field::extension::Complex;
    use p3_field::{BasedVectorSpace, PrimeCharacteristicRing};
    use p3_mersenne_31::Mersenne31;

    use super::*;
    use crate::dense::RowMajorMatrix;

    type F = Mersenne31;
    type EF = Complex<Mersenne31>;

    /// Extension element whose `i`-th basis coefficient is `offset + i`.
    fn ef_at(offset: u8) -> EF {
        EF::from_basis_coefficients_fn(|i| F::from_u8(offset + i as u8))
    }

    /// Lifts small integers into base field elements.
    fn base<const N: usize>(vals: [u8; N]) -> [F; N] {
        vals.map(F::from_u8)
    }

    #[test]
    fn flat_matrix() {
        // 3 rows x 2 columns; element k (row-major, 1-based) has coefficients [10k, 10k + 1].
        let values = (1..=6u8).map(|k| ef_at(10 * k)).collect::<Vec<_>>();
        let flat = FlatMatrixView::<F, EF, _>::new(RowMajorMatrix::<EF>::new(values, 2));

        // Shape: width doubles, height is preserved.
        assert_eq!((flat.width(), flat.height()), (4, 3));

        // Checked element access.
        for (r, c, want) in [(0, 2, 20), (1, 3, 41), (2, 0, 50)] {
            assert_eq!(flat.get(r, c), Some(F::from_u8(want)));
        }

        // Unchecked element access.
        for (r, c, want) in [(0, 1, 11), (1, 0, 30), (2, 2, 60)] {
            let got = unsafe { flat.get_unchecked(r, c) };
            assert_eq!(got, F::from_u8(want));
        }

        // Row slices.
        assert_eq!(&*flat.row_slice(0).unwrap(), &base([10, 11, 20, 21]));
        unsafe {
            assert_eq!(&*flat.row_slice_unchecked(1), &base([30, 31, 40, 41]));
            assert_eq!(
                &*flat.row_subslice_unchecked(2, 0, 3),
                &base([50, 51, 60])
            );
        }

        // Row iterators.
        assert_eq!(
            flat.row(2).unwrap().into_iter().collect_vec(),
            base([50, 51, 60, 61])
        );
        unsafe {
            assert_eq!(
                flat.row_unchecked(1).into_iter().collect_vec(),
                base([30, 31, 40, 41])
            );
            assert_eq!(
                flat.row_subseq_unchecked(0, 1, 4).into_iter().collect_vec(),
                base([11, 20, 21])
            );
        }

        // The checked API rejects out-of-range indices.
        assert!(flat.get(0, 4).is_none()); // Width out of bounds
        assert!(flat.get(3, 0).is_none()); // Height out of bounds
        assert!(flat.row(3).is_none()); // Height out of bounds
        assert!(flat.row_slice(3).is_none()); // Height out of bounds
    }

    #[test]
    fn test_flat_matrix_width() {
        // Two EF columns, each expanding into `DIMENSION` base field columns.
        let inner = RowMajorMatrix::<EF>::new(vec![EF::default(); 4], 2);
        let flat = FlatMatrixView::<F, EF, _>::new(inner);
        let dim = <EF as BasedVectorSpace<F>>::DIMENSION;
        assert_eq!(flat.width(), dim * 2);
    }

    #[test]
    fn test_flat_matrix_height() {
        // Six elements in three columns give two rows, which the view must keep.
        let inner = RowMajorMatrix::<EF>::new(vec![EF::default(); 6], 3);
        assert_eq!(FlatMatrixView::<F, EF, _>::new(inner).height(), 2);
    }

    #[test]
    fn test_flat_matrix_row_iterator() {
        // A single row holding [1, 2] and [10, 11].
        let inner = RowMajorMatrix::new(vec![ef_at(1), ef_at(10)], 2);
        let flat = FlatMatrixView::<F, EF, _>::new(inner);

        let row = flat.first_row().unwrap().into_iter().collect::<Vec<_>>();
        assert_eq!(row, base([1, 2, 10, 11]).to_vec());
    }

    #[test]
    fn test_flat_matrix_row_slice_correctness() {
        // `row_slice` must materialise the coefficients of [1, 2] and [10, 11] contiguously.
        let inner = RowMajorMatrix::new(vec![ef_at(1), ef_at(10)], 2);
        let flat = FlatMatrixView::<F, EF, _>::new(inner);

        assert_eq!(&*flat.row_slice(0).unwrap(), &base([1, 2, 10, 11]));
    }

    #[test]
    fn test_flat_matrix_empty() {
        // An empty inner matrix yields an empty flattened view.
        let inner = RowMajorMatrix::<EF>::new(vec![], 0);
        let flat = FlatMatrixView::<F, EF, _>::new(inner);

        assert_eq!((flat.height(), flat.width()), (0, 0));
    }

    #[test]
    fn test_flat_iter_length_and_values() {
        // One row of [0, 1], [10, 11], [20, 21] flattens to six coefficients.
        let inner = RowMajorMatrix::new(vec![ef_at(0), ef_at(10), ef_at(20)], 3);
        let flat = FlatMatrixView::<F, EF, _>::new(inner);

        let row = flat.first_row().unwrap().into_iter().collect::<Vec<_>>();
        assert_eq!(row, base([0, 1, 10, 11, 20, 21]).to_vec());
    }

    #[test]
    fn test_flat_matrix_multiple_rows() {
        // Row 0: [0, 1], [10, 11]; row 1: [20, 21], [30, 31]. Row boundaries must survive.
        let inner = RowMajorMatrix::new(vec![ef_at(0), ef_at(10), ef_at(20), ef_at(30)], 2);
        let flat = FlatMatrixView::<F, EF, _>::new(inner);

        let first = flat.first_row().unwrap().into_iter().collect::<Vec<_>>();
        let second = flat.row(1).unwrap().into_iter().collect::<Vec<_>>();

        assert_eq!(first, base([0, 1, 10, 11]).to_vec());
        assert_eq!(second, base([20, 21, 30, 31]).to_vec());
    }

    #[test]
    fn test_flat_iter_yields_across_multiple_efs() {
        // Stepping the iterator one value at a time must cross element boundaries cleanly
        // and then report exhaustion.
        let inner = RowMajorMatrix::new(vec![ef_at(0), ef_at(10), ef_at(20)], 3);
        let flat = FlatMatrixView::<F, EF, _>::new(inner);

        let mut row_iter = flat.row(0).unwrap().into_iter();
        for &want in base([0, 1, 10, 11, 20, 21]).iter() {
            assert_eq!(row_iter.next(), Some(want));
        }

        assert_eq!(row_iter.next(), None);
    }
}