Skip to main content

dcrypt_algorithms/poly/fft/
mod.rs

1// Path: dcrypt/crates/algorithms/src/poly/fft/mod.rs
2//! Fast Fourier Transform (FFT) over the BLS12-381 Scalar Field
3//!
4//! This module implements a Number Theoretic Transform (NTT), which is an FFT
5//! adapted for finite fields. It operates on vectors of `Scalar` elements from the
6//! BLS12-381 curve, enabling O(n log n) polynomial multiplication and interpolation.
7//!
8//! This is the high-performance engine required for schemes like Verkle trees that
9//! rely on polynomial commitments over large prime fields.
10
11#![cfg_attr(not(feature = "std"), no_std)]
12#![allow(clippy::needless_range_loop)]
13
14#[cfg(feature = "alloc")]
15extern crate alloc;
16#[cfg(feature = "alloc")]
17use alloc::vec::Vec;
18
19use crate::ec::bls12_381::Bls12_381Scalar as Scalar;
20use crate::error::{Error, Result};
21use std::sync::OnceLock;
22
/// Transform length `N` accepted by every public entry point in this module.
const FFT_SIZE: usize = 256;

// --- field-specific 2-adicity and odd cofactor for BLS12-381 Fr ---

/// Two-adicity `S` of the scalar field modulus `r`: `r - 1 = T · 2^S` with `T` odd.
const TWO_ADICITY_FR: u32 = 32;

/// Odd cofactor `T` of `r - 1 = T · 2^S`, as four `u64` limbs, least-significant first.
const FR_ODD_PART: [u64; 4] = [
    0xfffe_5bfe_ffff_ffff,
    0x09a1_d805_53bd_a402,
    0x299d_7d48_3339_d808,
    0x0000_0000_73ed_a753,
];

// Compile-time guards for invariants the root derivation relies on. Violating any of
// these would otherwise only surface at runtime (a panic while selecting a seed, or a
// failed debug assertion).
const _: () = assert!(FFT_SIZE.is_power_of_two(), "FFT_SIZE must be a power of two");
// The negacyclic transform needs a primitive 2N-th root, i.e. log2(N) + 1 <= S.
const _: () = assert!(
    FFT_SIZE.trailing_zeros() < TWO_ADICITY_FR,
    "FFT_SIZE is too large for the 2-adicity of the scalar field"
);
const _: () = assert!(FR_ODD_PART[0] & 1 == 1, "FR_ODD_PART must be odd");
33
34// Statics
35static ROOT_OF_UNITY: OnceLock<Scalar> = OnceLock::new();
36static FFT_N_ROOT: OnceLock<Scalar> = OnceLock::new();
37static ROOTS_OF_UNITY: OnceLock<Vec<Scalar>> = OnceLock::new();
38static INVERSE_ROOTS_OF_UNITY: OnceLock<Vec<Scalar>> = OnceLock::new();
39static N_INV: OnceLock<Scalar> = OnceLock::new();
40static PRIMITIVE_2N_ROOT: OnceLock<Scalar> = OnceLock::new();
41static TWIST_FACTORS: OnceLock<Vec<Scalar>> = OnceLock::new();
42static INVERSE_TWIST_FACTORS: OnceLock<Vec<Scalar>> = OnceLock::new();
43
44// The original hardcoded constant (kept as a seed candidate).
45fn get_root_of_unity() -> &'static Scalar {
46    ROOT_OF_UNITY.get_or_init(|| {
47        Scalar::from_raw([
48            0x4253_d252_a210_b619,
49            0x81c3_5f15_01a0_2431,
50            0xb734_6a32_008b_0320,
51            0x0a16_14a8_64b3_09e1,
52        ])
53    })
54}
55
56// --- NEW: small helpers ---
57
58#[inline]
59fn pow_vartime_u64x4(base: Scalar, by: &[u64; 4]) -> Scalar {
60    let mut res = Scalar::one();
61    for e in by.iter().rev() {
62        for i in (0..64).rev() {
63            res = res.square();
64            if ((*e >> i) & 1) == 1 {
65                res *= base;
66            }
67        }
68    }
69    res
70}
71
72/// Project an arbitrary element into μ_{2^S}: x ↦ x^T
73#[inline]
74fn project_to_2power(x: Scalar) -> Scalar {
75    pow_vartime_u64x4(x, &FR_ODD_PART)
76}
77
78/// Compute the 2-adic order k of an element r ∈ μ_{2^S}:
79/// the smallest k ≥ 1 such that r^(2^k) = 1.
80fn two_adicity(mut r: Scalar) -> u32 {
81    for k in 1..=TWO_ADICITY_FR {
82        r = r.square();
83        if r == Scalar::one() {
84            return k;
85        }
86    }
87    // FIX: Escape the curly braces in the format string.
88    debug_assert!(false, "two_adicity: element not in μ_{{2^S}}");
89    TWO_ADICITY_FR
90}
91
92/// Deterministically pick a seed in μ_{2^S} whose 2-adic order k ≥ min_k.
93fn select_2power_seed(min_k: u32) -> (Scalar, u32) {
94    let bases: [Scalar; 12] = [
95        *get_root_of_unity(),
96        Scalar::from(5u64),
97        Scalar::from(7u64),
98        Scalar::from(2u64),
99        Scalar::from(3u64),
100        Scalar::from(11u64),
101        Scalar::from(13u64),
102        Scalar::from(17u64),
103        Scalar::from(19u64),
104        Scalar::from(29u64),
105        Scalar::from(31u64),
106        Scalar::from(37u64),
107    ];
108
109    for base in bases.iter() {
110        let seed = project_to_2power(*base);
111        if !bool::from(seed.is_zero()) {
112            let k = two_adicity(seed);
113            if k >= min_k {
114                return (seed, k);
115            }
116        }
117    }
118
119    panic!("Could not find a suitable 2-power root of unity seed");
120}
121
122// --- Derived roots built from a consistent seed ---
123
124fn get_fft_n_root() -> &'static Scalar {
125    FFT_N_ROOT.get_or_init(|| {
126        let need = FFT_SIZE.trailing_zeros();
127        let (seed, k) = select_2power_seed(need);
128
129        let mut w_n = seed;
130        for _ in 0..(k - need) {
131            w_n = w_n.square();
132        }
133
134        #[cfg(debug_assertions)]
135        {
136            let mut t = w_n;
137            for _ in 0..need {
138                t = t.square();
139            }
140            debug_assert_eq!(t, Scalar::one(), "w_N^N must be 1");
141
142            let mut half = w_n;
143            for _ in 0..(need - 1) {
144                half = half.square();
145            }
146            debug_assert_eq!(half, -Scalar::one(), "w_N^(N/2) must be -1");
147        }
148        w_n
149    })
150}
151
152fn get_roots_of_unity() -> &'static Vec<Scalar> {
153    ROOTS_OF_UNITY.get_or_init(|| {
154        let w_n = *get_fft_n_root();
155        let mut roots = vec![Scalar::one(); FFT_SIZE];
156        for i in 1..FFT_SIZE {
157            roots[i] = roots[i - 1] * w_n;
158        }
159        roots
160    })
161}
162
163fn get_inverse_roots_of_unity() -> &'static Vec<Scalar> {
164    INVERSE_ROOTS_OF_UNITY.get_or_init(|| {
165        let inv_w_n = get_fft_n_root().invert().unwrap();
166        let mut roots = vec![Scalar::one(); FFT_SIZE];
167        for i in 1..FFT_SIZE {
168            roots[i] = roots[i - 1] * inv_w_n;
169        }
170        roots
171    })
172}
173
174fn get_n_inv() -> &'static Scalar {
175    N_INV.get_or_init(|| Scalar::from(FFT_SIZE as u64).invert().unwrap())
176}
177
178fn get_primitive_2n_root() -> &'static Scalar {
179    PRIMITIVE_2N_ROOT.get_or_init(|| {
180        let need = FFT_SIZE.trailing_zeros();
181        let (seed, k) = select_2power_seed(need + 1);
182
183        let mut g = seed;
184        for _ in 0..(k - (need + 1)) {
185            g = g.square();
186        }
187
188        debug_assert_eq!(g.square(), *get_fft_n_root(), "g^2 must equal w_N");
189
190        let mut gn = g;
191        for _ in 0..need {
192            gn = gn.square();
193        }
194        debug_assert_eq!(gn, -Scalar::one(), "g^N must be -1");
195
196        g
197    })
198}
199
200fn get_twist_factors() -> &'static Vec<Scalar> {
201    TWIST_FACTORS.get_or_init(|| {
202        let g = *get_primitive_2n_root();
203        let mut factors = vec![Scalar::one(); FFT_SIZE];
204        for i in 1..FFT_SIZE {
205            factors[i] = factors[i - 1] * g;
206        }
207        factors
208    })
209}
210
211fn get_inverse_twist_factors() -> &'static Vec<Scalar> {
212    INVERSE_TWIST_FACTORS.get_or_init(|| {
213        let inv_g = get_primitive_2n_root().invert().unwrap();
214        let mut factors = vec![Scalar::one(); FFT_SIZE];
215        for i in 1..FFT_SIZE {
216            factors[i] = factors[i - 1] * inv_g;
217        }
218        factors
219    })
220}
221
/// Reorders `data` in place so that element `i` moves to the index whose bits are `i`'s
/// bits reversed (for power-of-two lengths).
///
/// Walks a bit-reversed counter alongside the ordinary index and swaps each pair once
/// (only when the ordinary index is the smaller one).
fn bit_reverse_permutation<T>(data: &mut [T]) {
    let len = data.len();
    let mut reversed = 0;
    for idx in 1..len {
        // Increment `reversed` as if its bits were mirrored: starting from the top bit,
        // flip bits until one flips from 0 to 1.
        let mut mask = len >> 1;
        loop {
            let was_set = reversed & mask != 0;
            reversed ^= mask;
            if !was_set {
                break;
            }
            mask >>= 1;
        }
        if reversed > idx {
            data.swap(idx, reversed);
        }
    }
}
238
239/// Core Cooley-Tukey FFT/NTT algorithm.
240fn fft_cooley_tukey(coeffs: &mut [Scalar], roots: &[Scalar]) {
241    let n = coeffs.len();
242    let mut len = 2;
243    while len <= n {
244        let half_len = len >> 1;
245        let step = roots.len() / len;
246        let root = roots[step];
247        for i in (0..n).step_by(len) {
248            let mut w = Scalar::one();
249            for j in 0..half_len {
250                let u = coeffs[i + j];
251                let v = coeffs[i + j + half_len] * w;
252                coeffs[i + j] = u + v;
253                coeffs[i + j + half_len] = u - v;
254                w *= root;
255            }
256        }
257        len <<= 1;
258    }
259}
260
261/// Computes the forward Fast Fourier Transform (NTT) of a polynomial for **cyclic** convolution.
262pub fn fft(coeffs: &mut [Scalar]) -> Result<()> {
263    if coeffs.len() != FFT_SIZE || !coeffs.len().is_power_of_two() {
264        return Err(Error::Parameter {
265            name: "coeffs".into(),
266            reason: "FFT length must be a power of two (256)".into(),
267        });
268    }
269    bit_reverse_permutation(coeffs);
270    fft_cooley_tukey(coeffs, get_roots_of_unity());
271    Ok(())
272}
273
274/// Computes the inverse Fast Fourier Transform (iNTT) for **cyclic** convolution.
275pub fn ifft(evals: &mut [Scalar]) -> Result<()> {
276    if evals.len() != FFT_SIZE || !evals.len().is_power_of_two() {
277        return Err(Error::Parameter {
278            name: "evals".into(),
279            reason: "FFT length must be a power of two (256)".into(),
280        });
281    }
282    bit_reverse_permutation(evals);
283    fft_cooley_tukey(evals, get_inverse_roots_of_unity());
284
285    let n_inv = get_n_inv();
286    for c in evals.iter_mut() {
287        *c *= *n_inv;
288    }
289    Ok(())
290}
291
292/// Computes the forward **negacyclic** NTT.
293pub fn fft_negacyclic(coeffs: &mut [Scalar]) -> Result<()> {
294    if coeffs.len() != FFT_SIZE {
295        return Err(Error::Parameter {
296            name: "coeffs".into(),
297            reason: "Negacyclic FFT requires length 256".into(),
298        });
299    }
300
301    let twists = get_twist_factors();
302    for i in 0..FFT_SIZE {
303        coeffs[i] *= twists[i];
304    }
305
306    fft(coeffs)
307}
308
309/// Computes the inverse **negacyclic** NTT.
310pub fn ifft_negacyclic(evals: &mut [Scalar]) -> Result<()> {
311    if evals.len() != FFT_SIZE {
312        return Err(Error::Parameter {
313            name: "evals".into(),
314            reason: "Negacyclic IFFT requires length 256".into(),
315        });
316    }
317
318    ifft(evals)?;
319
320    let inv_twists = get_inverse_twist_factors();
321    for i in 0..FFT_SIZE {
322        evals[i] *= inv_twists[i];
323    }
324
325    Ok(())
326}
327
328#[cfg(test)]
329mod tests;