Skip to main content

hekate_math/fft/
additive.rs

// SPDX-License-Identifier: Apache-2.0
// This file is part of the hekate-math project.
// Copyright (C) 2026 Andrei Kochergin <andrei@oumuamua.dev>
// Copyright (C) 2026 Oumuamua Labs <info@oumuamua.dev>.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Gao–Mateer additive FFT over binary tower fields, evaluating on
//! subspaces (and their cosets) spanned by a Cantor basis.

20use crate::{BinaryFieldExtras, Flat, HardwareField, PackedFlat};
21use alloc::boxed::Box;
22use alloc::vec::Vec;
23
/// Failure reported by the additive-FFT transforms.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum FftError {
    /// The buffer handed to a transform does not have the plan's length.
    BadLength {
        /// Length the plan requires (`2^log_n`).
        expected: usize,
        /// Length of the buffer the caller actually supplied.
        got: usize,
    },
}
30
31impl core::fmt::Display for FftError {
32    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
33        match self {
34            FftError::BadLength { expected, got } => {
35                write!(f, "AdditiveFft data length {got}, expected {expected}")
36            }
37        }
38    }
39}
40
// Marker impl so `FftError` can be used as a standard error (e.g. boxed as
// `dyn Error`). Its only variant carries plain lengths and no wrapped cause,
// so the trait's default methods (`source()` -> `None`) are used unchanged.
impl core::error::Error for FftError {}
42
43/// In-place additive FFT over a 2^log_n subspace of a binary
44/// tower field. Transforms return `Err(FftError::BadLength)`
45/// unless data.len() == 2^log_n; on success the buffer
46/// is overwritten in place.
47pub struct AdditiveFft<F> {
48    log_n: u32,
49
50    // twiddles[t] = Σ_{bit i of t} β_{i+1},
51    // flat basis.
52    twiddles: Box<[Flat<F>]>,
53}
54
55impl<F: BinaryFieldExtras + HardwareField> AdditiveFft<F> {
56    /// Derives the Cantor basis (via solve_quadratic) and
57    /// the twiddle schedule for transform size 2^log_n.
58    /// This one-time allocation is the only heap use;
59    /// the transforms are in-place.
60    ///
61    /// # Panics
62    /// If log_n is not in 1..=min(F::BITS, 63),
63    /// or F admits no Cantor basis of that size.
64    pub fn new(log_n: u32) -> Self {
65        assert!(
66            (1..=F::BITS).contains(&(log_n as usize)) && log_n < usize::BITS,
67            "AdditiveFft: log_n must be in 1..=min(F::BITS, 63)"
68        );
69
70        let dim = log_n as usize;
71
72        let mut lift: Vec<Flat<F>> = Vec::with_capacity(dim - 1);
73        let mut beta = F::ONE;
74
75        for _ in 1..dim {
76            beta = F::solve_quadratic(beta).expect("field admits no Cantor basis of this size");
77            lift.push(beta.to_hardware());
78        }
79
80        let half = 1usize << (log_n - 1);
81
82        let mut twiddles = Vec::with_capacity(half);
83        for t in 0..half {
84            let mut acc = Flat::from_raw(F::ZERO);
85            let mut bits = t;
86
87            while bits != 0 {
88                let j = bits.trailing_zeros() as usize;
89                acc += lift[j];
90                bits &= bits - 1;
91            }
92
93            twiddles.push(acc);
94        }
95
96        Self {
97            log_n,
98            twiddles: twiddles.into_boxed_slice(),
99        }
100    }
101
102    /// Forward: novel-basis coefficients to evaluations.
103    pub fn forward_scalar(&self, data: &mut [Flat<F>]) -> Result<(), FftError> {
104        self.forward_coset_scalar(data, Flat::from_raw(F::ZERO))
105    }
106
107    /// Inverse: evaluations to novel-basis coefficients.
108    pub fn inverse_scalar(&self, data: &mut [Flat<F>]) -> Result<(), FftError> {
109        self.inverse_coset_scalar(data, Flat::from_raw(F::ZERO))
110    }
111
112    /// Forward over the coset offset + W_log_n.
113    pub fn forward_coset_scalar(
114        &self,
115        data: &mut [Flat<F>],
116        offset: Flat<F>,
117    ) -> Result<(), FftError> {
118        self.check_len(data.len())?;
119        self.fwd_scalar(data, 0, 1, self.log_n, offset);
120
121        Ok(())
122    }
123
124    /// Inverse over the coset offset + W_log_n.
125    pub fn inverse_coset_scalar(
126        &self,
127        data: &mut [Flat<F>],
128        offset: Flat<F>,
129    ) -> Result<(), FftError> {
130        self.check_len(data.len())?;
131        self.inv_scalar(data, 0, 1, self.log_n, offset);
132
133        Ok(())
134    }
135
136    /// Forward, F::WIDTH column-lanes per element in lockstep.
137    pub fn forward(&self, data: &mut [PackedFlat<F>]) -> Result<(), FftError> {
138        self.forward_coset(data, Flat::from_raw(F::ZERO))
139    }
140
141    /// Inverse, F::WIDTH column-lanes per element in lockstep.
142    pub fn inverse(&self, data: &mut [PackedFlat<F>]) -> Result<(), FftError> {
143        self.inverse_coset(data, Flat::from_raw(F::ZERO))
144    }
145
146    /// Packed forward over the coset offset + W_log_n.
147    pub fn forward_coset(
148        &self,
149        data: &mut [PackedFlat<F>],
150        offset: Flat<F>,
151    ) -> Result<(), FftError> {
152        self.check_len(data.len())?;
153        self.fwd_packed(data, 0, 1, self.log_n, offset);
154
155        Ok(())
156    }
157
158    /// Packed inverse over the coset offset + W_log_n.
159    pub fn inverse_coset(
160        &self,
161        data: &mut [PackedFlat<F>],
162        offset: Flat<F>,
163    ) -> Result<(), FftError> {
164        self.check_len(data.len())?;
165        self.inv_packed(data, 0, 1, self.log_n, offset);
166
167        Ok(())
168    }
169
170    fn check_len(&self, got: usize) -> Result<(), FftError> {
171        let expected = 1usize << self.log_n;
172
173        if got != expected {
174            return Err(FftError::BadLength { expected, got });
175        }
176
177        Ok(())
178    }
179
180    // Decimation-in-time radix-2 butterfly (Gao–Mateer).
181    // σ(x) = x^2 + x maps W_d two-to-one onto W_{d-1}; the
182    // pair (2t, 2t+1) differs by β_0 = 1, so the twiddle is
183    // coset + twiddles[t] and the odd output is just + q.
184    // Strided recursion keeps the output in natural order.
185    fn fwd_scalar(&self, data: &mut [Flat<F>], off: usize, stride: usize, d: u32, coset: Flat<F>) {
186        if d == 0 {
187            return;
188        }
189
190        let half = 1usize << (d - 1);
191        let child = coset * coset + coset;
192
193        self.fwd_scalar(data, off, stride * 2, d - 1, child);
194        self.fwd_scalar(data, off + stride, stride * 2, d - 1, child);
195
196        for t in 0..half {
197            let tw = coset + self.twiddles[t];
198            let i0 = off + 2 * t * stride;
199            let i1 = i0 + stride;
200
201            let p = data[i0];
202            let q = data[i1];
203            let lo = p + tw * q;
204
205            data[i0] = lo;
206            data[i1] = lo + q;
207        }
208    }
209
210    // Inverse butterfly:
211    // q = o0 + o1, then p = o0 + tw*q.
212    // No β^-1 needed (pairs differ by β_0 = 1).
213    fn inv_scalar(&self, data: &mut [Flat<F>], off: usize, stride: usize, d: u32, coset: Flat<F>) {
214        if d == 0 {
215            return;
216        }
217
218        let half = 1usize << (d - 1);
219        let child = coset * coset + coset;
220
221        for t in 0..half {
222            let tw = coset + self.twiddles[t];
223            let i0 = off + 2 * t * stride;
224            let i1 = i0 + stride;
225
226            let o0 = data[i0];
227            let o1 = data[i1];
228            let q = o0 + o1;
229
230            data[i0] = o0 + tw * q;
231            data[i1] = q;
232        }
233
234        self.inv_scalar(data, off, stride * 2, d - 1, child);
235        self.inv_scalar(data, off + stride, stride * 2, d - 1, child);
236    }
237
238    fn fwd_packed(
239        &self,
240        data: &mut [PackedFlat<F>],
241        off: usize,
242        stride: usize,
243        d: u32,
244        coset: Flat<F>,
245    ) {
246        if d == 0 {
247            return;
248        }
249
250        let half = 1usize << (d - 1);
251        let child = coset * coset + coset;
252
253        self.fwd_packed(data, off, stride * 2, d - 1, child);
254        self.fwd_packed(data, off + stride, stride * 2, d - 1, child);
255
256        for t in 0..half {
257            let tw = coset + self.twiddles[t];
258            let i0 = off + 2 * t * stride;
259            let i1 = i0 + stride;
260
261            let p = data[i0];
262            let q = data[i1];
263            let lo = F::add_hardware_packed(p, F::mul_hardware_scalar_packed(q, tw));
264
265            data[i0] = lo;
266            data[i1] = F::add_hardware_packed(lo, q);
267        }
268    }
269
270    fn inv_packed(
271        &self,
272        data: &mut [PackedFlat<F>],
273        off: usize,
274        stride: usize,
275        d: u32,
276        coset: Flat<F>,
277    ) {
278        if d == 0 {
279            return;
280        }
281
282        let half = 1usize << (d - 1);
283        let child = coset * coset + coset;
284
285        for t in 0..half {
286            let tw = coset + self.twiddles[t];
287            let i0 = off + 2 * t * stride;
288            let i1 = i0 + stride;
289
290            let o0 = data[i0];
291            let o1 = data[i1];
292            let q = F::add_hardware_packed(o0, o1);
293
294            data[i0] = F::add_hardware_packed(o0, F::mul_hardware_scalar_packed(q, tw));
295            data[i1] = q;
296        }
297
298        self.inv_packed(data, off, stride * 2, d - 1, child);
299        self.inv_packed(data, off + stride, stride * 2, d - 1, child);
300    }
301}