// p3_matrix/extension.rs

1use core::iter;
2use core::marker::PhantomData;
3
4use p3_field::{ExtensionField, Field};
5
6use crate::Matrix;
7
/// A view of a matrix over the extension field `EF` as a matrix over the base field `F`.
///
/// Each `EF` entry of the wrapped matrix is presented as its base-field coefficients
/// (in the order given by `as_base_slice`), so the view keeps the inner matrix's
/// height while its width is multiplied by `EF::D`.
#[derive(Debug)]
pub struct FlatMatrixView<F, EF, Inner>(Inner, PhantomData<(F, EF)>);

impl<F, EF, Inner> FlatMatrixView<F, EF, Inner> {
    /// Wraps `inner`; no entries are copied or converted up front.
    pub fn new(inner: Inner) -> Self {
        let marker = PhantomData;
        FlatMatrixView(inner, marker)
    }

    /// Borrows the wrapped extension-field matrix.
    pub fn inner_ref(&self) -> &Inner {
        let Self(inner, _) = self;
        inner
    }
}
19
20impl<F, EF, Inner> Matrix<F> for FlatMatrixView<F, EF, Inner>
21where
22    F: Field,
23    EF: ExtensionField<F>,
24    Inner: Matrix<EF>,
25{
26    fn width(&self) -> usize {
27        self.0.width() * EF::D
28    }
29
30    fn height(&self) -> usize {
31        self.0.height()
32    }
33
34    type Row<'a>
35        = FlatIter<F, Inner::Row<'a>>
36    where
37        Self: 'a;
38
39    fn row(&self, r: usize) -> Self::Row<'_> {
40        FlatIter {
41            inner: self.0.row(r).peekable(),
42            idx: 0,
43            _phantom: PhantomData,
44        }
45    }
46}
47
48pub struct FlatIter<F, I: Iterator> {
49    inner: iter::Peekable<I>,
50    idx: usize,
51    _phantom: PhantomData<F>,
52}
53
54impl<F, EF, I> Iterator for FlatIter<F, I>
55where
56    F: Field,
57    EF: ExtensionField<F>,
58    I: Iterator<Item = EF>,
59{
60    type Item = F;
61    fn next(&mut self) -> Option<Self::Item> {
62        if self.idx == EF::D {
63            self.idx = 0;
64            self.inner.next();
65        }
66        let value = self.inner.peek()?.as_base_slice()[self.idx];
67        self.idx += 1;
68        Some(value)
69    }
70}
71
#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_field::extension::Complex;
    use p3_field::{AbstractExtensionField, AbstractField};
    use p3_mersenne_31::Mersenne31;

    use super::*;
    use crate::dense::RowMajorMatrix;
    type F = Mersenne31;
    type EF = Complex<Mersenne31>;

    #[test]
    fn flat_matrix() {
        let values = vec![
            EF::from_base_fn(|i| F::from_canonical_usize(i + 10)),
            EF::from_base_fn(|i| F::from_canonical_usize(i + 20)),
            EF::from_base_fn(|i| F::from_canonical_usize(i + 30)),
            EF::from_base_fn(|i| F::from_canonical_usize(i + 40)),
        ];
        let ext = RowMajorMatrix::<EF>::new(values, 2);
        let flat = FlatMatrixView::<F, EF, _>::new(ext);

        // Two extension columns of two coefficients each give four base columns;
        // the number of rows is unchanged.
        assert_eq!(flat.width(), 4);
        assert_eq!(flat.height(), 2);

        let row0 = [10, 11, 20, 21].map(F::from_canonical_usize);
        let row1 = [30, 31, 40, 41].map(F::from_canonical_usize);
        assert_eq!(&*flat.row_slice(0), &row0);
        assert_eq!(&*flat.row_slice(1), &row1);

        // The lazily evaluated row iterator must agree with the materialized rows.
        assert!(flat.row(0).eq(row0.iter().copied()));
        assert!(flat.row(1).eq(row1.iter().copied()));
    }
}
104}