// p3_matrix/extension.rs

1use alloc::vec::Vec;
2use core::iter;
3use core::marker::PhantomData;
4use core::ops::Deref;
5
6use p3_field::{ExtensionField, Field};
7
8use crate::Matrix;
9use crate::bitrev::BitReversibleMatrix;
10
/// Presents a matrix over an extension field `EF` as a wider matrix over its base field `F`.
///
/// Every `EF` entry of the wrapped matrix is replaced by its `EF::DIMENSION` basis
/// coefficients, laid out next to each other in the same row. The view therefore keeps the
/// height of the wrapped matrix and multiplies its width by `EF::DIMENSION`.
///
/// `F` and `EF` appear only as marker types; the single stored value is the wrapped
/// matrix `Inner`.
#[derive(Debug)]
pub struct FlatMatrixView<F, EF, Inner>(Inner, PhantomData<(F, EF)>);
18
19impl<F, EF, Inner> FlatMatrixView<F, EF, Inner> {
20    pub const fn new(inner: Inner) -> Self {
21        Self(inner, PhantomData)
22    }
23}
24
25impl<F, EF, Inner> Deref for FlatMatrixView<F, EF, Inner> {
26    type Target = Inner;
27
28    fn deref(&self) -> &Self::Target {
29        &self.0
30    }
31}
32
33impl<F, EF, Inner> Matrix<F> for FlatMatrixView<F, EF, Inner>
34where
35    F: Field,
36    EF: ExtensionField<F>,
37    Inner: Matrix<EF>,
38{
39    fn width(&self) -> usize {
40        self.0.width() * EF::DIMENSION
41    }
42
43    fn height(&self) -> usize {
44        self.0.height()
45    }
46
47    unsafe fn get_unchecked(&self, r: usize, c: usize) -> F {
48        // The c'th base field element in a row of extension field elements is
49        // at index c % EF::DIMENSION in the c / EF::DIMENSION'th extension element.
50        let c_inner = c / EF::DIMENSION;
51        let inner = unsafe {
52            // Safety: The caller must ensure that r < self.height() and c < self.width().
53            // Assuming this, c / EF::DIMENSION < self.0.width().
54            self.0.get_unchecked(r, c_inner)
55        };
56        inner.as_basis_coefficients_slice()[c % EF::DIMENSION]
57    }
58
59    unsafe fn row_unchecked(
60        &self,
61        r: usize,
62    ) -> impl IntoIterator<Item = F, IntoIter = impl Iterator<Item = F> + Send + Sync> {
63        unsafe {
64            // Safety: The caller must ensure that r < self.height().
65            FlatIter {
66                inner: self.0.row_unchecked(r).into_iter().peekable(),
67                idx: 0,
68                _phantom: PhantomData,
69            }
70        }
71    }
72
73    unsafe fn row_subseq_unchecked(
74        &self,
75        r: usize,
76        start: usize,
77        end: usize,
78    ) -> impl IntoIterator<Item = F, IntoIter = impl Iterator<Item = F> + Send + Sync> {
79        // We can skip the first start / EF::DIMENSION elements in the row.
80        let len = end - start;
81        let inner_start = start / EF::DIMENSION;
82        unsafe {
83            // Safety: The caller must ensure that r < self.height(), start <= end and end < self.width().
84            FlatIter {
85                inner: self
86                    .0
87                    // We set end to be the width of the inner matrix and use take to ensure we get the right
88                    // number of elements.
89                    .row_subseq_unchecked(r, inner_start, self.0.width())
90                    .into_iter()
91                    .peekable(),
92                idx: start % EF::DIMENSION,
93                _phantom: PhantomData,
94            }
95            .take(len)
96        }
97    }
98
99    unsafe fn row_slice_unchecked(&self, r: usize) -> impl Deref<Target = [F]> {
100        unsafe {
101            // Safety: The caller must ensure that r < self.height().
102            self.0
103                .row_slice_unchecked(r)
104                .iter()
105                .flat_map(|val| val.as_basis_coefficients_slice())
106                .copied()
107                .collect::<Vec<_>>()
108        }
109    }
110}
111
/// Iterator adaptor that turns a stream of extension field elements into the stream of
/// their base field coefficients.
///
/// The element currently being expanded is kept as the peeked head of `inner`; it is only
/// consumed once all of its coefficients have been emitted.
pub struct FlatIter<F, I: Iterator> {
    /// Source of extension field elements.
    inner: iter::Peekable<I>,
    /// Position of the next coefficient to emit within the peeked element.
    idx: usize,
    /// Ties the iterator to the base field type it yields.
    _phantom: PhantomData<F>,
}
117
118impl<F, EF, I> Iterator for FlatIter<F, I>
119where
120    F: Field,
121    EF: ExtensionField<F>,
122    I: Iterator<Item = EF>,
123{
124    type Item = F;
125    fn next(&mut self) -> Option<Self::Item> {
126        if self.idx == EF::DIMENSION {
127            self.idx = 0;
128            self.inner.next();
129        }
130        let value = self.inner.peek()?.as_basis_coefficients_slice()[self.idx];
131        self.idx += 1;
132        Some(value)
133    }
134}
135
136impl<F, EF, Inner> BitReversibleMatrix<F> for FlatMatrixView<F, EF, Inner>
137where
138    F: Field,
139    EF: ExtensionField<F>,
140    Inner: BitReversibleMatrix<EF>,
141{
142    type BitRev = FlatMatrixView<F, EF, Inner::BitRev>;
143
144    fn bit_reverse_rows(self) -> Self::BitRev {
145        FlatMatrixView::new(self.0.bit_reverse_rows())
146    }
147}
148
#[cfg(test)]
mod tests {
    use alloc::vec;

    use itertools::Itertools;
    use p3_field::extension::Complex;
    use p3_field::{BasedVectorSpace, PrimeCharacteristicRing};
    use p3_mersenne_31::Mersenne31;

    use super::*;
    use crate::dense::RowMajorMatrix;
    // Base field used throughout the tests.
    type F = Mersenne31;
    // Extension field used throughout the tests; the assertions below rely on each `EF`
    // element flattening to two `F` coefficients.
    type EF = Complex<Mersenne31>;

    /// End-to-end check of every accessor on a 3x2 extension matrix (flattened to 3x4).
    #[test]
    fn flat_matrix() {
        // Element k (row-major) has coefficients [10 * (k + 1), 10 * (k + 1) + 1].
        let values = vec![
            EF::from_basis_coefficients_fn(|i| F::from_u8(i as u8 + 10)),
            EF::from_basis_coefficients_fn(|i| F::from_u8(i as u8 + 20)),
            EF::from_basis_coefficients_fn(|i| F::from_u8(i as u8 + 30)),
            EF::from_basis_coefficients_fn(|i| F::from_u8(i as u8 + 40)),
            EF::from_basis_coefficients_fn(|i| F::from_u8(i as u8 + 50)),
            EF::from_basis_coefficients_fn(|i| F::from_u8(i as u8 + 60)),
        ];
        let ext = RowMajorMatrix::<EF>::new(values, 2);
        let flat = FlatMatrixView::<F, EF, _>::new(ext);

        assert_eq!(flat.width(), 4);
        assert_eq!(flat.height(), 3);

        // Checked element access.
        assert_eq!(flat.get(0, 2), Some(F::from_u8(20)));
        assert_eq!(flat.get(1, 3), Some(F::from_u8(41)));
        assert_eq!(flat.get(2, 0), Some(F::from_u8(50)));

        // Unchecked element access with in-bounds indices.
        unsafe {
            assert_eq!(flat.get_unchecked(0, 1), F::from_u8(11));
            assert_eq!(flat.get_unchecked(1, 0), F::from_u8(30));
            assert_eq!(flat.get_unchecked(2, 2), F::from_u8(60));
        }

        // Row slices, including a sub-slice that ends in the middle of an extension element.
        assert_eq!(
            &*flat.row_slice(0).unwrap(),
            &[10, 11, 20, 21].map(F::from_u8)
        );
        unsafe {
            assert_eq!(
                &*flat.row_slice_unchecked(1),
                &[30, 31, 40, 41].map(F::from_u8)
            );
            assert_eq!(
                &*flat.row_subslice_unchecked(2, 0, 3),
                &[50, 51, 60].map(F::from_u8)
            );
        }

        // Row iterators, including a sub-sequence that starts in the middle of an
        // extension element and ends exactly at the flat width.
        assert_eq!(
            flat.row(2).unwrap().into_iter().collect_vec(),
            [50, 51, 60, 61].map(F::from_u8)
        );
        unsafe {
            assert_eq!(
                flat.row_unchecked(1).into_iter().collect_vec(),
                [30, 31, 40, 41].map(F::from_u8)
            );
            assert_eq!(
                flat.row_subseq_unchecked(0, 1, 4).into_iter().collect_vec(),
                [11, 20, 21].map(F::from_u8)
            );
        }

        assert!(flat.get(0, 4).is_none()); // Width out of bounds
        assert!(flat.get(3, 0).is_none()); // Height out of bounds
        assert!(flat.row(3).is_none()); // Height out of bounds
        assert!(flat.row_slice(3).is_none()); // Height out of bounds
    }

    #[test]
    fn test_flat_matrix_width() {
        // Create a 2-column, 2-row matrix of EF elements.
        // Each EF element expands to EF::DIMENSION base field elements when flattened.
        // Therefore, the flattened width should be 2 * EF::DIMENSION.
        let matrix = RowMajorMatrix::<EF>::new(vec![EF::default(); 4], 2);
        let flat = FlatMatrixView::<F, EF, _>::new(matrix);
        assert_eq!(flat.width(), 2 * <EF as BasedVectorSpace<F>>::DIMENSION);
    }

    #[test]
    fn test_flat_matrix_height() {
        // Construct a 3-column matrix with 6 EF elements (2 rows).
        // The flattened view should preserve the original number of rows.
        let matrix = RowMajorMatrix::<EF>::new(vec![EF::default(); 6], 3);
        let flat = FlatMatrixView::<F, EF, _>::new(matrix);
        assert_eq!(flat.height(), 2);
    }

    #[test]
    fn test_flat_matrix_row_iterator() {
        // Create a single row of two EF elements:
        // First EF = [1, 2], second EF = [10, 11] (in base field representation).
        let values = vec![
            EF::from_basis_coefficients_fn(|i| F::from_u8(i as u8 + 1)),
            EF::from_basis_coefficients_fn(|i| F::from_u8(i as u8 + 10)),
        ];
        let matrix = RowMajorMatrix::new(values, 2);
        let flat = FlatMatrixView::<F, EF, _>::new(matrix);

        // Flattened row should concatenate basis coefficients of both EF elements.
        let row: Vec<_> = flat.first_row().unwrap().into_iter().collect();
        let expected = [1, 2, 10, 11].map(F::from_u8).to_vec();

        assert_eq!(row, expected);
    }

    #[test]
    fn test_flat_matrix_row_slice_correctness() {
        // Construct a row with two EF values: [1, 2] and [10, 11].
        // Verify that row_slice() correctly returns a flat &[F] of base field values.
        let ef = |offset| EF::from_basis_coefficients_fn(|i| F::from_u8(i as u8 + offset));
        let matrix = RowMajorMatrix::new(vec![ef(1), ef(10)], 2);
        let flat = FlatMatrixView::<F, EF, _>::new(matrix);

        assert_eq!(
            &*flat.row_slice(0).unwrap(),
            &[1, 2, 10, 11].map(F::from_u8)
        );
    }

    #[test]
    fn test_flat_matrix_empty() {
        // Edge case: test behavior on empty matrix.
        // Expect zero width and height in the flattened view.
        let matrix = RowMajorMatrix::<EF>::new(vec![], 0);
        let flat = FlatMatrixView::<F, EF, _>::new(matrix);

        assert_eq!(flat.height(), 0);
        assert_eq!(flat.width(), 0);
    }

    #[test]
    fn test_flat_iter_length_and_values() {
        // Create a row with three EF values, each with offset base coefficients:
        // [0,1], [10,11], [20,21] -> flattened row should be [0,1,10,11,20,21].
        let ef = |offset| EF::from_basis_coefficients_fn(|i| F::from_u8(i as u8 + offset));
        let values = vec![ef(0), ef(10), ef(20)];
        let matrix = RowMajorMatrix::new(values, 3); // 1 row
        let flat = FlatMatrixView::<F, EF, _>::new(matrix);

        // Comparing the collected Vec checks both the number of items and their order.
        let row: Vec<_> = flat.first_row().unwrap().into_iter().collect();
        let expected = [0, 1, 10, 11, 20, 21].map(F::from_u8).to_vec();
        assert_eq!(row, expected);
    }

    #[test]
    fn test_flat_matrix_multiple_rows() {
        // Construct a 2-column, 2-row matrix of EF values, with varying offsets per row.
        // Row 0: [0,1], [10,11]; Row 1: [20,21], [30,31].
        // Verify that the flattening preserves row structure and ordering.
        let ef = |base| EF::from_basis_coefficients_fn(|i| F::from_u8(base + i as u8));
        let matrix = RowMajorMatrix::new(vec![ef(0), ef(10), ef(20), ef(30)], 2);
        let flat = FlatMatrixView::<F, EF, _>::new(matrix);

        let row0: Vec<_> = flat.first_row().unwrap().into_iter().collect();
        let row1: Vec<_> = flat.row(1).unwrap().into_iter().collect();

        assert_eq!(row0, [0, 1, 10, 11].map(F::from_u8).to_vec());
        assert_eq!(row1, [20, 21, 30, 31].map(F::from_u8).to_vec());
    }

    #[test]
    fn test_flat_iter_yields_across_multiple_efs() {
        // Build 1 row with 3 EF elements:
        // - ef(0)   = [0, 1]
        // - ef(10)  = [10, 11]
        // - ef(20)  = [20, 21]
        //
        // The flattened row should yield:
        // [0, 1, 10, 11, 20, 21] as base field elements (F)
        let ef = |offset| EF::from_basis_coefficients_fn(|i| F::from_u8(i as u8 + offset));
        let matrix = RowMajorMatrix::new(vec![ef(0), ef(10), ef(20)], 3); // 1 row, 3 EF elements
        let flat = FlatMatrixView::<F, EF, _>::new(matrix);

        let mut row_iter = flat.row(0).unwrap().into_iter();

        // Expected flattened result
        let expected = [0, 1, 10, 11, 20, 21].map(F::from_u8);

        // Step the iterator by hand so each element boundary is crossed explicitly.
        for expected_val in expected {
            assert_eq!(row_iter.next(), Some(expected_val));
        }

        // Iterator should now be exhausted
        assert_eq!(row_iter.next(), None);
    }

    #[test]
    fn test_row_subseq_start_ge_dimension() {
        // One row of three EF elements: [10, 11], [20, 21], [30, 31] (flat width 6).
        // Both ranges start at or beyond flat column 2, i.e. past the first EF element.
        let ef = |offset| EF::from_basis_coefficients_fn(|i| F::from_u8(i as u8 + offset));
        let values = vec![ef(10), ef(20), ef(30)];
        let matrix = RowMajorMatrix::new(values, 3);
        let flat = FlatMatrixView::<F, EF, _>::new(matrix);

        unsafe {
            // Start aligned to an element boundary, end in the middle of the last element.
            let result: Vec<_> = flat.row_subseq_unchecked(0, 2, 5).into_iter().collect();
            assert_eq!(result, [20, 21, 30].map(F::from_u8).to_vec());

            // Start in the middle of the second element, end exactly at the flat width.
            let result: Vec<_> = flat.row_subseq_unchecked(0, 3, 6).into_iter().collect();
            assert_eq!(result, [21, 30, 31].map(F::from_u8).to_vec());
        }
    }
}