dcrypt_algorithms/poly/fft/
mod.rs1#![cfg_attr(not(feature = "std"), no_std)]
12#![allow(clippy::needless_range_loop)]
13
14#[cfg(feature = "alloc")]
15extern crate alloc;
16#[cfg(feature = "alloc")]
17use alloc::vec::Vec;
18
19use crate::ec::bls12_381::Bls12_381Scalar as Scalar;
20use crate::error::{Error, Result};
21use std::sync::OnceLock;
22
/// Number of points handled by every transform in this module (2^8 = 256).
const FFT_SIZE: usize = 1 << 8;

/// Upper bound on the number of squarings `two_adicity` tries before giving up.
/// For the BLS12-381 scalar field modulus r, r - 1 is divisible by 2^32.
const TWO_ADICITY_FR: u32 = 32;

/// Exponent used by `project_to_2power`, stored as four 64-bit limbs with the
/// least-significant limb first. The value equals (r - 1) >> 32 for the
/// BLS12-381 scalar field modulus r, i.e. the odd cofactor of r - 1.
const FR_ODD_PART: [u64; 4] = [
    0xfffe_5bfe_ffff_ffff, // bits   0..64
    0x09a1_d805_53bd_a402, // bits  64..128
    0x299d_7d48_3339_d808, // bits 128..192
    0x0000_0000_73ed_a753, // bits 192..256
];
33
34static ROOT_OF_UNITY: OnceLock<Scalar> = OnceLock::new();
36static FFT_N_ROOT: OnceLock<Scalar> = OnceLock::new();
37static ROOTS_OF_UNITY: OnceLock<Vec<Scalar>> = OnceLock::new();
38static INVERSE_ROOTS_OF_UNITY: OnceLock<Vec<Scalar>> = OnceLock::new();
39static N_INV: OnceLock<Scalar> = OnceLock::new();
40static PRIMITIVE_2N_ROOT: OnceLock<Scalar> = OnceLock::new();
41static TWIST_FACTORS: OnceLock<Vec<Scalar>> = OnceLock::new();
42static INVERSE_TWIST_FACTORS: OnceLock<Vec<Scalar>> = OnceLock::new();
43
44fn get_root_of_unity() -> &'static Scalar {
46 ROOT_OF_UNITY.get_or_init(|| {
47 Scalar::from_raw([
48 0x4253_d252_a210_b619, 0x81c3_5f15_01a0_2431,
49 0xb734_6a32_008b_0320, 0x0a16_14a8_64b3_09e1
50 ])
51 })
52}
53
54#[inline]
57fn pow_vartime_u64x4(base: Scalar, by: &[u64; 4]) -> Scalar {
58 let mut res = Scalar::one();
59 for e in by.iter().rev() {
60 for i in (0..64).rev() {
61 res = res.square();
62 if ((*e >> i) & 1) == 1 {
63 res *= base;
64 }
65 }
66 }
67 res
68}
69
70#[inline]
72fn project_to_2power(x: Scalar) -> Scalar {
73 pow_vartime_u64x4(x, &FR_ODD_PART)
74}
75
76fn two_adicity(mut r: Scalar) -> u32 {
79 for k in 1..=TWO_ADICITY_FR {
80 r = r.square();
81 if r == Scalar::one() {
82 return k;
83 }
84 }
85 debug_assert!(false, "two_adicity: element not in μ_{{2^S}}");
87 TWO_ADICITY_FR
88}
89
90fn select_2power_seed(min_k: u32) -> (Scalar, u32) {
92 let bases: [Scalar; 12] = [
93 *get_root_of_unity(),
94 Scalar::from(5u64), Scalar::from(7u64), Scalar::from(2u64),
95 Scalar::from(3u64), Scalar::from(11u64), Scalar::from(13u64),
96 Scalar::from(17u64), Scalar::from(19u64), Scalar::from(29u64),
97 Scalar::from(31u64), Scalar::from(37u64),
98 ];
99
100 for base in bases.iter() {
101 let seed = project_to_2power(*base);
102 if !bool::from(seed.is_zero()) {
103 let k = two_adicity(seed);
104 if k >= min_k {
105 return (seed, k);
106 }
107 }
108 }
109
110 panic!("Could not find a suitable 2-power root of unity seed");
111}
112
113fn get_fft_n_root() -> &'static Scalar {
116 FFT_N_ROOT.get_or_init(|| {
117 let need = FFT_SIZE.trailing_zeros();
118 let (seed, k) = select_2power_seed(need);
119
120 let mut w_n = seed;
121 for _ in 0..(k - need) {
122 w_n = w_n.square();
123 }
124
125 #[cfg(debug_assertions)]
126 {
127 let mut t = w_n;
128 for _ in 0..need { t = t.square(); }
129 debug_assert_eq!(t, Scalar::one(), "w_N^N must be 1");
130
131 let mut half = w_n;
132 for _ in 0..(need - 1) { half = half.square(); }
133 debug_assert_eq!(half, -Scalar::one(), "w_N^(N/2) must be -1");
134 }
135 w_n
136 })
137}
138
139fn get_roots_of_unity() -> &'static Vec<Scalar> {
140 ROOTS_OF_UNITY.get_or_init(|| {
141 let w_n = *get_fft_n_root();
142 let mut roots = vec![Scalar::one(); FFT_SIZE];
143 for i in 1..FFT_SIZE {
144 roots[i] = roots[i - 1] * w_n;
145 }
146 roots
147 })
148}
149
150fn get_inverse_roots_of_unity() -> &'static Vec<Scalar> {
151 INVERSE_ROOTS_OF_UNITY.get_or_init(|| {
152 let inv_w_n = get_fft_n_root().invert().unwrap();
153 let mut roots = vec![Scalar::one(); FFT_SIZE];
154 for i in 1..FFT_SIZE {
155 roots[i] = roots[i - 1] * inv_w_n;
156 }
157 roots
158 })
159}
160
161fn get_n_inv() -> &'static Scalar {
162 N_INV.get_or_init(|| Scalar::from(FFT_SIZE as u64).invert().unwrap())
163}
164
165fn get_primitive_2n_root() -> &'static Scalar {
166 PRIMITIVE_2N_ROOT.get_or_init(|| {
167 let need = FFT_SIZE.trailing_zeros();
168 let (seed, k) = select_2power_seed(need + 1);
169
170 let mut g = seed;
171 for _ in 0..(k - (need + 1)) {
172 g = g.square();
173 }
174
175 debug_assert_eq!(g.square(), *get_fft_n_root(), "g^2 must equal w_N");
176
177 let mut gn = g;
178 for _ in 0..need { gn = gn.square(); }
179 debug_assert_eq!(gn, -Scalar::one(), "g^N must be -1");
180
181 g
182 })
183}
184
185fn get_twist_factors() -> &'static Vec<Scalar> {
186 TWIST_FACTORS.get_or_init(|| {
187 let g = *get_primitive_2n_root();
188 let mut factors = vec![Scalar::one(); FFT_SIZE];
189 for i in 1..FFT_SIZE {
190 factors[i] = factors[i - 1] * g;
191 }
192 factors
193 })
194}
195
196fn get_inverse_twist_factors() -> &'static Vec<Scalar> {
197 INVERSE_TWIST_FACTORS.get_or_init(|| {
198 let inv_g = get_primitive_2n_root().invert().unwrap();
199 let mut factors = vec![Scalar::one(); FFT_SIZE];
200 for i in 1..FFT_SIZE {
201 factors[i] = factors[i - 1] * inv_g;
202 }
203 factors
204 })
205}
206
207
/// Reorders `data` in place so that, for a power-of-two length, the element at
/// index `i` trades places with the element at the bit-reversal of `i`.
///
/// The reversed index is maintained incrementally: it is advanced by "adding
/// one from the top bit down", so no per-index bit reversal is computed. Each
/// pair is swapped once, when the smaller index of the pair is visited.
fn bit_reverse_permutation<T>(data: &mut [T]) {
    let len = data.len();
    let top_bit = len >> 1;
    let mut reversed = 0usize;

    for index in 1..len {
        // Reversed increment: clear the run of set bits starting at the top,
        // then flip the first clear bit that follows.
        let mut mask = top_bit;
        while reversed & mask != 0 {
            reversed ^= mask;
            mask >>= 1;
        }
        reversed ^= mask;

        if index < reversed {
            data.swap(index, reversed);
        }
    }
}
224
225fn fft_cooley_tukey(coeffs: &mut [Scalar], roots: &[Scalar]) {
227 let n = coeffs.len();
228 let mut len = 2;
229 while len <= n {
230 let half_len = len >> 1;
231 let step = roots.len() / len;
232 let root = roots[step];
233 for i in (0..n).step_by(len) {
234 let mut w = Scalar::one();
235 for j in 0..half_len {
236 let u = coeffs[i + j];
237 let v = coeffs[i + j + half_len] * w;
238 coeffs[i + j] = u + v;
239 coeffs[i + j + half_len] = u - v;
240 w *= root;
241 }
242 }
243 len <<= 1;
244 }
245}
246
247pub fn fft(coeffs: &mut [Scalar]) -> Result<()> {
249 if coeffs.len() != FFT_SIZE || !coeffs.len().is_power_of_two() {
250 return Err(Error::Parameter {
251 name: "coeffs".into(),
252 reason: "FFT length must be a power of two (256)".into(),
253 });
254 }
255 bit_reverse_permutation(coeffs);
256 fft_cooley_tukey(coeffs, get_roots_of_unity());
257 Ok(())
258}
259
260pub fn ifft(evals: &mut [Scalar]) -> Result<()> {
262 if evals.len() != FFT_SIZE || !evals.len().is_power_of_two() {
263 return Err(Error::Parameter {
264 name: "evals".into(),
265 reason: "FFT length must be a power of two (256)".into(),
266 });
267 }
268 bit_reverse_permutation(evals);
269 fft_cooley_tukey(evals, get_inverse_roots_of_unity());
270
271 let n_inv = get_n_inv();
272 for c in evals.iter_mut() {
273 *c *= *n_inv;
274 }
275 Ok(())
276}
277
278pub fn fft_negacyclic(coeffs: &mut [Scalar]) -> Result<()> {
280 if coeffs.len() != FFT_SIZE {
281 return Err(Error::Parameter {
282 name: "coeffs".into(),
283 reason: "Negacyclic FFT requires length 256".into(),
284 });
285 }
286
287 let twists = get_twist_factors();
288 for i in 0..FFT_SIZE {
289 coeffs[i] *= twists[i];
290 }
291
292 fft(coeffs)
293}
294
295pub fn ifft_negacyclic(evals: &mut [Scalar]) -> Result<()> {
297 if evals.len() != FFT_SIZE {
298 return Err(Error::Parameter {
299 name: "evals".into(),
300 reason: "Negacyclic IFFT requires length 256".into(),
301 });
302 }
303
304 ifft(evals)?;
305
306 let inv_twists = get_inverse_twist_factors();
307 for i in 0..FFT_SIZE {
308 evals[i] *= inv_twists[i];
309 }
310
311 Ok(())
312}
313
314
315#[cfg(test)]
316mod tests;