Skip to main content

rlx_fft/
train_multi.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Multi-n_fft encoder–decoder training study and train×eval matrix.
17
18use crate::butterfly::butterfly_train_step_encdec;
19use crate::config::{FftLearnConfig, MultiTrainConfig, MultiTrainSchedule};
20use crate::fused_train::fused_encdec_train_step;
21use crate::second_order::{TwiddleOptState, TwiddleOptimizer};
22use crate::train::random_batch;
23use crate::train_phased::precision_encdec;
24use crate::twiddle::exact_twiddles;
25use crate::weights::{EncDecWeights, export_safetensors};
26use anyhow::{Result, ensure};
27use rand::prelude::*;
28use serde::{Deserialize, Deserializer, Serialize};
29use std::collections::HashMap;
30use std::path::{Path, PathBuf};
31use std::time::Instant;
32
33fn null_as_nan<'de, D: Deserializer<'de>>(deserializer: D) -> Result<f32, D::Error> {
34    Ok(Option::<f32>::deserialize(deserializer)?.unwrap_or(f32::NAN))
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct MultiTrainEvalRow {
39    /// Training regime label (`single_64`, `mixed_round_robin`, `exact`, …).
40    pub regime: String,
41    pub schedule: String,
42    /// n_fft sizes included in this training run.
43    pub train_sizes: Vec<usize>,
44    pub eval_n_fft: usize,
45    pub train_steps_total: usize,
46    pub train_elapsed_ms: f64,
47    #[serde(deserialize_with = "null_as_nan")]
48    pub encoder_spectrum_mse: f32,
49    #[serde(deserialize_with = "null_as_nan")]
50    pub encoder_spectrum_max_err: f32,
51    #[serde(deserialize_with = "null_as_nan")]
52    pub decoder_time_mse: f32,
53    #[serde(deserialize_with = "null_as_nan")]
54    pub decoder_time_max_err: f32,
55    #[serde(deserialize_with = "null_as_nan")]
56    pub roundtrip_mse: f32,
57    #[serde(deserialize_with = "null_as_nan")]
58    pub roundtrip_max_err: f32,
59    pub converged: bool,
60    #[serde(deserialize_with = "null_as_nan")]
61    pub final_holdout_mse: f32,
62    #[serde(skip_serializing_if = "Option::is_none")]
63    pub checkpoint: Option<PathBuf>,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct MultiTrainReport {
68    pub batch: usize,
69    pub n_ffts: Vec<usize>,
70    pub max_steps: usize,
71    pub min_steps: usize,
72    pub until_converged: bool,
73    pub eval_batches: usize,
74    pub seed: u64,
75    #[serde(default = "default_grad_clip_report")]
76    pub grad_clip: f32,
77    #[serde(default)]
78    pub project_twiddles: bool,
79    #[serde(default = "default_use_fused")]
80    pub use_fused_train: bool,
81    #[serde(default = "default_optimizer_label")]
82    pub optimizer: String,
83    pub elapsed_ms: f64,
84    pub rows: Vec<MultiTrainEvalRow>,
85}
86
/// Serde default for [`MultiTrainReport::grad_clip`] when the key is absent.
fn default_grad_clip_report() -> f32 {
    1.0_f32
}
90
/// Serde default for [`MultiTrainReport::use_fused_train`] when the key is
/// absent: the fused path is assumed.
fn default_use_fused() -> bool {
    true
}
94
/// Serde default for [`MultiTrainReport::optimizer`] when the key is absent.
fn default_optimizer_label() -> String {
    String::from("sgd")
}
98
99struct SizeTwiddles {
100    encoder: Vec<f32>,
101    decoder: Vec<f32>,
102    opt: TwiddleOptState,
103}
104
105fn new_size_twiddles(model: &FftLearnConfig, optimizer: TwiddleOptimizer) -> SizeTwiddles {
106    let stages = model.n_fft.trailing_zeros() as usize;
107    let half = model.n_fft / 2;
108    let tw_len = stages * half * 2;
109    SizeTwiddles {
110        encoder: exact_twiddles(model),
111        decoder: exact_twiddles(model),
112        opt: TwiddleOptState::new(optimizer, tw_len, tw_len),
113    }
114}
115struct ConvergenceTracker {
116    patience: usize,
117    rel_delta: f32,
118    abs_delta: f32,
119    best: f32,
120    stale: usize,
121}
122
123impl ConvergenceTracker {
124    fn new(cfg: &MultiTrainConfig) -> Self {
125        Self {
126            patience: cfg.converge_patience,
127            rel_delta: cfg.converge_delta,
128            abs_delta: cfg.converge_delta * 1e-4,
129            best: f32::INFINITY,
130            stale: 0,
131        }
132    }
133
134    /// Returns true when training should stop (loss plateau).
135    fn observe(&mut self, loss: f32) -> bool {
136        if !loss.is_finite() {
137            self.stale = 0;
138            return false;
139        }
140        let improved = if !self.best.is_finite() {
141            true
142        } else {
143            let drop = self.best - loss;
144            drop > self.abs_delta || drop / self.best.max(1e-12) > self.rel_delta
145        };
146        if improved {
147            self.best = loss;
148            self.stale = 0;
149        } else {
150            self.stale += 1;
151        }
152        self.stale >= self.patience
153    }
154}
155
156pub fn run_multi_train(cfg: &MultiTrainConfig) -> Result<MultiTrainReport> {
157    ensure!(!cfg.n_ffts.is_empty(), "n_ffts must not be empty");
158    ensure!(cfg.steps >= 1, "steps must be >= 1");
159    for &n in &cfg.n_ffts {
160        FftLearnConfig::new(n, cfg.batch)?;
161    }
162
163    let started = Instant::now();
164    let mut rows = Vec::new();
165
166    rows.extend(eval_exact_baseline(cfg)?);
167
168    for &schedule in &cfg.schedules {
169        eprintln!("[train-multi] schedule={}", schedule.label());
170        let regime_rows = train_schedule(cfg, schedule)?;
171        rows.extend(regime_rows);
172    }
173
174    let report = MultiTrainReport {
175        batch: cfg.batch,
176        n_ffts: cfg.n_ffts.clone(),
177        max_steps: cfg.steps,
178        min_steps: cfg.min_steps,
179        until_converged: cfg.until_converged,
180        eval_batches: cfg.eval_batches,
181        seed: cfg.seed,
182        grad_clip: cfg.grad_clip,
183        project_twiddles: cfg.project_twiddles,
184        use_fused_train: cfg.use_fused_train,
185        optimizer: cfg.optimizer.label().to_string(),
186        elapsed_ms: started.elapsed().as_secs_f64() * 1000.0,
187        rows,
188    };
189
190    if let Some(out) = &cfg.out_dir {
191        std::fs::create_dir_all(out)?;
192        let path = out.join("multi_train_report.json");
193        std::fs::write(&path, serde_json::to_vec_pretty(&report)?)?;
194        eprintln!("wrote {}", path.display());
195    }
196
197    Ok(report)
198}
199
200fn eval_exact_baseline(cfg: &MultiTrainConfig) -> Result<Vec<MultiTrainEvalRow>> {
201    let mut rows = Vec::new();
202    for &n in &cfg.n_ffts {
203        let model = FftLearnConfig::new(n, cfg.batch)?;
204        let enc = exact_twiddles(&model);
205        let dec = exact_twiddles(&model);
206        let mut rng = rand::rngs::StdRng::seed_from_u64(cfg.seed);
207        let (enc_mse, enc_max, dec_mse, dec_max, rt_mse, rt_max) =
208            precision_encdec(&enc, &dec, &model, cfg.eval_batches, &mut rng)?;
209        rows.push(row_from_metrics(
210            "exact",
211            "exact",
212            vec![n],
213            n,
214            0,
215            0.0,
216            enc_mse,
217            enc_max,
218            dec_mse,
219            dec_max,
220            rt_mse,
221            rt_max,
222            true,
223            rt_mse,
224            None,
225        ));
226    }
227    Ok(rows)
228}
229
230fn train_schedule(
231    cfg: &MultiTrainConfig,
232    schedule: MultiTrainSchedule,
233) -> Result<Vec<MultiTrainEvalRow>> {
234    match schedule {
235        MultiTrainSchedule::Single => train_single_per_size(cfg),
236        MultiTrainSchedule::RoundRobin
237        | MultiTrainSchedule::Random
238        | MultiTrainSchedule::Balanced => train_mixed(cfg, schedule),
239    }
240}
241
242fn train_single_per_size(cfg: &MultiTrainConfig) -> Result<Vec<MultiTrainEvalRow>> {
243    let mut all_rows = Vec::new();
244
245    for &n in &cfg.n_ffts {
246        let model = FftLearnConfig::new(n, cfg.batch)?;
247        let tw = new_size_twiddles(&model, cfg.optimizer);
248        let mut rng = rand::rngs::StdRng::seed_from_u64(cfg.seed.wrapping_add(n as u64));
249        let regime = format!("single_{n}");
250        let label = regime.clone();
251
252        let outcome = train_until_converged(
253            cfg,
254            &label,
255            &mut rng,
256            move |_step, tw_map, rng| {
257                let st = tw_map.get_mut(&n).expect("twiddles");
258                train_encdec_step_on(cfg, st, n, rng)
259            },
260            HashMap::from([(n, tw)]),
261            |tw_map, rng| holdout_mse(cfg, tw_map, &[n], rng),
262        )?;
263
264        let tw = outcome.tw;
265        let st = tw.get(&n).expect("twiddles");
266        let weights = EncDecWeights::from_twiddles(&st.encoder, &st.decoder, n);
267        let checkpoint = save_multi_checkpoint(cfg, &regime, &weights, n)?;
268
269        all_rows.extend(eval_twiddles_matrix(
270            cfg,
271            &regime,
272            schedule_label(MultiTrainSchedule::Single),
273            &[n],
274            outcome.steps,
275            outcome.elapsed_ms,
276            outcome.converged,
277            outcome.holdout_mse,
278            &tw,
279            checkpoint,
280        )?);
281    }
282    Ok(all_rows)
283}
284
285fn train_mixed(
286    cfg: &MultiTrainConfig,
287    schedule: MultiTrainSchedule,
288) -> Result<Vec<MultiTrainEvalRow>> {
289    let mut rng = rand::rngs::StdRng::seed_from_u64(cfg.seed);
290    let mut tw: HashMap<usize, SizeTwiddles> = HashMap::new();
291    for &n in &cfg.n_ffts {
292        let model = FftLearnConfig::new(n, cfg.batch)?;
293        tw.insert(n, new_size_twiddles(&model, cfg.optimizer));
294    }
295
296    let regime = format!("mixed_{}", schedule.label());
297    let n_sizes = cfg.n_ffts.len();
298    let eval_sizes = cfg.n_ffts.clone();
299
300    let outcome = match schedule {
301        MultiTrainSchedule::Balanced => {
302            let per = cfg.steps / n_sizes;
303            ensure!(
304                per >= 1,
305                "steps={} too small for {} sizes in balanced mode",
306                cfg.steps,
307                n_sizes
308            );
309            train_balanced_until_converged(cfg, &regime, per, &mut tw, &mut rng)?
310        }
311        MultiTrainSchedule::RoundRobin | MultiTrainSchedule::Random => train_until_converged(
312            cfg,
313            &regime,
314            &mut rng,
315            move |step, tw_map, rng| {
316                let pick = match schedule {
317                    MultiTrainSchedule::RoundRobin => cfg.n_ffts[step % n_sizes],
318                    MultiTrainSchedule::Random => cfg.n_ffts[rng.gen_range(0..n_sizes)],
319                    _ => unreachable!(),
320                };
321                let st = tw_map.get_mut(&pick).expect("twiddles");
322                train_encdec_step_on(cfg, st, pick, rng)
323            },
324            tw,
325            {
326                let eval_sizes = eval_sizes.clone();
327                move |tw_map, rng| holdout_mse(cfg, tw_map, &eval_sizes, rng)
328            },
329        )?,
330        MultiTrainSchedule::Single => unreachable!(),
331    };
332
333    let tw = outcome.tw;
334    let mut checkpoint = None;
335    if let Some(out_dir) = &cfg.out_dir {
336        let dir = out_dir.join(&regime);
337        std::fs::create_dir_all(&dir)?;
338        for &n in &cfg.n_ffts {
339            let st = tw.get(&n).expect("twiddles");
340            let weights = EncDecWeights::from_twiddles(&st.encoder, &st.decoder, n);
341            let path = dir.join(format!("n{n}_encdec.safetensors"));
342            export_safetensors(&path, &weights.merged())?;
343        }
344        checkpoint = Some(dir);
345    }
346
347    eval_twiddles_matrix(
348        cfg,
349        &regime,
350        schedule.label().to_string(),
351        &cfg.n_ffts,
352        outcome.steps,
353        outcome.elapsed_ms,
354        outcome.converged,
355        outcome.holdout_mse,
356        &tw,
357        checkpoint,
358    )
359}
360
361struct ConvergeOutcome {
362    tw: HashMap<usize, SizeTwiddles>,
363    steps: usize,
364    elapsed_ms: f64,
365    converged: bool,
366    holdout_mse: f32,
367}
368
369fn train_until_converged<R: Rng>(
370    cfg: &MultiTrainConfig,
371    label: &str,
372    rng: &mut R,
373    mut step_fn: impl FnMut(usize, &mut HashMap<usize, SizeTwiddles>, &mut R) -> Result<()>,
374    mut tw: HashMap<usize, SizeTwiddles>,
375    mut holdout_fn: impl FnMut(&HashMap<usize, SizeTwiddles>, &mut R) -> Result<f32>,
376) -> Result<ConvergeOutcome> {
377    let started = Instant::now();
378    let mut tracker = ConvergenceTracker::new(cfg);
379    let mut step = 0usize;
380    let mut converged = false;
381    let mut holdout_mse = f32::INFINITY;
382
383    while step < cfg.steps {
384        step_fn(step, &mut tw, rng)?;
385        step += 1;
386
387        if cfg.until_converged && step >= cfg.min_steps && step.is_multiple_of(cfg.converge_every) {
388            holdout_mse = holdout_fn(&tw, rng)?;
389            eprintln!(
390                "  [{label}] step {step} holdout_mse={holdout_mse:.6e} best={:.6e}",
391                tracker.best
392            );
393            if tracker.observe(holdout_mse) {
394                converged = true;
395                eprintln!("  [{label}] converged at step {step} holdout_mse={holdout_mse:.6e}");
396                break;
397            }
398        } else if cfg.log_every > 0 && step.is_multiple_of(cfg.log_every) {
399            eprintln!("  [{label}] step {step}/{}", cfg.steps);
400        }
401    }
402
403    if !holdout_mse.is_finite() {
404        holdout_mse = holdout_fn(&tw, rng)?;
405    }
406
407    Ok(ConvergeOutcome {
408        tw,
409        steps: step,
410        elapsed_ms: started.elapsed().as_secs_f64() * 1000.0,
411        converged: converged && holdout_mse.is_finite(),
412        holdout_mse,
413    })
414}
415
416fn train_balanced_until_converged<R: Rng>(
417    cfg: &MultiTrainConfig,
418    label: &str,
419    per_size: usize,
420    tw: &mut HashMap<usize, SizeTwiddles>,
421    rng: &mut R,
422) -> Result<ConvergeOutcome> {
423    let started = Instant::now();
424    let mut tracker = ConvergenceTracker::new(cfg);
425    let mut step = 0usize;
426    let mut converged = false;
427    let mut final_holdout = f32::INFINITY;
428    let eval_sizes = cfg.n_ffts.clone();
429
430    'outer: while step < cfg.steps {
431        for &n in &cfg.n_ffts {
432            if step >= cfg.steps {
433                break 'outer;
434            }
435            let st = tw.get_mut(&n).expect("twiddles");
436            train_encdec_step_on(cfg, st, n, rng)?;
437            step += 1;
438
439            if cfg.until_converged
440                && step >= cfg.min_steps
441                && step.is_multiple_of(cfg.converge_every)
442            {
443                let loss = holdout_mse(cfg, tw, &eval_sizes, rng)?;
444                eprintln!(
445                    "  [{label}] step {step} holdout_mse={loss:.6e} best={:.6e}",
446                    tracker.best
447                );
448                if tracker.observe(loss) {
449                    converged = true;
450                    final_holdout = loss;
451                    eprintln!("  [{label}] converged at step {step} holdout_mse={loss:.6e}");
452                    break 'outer;
453                }
454                final_holdout = loss;
455            } else if cfg.log_every > 0 && step.is_multiple_of(cfg.log_every) {
456                eprintln!(
457                    "  [{label}] step {step}/{} (balanced ~{per_size}/size)",
458                    cfg.steps
459                );
460            }
461        }
462    }
463
464    if !final_holdout.is_finite() {
465        final_holdout = holdout_mse(cfg, tw, &eval_sizes, rng)?;
466    }
467
468    Ok(ConvergeOutcome {
469        tw: std::mem::take(tw),
470        steps: step,
471        elapsed_ms: started.elapsed().as_secs_f64() * 1000.0,
472        converged: converged && final_holdout.is_finite(),
473        holdout_mse: final_holdout,
474    })
475}
476
477fn holdout_mse(
478    cfg: &MultiTrainConfig,
479    tw: &HashMap<usize, SizeTwiddles>,
480    sizes: &[usize],
481    rng: &mut impl Rng,
482) -> Result<f32> {
483    let mut acc = 0f32;
484    let mut n = 0f32;
485    for &size in sizes {
486        let Some(st) = tw.get(&size) else {
487            continue;
488        };
489        let model = FftLearnConfig::new(size, cfg.batch)?;
490        let (_, _, _, _, rt_mse, _) =
491            precision_encdec(&st.encoder, &st.decoder, &model, cfg.eval_batches, rng)?;
492        acc += rt_mse;
493        n += 1.0;
494    }
495    Ok(if n > 0.0 { acc / n } else { f32::INFINITY })
496}
497
498fn train_encdec_step_on(
499    cfg: &MultiTrainConfig,
500    st: &mut SizeTwiddles,
501    n: usize,
502    rng: &mut impl Rng,
503) -> Result<()> {
504    let signal = random_batch(rng, cfg.batch, n);
505    if cfg.use_fused_train {
506        fused_encdec_train_step(
507            &signal,
508            &mut st.encoder,
509            &mut st.decoder,
510            cfg.batch,
511            n,
512            cfg.lr,
513            cfg.spectrum_weight,
514            cfg.grad_clip,
515            cfg.project_twiddles,
516            Some(&mut st.opt),
517        )?;
518    } else {
519        butterfly_train_step_encdec(
520            &signal,
521            &mut st.encoder,
522            &mut st.decoder,
523            cfg.batch,
524            n,
525            cfg.lr as f32,
526            cfg.spectrum_weight,
527        )?;
528    }
529    Ok(())
530}
531
532fn eval_twiddles_matrix(
533    cfg: &MultiTrainConfig,
534    regime: &str,
535    schedule: String,
536    train_sizes: &[usize],
537    train_steps: usize,
538    train_elapsed_ms: f64,
539    converged: bool,
540    holdout_mse: f32,
541    tw: &HashMap<usize, SizeTwiddles>,
542    checkpoint: Option<PathBuf>,
543) -> Result<Vec<MultiTrainEvalRow>> {
544    let mut rows = Vec::new();
545    let mut rng = rand::rngs::StdRng::seed_from_u64(cfg.seed.wrapping_add(17));
546
547    for &eval_n in &cfg.n_ffts {
548        let Some(st) = tw.get(&eval_n) else {
549            continue;
550        };
551        let model = FftLearnConfig::new(eval_n, cfg.batch)?;
552        let (enc_mse, enc_max, dec_mse, dec_max, rt_mse, rt_max) =
553            precision_encdec(&st.encoder, &st.decoder, &model, cfg.eval_batches, &mut rng)?;
554        rows.push(row_from_metrics(
555            regime,
556            &schedule,
557            train_sizes.to_vec(),
558            eval_n,
559            train_steps,
560            train_elapsed_ms,
561            enc_mse,
562            enc_max,
563            dec_mse,
564            dec_max,
565            rt_mse,
566            rt_max,
567            converged,
568            holdout_mse,
569            checkpoint.clone(),
570        ));
571    }
572    Ok(rows)
573}
574
575#[allow(clippy::too_many_arguments)]
576fn row_from_metrics(
577    regime: &str,
578    schedule: &str,
579    train_sizes: Vec<usize>,
580    eval_n_fft: usize,
581    train_steps_total: usize,
582    train_elapsed_ms: f64,
583    encoder_spectrum_mse: f32,
584    encoder_spectrum_max_err: f32,
585    decoder_time_mse: f32,
586    decoder_time_max_err: f32,
587    roundtrip_mse: f32,
588    roundtrip_max_err: f32,
589    converged: bool,
590    final_holdout_mse: f32,
591    checkpoint: Option<PathBuf>,
592) -> MultiTrainEvalRow {
593    MultiTrainEvalRow {
594        regime: regime.to_string(),
595        schedule: schedule.to_string(),
596        train_sizes,
597        eval_n_fft,
598        train_steps_total,
599        train_elapsed_ms,
600        encoder_spectrum_mse,
601        encoder_spectrum_max_err,
602        decoder_time_mse,
603        decoder_time_max_err,
604        roundtrip_mse,
605        roundtrip_max_err,
606        converged,
607        final_holdout_mse,
608        checkpoint,
609    }
610}
611
612fn save_multi_checkpoint(
613    cfg: &MultiTrainConfig,
614    regime: &str,
615    weights: &EncDecWeights,
616    n: usize,
617) -> Result<Option<PathBuf>> {
618    let Some(out) = &cfg.out_dir else {
619        return Ok(None);
620    };
621    let dir = out.join(regime);
622    std::fs::create_dir_all(&dir)?;
623    let path = dir.join(format!("n{n}_encdec.safetensors"));
624    export_safetensors(&path, &weights.merged())?;
625    Ok(Some(path))
626}
627
628fn schedule_label(s: MultiTrainSchedule) -> String {
629    s.label().to_string()
630}
631
632pub fn write_multi_train_json(path: &Path, report: &MultiTrainReport) -> Result<()> {
633    if let Some(parent) = path.parent() {
634        std::fs::create_dir_all(parent)?;
635    }
636    std::fs::write(path, serde_json::to_vec_pretty(report)?)?;
637    Ok(())
638}
639
640pub fn best_regime_per_eval(report: &MultiTrainReport) -> Vec<(usize, String, f32)> {
641    let mut out = Vec::new();
642    for &n in &report.n_ffts {
643        let best = report
644            .rows
645            .iter()
646            .filter(|r| r.eval_n_fft == n && r.regime != "exact")
647            .min_by(|a, b| {
648                a.roundtrip_max_err
649                    .partial_cmp(&b.roundtrip_max_err)
650                    .unwrap_or(std::cmp::Ordering::Equal)
651            });
652        if let Some(r) = best {
653            out.push((n, r.regime.clone(), r.roundtrip_max_err));
654        }
655    }
656    out
657}
658
659pub fn print_multi_train_table(report: &MultiTrainReport) {
660    eprintln!(
661        "\n=== Multi-n_fft training study (batch={}, max_steps={}, min_steps={}, until_converged={}) ===\n",
662        report.batch, report.max_steps, report.min_steps, report.until_converged
663    );
664
665    for &eval_n in &report.n_ffts {
666        eprintln!("--- eval n_fft={eval_n} ---");
667        eprintln!(
668            "{:<22} {:>10} {:>6} {:>10} {:>10} {:>10}",
669            "regime", "steps", "conv", "rt_max", "enc_max", "train_ms"
670        );
671        let mut subset: Vec<&MultiTrainEvalRow> = report
672            .rows
673            .iter()
674            .filter(|r| r.eval_n_fft == eval_n)
675            .collect();
676        subset.sort_by(|a, b| {
677            a.roundtrip_max_err
678                .partial_cmp(&b.roundtrip_max_err)
679                .unwrap_or(std::cmp::Ordering::Equal)
680        });
681        for r in &subset {
682            eprintln!(
683                "{:<22} {:>10} {:>6} {:>10.3e} {:>10.3e} {:>10.1}",
684                r.regime,
685                r.train_steps_total,
686                if r.converged { "yes" } else { "no" },
687                r.roundtrip_max_err,
688                r.encoder_spectrum_max_err,
689                r.train_elapsed_ms
690            );
691        }
692        if let Some(best) = subset.first() {
693            eprintln!(
694                "  → best: {} (rt_max={:.3e}, steps={})\n",
695                best.regime, best.roundtrip_max_err, best.train_steps_total
696            );
697        }
698    }
699
700    eprintln!("--- train×eval roundtrip max_err matrix ---");
701    let regimes: Vec<String> = report
702        .rows
703        .iter()
704        .map(|r| r.regime.clone())
705        .collect::<std::collections::BTreeSet<_>>()
706        .into_iter()
707        .collect();
708    eprint!("{:>22}", "regime \\ eval");
709    for &n in &report.n_ffts {
710        eprint!(" {:>10}", n);
711    }
712    eprintln!();
713    for regime in &regimes {
714        eprint!("{regime:>22}");
715        for &n in &report.n_ffts {
716            let cell = report
717                .rows
718                .iter()
719                .find(|r| r.regime == *regime && r.eval_n_fft == n);
720            if let Some(r) = cell {
721                eprint!(" {:>10.2e}", r.roundtrip_max_err);
722            } else {
723                eprint!(" {:>10}", "—");
724            }
725        }
726        eprintln!();
727    }
728    eprintln!("\nTotal study time: {:.1} ms\n", report.elapsed_ms);
729}
730
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::MultiTrainSchedule;

    /// Small two-size configuration (64 and 128) shared by the tests; early
    /// stopping is off unless a test overrides it.
    fn test_cfg(steps: usize, schedules: Vec<MultiTrainSchedule>) -> MultiTrainConfig {
        MultiTrainConfig {
            n_ffts: vec![64, 128],
            batch: 4,
            steps,
            schedules,
            lr: 5e-4,
            spectrum_weight: 1.0,
            seed: 1,
            log_every: 0,
            eval_batches: 2,
            out_dir: None,
            until_converged: false,
            min_steps: 300,
            converge_every: 25,
            converge_patience: 5,
            converge_delta: 1e-4,
            grad_clip: 1.0,
            project_twiddles: true,
            use_fused_train: true,
            optimizer: TwiddleOptimizer::Sgd,
        }
    }

    /// First row of `regime`; panics with the regime name when it is missing.
    fn row_for<'a>(report: &'a MultiTrainReport, regime: &str) -> &'a MultiTrainEvalRow {
        report
            .rows
            .iter()
            .find(|r| r.regime == regime)
            .expect(regime)
    }

    /// Single-size training yields a row per size with a bounded round-trip error.
    #[test]
    fn multi_train_single_schedule() {
        let cfg = test_cfg(40, vec![MultiTrainSchedule::Single]);
        let report = run_multi_train(&cfg).unwrap();

        for regime in ["single_64", "single_128"] {
            assert!(report.rows.iter().any(|r| r.regime == regime));
        }
        for n in [64usize, 128] {
            let best = report
                .rows
                .iter()
                .filter(|r| r.eval_n_fft == n && r.regime.starts_with("single_"))
                .map(|r| r.roundtrip_max_err)
                .fold(f32::INFINITY, f32::min);
            assert!(best < 0.5, "n={n} single train rt_max={best}");
        }
    }

    /// The round-robin schedule produces rows under the `mixed_round_robin` regime.
    #[test]
    fn mixed_round_robin_runs() {
        let cfg = test_cfg(20, vec![MultiTrainSchedule::RoundRobin]);
        let report = run_multi_train(&cfg).unwrap();
        assert!(report.rows.iter().any(|r| r.regime == "mixed_round_robin"));
    }

    /// With early stopping enabled the run ends before the step budget is used up.
    #[test]
    fn convergence_stops_early() {
        let cfg = MultiTrainConfig {
            n_ffts: vec![64],
            until_converged: true,
            min_steps: 20,
            converge_every: 10,
            converge_patience: 2,
            converge_delta: 1e-2,
            ..test_cfg(2000, vec![MultiTrainSchedule::Single])
        };
        let report = run_multi_train(&cfg).unwrap();
        let row = row_for(&report, "single_64");
        assert!(row.converged, "expected early convergence");
        assert!(
            row.train_steps_total < cfg.steps,
            "expected fewer than max steps"
        );
    }

    /// Fused Adam training at n_fft = 1024 must not produce non-finite errors.
    #[test]
    fn fused_single_1024_stays_finite() {
        let cfg = MultiTrainConfig {
            n_ffts: vec![1024],
            until_converged: false,
            lr: 1e-4,
            use_fused_train: true,
            optimizer: TwiddleOptimizer::Adam,
            project_twiddles: true,
            ..test_cfg(80, vec![MultiTrainSchedule::Single])
        };
        let report = run_multi_train(&cfg).unwrap();
        let row = row_for(&report, "single_1024");
        assert!(
            row.roundtrip_max_err.is_finite(),
            "rt_max={}",
            row.roundtrip_max_err
        );
        assert!(row.encoder_spectrum_max_err.is_finite());
    }
}