dcrypt_algorithms/poly/fft/
mod.rs

1// Path: dcrypt/crates/algorithms/src/poly/fft/mod.rs
2//! Fast Fourier Transform (FFT) over the BLS12-381 Scalar Field
3//!
4//! This module implements a Number Theoretic Transform (NTT), which is an FFT
5//! adapted for finite fields. It operates on vectors of `Scalar` elements from the
6//! BLS12-381 curve, enabling O(n log n) polynomial multiplication and interpolation.
7//!
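//! All transforms operate on vectors of exactly 256 (`FFT_SIZE`) elements. Two variants
//! are provided. The cyclic pair (`fft`/`ifft`) computes products modulo x^256 − 1. The
//! negacyclic pair (`fft_negacyclic`/`ifft_negacyclic`) computes products modulo
//! x^256 + 1. They serve schemes such as Verkle trees that rely on polynomial
//! commitments over large prime fields.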
10
11#![cfg_attr(not(feature = "std"), no_std)]
12#![allow(clippy::needless_range_loop)]
13
14#[cfg(feature = "alloc")]
15extern crate alloc;
16#[cfg(feature = "alloc")]
17use alloc::vec::Vec;
18
19use crate::ec::bls12_381::Bls12_381Scalar as Scalar;
20use crate::error::{Error, Result};
21use std::sync::OnceLock;
22
/// Transform length N supported by every transform in this module.
const FFT_SIZE: usize = 256;

// --- Field-specific constants for the BLS12-381 scalar field Fr (modulus r) ---

/// 2-adicity S of r − 1: the largest S with 2^S | (r − 1).
const TWO_ADICITY_FR: u32 = 32;

/// Odd cofactor T = (r − 1) / 2^S, stored as four little-endian 64-bit limbs.
/// `FR_ODD_PART[0]` is the least significant limb, which is the order
/// `pow_vartime_u64x4` expects.
const FR_ODD_PART: [u64; 4] = [
    0xfffe_5bfe_ffff_ffff,
    0x09a1_d805_53bd_a402,
    0x299d_7d48_3339_d808,
    0x0000_0000_73ed_a753,
];

// Compile-time guard. The radix-2 transforms need N to be a power of two. The negacyclic
// twist needs a primitive 2N-th root of unity, i.e. log2(N) + 1 <= S. Without this check,
// a bad FFT_SIZE would only show up at runtime, as a panic or as silently wrong roots
// inside the lazy root initialisers.
const _: () = assert!(
    FFT_SIZE.is_power_of_two() && FFT_SIZE.trailing_zeros() < TWO_ADICITY_FR,
    "FFT_SIZE must be a power of two with log2(FFT_SIZE) < TWO_ADICITY_FR"
);
33
34// Statics
35static ROOT_OF_UNITY: OnceLock<Scalar> = OnceLock::new();
36static FFT_N_ROOT: OnceLock<Scalar> = OnceLock::new();
37static ROOTS_OF_UNITY: OnceLock<Vec<Scalar>> = OnceLock::new();
38static INVERSE_ROOTS_OF_UNITY: OnceLock<Vec<Scalar>> = OnceLock::new();
39static N_INV: OnceLock<Scalar> = OnceLock::new();
40static PRIMITIVE_2N_ROOT: OnceLock<Scalar> = OnceLock::new();
41static TWIST_FACTORS: OnceLock<Vec<Scalar>> = OnceLock::new();
42static INVERSE_TWIST_FACTORS: OnceLock<Vec<Scalar>> = OnceLock::new();
43
44// The original hardcoded constant (kept as a seed candidate).
45fn get_root_of_unity() -> &'static Scalar {
46    ROOT_OF_UNITY.get_or_init(|| {
47        Scalar::from_raw([
48            0x4253_d252_a210_b619, 0x81c3_5f15_01a0_2431,
49            0xb734_6a32_008b_0320, 0x0a16_14a8_64b3_09e1
50        ])
51    })
52}
53
// --- Exponentiation and 2-adic helpers used to select a root-of-unity seed ---
55
56#[inline]
57fn pow_vartime_u64x4(base: Scalar, by: &[u64; 4]) -> Scalar {
58    let mut res = Scalar::one();
59    for e in by.iter().rev() {
60        for i in (0..64).rev() {
61            res = res.square();
62            if ((*e >> i) & 1) == 1 {
63                res *= base;
64            }
65        }
66    }
67    res
68}
69
70/// Project an arbitrary element into μ_{2^S}: x ↦ x^T
71#[inline]
72fn project_to_2power(x: Scalar) -> Scalar {
73    pow_vartime_u64x4(x, &FR_ODD_PART)
74}
75
76/// Compute the 2-adic order k of an element r ∈ μ_{2^S}:
77/// the smallest k ≥ 1 such that r^(2^k) = 1.
78fn two_adicity(mut r: Scalar) -> u32 {
79    for k in 1..=TWO_ADICITY_FR {
80        r = r.square();
81        if r == Scalar::one() {
82            return k;
83        }
84    }
85    // FIX: Escape the curly braces in the format string.
86    debug_assert!(false, "two_adicity: element not in μ_{{2^S}}");
87    TWO_ADICITY_FR
88}
89
90/// Deterministically pick a seed in μ_{2^S} whose 2-adic order k ≥ min_k.
91fn select_2power_seed(min_k: u32) -> (Scalar, u32) {
92    let bases: [Scalar; 12] = [
93        *get_root_of_unity(),
94        Scalar::from(5u64), Scalar::from(7u64), Scalar::from(2u64),
95        Scalar::from(3u64), Scalar::from(11u64), Scalar::from(13u64),
96        Scalar::from(17u64), Scalar::from(19u64), Scalar::from(29u64),
97        Scalar::from(31u64), Scalar::from(37u64),
98    ];
99
100    for base in bases.iter() {
101        let seed = project_to_2power(*base);
102        if !bool::from(seed.is_zero()) {
103            let k = two_adicity(seed);
104            if k >= min_k {
105                return (seed, k);
106            }
107        }
108    }
109
110    panic!("Could not find a suitable 2-power root of unity seed");
111}
112
// --- Roots of unity, twist factors and scaling constants, derived lazily from a 2-power seed ---
114
115fn get_fft_n_root() -> &'static Scalar {
116    FFT_N_ROOT.get_or_init(|| {
117        let need = FFT_SIZE.trailing_zeros();
118        let (seed, k) = select_2power_seed(need);
119
120        let mut w_n = seed;
121        for _ in 0..(k - need) {
122            w_n = w_n.square();
123        }
124
125        #[cfg(debug_assertions)]
126        {
127            let mut t = w_n;
128            for _ in 0..need { t = t.square(); }
129            debug_assert_eq!(t, Scalar::one(), "w_N^N must be 1");
130
131            let mut half = w_n;
132            for _ in 0..(need - 1) { half = half.square(); }
133            debug_assert_eq!(half, -Scalar::one(), "w_N^(N/2) must be -1");
134        }
135        w_n
136    })
137}
138
139fn get_roots_of_unity() -> &'static Vec<Scalar> {
140    ROOTS_OF_UNITY.get_or_init(|| {
141        let w_n = *get_fft_n_root();
142        let mut roots = vec![Scalar::one(); FFT_SIZE];
143        for i in 1..FFT_SIZE {
144            roots[i] = roots[i - 1] * w_n;
145        }
146        roots
147    })
148}
149
150fn get_inverse_roots_of_unity() -> &'static Vec<Scalar> {
151    INVERSE_ROOTS_OF_UNITY.get_or_init(|| {
152        let inv_w_n = get_fft_n_root().invert().unwrap();
153        let mut roots = vec![Scalar::one(); FFT_SIZE];
154        for i in 1..FFT_SIZE {
155            roots[i] = roots[i - 1] * inv_w_n;
156        }
157        roots
158    })
159}
160
161fn get_n_inv() -> &'static Scalar {
162    N_INV.get_or_init(|| Scalar::from(FFT_SIZE as u64).invert().unwrap())
163}
164
165fn get_primitive_2n_root() -> &'static Scalar {
166    PRIMITIVE_2N_ROOT.get_or_init(|| {
167        let need = FFT_SIZE.trailing_zeros();
168        let (seed, k) = select_2power_seed(need + 1);
169
170        let mut g = seed;
171        for _ in 0..(k - (need + 1)) {
172            g = g.square();
173        }
174
175        debug_assert_eq!(g.square(), *get_fft_n_root(), "g^2 must equal w_N");
176        
177        let mut gn = g;
178        for _ in 0..need { gn = gn.square(); }
179        debug_assert_eq!(gn, -Scalar::one(), "g^N must be -1");
180
181        g
182    })
183}
184
185fn get_twist_factors() -> &'static Vec<Scalar> {
186    TWIST_FACTORS.get_or_init(|| {
187        let g = *get_primitive_2n_root();
188        let mut factors = vec![Scalar::one(); FFT_SIZE];
189        for i in 1..FFT_SIZE {
190            factors[i] = factors[i - 1] * g;
191        }
192        factors
193    })
194}
195
196fn get_inverse_twist_factors() -> &'static Vec<Scalar> {
197    INVERSE_TWIST_FACTORS.get_or_init(|| {
198        let inv_g = get_primitive_2n_root().invert().unwrap();
199        let mut factors = vec![Scalar::one(); FFT_SIZE];
200        for i in 1..FFT_SIZE {
201            factors[i] = factors[i - 1] * inv_g;
202        }
203        factors
204    })
205}
206
207
/// Reorders `data` in place so that index `i` is exchanged with the index whose
/// log2(len)-bit binary representation is `i` reversed.
///
/// Intended for power-of-two lengths, as required by the radix-2 transforms. Lengths 0
/// and 1 are left unchanged.
fn bit_reverse_permutation<T>(data: &mut [T]) {
    let len = data.len();
    // `rev` is the bit-reversed counterpart of `idx`, maintained by a "reversed increment".
    let mut rev = 0usize;
    for idx in 1..len {
        // Flip bits from the top down. Stop at the first bit that turns on, or when the
        // mask runs out.
        let mut mask = len >> 1;
        loop {
            rev ^= mask;
            if (rev & mask) != 0 || mask == 0 {
                break;
            }
            mask >>= 1;
        }
        // Swap each pair only once.
        if idx < rev {
            data.swap(idx, rev);
        }
    }
}
224
225/// Core Cooley-Tukey FFT/NTT algorithm.
226fn fft_cooley_tukey(coeffs: &mut [Scalar], roots: &[Scalar]) {
227    let n = coeffs.len();
228    let mut len = 2;
229    while len <= n {
230        let half_len = len >> 1;
231        let step = roots.len() / len;
232        let root = roots[step];
233        for i in (0..n).step_by(len) {
234            let mut w = Scalar::one();
235            for j in 0..half_len {
236                let u = coeffs[i + j];
237                let v = coeffs[i + j + half_len] * w;
238                coeffs[i + j] = u + v;
239                coeffs[i + j + half_len] = u - v;
240                w *= root;
241            }
242        }
243        len <<= 1;
244    }
245}
246
247/// Computes the forward Fast Fourier Transform (NTT) of a polynomial for **cyclic** convolution.
248pub fn fft(coeffs: &mut [Scalar]) -> Result<()> {
249    if coeffs.len() != FFT_SIZE || !coeffs.len().is_power_of_two() {
250        return Err(Error::Parameter {
251            name: "coeffs".into(),
252            reason: "FFT length must be a power of two (256)".into(),
253        });
254    }
255    bit_reverse_permutation(coeffs);
256    fft_cooley_tukey(coeffs, get_roots_of_unity());
257    Ok(())
258}
259
260/// Computes the inverse Fast Fourier Transform (iNTT) for **cyclic** convolution.
261pub fn ifft(evals: &mut [Scalar]) -> Result<()> {
262    if evals.len() != FFT_SIZE || !evals.len().is_power_of_two() {
263        return Err(Error::Parameter {
264            name: "evals".into(),
265            reason: "FFT length must be a power of two (256)".into(),
266        });
267    }
268    bit_reverse_permutation(evals);
269    fft_cooley_tukey(evals, get_inverse_roots_of_unity());
270
271    let n_inv = get_n_inv();
272    for c in evals.iter_mut() {
273        *c *= *n_inv;
274    }
275    Ok(())
276}
277
278/// Computes the forward **negacyclic** NTT.
279pub fn fft_negacyclic(coeffs: &mut [Scalar]) -> Result<()> {
280    if coeffs.len() != FFT_SIZE {
281        return Err(Error::Parameter {
282            name: "coeffs".into(),
283            reason: "Negacyclic FFT requires length 256".into(),
284        });
285    }
286    
287    let twists = get_twist_factors();
288    for i in 0..FFT_SIZE {
289        coeffs[i] *= twists[i];
290    }
291
292    fft(coeffs)
293}
294
295/// Computes the inverse **negacyclic** NTT.
296pub fn ifft_negacyclic(evals: &mut [Scalar]) -> Result<()> {
297    if evals.len() != FFT_SIZE {
298        return Err(Error::Parameter {
299            name: "evals".into(),
300            reason: "Negacyclic IFFT requires length 256".into(),
301        });
302    }
303    
304    ifft(evals)?;
305
306    let inv_twists = get_inverse_twist_factors();
307    for i in 0..FFT_SIZE {
308        evals[i] *= inv_twists[i];
309    }
310    
311    Ok(())
312}
313
314
// Unit tests for this module. They are compiled only for test builds and live in a
// separate `tests` module file next to this one.
#[cfg(test)]
mod tests;