1use crate::{BinaryFieldExtras, Flat, HardwareField, PackedFlat};
21use alloc::boxed::Box;
22use alloc::vec::Vec;
23
/// Errors reported by the [`AdditiveFft`] transform entry points.
///
/// Marked `#[non_exhaustive]`, so downstream `match`es need a wildcard arm.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum FftError {
    /// The buffer handed to a transform did not have the required length.
    BadLength {
        /// Number of elements the transform requires.
        expected: usize,
        /// Number of elements the caller actually supplied.
        got: usize,
    },
}

impl core::fmt::Display for FftError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Both fields are `usize`, so binding through `*self` copies them out.
        match *self {
            Self::BadLength { expected, got } => {
                write!(f, "AdditiveFft data length {got}, expected {expected}")
            }
        }
    }
}

// No `source()` override: the error carries no underlying cause.
impl core::error::Error for FftError {}
42
43pub struct AdditiveFft<F> {
48 log_n: u32,
49
50 twiddles: Box<[Flat<F>]>,
53}
54
55impl<F: BinaryFieldExtras + HardwareField> AdditiveFft<F> {
56 pub fn new(log_n: u32) -> Self {
65 assert!(
66 (1..=F::BITS).contains(&(log_n as usize)) && log_n < usize::BITS,
67 "AdditiveFft: log_n must be in 1..=min(F::BITS, 63)"
68 );
69
70 let dim = log_n as usize;
71
72 let mut lift: Vec<Flat<F>> = Vec::with_capacity(dim - 1);
73 let mut beta = F::ONE;
74
75 for _ in 1..dim {
76 beta = F::solve_quadratic(beta).expect("field admits no Cantor basis of this size");
77 lift.push(beta.to_hardware());
78 }
79
80 let half = 1usize << (log_n - 1);
81
82 let mut twiddles = Vec::with_capacity(half);
83 for t in 0..half {
84 let mut acc = Flat::from_raw(F::ZERO);
85 let mut bits = t;
86
87 while bits != 0 {
88 let j = bits.trailing_zeros() as usize;
89 acc += lift[j];
90 bits &= bits - 1;
91 }
92
93 twiddles.push(acc);
94 }
95
96 Self {
97 log_n,
98 twiddles: twiddles.into_boxed_slice(),
99 }
100 }
101
102 pub fn forward_scalar(&self, data: &mut [Flat<F>]) -> Result<(), FftError> {
104 self.forward_coset_scalar(data, Flat::from_raw(F::ZERO))
105 }
106
107 pub fn inverse_scalar(&self, data: &mut [Flat<F>]) -> Result<(), FftError> {
109 self.inverse_coset_scalar(data, Flat::from_raw(F::ZERO))
110 }
111
112 pub fn forward_coset_scalar(
114 &self,
115 data: &mut [Flat<F>],
116 offset: Flat<F>,
117 ) -> Result<(), FftError> {
118 self.check_len(data.len())?;
119 self.fwd_scalar(data, 0, 1, self.log_n, offset);
120
121 Ok(())
122 }
123
124 pub fn inverse_coset_scalar(
126 &self,
127 data: &mut [Flat<F>],
128 offset: Flat<F>,
129 ) -> Result<(), FftError> {
130 self.check_len(data.len())?;
131 self.inv_scalar(data, 0, 1, self.log_n, offset);
132
133 Ok(())
134 }
135
136 pub fn forward(&self, data: &mut [PackedFlat<F>]) -> Result<(), FftError> {
138 self.forward_coset(data, Flat::from_raw(F::ZERO))
139 }
140
141 pub fn inverse(&self, data: &mut [PackedFlat<F>]) -> Result<(), FftError> {
143 self.inverse_coset(data, Flat::from_raw(F::ZERO))
144 }
145
146 pub fn forward_coset(
148 &self,
149 data: &mut [PackedFlat<F>],
150 offset: Flat<F>,
151 ) -> Result<(), FftError> {
152 self.check_len(data.len())?;
153 self.fwd_packed(data, 0, 1, self.log_n, offset);
154
155 Ok(())
156 }
157
158 pub fn inverse_coset(
160 &self,
161 data: &mut [PackedFlat<F>],
162 offset: Flat<F>,
163 ) -> Result<(), FftError> {
164 self.check_len(data.len())?;
165 self.inv_packed(data, 0, 1, self.log_n, offset);
166
167 Ok(())
168 }
169
170 fn check_len(&self, got: usize) -> Result<(), FftError> {
171 let expected = 1usize << self.log_n;
172
173 if got != expected {
174 return Err(FftError::BadLength { expected, got });
175 }
176
177 Ok(())
178 }
179
180 fn fwd_scalar(&self, data: &mut [Flat<F>], off: usize, stride: usize, d: u32, coset: Flat<F>) {
186 if d == 0 {
187 return;
188 }
189
190 let half = 1usize << (d - 1);
191 let child = coset * coset + coset;
192
193 self.fwd_scalar(data, off, stride * 2, d - 1, child);
194 self.fwd_scalar(data, off + stride, stride * 2, d - 1, child);
195
196 for t in 0..half {
197 let tw = coset + self.twiddles[t];
198 let i0 = off + 2 * t * stride;
199 let i1 = i0 + stride;
200
201 let p = data[i0];
202 let q = data[i1];
203 let lo = p + tw * q;
204
205 data[i0] = lo;
206 data[i1] = lo + q;
207 }
208 }
209
210 fn inv_scalar(&self, data: &mut [Flat<F>], off: usize, stride: usize, d: u32, coset: Flat<F>) {
214 if d == 0 {
215 return;
216 }
217
218 let half = 1usize << (d - 1);
219 let child = coset * coset + coset;
220
221 for t in 0..half {
222 let tw = coset + self.twiddles[t];
223 let i0 = off + 2 * t * stride;
224 let i1 = i0 + stride;
225
226 let o0 = data[i0];
227 let o1 = data[i1];
228 let q = o0 + o1;
229
230 data[i0] = o0 + tw * q;
231 data[i1] = q;
232 }
233
234 self.inv_scalar(data, off, stride * 2, d - 1, child);
235 self.inv_scalar(data, off + stride, stride * 2, d - 1, child);
236 }
237
238 fn fwd_packed(
239 &self,
240 data: &mut [PackedFlat<F>],
241 off: usize,
242 stride: usize,
243 d: u32,
244 coset: Flat<F>,
245 ) {
246 if d == 0 {
247 return;
248 }
249
250 let half = 1usize << (d - 1);
251 let child = coset * coset + coset;
252
253 self.fwd_packed(data, off, stride * 2, d - 1, child);
254 self.fwd_packed(data, off + stride, stride * 2, d - 1, child);
255
256 for t in 0..half {
257 let tw = coset + self.twiddles[t];
258 let i0 = off + 2 * t * stride;
259 let i1 = i0 + stride;
260
261 let p = data[i0];
262 let q = data[i1];
263 let lo = F::add_hardware_packed(p, F::mul_hardware_scalar_packed(q, tw));
264
265 data[i0] = lo;
266 data[i1] = F::add_hardware_packed(lo, q);
267 }
268 }
269
270 fn inv_packed(
271 &self,
272 data: &mut [PackedFlat<F>],
273 off: usize,
274 stride: usize,
275 d: u32,
276 coset: Flat<F>,
277 ) {
278 if d == 0 {
279 return;
280 }
281
282 let half = 1usize << (d - 1);
283 let child = coset * coset + coset;
284
285 for t in 0..half {
286 let tw = coset + self.twiddles[t];
287 let i0 = off + 2 * t * stride;
288 let i1 = i0 + stride;
289
290 let o0 = data[i0];
291 let o1 = data[i1];
292 let q = F::add_hardware_packed(o0, o1);
293
294 data[i0] = F::add_hardware_packed(o0, F::mul_hardware_scalar_packed(q, tw));
295 data[i1] = q;
296 }
297
298 self.inv_packed(data, off, stride * 2, d - 1, child);
299 self.inv_packed(data, off + stride, stride * 2, d - 1, child);
300 }
301}