Skip to main content

rlx_fft/
config.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Configuration for learned FFT models.
17
18use anyhow::{Context, Result, bail, ensure};
19use serde::{Deserialize, Serialize};
20
/// Supported transform sizes (power-of-two complex FFT length).
///
/// Lists every power of two from 2^6 (64) through 2^17 (131072) in ascending order.
pub const SUPPORTED_N_FFT: &[usize] = &[
    1 << 6,  // 64
    1 << 7,  // 128
    1 << 8,  // 256
    1 << 9,  // 512
    1 << 10, // 1024
    1 << 11, // 2048
    1 << 12, // 4096
    1 << 13, // 8192
    1 << 14, // 16384
    1 << 15, // 32768
    1 << 16, // 65536
    1 << 17, // 131072
];
25
26/// Full n_fft sweep for limit / capacity studies (same as [`SUPPORTED_N_FFT`]).
27pub const FULL_N_FFT_SWEEP: &[usize] = SUPPORTED_N_FFT;
28
/// Batch sizes attempted in limit sweeps, listed in descending order
/// (the goal is to find the largest working batch per `n_fft`).
///
/// Every entry is a power of two, from 4096 down to 1.
pub const LIMIT_SWEEP_REQUESTED_BATCHES: &[usize] = &[
    4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1,
];
32
/// Largest batch size worth attempting at a given `n_fft` during limit sweeps
/// (sweeps push up to this value until something fails).
///
/// The cap halves each time the `n_fft` tier doubles: 4096 for `n_fft <= 128`,
/// down to 8 for `n_fft <= 65536`, and 4 for anything larger.
pub fn batch_cap_for_limit_sweep(n_fft: usize) -> usize {
    // (inclusive upper bound on n_fft, batch cap), ordered by ascending bound.
    const TIERS: [(usize, usize); 10] = [
        (128, 4096),
        (256, 2048),
        (512, 1024),
        (1024, 512),
        (2048, 256),
        (4096, 128),
        (8192, 64),
        (16384, 32),
        (32768, 16),
        (65536, 8),
    ];
    // Cap used once n_fft exceeds every tier bound.
    const FALLBACK_CAP: usize = 4;

    TIERS
        .iter()
        .find(|&&(upper, _)| n_fft <= upper)
        .map_or(FALLBACK_CAP, |&(_, cap)| cap)
}
49
50/// Cap batch size by n_fft to stay within practical memory / compile limits.
51pub fn adaptive_batches_for_n_fft(n_fft: usize, requested: &[usize]) -> Vec<usize> {
52    adaptive_batches_with_cap(n_fft, requested, batch_cap_for_limit_sweep(n_fft))
53}
54
/// Keeps the requested batch sizes that lie in `1..=cap`, returned sorted
/// ascending with duplicates removed.
///
/// When no requested size survives the filter, the result is the single
/// element `cap.max(1)`, so the returned vector is never empty.
///
/// `n_fft` is accepted but not used by this function.
pub fn adaptive_batches_with_cap(n_fft: usize, requested: &[usize], cap: usize) -> Vec<usize> {
    let _ = n_fft;

    let mut kept: Vec<usize> = Vec::new();
    kept.extend(
        requested
            .iter()
            .copied()
            .filter(|size| (1..=cap).contains(size)),
    );

    if kept.is_empty() {
        // Nothing fit under the cap: fall back to the cap itself (at least 1).
        return vec![cap.max(1)];
    }

    kept.sort_unstable();
    kept.dedup();
    kept
}
69
70/// Batches used by `--limit-sweep`.
71pub fn limit_sweep_batches(n_fft: usize) -> Vec<usize> {
72    adaptive_batches_for_n_fft(n_fft, LIMIT_SWEEP_REQUESTED_BATCHES)
73}
74
/// Compiled-graph variants are only attempted up to this n_fft (larger sizes hang on compile).
///
/// Returns `true` when `n_fft` is at most 1024.
pub fn compiled_ok_for_n_fft(n_fft: usize) -> bool {
    const MAX_COMPILED_N_FFT: usize = 1024;
    n_fft <= MAX_COMPILED_N_FFT
}
79
/// Per-device compiled ceiling for limit sweeps (GPU backends tolerate slightly larger graphs).
///
/// Returns `true` when `n_fft` is at most 2048 for a GPU label (as decided by
/// [`is_gpu_device_label`]) and at most 1024 for `"cpu"` or any unrecognized label.
/// The device label is compared ASCII case-insensitively and is not trimmed.
pub fn compiled_ok_for_limit_sweep(n_fft: usize, device: &str) -> bool {
    const GPU_CEILING: usize = 2048;
    const DEFAULT_CEILING: usize = 1024;

    // Single source of truth for "is this a GPU?" so the two functions cannot drift.
    let ceiling = if is_gpu_device_label(device) {
        GPU_CEILING
    } else {
        DEFAULT_CEILING
    };
    n_fft <= ceiling
}

/// Reports whether `device` is one of the recognized GPU backend labels.
///
/// Matching is ASCII case-insensitive and exact (no trimming, no prefix matching).
pub fn is_gpu_device_label(device: &str) -> bool {
    // NOTE(review): "wgu" is kept because the original accepted it; it may be a
    // typo-tolerant alias of "wgpu" — confirm before removing.
    const GPU_LABELS: [&str; 9] = [
        "metal", "cuda", "mlx", "mps", "rocm", "wgpu", "wgu", "vulkan", "gpu",
    ];
    GPU_LABELS
        .iter()
        .any(|label| device.eq_ignore_ascii_case(label))
}
100
/// Welch PSD path — skip at extreme sizes (segment buffers grow with n_fft).
///
/// Returns `true` when `n_fft` is at most 32768.
pub fn welch_ok_for_limit_sweep(n_fft: usize) -> bool {
    const MAX_WELCH_N_FFT: usize = 32768;
    n_fft <= MAX_WELCH_N_FFT
}
105
106/// Welch signal buffer is `batch × frame_len` floats — skip huge combos.
107pub fn welch_ok_for_config(n_fft: usize, batch: usize) -> bool {
108    if !welch_ok_for_limit_sweep(n_fft) {
109        return false;
110    }
111    let hop = n_fft / 2;
112    let frame = n_fft + 7 * hop;
113    let bytes = batch.saturating_mul(frame).saturating_mul(4);
114    bytes <= 512 * 1024 * 1024
115}
116
/// Reduce training steps at large FFT sizes during sweeps.
///
/// Returns `base` unchanged for `n_fft <= 1024`. Above that, `base` is clamped
/// to a ceiling that shrinks as `n_fft` grows (15 steps above 1024, down to
/// 2 steps above 65536). The result never exceeds `base`.
pub fn train_steps_for_n_fft(base: usize, n_fft: usize) -> usize {
    // (exclusive lower bound on n_fft, max steps), ordered by descending bound
    // so the first match is the tightest applicable ceiling.
    const CEILINGS: [(usize, usize); 7] = [
        (65536, 2),
        (32768, 3),
        (16384, 4),
        (8192, 5),
        (4096, 8),
        (2048, 12),
        (1024, 15),
    ];

    match CEILINGS.iter().find(|&&(floor, _)| n_fft > floor) {
        Some(&(_, max_steps)) => base.min(max_steps),
        None => base,
    }
}
130
131#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
132pub enum TransformDir {
133    Forward,
134    Inverse,
135}
136
137impl TransformDir {
138    pub fn is_forward(self) -> bool {
139        matches!(self, Self::Forward)
140    }
141
142    pub fn is_inverse(self) -> bool {
143        matches!(self, Self::Inverse)
144    }
145}
146
147impl std::str::FromStr for TransformDir {
148    type Err = anyhow::Error;
149
150    fn from_str(s: &str) -> Result<Self> {
151        match s.to_ascii_lowercase().as_str() {
152            "forward" | "fft" => Ok(Self::Forward),
153            "inverse" | "ifft" => Ok(Self::Inverse),
154            other => bail!("unknown transform direction: {other} (use fft|ifft)"),
155        }
156    }
157}
158
159pub fn parse_transform_dir(s: &str) -> Result<TransformDir> {
160    s.parse()
161}
162
163#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct FftLearnConfig {
165    /// Complex FFT length (must be a power of two).
166    pub n_fft: usize,
167    /// Training / inference batch size.
168    pub batch: usize,
169}
170
171impl FftLearnConfig {
172    pub fn new(n_fft: usize, batch: usize) -> Result<Self> {
173        ensure!(
174            n_fft.is_power_of_two(),
175            "n_fft must be a power of two, got {n_fft}"
176        );
177        ensure!(n_fft >= 4, "n_fft must be at least 4");
178        ensure!(batch >= 1, "batch must be >= 1");
179        Ok(Self { n_fft, batch })
180    }
181
182    pub fn tiny() -> Self {
183        Self {
184            n_fft: 64,
185            batch: 4,
186        }
187    }
188
189    pub fn num_stages(&self) -> usize {
190        self.n_fft.trailing_zeros() as usize
191    }
192
193    pub fn butterflies_per_stage(&self) -> usize {
194        self.n_fft / 2
195    }
196
197    pub fn twiddle_param_count(&self) -> usize {
198        self.num_stages() * self.butterflies_per_stage() * 2
199    }
200
201    pub fn validate(&self) -> Result<()> {
202        Self::new(self.n_fft, self.batch)?;
203        Ok(())
204    }
205}
206
207pub fn parse_n_fft(s: &str) -> Result<usize> {
208    let n: usize = s.parse().context("n_fft: usize")?;
209    FftLearnConfig::new(n, 1).map(|_| n)
210}
211
212pub fn ensure_supported_n_fft(n_fft: usize) -> Result<()> {
213    if SUPPORTED_N_FFT.contains(&n_fft) {
214        return Ok(());
215    }
216    bail!(
217        "unsupported n_fft={n_fft}; supported: {}",
218        SUPPORTED_N_FFT
219            .iter()
220            .map(|n| n.to_string())
221            .collect::<Vec<_>>()
222            .join(", ")
223    );
224}
225
226#[derive(Debug, Clone, Serialize, Deserialize)]
227pub struct TrainConfig {
228    pub model: FftLearnConfig,
229    pub direction: TransformDir,
230    pub steps: usize,
231    pub lr: f64,
232    pub weight_decay: f32,
233    pub beta1: f64,
234    pub beta2: f64,
235    pub eps: f64,
236    pub grad_clip: f32,
237    pub seed: u64,
238    pub log_every: usize,
239    pub device: String,
240    pub out_dir: Option<std::path::PathBuf>,
241}
242
243impl Default for TrainConfig {
244    fn default() -> Self {
245        Self {
246            model: FftLearnConfig::tiny(),
247            direction: TransformDir::Forward,
248            steps: 500,
249            lr: 1e-3,
250            weight_decay: 0.0,
251            beta1: 0.9,
252            beta2: 0.999,
253            eps: 1e-8,
254            grad_clip: 1.0,
255            seed: 42,
256            log_every: 50,
257            device: "auto".to_string(),
258            out_dir: None,
259        }
260    }
261}
262
263/// Three-phase training: encoder only → decoder only → joint fine-tune.
264#[derive(Debug, Clone, Serialize, Deserialize)]
265pub struct PhasedTrainConfig {
266    pub model: FftLearnConfig,
267    pub encoder_steps: usize,
268    pub decoder_steps: usize,
269    pub joint_steps: usize,
270    pub lr: f64,
271    pub spectrum_weight: f32,
272    pub seed: u64,
273    pub log_every: usize,
274    pub out_dir: Option<std::path::PathBuf>,
275}
276
277impl Default for PhasedTrainConfig {
278    fn default() -> Self {
279        Self {
280            model: FftLearnConfig::tiny(),
281            encoder_steps: 300,
282            decoder_steps: 300,
283            joint_steps: 300,
284            lr: 5e-4,
285            spectrum_weight: 1.0,
286            seed: 42,
287            log_every: 50,
288            out_dir: None,
289        }
290    }
291}
292
293/// Train encoder (FFT) + decoder (IFFT) jointly on synthetic roundtrip data.
294#[derive(Debug, Clone, Serialize, Deserialize)]
295pub struct EncDecTrainConfig {
296    pub model: FftLearnConfig,
297    pub steps: usize,
298    pub lr: f64,
299    /// Weight for auxiliary encoder loss vs reference FFT (0 = roundtrip only).
300    pub spectrum_weight: f32,
301    pub seed: u64,
302    pub log_every: usize,
303    pub device: String,
304    pub out_dir: Option<std::path::PathBuf>,
305    #[serde(default = "default_grad_clip")]
306    pub grad_clip: f32,
307    #[serde(default = "default_project_twiddles")]
308    pub project_twiddles: bool,
309}
310
/// Serde fallback for `EncDecTrainConfig::grad_clip`: always `1.0`.
fn default_grad_clip() -> f32 {
    1.0_f32
}
314
/// Serde fallback for `EncDecTrainConfig::project_twiddles`: always `true`.
fn default_project_twiddles() -> bool {
    const PROJECT_BY_DEFAULT: bool = true;
    PROJECT_BY_DEFAULT
}
318
319/// Multi-size encoder–decoder training study (compare schedules across n_fft).
320#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
321pub enum MultiTrainSchedule {
322    /// Train each n_fft in isolation (`steps` per size).
323    Single,
324    /// Cycle sizes each step (`steps` total).
325    RoundRobin,
326    /// Random n_fft each step (`steps` total).
327    Random,
328    /// Equal step count per size (`steps / n_sizes` each, `steps` total).
329    Balanced,
330}
331
332impl MultiTrainSchedule {
333    pub fn label(self) -> &'static str {
334        match self {
335            Self::Single => "single",
336            Self::RoundRobin => "round_robin",
337            Self::Random => "random",
338            Self::Balanced => "balanced",
339        }
340    }
341
342    pub fn all() -> &'static [Self] {
343        &[Self::Single, Self::RoundRobin, Self::Random, Self::Balanced]
344    }
345
346    pub fn parse_csv(s: &str) -> anyhow::Result<Vec<Self>> {
347        let mut out = Vec::new();
348        for part in s.split(',') {
349            let part = part.trim().to_ascii_lowercase();
350            if part.is_empty() {
351                continue;
352            }
353            out.push(match part.as_str() {
354                "single" => Self::Single,
355                "round_robin" | "round-robin" | "rr" => Self::RoundRobin,
356                "random" => Self::Random,
357                "balanced" => Self::Balanced,
358                other => anyhow::bail!(
359                    "unknown schedule {other} (use single,round_robin,random,balanced)"
360                ),
361            });
362        }
363        anyhow::ensure!(!out.is_empty(), "schedules list is empty");
364        Ok(out)
365    }
366}
367
368#[derive(Debug, Clone, Serialize, Deserialize)]
369pub struct MultiTrainConfig {
370    pub n_ffts: Vec<usize>,
371    pub batch: usize,
372    /// Maximum training steps (per size for `single`, total for mixed schedules).
373    pub steps: usize,
374    pub schedules: Vec<MultiTrainSchedule>,
375    pub lr: f64,
376    pub spectrum_weight: f32,
377    pub seed: u64,
378    pub log_every: usize,
379    pub eval_batches: usize,
380    pub out_dir: Option<std::path::PathBuf>,
381    /// Stop early when holdout loss plateaus (after `min_steps`).
382    pub until_converged: bool,
383    /// Minimum steps before convergence checks apply.
384    pub min_steps: usize,
385    /// Re-check holdout loss every N steps.
386    pub converge_every: usize,
387    /// Consecutive checks without meaningful improvement before stop.
388    pub converge_patience: usize,
389    /// Relative improvement threshold (fraction of best loss).
390    pub converge_delta: f32,
391    /// Clip twiddle gradient L2 norm (0 = off).
392    pub grad_clip: f32,
393    /// Keep twiddles on the unit circle after each step.
394    pub project_twiddles: bool,
395    /// Use fused enc–dec step (batched ref FFT + shared backward).
396    pub use_fused_train: bool,
397    pub optimizer: crate::second_order::TwiddleOptimizer,
398}
399
400impl Default for MultiTrainConfig {
401    fn default() -> Self {
402        Self {
403            n_ffts: vec![64, 256],
404            batch: 8,
405            steps: 10_000,
406            schedules: MultiTrainSchedule::all().to_vec(),
407            lr: 5e-4,
408            spectrum_weight: 1.0,
409            seed: 42,
410            log_every: 50,
411            eval_batches: 8,
412            out_dir: None,
413            until_converged: true,
414            min_steps: 300,
415            converge_every: 25,
416            converge_patience: 5,
417            converge_delta: 1e-4,
418            grad_clip: 1.0,
419            project_twiddles: true,
420            use_fused_train: true,
421            optimizer: crate::second_order::TwiddleOptimizer::Sgd,
422        }
423    }
424}
425
426impl Default for EncDecTrainConfig {
427    fn default() -> Self {
428        Self {
429            model: FftLearnConfig::tiny(),
430            steps: 500,
431            lr: 1e-3,
432            spectrum_weight: 1.0,
433            seed: 42,
434            log_every: 50,
435            device: "auto".to_string(),
436            out_dir: None,
437            grad_clip: default_grad_clip(),
438            project_twiddles: default_project_twiddles(),
439        }
440    }
441}