1use crate::butterfly::butterfly_train_step_encdec;
19use crate::config::{FftLearnConfig, MultiTrainConfig, MultiTrainSchedule};
20use crate::fused_train::fused_encdec_train_step;
21use crate::second_order::{TwiddleOptState, TwiddleOptimizer};
22use crate::train::random_batch;
23use crate::train_phased::precision_encdec;
24use crate::twiddle::exact_twiddles;
25use crate::weights::{EncDecWeights, export_safetensors};
26use anyhow::{Result, ensure};
27use rand::prelude::*;
28use serde::{Deserialize, Deserializer, Serialize};
29use std::collections::HashMap;
30use std::path::{Path, PathBuf};
31use std::time::Instant;
32
33fn null_as_nan<'de, D: Deserializer<'de>>(deserializer: D) -> Result<f32, D::Error> {
34 Ok(Option::<f32>::deserialize(deserializer)?.unwrap_or(f32::NAN))
35}
36
/// One evaluation result: a trained (or exact-baseline) regime evaluated at one FFT size.
///
/// Every `f32` metric is deserialized through `null_as_nan`, so a `null` value in the
/// input is read back as NaN.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiTrainEvalRow {
    /// Regime label: `"exact"` for the baseline, `single_{n}` for per-size training,
    /// `mixed_{schedule}` for mixed-size training.
    pub regime: String,
    /// Label of the schedule that produced the row (`"exact"` for the baseline).
    pub schedule: String,
    /// FFT sizes the regime was trained on.
    pub train_sizes: Vec<usize>,
    /// FFT size this row was evaluated at.
    pub eval_n_fft: usize,
    /// Number of training steps executed by the regime (0 for the baseline).
    pub train_steps_total: usize,
    /// Wall-clock training time in milliseconds (0.0 for the baseline).
    pub train_elapsed_ms: f64,
    /// Encoder spectrum MSE as reported by `precision_encdec`.
    #[serde(deserialize_with = "null_as_nan")]
    pub encoder_spectrum_mse: f32,
    /// Encoder spectrum maximum error as reported by `precision_encdec`.
    #[serde(deserialize_with = "null_as_nan")]
    pub encoder_spectrum_max_err: f32,
    /// Decoder time-domain MSE as reported by `precision_encdec`.
    #[serde(deserialize_with = "null_as_nan")]
    pub decoder_time_mse: f32,
    /// Decoder time-domain maximum error as reported by `precision_encdec`.
    #[serde(deserialize_with = "null_as_nan")]
    pub decoder_time_max_err: f32,
    /// Round-trip MSE as reported by `precision_encdec`.
    #[serde(deserialize_with = "null_as_nan")]
    pub roundtrip_mse: f32,
    /// Round-trip maximum error as reported by `precision_encdec`.
    #[serde(deserialize_with = "null_as_nan")]
    pub roundtrip_max_err: f32,
    /// Whether the convergence tracker stopped training early (always `true` for the baseline).
    pub converged: bool,
    /// Hold-out MSE recorded at the end of training (the round-trip MSE for the baseline).
    #[serde(deserialize_with = "null_as_nan")]
    pub final_holdout_mse: f32,
    /// Checkpoint location when an output directory was configured: the exported file for
    /// `single_*` regimes, the regime directory for `mixed_*` regimes. Omitted from the
    /// serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub checkpoint: Option<PathBuf>,
}
65
/// Full result of a multi-size training study: the configuration that was used plus one
/// [`MultiTrainEvalRow`] per (regime, evaluation size) pair.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiTrainReport {
    /// Batch size copied from the training configuration.
    pub batch: usize,
    /// FFT sizes covered by the study.
    pub n_ffts: Vec<usize>,
    /// Upper bound on training steps per regime (the configuration's `steps`).
    pub max_steps: usize,
    /// Steps that must elapse before convergence checks begin.
    pub min_steps: usize,
    /// Whether early stopping on hold-out convergence was enabled.
    pub until_converged: bool,
    /// Number of evaluation batches passed to `precision_encdec`.
    pub eval_batches: usize,
    /// Base RNG seed.
    pub seed: u64,
    /// Gradient clip value; defaults to 1.0 when absent from the input.
    #[serde(default = "default_grad_clip_report")]
    pub grad_clip: f32,
    /// Value of the `project_twiddles` training option; defaults to `false` when absent.
    #[serde(default)]
    pub project_twiddles: bool,
    /// Whether the fused training step was used; defaults to `true` when absent.
    #[serde(default = "default_use_fused")]
    pub use_fused_train: bool,
    /// Optimizer label; defaults to `"sgd"` when absent.
    #[serde(default = "default_optimizer_label")]
    pub optimizer: String,
    /// Wall-clock duration of the whole study in milliseconds.
    pub elapsed_ms: f64,
    /// Evaluation rows: the exact baseline first, then the rows of each schedule in order.
    pub rows: Vec<MultiTrainEvalRow>,
}
86
/// Serde fallback for `MultiTrainReport::grad_clip` when the field is missing.
fn default_grad_clip_report() -> f32 {
    1.0_f32
}
90
/// Serde fallback for `MultiTrainReport::use_fused_train` when the field is missing.
fn default_use_fused() -> bool {
    const FUSED_BY_DEFAULT: bool = true;
    FUSED_BY_DEFAULT
}
94
/// Serde fallback for `MultiTrainReport::optimizer` when the field is missing.
fn default_optimizer_label() -> String {
    String::from("sgd")
}
98
/// Trainable state kept for one FFT size.
struct SizeTwiddles {
    /// Encoder twiddle table, initialised from `exact_twiddles`.
    encoder: Vec<f32>,
    /// Decoder twiddle table, initialised from `exact_twiddles`.
    decoder: Vec<f32>,
    /// Optimizer state handed to the fused training step.
    opt: TwiddleOptState,
}
104
105fn new_size_twiddles(model: &FftLearnConfig, optimizer: TwiddleOptimizer) -> SizeTwiddles {
106 let stages = model.n_fft.trailing_zeros() as usize;
107 let half = model.n_fft / 2;
108 let tw_len = stages * half * 2;
109 SizeTwiddles {
110 encoder: exact_twiddles(model),
111 decoder: exact_twiddles(model),
112 opt: TwiddleOptState::new(optimizer, tw_len, tw_len),
113 }
114}
115struct ConvergenceTracker {
116 patience: usize,
117 rel_delta: f32,
118 abs_delta: f32,
119 best: f32,
120 stale: usize,
121}
122
123impl ConvergenceTracker {
124 fn new(cfg: &MultiTrainConfig) -> Self {
125 Self {
126 patience: cfg.converge_patience,
127 rel_delta: cfg.converge_delta,
128 abs_delta: cfg.converge_delta * 1e-4,
129 best: f32::INFINITY,
130 stale: 0,
131 }
132 }
133
134 fn observe(&mut self, loss: f32) -> bool {
136 if !loss.is_finite() {
137 self.stale = 0;
138 return false;
139 }
140 let improved = if !self.best.is_finite() {
141 true
142 } else {
143 let drop = self.best - loss;
144 drop > self.abs_delta || drop / self.best.max(1e-12) > self.rel_delta
145 };
146 if improved {
147 self.best = loss;
148 self.stale = 0;
149 } else {
150 self.stale += 1;
151 }
152 self.stale >= self.patience
153 }
154}
155
156pub fn run_multi_train(cfg: &MultiTrainConfig) -> Result<MultiTrainReport> {
157 ensure!(!cfg.n_ffts.is_empty(), "n_ffts must not be empty");
158 ensure!(cfg.steps >= 1, "steps must be >= 1");
159 for &n in &cfg.n_ffts {
160 FftLearnConfig::new(n, cfg.batch)?;
161 }
162
163 let started = Instant::now();
164 let mut rows = Vec::new();
165
166 rows.extend(eval_exact_baseline(cfg)?);
167
168 for &schedule in &cfg.schedules {
169 eprintln!("[train-multi] schedule={}", schedule.label());
170 let regime_rows = train_schedule(cfg, schedule)?;
171 rows.extend(regime_rows);
172 }
173
174 let report = MultiTrainReport {
175 batch: cfg.batch,
176 n_ffts: cfg.n_ffts.clone(),
177 max_steps: cfg.steps,
178 min_steps: cfg.min_steps,
179 until_converged: cfg.until_converged,
180 eval_batches: cfg.eval_batches,
181 seed: cfg.seed,
182 grad_clip: cfg.grad_clip,
183 project_twiddles: cfg.project_twiddles,
184 use_fused_train: cfg.use_fused_train,
185 optimizer: cfg.optimizer.label().to_string(),
186 elapsed_ms: started.elapsed().as_secs_f64() * 1000.0,
187 rows,
188 };
189
190 if let Some(out) = &cfg.out_dir {
191 std::fs::create_dir_all(out)?;
192 let path = out.join("multi_train_report.json");
193 std::fs::write(&path, serde_json::to_vec_pretty(&report)?)?;
194 eprintln!("wrote {}", path.display());
195 }
196
197 Ok(report)
198}
199
200fn eval_exact_baseline(cfg: &MultiTrainConfig) -> Result<Vec<MultiTrainEvalRow>> {
201 let mut rows = Vec::new();
202 for &n in &cfg.n_ffts {
203 let model = FftLearnConfig::new(n, cfg.batch)?;
204 let enc = exact_twiddles(&model);
205 let dec = exact_twiddles(&model);
206 let mut rng = rand::rngs::StdRng::seed_from_u64(cfg.seed);
207 let (enc_mse, enc_max, dec_mse, dec_max, rt_mse, rt_max) =
208 precision_encdec(&enc, &dec, &model, cfg.eval_batches, &mut rng)?;
209 rows.push(row_from_metrics(
210 "exact",
211 "exact",
212 vec![n],
213 n,
214 0,
215 0.0,
216 enc_mse,
217 enc_max,
218 dec_mse,
219 dec_max,
220 rt_mse,
221 rt_max,
222 true,
223 rt_mse,
224 None,
225 ));
226 }
227 Ok(rows)
228}
229
230fn train_schedule(
231 cfg: &MultiTrainConfig,
232 schedule: MultiTrainSchedule,
233) -> Result<Vec<MultiTrainEvalRow>> {
234 match schedule {
235 MultiTrainSchedule::Single => train_single_per_size(cfg),
236 MultiTrainSchedule::RoundRobin
237 | MultiTrainSchedule::Random
238 | MultiTrainSchedule::Balanced => train_mixed(cfg, schedule),
239 }
240}
241
242fn train_single_per_size(cfg: &MultiTrainConfig) -> Result<Vec<MultiTrainEvalRow>> {
243 let mut all_rows = Vec::new();
244
245 for &n in &cfg.n_ffts {
246 let model = FftLearnConfig::new(n, cfg.batch)?;
247 let tw = new_size_twiddles(&model, cfg.optimizer);
248 let mut rng = rand::rngs::StdRng::seed_from_u64(cfg.seed.wrapping_add(n as u64));
249 let regime = format!("single_{n}");
250 let label = regime.clone();
251
252 let outcome = train_until_converged(
253 cfg,
254 &label,
255 &mut rng,
256 move |_step, tw_map, rng| {
257 let st = tw_map.get_mut(&n).expect("twiddles");
258 train_encdec_step_on(cfg, st, n, rng)
259 },
260 HashMap::from([(n, tw)]),
261 |tw_map, rng| holdout_mse(cfg, tw_map, &[n], rng),
262 )?;
263
264 let tw = outcome.tw;
265 let st = tw.get(&n).expect("twiddles");
266 let weights = EncDecWeights::from_twiddles(&st.encoder, &st.decoder, n);
267 let checkpoint = save_multi_checkpoint(cfg, ®ime, &weights, n)?;
268
269 all_rows.extend(eval_twiddles_matrix(
270 cfg,
271 ®ime,
272 schedule_label(MultiTrainSchedule::Single),
273 &[n],
274 outcome.steps,
275 outcome.elapsed_ms,
276 outcome.converged,
277 outcome.holdout_mse,
278 &tw,
279 checkpoint,
280 )?);
281 }
282 Ok(all_rows)
283}
284
285fn train_mixed(
286 cfg: &MultiTrainConfig,
287 schedule: MultiTrainSchedule,
288) -> Result<Vec<MultiTrainEvalRow>> {
289 let mut rng = rand::rngs::StdRng::seed_from_u64(cfg.seed);
290 let mut tw: HashMap<usize, SizeTwiddles> = HashMap::new();
291 for &n in &cfg.n_ffts {
292 let model = FftLearnConfig::new(n, cfg.batch)?;
293 tw.insert(n, new_size_twiddles(&model, cfg.optimizer));
294 }
295
296 let regime = format!("mixed_{}", schedule.label());
297 let n_sizes = cfg.n_ffts.len();
298 let eval_sizes = cfg.n_ffts.clone();
299
300 let outcome = match schedule {
301 MultiTrainSchedule::Balanced => {
302 let per = cfg.steps / n_sizes;
303 ensure!(
304 per >= 1,
305 "steps={} too small for {} sizes in balanced mode",
306 cfg.steps,
307 n_sizes
308 );
309 train_balanced_until_converged(cfg, ®ime, per, &mut tw, &mut rng)?
310 }
311 MultiTrainSchedule::RoundRobin | MultiTrainSchedule::Random => train_until_converged(
312 cfg,
313 ®ime,
314 &mut rng,
315 move |step, tw_map, rng| {
316 let pick = match schedule {
317 MultiTrainSchedule::RoundRobin => cfg.n_ffts[step % n_sizes],
318 MultiTrainSchedule::Random => cfg.n_ffts[rng.gen_range(0..n_sizes)],
319 _ => unreachable!(),
320 };
321 let st = tw_map.get_mut(&pick).expect("twiddles");
322 train_encdec_step_on(cfg, st, pick, rng)
323 },
324 tw,
325 {
326 let eval_sizes = eval_sizes.clone();
327 move |tw_map, rng| holdout_mse(cfg, tw_map, &eval_sizes, rng)
328 },
329 )?,
330 MultiTrainSchedule::Single => unreachable!(),
331 };
332
333 let tw = outcome.tw;
334 let mut checkpoint = None;
335 if let Some(out_dir) = &cfg.out_dir {
336 let dir = out_dir.join(®ime);
337 std::fs::create_dir_all(&dir)?;
338 for &n in &cfg.n_ffts {
339 let st = tw.get(&n).expect("twiddles");
340 let weights = EncDecWeights::from_twiddles(&st.encoder, &st.decoder, n);
341 let path = dir.join(format!("n{n}_encdec.safetensors"));
342 export_safetensors(&path, &weights.merged())?;
343 }
344 checkpoint = Some(dir);
345 }
346
347 eval_twiddles_matrix(
348 cfg,
349 ®ime,
350 schedule.label().to_string(),
351 &cfg.n_ffts,
352 outcome.steps,
353 outcome.elapsed_ms,
354 outcome.converged,
355 outcome.holdout_mse,
356 &tw,
357 checkpoint,
358 )
359}
360
/// Result of one training run, returned by the `*_until_converged` loops.
struct ConvergeOutcome {
    /// Per-size twiddle state after training, keyed by FFT size.
    tw: HashMap<usize, SizeTwiddles>,
    /// Number of training steps actually executed.
    steps: usize,
    /// Wall-clock duration of the run in milliseconds.
    elapsed_ms: f64,
    /// `true` when the tracker stopped the run early and the hold-out MSE is finite.
    converged: bool,
    /// Last hold-out MSE measured for the run.
    holdout_mse: f32,
}
368
369fn train_until_converged<R: Rng>(
370 cfg: &MultiTrainConfig,
371 label: &str,
372 rng: &mut R,
373 mut step_fn: impl FnMut(usize, &mut HashMap<usize, SizeTwiddles>, &mut R) -> Result<()>,
374 mut tw: HashMap<usize, SizeTwiddles>,
375 mut holdout_fn: impl FnMut(&HashMap<usize, SizeTwiddles>, &mut R) -> Result<f32>,
376) -> Result<ConvergeOutcome> {
377 let started = Instant::now();
378 let mut tracker = ConvergenceTracker::new(cfg);
379 let mut step = 0usize;
380 let mut converged = false;
381 let mut holdout_mse = f32::INFINITY;
382
383 while step < cfg.steps {
384 step_fn(step, &mut tw, rng)?;
385 step += 1;
386
387 if cfg.until_converged && step >= cfg.min_steps && step.is_multiple_of(cfg.converge_every) {
388 holdout_mse = holdout_fn(&tw, rng)?;
389 eprintln!(
390 " [{label}] step {step} holdout_mse={holdout_mse:.6e} best={:.6e}",
391 tracker.best
392 );
393 if tracker.observe(holdout_mse) {
394 converged = true;
395 eprintln!(" [{label}] converged at step {step} holdout_mse={holdout_mse:.6e}");
396 break;
397 }
398 } else if cfg.log_every > 0 && step.is_multiple_of(cfg.log_every) {
399 eprintln!(" [{label}] step {step}/{}", cfg.steps);
400 }
401 }
402
403 if !holdout_mse.is_finite() {
404 holdout_mse = holdout_fn(&tw, rng)?;
405 }
406
407 Ok(ConvergeOutcome {
408 tw,
409 steps: step,
410 elapsed_ms: started.elapsed().as_secs_f64() * 1000.0,
411 converged: converged && holdout_mse.is_finite(),
412 holdout_mse,
413 })
414}
415
416fn train_balanced_until_converged<R: Rng>(
417 cfg: &MultiTrainConfig,
418 label: &str,
419 per_size: usize,
420 tw: &mut HashMap<usize, SizeTwiddles>,
421 rng: &mut R,
422) -> Result<ConvergeOutcome> {
423 let started = Instant::now();
424 let mut tracker = ConvergenceTracker::new(cfg);
425 let mut step = 0usize;
426 let mut converged = false;
427 let mut final_holdout = f32::INFINITY;
428 let eval_sizes = cfg.n_ffts.clone();
429
430 'outer: while step < cfg.steps {
431 for &n in &cfg.n_ffts {
432 if step >= cfg.steps {
433 break 'outer;
434 }
435 let st = tw.get_mut(&n).expect("twiddles");
436 train_encdec_step_on(cfg, st, n, rng)?;
437 step += 1;
438
439 if cfg.until_converged
440 && step >= cfg.min_steps
441 && step.is_multiple_of(cfg.converge_every)
442 {
443 let loss = holdout_mse(cfg, tw, &eval_sizes, rng)?;
444 eprintln!(
445 " [{label}] step {step} holdout_mse={loss:.6e} best={:.6e}",
446 tracker.best
447 );
448 if tracker.observe(loss) {
449 converged = true;
450 final_holdout = loss;
451 eprintln!(" [{label}] converged at step {step} holdout_mse={loss:.6e}");
452 break 'outer;
453 }
454 final_holdout = loss;
455 } else if cfg.log_every > 0 && step.is_multiple_of(cfg.log_every) {
456 eprintln!(
457 " [{label}] step {step}/{} (balanced ~{per_size}/size)",
458 cfg.steps
459 );
460 }
461 }
462 }
463
464 if !final_holdout.is_finite() {
465 final_holdout = holdout_mse(cfg, tw, &eval_sizes, rng)?;
466 }
467
468 Ok(ConvergeOutcome {
469 tw: std::mem::take(tw),
470 steps: step,
471 elapsed_ms: started.elapsed().as_secs_f64() * 1000.0,
472 converged: converged && final_holdout.is_finite(),
473 holdout_mse: final_holdout,
474 })
475}
476
477fn holdout_mse(
478 cfg: &MultiTrainConfig,
479 tw: &HashMap<usize, SizeTwiddles>,
480 sizes: &[usize],
481 rng: &mut impl Rng,
482) -> Result<f32> {
483 let mut acc = 0f32;
484 let mut n = 0f32;
485 for &size in sizes {
486 let Some(st) = tw.get(&size) else {
487 continue;
488 };
489 let model = FftLearnConfig::new(size, cfg.batch)?;
490 let (_, _, _, _, rt_mse, _) =
491 precision_encdec(&st.encoder, &st.decoder, &model, cfg.eval_batches, rng)?;
492 acc += rt_mse;
493 n += 1.0;
494 }
495 Ok(if n > 0.0 { acc / n } else { f32::INFINITY })
496}
497
498fn train_encdec_step_on(
499 cfg: &MultiTrainConfig,
500 st: &mut SizeTwiddles,
501 n: usize,
502 rng: &mut impl Rng,
503) -> Result<()> {
504 let signal = random_batch(rng, cfg.batch, n);
505 if cfg.use_fused_train {
506 fused_encdec_train_step(
507 &signal,
508 &mut st.encoder,
509 &mut st.decoder,
510 cfg.batch,
511 n,
512 cfg.lr,
513 cfg.spectrum_weight,
514 cfg.grad_clip,
515 cfg.project_twiddles,
516 Some(&mut st.opt),
517 )?;
518 } else {
519 butterfly_train_step_encdec(
520 &signal,
521 &mut st.encoder,
522 &mut st.decoder,
523 cfg.batch,
524 n,
525 cfg.lr as f32,
526 cfg.spectrum_weight,
527 )?;
528 }
529 Ok(())
530}
531
532fn eval_twiddles_matrix(
533 cfg: &MultiTrainConfig,
534 regime: &str,
535 schedule: String,
536 train_sizes: &[usize],
537 train_steps: usize,
538 train_elapsed_ms: f64,
539 converged: bool,
540 holdout_mse: f32,
541 tw: &HashMap<usize, SizeTwiddles>,
542 checkpoint: Option<PathBuf>,
543) -> Result<Vec<MultiTrainEvalRow>> {
544 let mut rows = Vec::new();
545 let mut rng = rand::rngs::StdRng::seed_from_u64(cfg.seed.wrapping_add(17));
546
547 for &eval_n in &cfg.n_ffts {
548 let Some(st) = tw.get(&eval_n) else {
549 continue;
550 };
551 let model = FftLearnConfig::new(eval_n, cfg.batch)?;
552 let (enc_mse, enc_max, dec_mse, dec_max, rt_mse, rt_max) =
553 precision_encdec(&st.encoder, &st.decoder, &model, cfg.eval_batches, &mut rng)?;
554 rows.push(row_from_metrics(
555 regime,
556 &schedule,
557 train_sizes.to_vec(),
558 eval_n,
559 train_steps,
560 train_elapsed_ms,
561 enc_mse,
562 enc_max,
563 dec_mse,
564 dec_max,
565 rt_mse,
566 rt_max,
567 converged,
568 holdout_mse,
569 checkpoint.clone(),
570 ));
571 }
572 Ok(rows)
573}
574
575#[allow(clippy::too_many_arguments)]
576fn row_from_metrics(
577 regime: &str,
578 schedule: &str,
579 train_sizes: Vec<usize>,
580 eval_n_fft: usize,
581 train_steps_total: usize,
582 train_elapsed_ms: f64,
583 encoder_spectrum_mse: f32,
584 encoder_spectrum_max_err: f32,
585 decoder_time_mse: f32,
586 decoder_time_max_err: f32,
587 roundtrip_mse: f32,
588 roundtrip_max_err: f32,
589 converged: bool,
590 final_holdout_mse: f32,
591 checkpoint: Option<PathBuf>,
592) -> MultiTrainEvalRow {
593 MultiTrainEvalRow {
594 regime: regime.to_string(),
595 schedule: schedule.to_string(),
596 train_sizes,
597 eval_n_fft,
598 train_steps_total,
599 train_elapsed_ms,
600 encoder_spectrum_mse,
601 encoder_spectrum_max_err,
602 decoder_time_mse,
603 decoder_time_max_err,
604 roundtrip_mse,
605 roundtrip_max_err,
606 converged,
607 final_holdout_mse,
608 checkpoint,
609 }
610}
611
612fn save_multi_checkpoint(
613 cfg: &MultiTrainConfig,
614 regime: &str,
615 weights: &EncDecWeights,
616 n: usize,
617) -> Result<Option<PathBuf>> {
618 let Some(out) = &cfg.out_dir else {
619 return Ok(None);
620 };
621 let dir = out.join(regime);
622 std::fs::create_dir_all(&dir)?;
623 let path = dir.join(format!("n{n}_encdec.safetensors"));
624 export_safetensors(&path, &weights.merged())?;
625 Ok(Some(path))
626}
627
628fn schedule_label(s: MultiTrainSchedule) -> String {
629 s.label().to_string()
630}
631
632pub fn write_multi_train_json(path: &Path, report: &MultiTrainReport) -> Result<()> {
633 if let Some(parent) = path.parent() {
634 std::fs::create_dir_all(parent)?;
635 }
636 std::fs::write(path, serde_json::to_vec_pretty(report)?)?;
637 Ok(())
638}
639
640pub fn best_regime_per_eval(report: &MultiTrainReport) -> Vec<(usize, String, f32)> {
641 let mut out = Vec::new();
642 for &n in &report.n_ffts {
643 let best = report
644 .rows
645 .iter()
646 .filter(|r| r.eval_n_fft == n && r.regime != "exact")
647 .min_by(|a, b| {
648 a.roundtrip_max_err
649 .partial_cmp(&b.roundtrip_max_err)
650 .unwrap_or(std::cmp::Ordering::Equal)
651 });
652 if let Some(r) = best {
653 out.push((n, r.regime.clone(), r.roundtrip_max_err));
654 }
655 }
656 out
657}
658
659pub fn print_multi_train_table(report: &MultiTrainReport) {
660 eprintln!(
661 "\n=== Multi-n_fft training study (batch={}, max_steps={}, min_steps={}, until_converged={}) ===\n",
662 report.batch, report.max_steps, report.min_steps, report.until_converged
663 );
664
665 for &eval_n in &report.n_ffts {
666 eprintln!("--- eval n_fft={eval_n} ---");
667 eprintln!(
668 "{:<22} {:>10} {:>6} {:>10} {:>10} {:>10}",
669 "regime", "steps", "conv", "rt_max", "enc_max", "train_ms"
670 );
671 let mut subset: Vec<&MultiTrainEvalRow> = report
672 .rows
673 .iter()
674 .filter(|r| r.eval_n_fft == eval_n)
675 .collect();
676 subset.sort_by(|a, b| {
677 a.roundtrip_max_err
678 .partial_cmp(&b.roundtrip_max_err)
679 .unwrap_or(std::cmp::Ordering::Equal)
680 });
681 for r in &subset {
682 eprintln!(
683 "{:<22} {:>10} {:>6} {:>10.3e} {:>10.3e} {:>10.1}",
684 r.regime,
685 r.train_steps_total,
686 if r.converged { "yes" } else { "no" },
687 r.roundtrip_max_err,
688 r.encoder_spectrum_max_err,
689 r.train_elapsed_ms
690 );
691 }
692 if let Some(best) = subset.first() {
693 eprintln!(
694 " → best: {} (rt_max={:.3e}, steps={})\n",
695 best.regime, best.roundtrip_max_err, best.train_steps_total
696 );
697 }
698 }
699
700 eprintln!("--- train×eval roundtrip max_err matrix ---");
701 let regimes: Vec<String> = report
702 .rows
703 .iter()
704 .map(|r| r.regime.clone())
705 .collect::<std::collections::BTreeSet<_>>()
706 .into_iter()
707 .collect();
708 eprint!("{:>22}", "regime \\ eval");
709 for &n in &report.n_ffts {
710 eprint!(" {:>10}", n);
711 }
712 eprintln!();
713 for regime in ®imes {
714 eprint!("{regime:>22}");
715 for &n in &report.n_ffts {
716 let cell = report
717 .rows
718 .iter()
719 .find(|r| r.regime == *regime && r.eval_n_fft == n);
720 if let Some(r) = cell {
721 eprint!(" {:>10.2e}", r.roundtrip_max_err);
722 } else {
723 eprint!(" {:>10}", "—");
724 }
725 }
726 eprintln!();
727 }
728 eprintln!("\nTotal study time: {:.1} ms\n", report.elapsed_ms);
729}
730
731#[cfg(test)]
732mod tests {
733 use super::*;
734 use crate::config::MultiTrainSchedule;
735
736 fn test_cfg(steps: usize, schedules: Vec<MultiTrainSchedule>) -> MultiTrainConfig {
737 MultiTrainConfig {
738 n_ffts: vec![64, 128],
739 batch: 4,
740 steps,
741 schedules,
742 lr: 5e-4,
743 spectrum_weight: 1.0,
744 seed: 1,
745 log_every: 0,
746 eval_batches: 2,
747 out_dir: None,
748 until_converged: false,
749 min_steps: 300,
750 converge_every: 25,
751 converge_patience: 5,
752 converge_delta: 1e-4,
753 grad_clip: 1.0,
754 project_twiddles: true,
755 use_fused_train: true,
756 optimizer: TwiddleOptimizer::Sgd,
757 }
758 }
759
760 #[test]
761 fn multi_train_single_schedule() {
762 let report = run_multi_train(&test_cfg(40, vec![MultiTrainSchedule::Single])).unwrap();
763 assert!(report.rows.iter().any(|r| r.regime == "single_64"));
764 assert!(report.rows.iter().any(|r| r.regime == "single_128"));
765 for &n in &[64usize, 128] {
766 let best = report
767 .rows
768 .iter()
769 .filter(|r| r.eval_n_fft == n && r.regime.starts_with("single_"))
770 .map(|r| r.roundtrip_max_err)
771 .fold(f32::INFINITY, f32::min);
772 assert!(best < 0.5, "n={n} single train rt_max={best}");
773 }
774 }
775
776 #[test]
777 fn mixed_round_robin_runs() {
778 let report = run_multi_train(&test_cfg(20, vec![MultiTrainSchedule::RoundRobin])).unwrap();
779 assert!(report.rows.iter().any(|r| r.regime == "mixed_round_robin"));
780 }
781
782 #[test]
783 fn convergence_stops_early() {
784 let mut cfg = test_cfg(2000, vec![MultiTrainSchedule::Single]);
785 cfg.n_ffts = vec![64];
786 cfg.until_converged = true;
787 cfg.min_steps = 20;
788 cfg.converge_every = 10;
789 cfg.converge_patience = 2;
790 cfg.converge_delta = 1e-2;
791 let report = run_multi_train(&cfg).unwrap();
792 let row = report
793 .rows
794 .iter()
795 .find(|r| r.regime == "single_64")
796 .expect("single_64");
797 assert!(row.converged, "expected early convergence");
798 assert!(
799 row.train_steps_total < cfg.steps,
800 "expected fewer than max steps"
801 );
802 }
803
804 #[test]
805 fn fused_single_1024_stays_finite() {
806 let mut cfg = test_cfg(80, vec![MultiTrainSchedule::Single]);
807 cfg.n_ffts = vec![1024];
808 cfg.until_converged = false;
809 cfg.lr = 1e-4;
810 cfg.use_fused_train = true;
811 cfg.optimizer = TwiddleOptimizer::Adam;
812 cfg.project_twiddles = true;
813 let report = run_multi_train(&cfg).unwrap();
814 let row = report
815 .rows
816 .iter()
817 .find(|r| r.regime == "single_1024")
818 .expect("single_1024");
819 assert!(
820 row.roundtrip_max_err.is_finite(),
821 "rt_max={}",
822 row.roundtrip_max_err
823 );
824 assert!(row.encoder_spectrum_max_err.is_finite());
825 }
826}