p3_dft/
traits.rs

1use alloc::vec::Vec;
2
3use p3_field::TwoAdicField;
4use p3_matrix::bitrev::BitReversableMatrix;
5use p3_matrix::dense::RowMajorMatrix;
6use p3_matrix::util::swap_rows;
7use p3_matrix::Matrix;
8
9use crate::util::divide_by_height;
10
11pub trait TwoAdicSubgroupDft<F: TwoAdicField>: Clone + Default {
12    // Effectively this is either RowMajorMatrix or BitReversedMatrixView<RowMajorMatrix>.
13    // Always owned.
14    type Evaluations: BitReversableMatrix<F> + 'static;
15
16    /// Compute the discrete Fourier transform (DFT) `vec`.
17    fn dft(&self, vec: Vec<F>) -> Vec<F> {
18        self.dft_batch(RowMajorMatrix::new_col(vec))
19            .to_row_major_matrix()
20            .values
21    }
22
23    /// Compute the discrete Fourier transform (DFT) of each column in `mat`.
24    /// This is the only method an implementer needs to define, all other
25    /// methods can be derived from this one.
26    fn dft_batch(&self, mat: RowMajorMatrix<F>) -> Self::Evaluations;
27
28    /// Compute the "coset DFT" of `vec`. This can be viewed as interpolation onto a coset of a
29    /// multiplicative subgroup, rather than the subgroup itself.
30    fn coset_dft(&self, vec: Vec<F>, shift: F) -> Vec<F> {
31        self.coset_dft_batch(RowMajorMatrix::new_col(vec), shift)
32            .to_row_major_matrix()
33            .values
34    }
35
36    /// Compute the "coset DFT" of each column in `mat`. This can be viewed as interpolation onto a
37    /// coset of a multiplicative subgroup, rather than the subgroup itself.
38    fn coset_dft_batch(&self, mut mat: RowMajorMatrix<F>, shift: F) -> Self::Evaluations {
39        // Observe that
40        //     y_i = \sum_j c_j (s g^i)^j
41        //         = \sum_j (c_j s^j) (g^i)^j
42        // which has the structure of an ordinary DFT, except each coefficient c_j is first replaced
43        // by c_j s^j.
44        mat.rows_mut()
45            .zip(shift.powers())
46            .for_each(|(row, weight)| {
47                row.iter_mut().for_each(|coeff| {
48                    *coeff *= weight;
49                })
50            });
51        self.dft_batch(mat)
52    }
53
54    /// Compute the inverse DFT of `vec`.
55    fn idft(&self, vec: Vec<F>) -> Vec<F> {
56        self.idft_batch(RowMajorMatrix::new(vec, 1)).values
57    }
58
59    /// Compute the inverse DFT of each column in `mat`.
60    fn idft_batch(&self, mat: RowMajorMatrix<F>) -> RowMajorMatrix<F> {
61        let mut dft = self.dft_batch(mat).to_row_major_matrix();
62        let h = dft.height();
63
64        divide_by_height(&mut dft);
65
66        for row in 1..h / 2 {
67            swap_rows(&mut dft, row, h - row);
68        }
69
70        dft
71    }
72
73    /// Compute the "coset iDFT" of `vec`. This can be viewed as an inverse operation of
74    /// "coset DFT", that interpolates over a coset of a multiplicative subgroup, rather than
75    /// subgroup itself.
76    fn coset_idft(&self, vec: Vec<F>, shift: F) -> Vec<F> {
77        self.coset_idft_batch(RowMajorMatrix::new(vec, 1), shift)
78            .values
79    }
80
81    /// Compute the "coset iDFT" of each column in `mat`. This can be viewed as an inverse operation
82    /// of "coset DFT", that interpolates over a coset of a multiplicative subgroup, rather than the
83    /// subgroup itself.
84    fn coset_idft_batch(&self, mut mat: RowMajorMatrix<F>, shift: F) -> RowMajorMatrix<F> {
85        mat = self.idft_batch(mat);
86
87        mat.rows_mut()
88            .zip(shift.inverse().powers())
89            .for_each(|(row, weight)| {
90                row.iter_mut().for_each(|coeff| {
91                    *coeff *= weight;
92                })
93            });
94
95        mat
96    }
97
98    /// Compute the low-degree extension of `vec` onto a larger subgroup.
99    fn lde(&self, vec: Vec<F>, added_bits: usize) -> Vec<F> {
100        self.lde_batch(RowMajorMatrix::new(vec, 1), added_bits)
101            .to_row_major_matrix()
102            .values
103    }
104
105    /// Compute the low-degree extension of each column in `mat` onto a larger subgroup.
106    fn lde_batch(&self, mat: RowMajorMatrix<F>, added_bits: usize) -> Self::Evaluations {
107        let mut coeffs = self.idft_batch(mat);
108        coeffs
109            .values
110            .resize(coeffs.values.len() << added_bits, F::zero());
111        self.dft_batch(coeffs)
112    }
113
114    /// Compute the low-degree extension of each column in `mat` onto a coset of a larger subgroup.
115    fn coset_lde(&self, vec: Vec<F>, added_bits: usize, shift: F) -> Vec<F> {
116        self.coset_lde_batch(RowMajorMatrix::new(vec, 1), added_bits, shift)
117            .to_row_major_matrix()
118            .values
119    }
120
121    /// Compute the low-degree extension of each column in `mat` onto a coset of a larger subgroup.
122    fn coset_lde_batch(
123        &self,
124        mat: RowMajorMatrix<F>,
125        added_bits: usize,
126        shift: F,
127    ) -> Self::Evaluations {
128        let mut coeffs = self.idft_batch(mat);
129        // PANICS: possible panic if the new resized length overflows
130        coeffs.values.resize(
131            coeffs
132                .values
133                .len()
134                .checked_shl(added_bits.try_into().unwrap())
135                .unwrap(),
136            F::zero(),
137        );
138        self.coset_dft_batch(coeffs, shift)
139    }
140}