dcrypt_algorithms/poly/fft/
mod.rs1#![cfg_attr(not(feature = "std"), no_std)]
12#![allow(clippy::needless_range_loop)]
13
14#[cfg(feature = "alloc")]
15extern crate alloc;
16#[cfg(feature = "alloc")]
17use alloc::vec::Vec;
18
19use crate::ec::bls12_381::Bls12_381Scalar as Scalar;
20use crate::error::{Error, Result};
21use std::sync::OnceLock;
22
/// Transform length `N` accepted by every public routine in this module (`2^8 = 256`).
const FFT_SIZE: usize = 1 << 8;

/// Two-adicity `S` of the BLS12-381 scalar field modulus `r`: `r - 1 = 2^S * t` with `t` odd.
const TWO_ADICITY_FR: u32 = 32;
/// The odd cofactor `t = (r - 1) / 2^S`, as 64-bit limbs with the least significant limb
/// first (the order `pow_vartime_u64x4` consumes them in).
const FR_ODD_PART: [u64; 4] = [
    0xfffe_5bfe_ffff_ffff, // least significant limb
    0x09a1_d805_53bd_a402,
    0x299d_7d48_3339_d808,
    0x0000_0000_73ed_a753, // most significant limb
];
33
34static ROOT_OF_UNITY: OnceLock<Scalar> = OnceLock::new();
36static FFT_N_ROOT: OnceLock<Scalar> = OnceLock::new();
37static ROOTS_OF_UNITY: OnceLock<Vec<Scalar>> = OnceLock::new();
38static INVERSE_ROOTS_OF_UNITY: OnceLock<Vec<Scalar>> = OnceLock::new();
39static N_INV: OnceLock<Scalar> = OnceLock::new();
40static PRIMITIVE_2N_ROOT: OnceLock<Scalar> = OnceLock::new();
41static TWIST_FACTORS: OnceLock<Vec<Scalar>> = OnceLock::new();
42static INVERSE_TWIST_FACTORS: OnceLock<Vec<Scalar>> = OnceLock::new();
43
44fn get_root_of_unity() -> &'static Scalar {
46 ROOT_OF_UNITY.get_or_init(|| {
47 Scalar::from_raw([
48 0x4253_d252_a210_b619,
49 0x81c3_5f15_01a0_2431,
50 0xb734_6a32_008b_0320,
51 0x0a16_14a8_64b3_09e1,
52 ])
53 })
54}
55
56#[inline]
59fn pow_vartime_u64x4(base: Scalar, by: &[u64; 4]) -> Scalar {
60 let mut res = Scalar::one();
61 for e in by.iter().rev() {
62 for i in (0..64).rev() {
63 res = res.square();
64 if ((*e >> i) & 1) == 1 {
65 res *= base;
66 }
67 }
68 }
69 res
70}
71
72#[inline]
74fn project_to_2power(x: Scalar) -> Scalar {
75 pow_vartime_u64x4(x, &FR_ODD_PART)
76}
77
78fn two_adicity(mut r: Scalar) -> u32 {
81 for k in 1..=TWO_ADICITY_FR {
82 r = r.square();
83 if r == Scalar::one() {
84 return k;
85 }
86 }
87 debug_assert!(false, "two_adicity: element not in μ_{{2^S}}");
89 TWO_ADICITY_FR
90}
91
92fn select_2power_seed(min_k: u32) -> (Scalar, u32) {
94 let bases: [Scalar; 12] = [
95 *get_root_of_unity(),
96 Scalar::from(5u64),
97 Scalar::from(7u64),
98 Scalar::from(2u64),
99 Scalar::from(3u64),
100 Scalar::from(11u64),
101 Scalar::from(13u64),
102 Scalar::from(17u64),
103 Scalar::from(19u64),
104 Scalar::from(29u64),
105 Scalar::from(31u64),
106 Scalar::from(37u64),
107 ];
108
109 for base in bases.iter() {
110 let seed = project_to_2power(*base);
111 if !bool::from(seed.is_zero()) {
112 let k = two_adicity(seed);
113 if k >= min_k {
114 return (seed, k);
115 }
116 }
117 }
118
119 panic!("Could not find a suitable 2-power root of unity seed");
120}
121
122fn get_fft_n_root() -> &'static Scalar {
125 FFT_N_ROOT.get_or_init(|| {
126 let need = FFT_SIZE.trailing_zeros();
127 let (seed, k) = select_2power_seed(need);
128
129 let mut w_n = seed;
130 for _ in 0..(k - need) {
131 w_n = w_n.square();
132 }
133
134 #[cfg(debug_assertions)]
135 {
136 let mut t = w_n;
137 for _ in 0..need {
138 t = t.square();
139 }
140 debug_assert_eq!(t, Scalar::one(), "w_N^N must be 1");
141
142 let mut half = w_n;
143 for _ in 0..(need - 1) {
144 half = half.square();
145 }
146 debug_assert_eq!(half, -Scalar::one(), "w_N^(N/2) must be -1");
147 }
148 w_n
149 })
150}
151
152fn get_roots_of_unity() -> &'static Vec<Scalar> {
153 ROOTS_OF_UNITY.get_or_init(|| {
154 let w_n = *get_fft_n_root();
155 let mut roots = vec![Scalar::one(); FFT_SIZE];
156 for i in 1..FFT_SIZE {
157 roots[i] = roots[i - 1] * w_n;
158 }
159 roots
160 })
161}
162
163fn get_inverse_roots_of_unity() -> &'static Vec<Scalar> {
164 INVERSE_ROOTS_OF_UNITY.get_or_init(|| {
165 let inv_w_n = get_fft_n_root().invert().unwrap();
166 let mut roots = vec![Scalar::one(); FFT_SIZE];
167 for i in 1..FFT_SIZE {
168 roots[i] = roots[i - 1] * inv_w_n;
169 }
170 roots
171 })
172}
173
174fn get_n_inv() -> &'static Scalar {
175 N_INV.get_or_init(|| Scalar::from(FFT_SIZE as u64).invert().unwrap())
176}
177
178fn get_primitive_2n_root() -> &'static Scalar {
179 PRIMITIVE_2N_ROOT.get_or_init(|| {
180 let need = FFT_SIZE.trailing_zeros();
181 let (seed, k) = select_2power_seed(need + 1);
182
183 let mut g = seed;
184 for _ in 0..(k - (need + 1)) {
185 g = g.square();
186 }
187
188 debug_assert_eq!(g.square(), *get_fft_n_root(), "g^2 must equal w_N");
189
190 let mut gn = g;
191 for _ in 0..need {
192 gn = gn.square();
193 }
194 debug_assert_eq!(gn, -Scalar::one(), "g^N must be -1");
195
196 g
197 })
198}
199
200fn get_twist_factors() -> &'static Vec<Scalar> {
201 TWIST_FACTORS.get_or_init(|| {
202 let g = *get_primitive_2n_root();
203 let mut factors = vec![Scalar::one(); FFT_SIZE];
204 for i in 1..FFT_SIZE {
205 factors[i] = factors[i - 1] * g;
206 }
207 factors
208 })
209}
210
211fn get_inverse_twist_factors() -> &'static Vec<Scalar> {
212 INVERSE_TWIST_FACTORS.get_or_init(|| {
213 let inv_g = get_primitive_2n_root().invert().unwrap();
214 let mut factors = vec![Scalar::one(); FFT_SIZE];
215 for i in 1..FFT_SIZE {
216 factors[i] = factors[i - 1] * inv_g;
217 }
218 factors
219 })
220}
221
/// Reorders `data` in place so that element `i` moves to the bit-reversed index of `i`
/// (over `log2(len)` bits).
///
/// Intended for power-of-two lengths; `rev` is maintained as a bit-reversed counter, so
/// each index is paired with its reversal without recomputing it from scratch.
fn bit_reverse_permutation<T>(data: &mut [T]) {
    let len = data.len();
    let mut rev = 0usize;
    for idx in 1..len {
        // Bit-reversed increment: flip bits from the top down until a 0 bit is flipped.
        let mut mask = len >> 1;
        loop {
            let was_set = (rev & mask) != 0;
            rev ^= mask;
            if !was_set {
                break;
            }
            mask >>= 1;
        }
        // Swap each pair once, from the smaller index.
        if idx < rev {
            data.swap(idx, rev);
        }
    }
}
238
239fn fft_cooley_tukey(coeffs: &mut [Scalar], roots: &[Scalar]) {
241 let n = coeffs.len();
242 let mut len = 2;
243 while len <= n {
244 let half_len = len >> 1;
245 let step = roots.len() / len;
246 let root = roots[step];
247 for i in (0..n).step_by(len) {
248 let mut w = Scalar::one();
249 for j in 0..half_len {
250 let u = coeffs[i + j];
251 let v = coeffs[i + j + half_len] * w;
252 coeffs[i + j] = u + v;
253 coeffs[i + j + half_len] = u - v;
254 w *= root;
255 }
256 }
257 len <<= 1;
258 }
259}
260
261pub fn fft(coeffs: &mut [Scalar]) -> Result<()> {
263 if coeffs.len() != FFT_SIZE || !coeffs.len().is_power_of_two() {
264 return Err(Error::Parameter {
265 name: "coeffs".into(),
266 reason: "FFT length must be a power of two (256)".into(),
267 });
268 }
269 bit_reverse_permutation(coeffs);
270 fft_cooley_tukey(coeffs, get_roots_of_unity());
271 Ok(())
272}
273
274pub fn ifft(evals: &mut [Scalar]) -> Result<()> {
276 if evals.len() != FFT_SIZE || !evals.len().is_power_of_two() {
277 return Err(Error::Parameter {
278 name: "evals".into(),
279 reason: "FFT length must be a power of two (256)".into(),
280 });
281 }
282 bit_reverse_permutation(evals);
283 fft_cooley_tukey(evals, get_inverse_roots_of_unity());
284
285 let n_inv = get_n_inv();
286 for c in evals.iter_mut() {
287 *c *= *n_inv;
288 }
289 Ok(())
290}
291
292pub fn fft_negacyclic(coeffs: &mut [Scalar]) -> Result<()> {
294 if coeffs.len() != FFT_SIZE {
295 return Err(Error::Parameter {
296 name: "coeffs".into(),
297 reason: "Negacyclic FFT requires length 256".into(),
298 });
299 }
300
301 let twists = get_twist_factors();
302 for i in 0..FFT_SIZE {
303 coeffs[i] *= twists[i];
304 }
305
306 fft(coeffs)
307}
308
309pub fn ifft_negacyclic(evals: &mut [Scalar]) -> Result<()> {
311 if evals.len() != FFT_SIZE {
312 return Err(Error::Parameter {
313 name: "evals".into(),
314 reason: "Negacyclic IFFT requires length 256".into(),
315 });
316 }
317
318 ifft(evals)?;
319
320 let inv_twists = get_inverse_twist_factors();
321 for i in 0..FFT_SIZE {
322 evals[i] *= inv_twists[i];
323 }
324
325 Ok(())
326}
327
328#[cfg(test)]
329mod tests;