Skip to main content

p3_dft/
butterflies.rs

1use core::mem::MaybeUninit;
2
3use itertools::izip;
4use p3_field::{Field, PackedField, PackedValue};
5
6/// A butterfly operation used in NTT to combine two values into a new pair.
7///
8/// This trait defines how to transform two elements (or vectors of elements)
9/// according to the structure of a butterfly gate.
10///
11/// In an NTT, butterflies are the core units that recursively combine values
12/// across layers. Each butterfly computes:
13/// ```text
14///   (a + b * twiddle, a - b * twiddle)   // DIT
15/// or
16///   (a + b, (a - b) * twiddle)           // DIF
17/// ```
18/// The transformation can be applied:
19/// - in-place (mutating input values)
20/// - to full rows of values (arrays of field elements)
21/// - out-of-place (writing results to separate destination buffers)
22///
23/// Different butterfly variants (DIT, DIF, or twiddle-free) define the exact formula.
24pub trait Butterfly<F: Field>: Copy + Send + Sync {
25    /// Applies the butterfly transformation to two packed field values.
26    ///
27    /// This method takes two inputs `x_1` and `x_2` and returns two outputs `(y_1, y_2)`
28    /// depending on the butterfly type.
29    /// ```text
30    /// Example (DIF):
31    ///   Input:  x_1 = a, x_2 = b
32    ///   Output: (a + b, (a - b) * twiddle)
33    /// ```
34    fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF);
35
36    /// Applies the butterfly in-place to two packed values.
37    ///
38    /// Mutates both `x_1` and `x_2` directly, storing the result of `apply`.
39    #[inline]
40    fn apply_in_place<PF: PackedField<Scalar = F>>(&self, x_1: &mut PF, x_2: &mut PF) {
41        (*x_1, *x_2) = self.apply(*x_1, *x_2);
42    }
43
44    /// Applies the butterfly transformation to two rows of scalar field values.
45    ///
46    /// Each row is a slice of `F`. This function processes the rows in packed
47    /// chunks using SIMD where possible, and falls back to scalar operations
48    /// for the suffix (remaining elements).
49    ///
50    /// The transformation is done in-place.
51    #[inline]
52    fn apply_to_rows(&self, row_1: &mut [F], row_2: &mut [F]) {
53        let (shorts_1, suffix_1) = F::Packing::pack_slice_with_suffix_mut(row_1);
54        let (shorts_2, suffix_2) = F::Packing::pack_slice_with_suffix_mut(row_2);
55        debug_assert_eq!(shorts_1.len(), shorts_2.len());
56        debug_assert_eq!(suffix_1.len(), suffix_2.len());
57        for (x_1, x_2) in shorts_1.iter_mut().zip(shorts_2) {
58            self.apply_in_place(x_1, x_2);
59        }
60        for (x_1, x_2) in suffix_1.iter_mut().zip(suffix_2) {
61            self.apply_in_place(x_1, x_2);
62        }
63    }
64
65    /// Applies the butterfly out-of-place to two source rows.
66    ///
67    /// This version does not overwrite the source. Instead, it writes the
68    /// result of each butterfly to separate destination slices (which may
69    /// be uninitialized memory).
70    ///
71    /// This is useful when performing LDE's where the size of the output is larger than the size of the input.
72    ///
73    /// - `src_1`, `src_2`: input slices
74    /// - `dst_1`, `dst_2`: output slices to write to (must be MaybeUninit)
75    #[inline]
76    fn apply_to_rows_oop(
77        &self,
78        src_1: &[F],
79        dst_1: &mut [MaybeUninit<F>],
80        src_2: &[F],
81        dst_2: &mut [MaybeUninit<F>],
82    ) {
83        let (src_shorts_1, src_suffix_1) = F::Packing::pack_slice_with_suffix(src_1);
84        let (src_shorts_2, src_suffix_2) = F::Packing::pack_slice_with_suffix(src_2);
85        let (dst_shorts_1, dst_suffix_1) =
86            F::Packing::pack_maybe_uninit_slice_with_suffix_mut(dst_1);
87        let (dst_shorts_2, dst_suffix_2) =
88            F::Packing::pack_maybe_uninit_slice_with_suffix_mut(dst_2);
89        debug_assert_eq!(src_shorts_1.len(), src_shorts_2.len());
90        debug_assert_eq!(src_suffix_1.len(), src_suffix_2.len());
91        debug_assert_eq!(dst_shorts_1.len(), dst_shorts_2.len());
92        debug_assert_eq!(dst_suffix_1.len(), dst_suffix_2.len());
93        for (s_1, s_2, d_1, d_2) in izip!(src_shorts_1, src_shorts_2, dst_shorts_1, dst_shorts_2) {
94            let (res_1, res_2) = self.apply(*s_1, *s_2);
95            d_1.write(res_1);
96            d_2.write(res_2);
97        }
98        for (s_1, s_2, d_1, d_2) in izip!(src_suffix_1, src_suffix_2, dst_suffix_1, dst_suffix_2) {
99            let (res_1, res_2) = self.apply(*s_1, *s_2);
100            d_1.write(res_1);
101            d_2.write(res_2);
102        }
103    }
104}
105
106/// DIF (Decimation-In-Frequency) butterfly operation.
107///
108/// Used in the *output-ordering* variant of NTT.
109/// This butterfly computes:
110/// ```text
111///   output_1 = x1 + x2
112///   output_2 = (x1 - x2) * twiddle
113/// ```
114/// The twiddle factor is applied after subtraction.
115/// Suitable for DIF-style recursive transforms.
116#[derive(Copy, Clone)]
117#[repr(transparent)] // Allows safe transmutes from F to this.
118pub struct DifButterfly<F>(pub F);
119
120impl<F: Field> Butterfly<F> for DifButterfly<F> {
121    #[inline]
122    fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF) {
123        (x_1 + x_2, (x_1 - x_2) * self.0)
124    }
125}
126
127/// DIF (Decimation-In-Frequency) butterfly operation where `x_2` is guaranteed to be zero.
128///
129/// Useful in scenarios where the input has just been padded with zeros.
130///
131/// Used in the *output-ordering* variant of NTT.
132/// This butterfly computes:
133/// ```text
134///   output_1 = x1
135///   output_2 = x1 * twiddle
136/// ```
137#[derive(Copy, Clone)]
138#[repr(transparent)] // Allows safe transmutes from F to this.
139pub struct DifButterflyZeros<F>(pub F);
140
141impl<F: Field> Butterfly<F> for DifButterflyZeros<F> {
142    #[inline]
143    fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF) {
144        debug_assert!(x_2.as_slice().iter().all(|x| x.is_zero())); // Slightly convoluted but PF may not implement equality.
145        (x_1, x_1 * self.0)
146    }
147
148    #[inline]
149    fn apply_to_rows(&self, row_1: &mut [F], row_2: &mut [F]) {
150        let (shorts_1, suffix_1) = F::Packing::pack_slice_with_suffix(row_1);
151        let (shorts_2, suffix_2) = F::Packing::pack_slice_with_suffix_mut(row_2);
152        debug_assert_eq!(shorts_1.len(), shorts_2.len());
153        debug_assert_eq!(suffix_1.len(), suffix_2.len());
154        for (x_1, x_2) in shorts_1.iter().zip(shorts_2) {
155            debug_assert!(x_2.as_slice().iter().all(|x| x.is_zero())); // Slightly convoluted but PF may not implement equality.
156            *x_2 = *x_1 * self.0; // x_2 is guaranteed to be zero, so we just set it to x_1 * twiddle. 
157        }
158        for (x_1, x_2) in suffix_1.iter().zip(suffix_2) {
159            debug_assert!(x_2.is_zero());
160            *x_2 = *x_1 * self.0; // x_2 is guaranteed to be zero, so we just set it to x_1 * twiddle. 
161        }
162    }
163}
164
165/// DIT (Decimation-In-Time) butterfly operation.
166///
167/// Used in the *input-ordering* variant of NTT/FFT.
168/// This butterfly computes:
169/// ```text
170///   output_1 = x1 + x2 * twiddle
171///   output_2 = x1 - x2 * twiddle
172/// ```
173/// The twiddle factor is applied to x2 before combining.
174/// Suitable for DIT-style recursive transforms.
175#[derive(Copy, Clone)]
176#[repr(transparent)] // Allows safe transmutes from F to this.
177pub struct DitButterfly<F>(pub F);
178
179impl<F: Field> Butterfly<F> for DitButterfly<F> {
180    #[inline]
181    fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF) {
182        let x_2_twiddle = x_2 * self.0;
183        (x_1 + x_2_twiddle, x_1 - x_2_twiddle)
184    }
185
186    /// Override `apply_to_rows` to pre-broadcast the twiddle factor into a packed field
187    /// once before the inner loop, avoiding a scalar-to-vector broadcast on each packed
188    /// multiplication. For wide rows (e.g., 256 columns with AVX512 width=16, giving 16
189    /// packed iterations per row-pair), this eliminates 15 redundant broadcasts per call.
190    #[inline]
191    fn apply_to_rows(&self, row_1: &mut [F], row_2: &mut [F]) {
192        let (shorts_1, suffix_1) = F::Packing::pack_slice_with_suffix_mut(row_1);
193        let (shorts_2, suffix_2) = F::Packing::pack_slice_with_suffix_mut(row_2);
194        debug_assert_eq!(shorts_1.len(), shorts_2.len());
195        debug_assert_eq!(suffix_1.len(), suffix_2.len());
196        // Pre-broadcast the scalar twiddle into a packed field once outside the loop.
197        let twiddle_packed = F::Packing::from(self.0);
198        for (x_1, x_2) in shorts_1.iter_mut().zip(shorts_2.iter_mut()) {
199            let x_2_twiddle = *x_2 * twiddle_packed;
200            let new_x1 = *x_1 + x_2_twiddle;
201            *x_2 = *x_1 - x_2_twiddle;
202            *x_1 = new_x1;
203        }
204        for (x_1, x_2) in suffix_1.iter_mut().zip(suffix_2.iter_mut()) {
205            self.apply_in_place(x_1, x_2);
206        }
207    }
208
209    /// Override `apply_to_rows_oop` similarly, pre-broadcasting the twiddle once.
210    #[inline]
211    fn apply_to_rows_oop(
212        &self,
213        src_1: &[F],
214        dst_1: &mut [MaybeUninit<F>],
215        src_2: &[F],
216        dst_2: &mut [MaybeUninit<F>],
217    ) {
218        let (src_shorts_1, src_suffix_1) = F::Packing::pack_slice_with_suffix(src_1);
219        let (src_shorts_2, src_suffix_2) = F::Packing::pack_slice_with_suffix(src_2);
220        let (dst_shorts_1, dst_suffix_1) =
221            F::Packing::pack_maybe_uninit_slice_with_suffix_mut(dst_1);
222        let (dst_shorts_2, dst_suffix_2) =
223            F::Packing::pack_maybe_uninit_slice_with_suffix_mut(dst_2);
224        debug_assert_eq!(src_shorts_1.len(), src_shorts_2.len());
225        debug_assert_eq!(src_suffix_1.len(), src_suffix_2.len());
226        debug_assert_eq!(dst_shorts_1.len(), dst_shorts_2.len());
227        debug_assert_eq!(dst_suffix_1.len(), dst_suffix_2.len());
228        // Pre-broadcast the scalar twiddle into a packed field once outside the loop.
229        let twiddle_packed = F::Packing::from(self.0);
230        for (s_1, s_2, d_1, d_2) in izip!(src_shorts_1, src_shorts_2, dst_shorts_1, dst_shorts_2) {
231            let x_2_twiddle = *s_2 * twiddle_packed;
232            d_1.write(*s_1 + x_2_twiddle);
233            d_2.write(*s_1 - x_2_twiddle);
234        }
235        for (s_1, s_2, d_1, d_2) in izip!(src_suffix_1, src_suffix_2, dst_suffix_1, dst_suffix_2) {
236            let (res_1, res_2) = self.apply(*s_1, *s_2);
237            d_1.write(res_1);
238            d_2.write(res_2);
239        }
240    }
241}
242
243/// DIT (Decimation-In-Time) butterfly operation with a post-multiplication scale factor.
244///
245/// This butterfly computes:
246/// ```text
247///   output_1 = (x1 + x2 * twiddle) * scale
248///   output_2 = (x1 - x2 * twiddle) * scale
249/// ```
250/// which is equivalent to:
251/// ```text
252///   output_1 = x1 * scale + x2 * (twiddle * scale)
253///   output_2 = x1 * scale - x2 * (twiddle * scale)
254/// ```
255///
256/// This is used to merge a uniform scaling step (e.g., 1/N normalization in inverse DFT)
257/// into a butterfly pass, avoiding a separate memory pass over the data.
258///
259/// The struct stores `scale` and `twiddle_times_scale = twiddle * scale` so that the
260/// `apply` method only needs 2 multiplications instead of 3.
261#[derive(Copy, Clone)]
262pub struct ScaledDitButterfly<F> {
263    pub twiddle: F,
264    pub scale: F,
265    /// Precomputed product `twiddle * scale` to reduce multiplications in the hot loop.
266    pub twiddle_times_scale: F,
267}
268
269impl<F: Field> ScaledDitButterfly<F> {
270    /// Construct a `ScaledDitButterfly`, precomputing `twiddle * scale`.
271    #[inline]
272    pub fn new(twiddle: F, scale: F) -> Self {
273        Self {
274            twiddle,
275            scale,
276            twiddle_times_scale: twiddle * scale,
277        }
278    }
279}
280
281impl<F: Field> Butterfly<F> for ScaledDitButterfly<F> {
282    #[inline]
283    fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF) {
284        // 2 multiplications instead of 3:
285        //   x1_s   = x1 * scale
286        //   x2_ts  = x2 * (twiddle * scale)   [precomputed]
287        //   out1   = x1_s + x2_ts
288        //   out2   = x1_s - x2_ts
289        let x_1_scale = x_1 * self.scale;
290        let x_2_twiddle_scale = x_2 * self.twiddle_times_scale;
291        (x_1_scale + x_2_twiddle_scale, x_1_scale - x_2_twiddle_scale)
292    }
293
294    /// Override `apply_to_rows` to pre-broadcast both `scale` and `twiddle_times_scale`
295    /// into packed fields once before the inner loop.
296    #[inline]
297    fn apply_to_rows(&self, row_1: &mut [F], row_2: &mut [F]) {
298        let (shorts_1, suffix_1) = F::Packing::pack_slice_with_suffix_mut(row_1);
299        let (shorts_2, suffix_2) = F::Packing::pack_slice_with_suffix_mut(row_2);
300        debug_assert_eq!(shorts_1.len(), shorts_2.len());
301        debug_assert_eq!(suffix_1.len(), suffix_2.len());
302        let scale_packed = F::Packing::from(self.scale);
303        let twiddle_times_scale_packed = F::Packing::from(self.twiddle_times_scale);
304        for (x_1, x_2) in shorts_1.iter_mut().zip(shorts_2.iter_mut()) {
305            let x_1_scale = *x_1 * scale_packed;
306            let x_2_twiddle_scale = *x_2 * twiddle_times_scale_packed;
307            *x_1 = x_1_scale + x_2_twiddle_scale;
308            *x_2 = x_1_scale - x_2_twiddle_scale;
309        }
310        for (x_1, x_2) in suffix_1.iter_mut().zip(suffix_2.iter_mut()) {
311            self.apply_in_place(x_1, x_2);
312        }
313    }
314}
315
316/// Butterfly with no twiddle factor (`twiddle = 1`).
317///
318/// This is used when no root-of-unity scaling is needed.
319/// It works for either DIT or DIF, and is often used at
320/// the final or base level of a transform tree.
321///
322/// This butterfly computes:
323/// ```text
324///   - output_1 = x1 + x2
325///   - output_2 = x1 - x2
326/// ```
327#[derive(Copy, Clone)]
328pub struct TwiddleFreeButterfly;
329
330impl<F: Field> Butterfly<F> for TwiddleFreeButterfly {
331    #[inline]
332    fn apply<PF: PackedField<Scalar = F>>(&self, x_1: PF, x_2: PF) -> (PF, PF) {
333        (x_1 + x_2, x_1 - x_2)
334    }
335}