1use anyhow::{Context, Result, bail, ensure};
19use serde::{Deserialize, Serialize};
20
/// FFT sizes accepted by [`ensure_supported_n_fft`]: every power of two from
/// 2^6 = 64 through 2^17 = 131072, in ascending order.
pub const SUPPORTED_N_FFT: &[usize] = &[
    1 << 6, 1 << 7, 1 << 8, 1 << 9, 1 << 10, 1 << 11,
    1 << 12, 1 << 13, 1 << 14, 1 << 15, 1 << 16, 1 << 17,
];

/// The full `n_fft` sweep; currently the same list as [`SUPPORTED_N_FFT`].
pub const FULL_N_FFT_SWEEP: &[usize] = SUPPORTED_N_FFT;

/// Candidate batch sizes for a limit sweep, largest first (4096 down to 1).
///
/// [`limit_sweep_batches`] hands this list to [`adaptive_batches_for_n_fft`],
/// which keeps only the entries within [`batch_cap_for_limit_sweep`] and
/// returns them in ascending order.
pub const LIMIT_SWEEP_REQUESTED_BATCHES: &[usize] = &[
    1 << 12, 1 << 11, 1 << 10, 1 << 9, 1 << 8, 1 << 7, 1 << 6,
    1 << 5, 1 << 4, 1 << 3, 1 << 2, 1 << 1, 1,
];
32
/// Largest batch size allowed for a limit sweep at the given `n_fft`.
///
/// The cap is 4096 for `n_fft <= 128`. It halves each time `n_fft` crosses the
/// next power of two, reaching 8 for `n_fft <= 65536`, and is 4 for anything
/// larger. For every power of two from 128 through 131072, `n_fft * cap` stays
/// at 2^19 (524288).
pub fn batch_cap_for_limit_sweep(n_fft: usize) -> usize {
    // (largest n_fft in the tier, cap for that tier), ascending by n_fft.
    const TIERS: [(usize, usize); 10] = [
        (128, 4096),
        (256, 2048),
        (512, 1024),
        (1024, 512),
        (2048, 256),
        (4096, 128),
        (8192, 64),
        (16384, 32),
        (32768, 16),
        (65536, 8),
    ];
    // Used for every n_fft beyond the last tier.
    const FLOOR_CAP: usize = 4;

    TIERS
        .iter()
        .find(|&&(max_n, _)| n_fft <= max_n)
        .map_or(FLOOR_CAP, |&(_, cap)| cap)
}
49
50pub fn adaptive_batches_for_n_fft(n_fft: usize, requested: &[usize]) -> Vec<usize> {
52 adaptive_batches_with_cap(n_fft, requested, batch_cap_for_limit_sweep(n_fft))
53}
54
/// Returns the entries of `requested` that lie in `1..=cap`, sorted ascending
/// with duplicates removed.
///
/// If no entry qualifies, returns the single value `max(cap, 1)`. Callers
/// therefore always receive at least one batch size. `n_fft` is currently
/// ignored.
pub fn adaptive_batches_with_cap(n_fft: usize, requested: &[usize], cap: usize) -> Vec<usize> {
    // Not consulted by the filtering rule.
    let _ = n_fft;

    // A BTreeSet both deduplicates and yields ascending order on iteration.
    let allowed: std::collections::BTreeSet<usize> = requested
        .iter()
        .copied()
        .filter(|b| (1..=cap).contains(b))
        .collect();

    if allowed.is_empty() {
        vec![cap.max(1)]
    } else {
        allowed.into_iter().collect()
    }
}
69
70pub fn limit_sweep_batches(n_fft: usize) -> Vec<usize> {
72 adaptive_batches_for_n_fft(n_fft, LIMIT_SWEEP_REQUESTED_BATCHES)
73}
74
/// Size gate for the compiled path: accepts `n_fft` up to and including 1024.
pub fn compiled_ok_for_n_fft(n_fft: usize) -> bool {
    const MAX_COMPILED_N_FFT: usize = 1024;
    matches!(n_fft, 0..=MAX_COMPILED_N_FFT)
}
79
/// Size gate for the compiled path during a limit sweep, depending on `device`.
///
/// `device` is compared ASCII case-insensitively. The GPU-style labels
/// (`metal`, `cuda`, `mlx`, `mps`, `rocm`, `wgpu`, `wgu`, `vulkan`, `gpu`)
/// allow `n_fft <= 2048`. Every other label, including `cpu`, allows
/// `n_fft <= 1024`. An absolute ceiling of 4096 is also checked; it is
/// currently looser than either device limit.
pub fn compiled_ok_for_limit_sweep(n_fft: usize, device: &str) -> bool {
    // Same label set as `is_gpu_device_label`; keep the two in sync.
    const GPU_LABELS: [&str; 9] = [
        "metal", "cuda", "mlx", "mps", "rocm", "wgpu", "wgu", "vulkan", "gpu",
    ];
    // Device-independent upper bound.
    const HARD_LIMIT: usize = 4096;

    let on_gpu = GPU_LABELS
        .iter()
        .any(|label| device.eq_ignore_ascii_case(label));
    let device_limit = if on_gpu { 2048 } else { 1024 };

    n_fft <= HARD_LIMIT && n_fft <= device_limit
}
93
/// Returns `true` if `device` names a GPU-style backend.
///
/// The recognised labels are `metal`, `cuda`, `mlx`, `mps`, `rocm`, `wgpu`,
/// `wgu`, `vulkan` and `gpu`, compared ASCII case-insensitively. No trimming
/// is applied.
pub fn is_gpu_device_label(device: &str) -> bool {
    const GPU_LABELS: [&str; 9] = [
        "metal", "cuda", "mlx", "mps", "rocm", "wgpu", "wgu", "vulkan", "gpu",
    ];
    GPU_LABELS
        .iter()
        .any(|label| device.eq_ignore_ascii_case(label))
}
100
/// Size gate for the Welch path in a limit sweep: accepts `n_fft` up to and
/// including 32768.
pub fn welch_ok_for_limit_sweep(n_fft: usize) -> bool {
    const MAX_WELCH_N_FFT: usize = 32_768;
    (..=MAX_WELCH_N_FFT).contains(&n_fft)
}
105
106pub fn welch_ok_for_config(n_fft: usize, batch: usize) -> bool {
108 if !welch_ok_for_limit_sweep(n_fft) {
109 return false;
110 }
111 let hop = n_fft / 2;
112 let frame = n_fft + 7 * hop;
113 let bytes = batch.saturating_mul(frame).saturating_mul(4);
114 bytes <= 512 * 1024 * 1024
115}
116
/// Caps the training step count `base` for large FFT sizes.
///
/// `n_fft <= 1024` leaves `base` unchanged. Above that, the cap tightens at
/// each power-of-two threshold: 15 steps for `n_fft > 1024`, down to 2 steps
/// for `n_fft > 65536`. The result never exceeds `base`.
pub fn train_steps_for_n_fft(base: usize, n_fft: usize) -> usize {
    // (exclusive lower bound on n_fft, step cap), checked from the largest bound down.
    const STEP_CAPS: [(usize, usize); 7] = [
        (65536, 2),
        (32768, 3),
        (16384, 4),
        (8192, 5),
        (4096, 8),
        (2048, 12),
        (1024, 15),
    ];

    match STEP_CAPS.iter().find(|&&(floor, _)| n_fft > floor) {
        Some(&(_, cap)) => base.min(cap),
        None => base,
    }
}
130
131#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
132pub enum TransformDir {
133 Forward,
134 Inverse,
135}
136
137impl TransformDir {
138 pub fn is_forward(self) -> bool {
139 matches!(self, Self::Forward)
140 }
141
142 pub fn is_inverse(self) -> bool {
143 matches!(self, Self::Inverse)
144 }
145}
146
147impl std::str::FromStr for TransformDir {
148 type Err = anyhow::Error;
149
150 fn from_str(s: &str) -> Result<Self> {
151 match s.to_ascii_lowercase().as_str() {
152 "forward" | "fft" => Ok(Self::Forward),
153 "inverse" | "ifft" => Ok(Self::Inverse),
154 other => bail!("unknown transform direction: {other} (use fft|ifft)"),
155 }
156 }
157}
158
159pub fn parse_transform_dir(s: &str) -> Result<TransformDir> {
160 s.parse()
161}
162
163#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct FftLearnConfig {
165 pub n_fft: usize,
167 pub batch: usize,
169}
170
171impl FftLearnConfig {
172 pub fn new(n_fft: usize, batch: usize) -> Result<Self> {
173 ensure!(
174 n_fft.is_power_of_two(),
175 "n_fft must be a power of two, got {n_fft}"
176 );
177 ensure!(n_fft >= 4, "n_fft must be at least 4");
178 ensure!(batch >= 1, "batch must be >= 1");
179 Ok(Self { n_fft, batch })
180 }
181
182 pub fn tiny() -> Self {
183 Self {
184 n_fft: 64,
185 batch: 4,
186 }
187 }
188
189 pub fn num_stages(&self) -> usize {
190 self.n_fft.trailing_zeros() as usize
191 }
192
193 pub fn butterflies_per_stage(&self) -> usize {
194 self.n_fft / 2
195 }
196
197 pub fn twiddle_param_count(&self) -> usize {
198 self.num_stages() * self.butterflies_per_stage() * 2
199 }
200
201 pub fn validate(&self) -> Result<()> {
202 Self::new(self.n_fft, self.batch)?;
203 Ok(())
204 }
205}
206
207pub fn parse_n_fft(s: &str) -> Result<usize> {
208 let n: usize = s.parse().context("n_fft: usize")?;
209 FftLearnConfig::new(n, 1).map(|_| n)
210}
211
212pub fn ensure_supported_n_fft(n_fft: usize) -> Result<()> {
213 if SUPPORTED_N_FFT.contains(&n_fft) {
214 return Ok(());
215 }
216 bail!(
217 "unsupported n_fft={n_fft}; supported: {}",
218 SUPPORTED_N_FFT
219 .iter()
220 .map(|n| n.to_string())
221 .collect::<Vec<_>>()
222 .join(", ")
223 );
224}
225
/// Settings for training one model ([`FftLearnConfig`]) in one [`TransformDir`].
///
/// Default values come from the `Default` impl for this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainConfig {
    /// Model shape (`n_fft`, `batch`).
    pub model: FftLearnConfig,
    /// Transform direction for this run.
    pub direction: TransformDir,
    /// Number of training steps.
    pub steps: usize,
    /// Learning rate.
    pub lr: f64,
    /// Weight-decay coefficient.
    pub weight_decay: f32,
    /// NOTE(review): `beta1`, `beta2` and `eps` default to 0.9 / 0.999 / 1e-8,
    /// the usual Adam-style moment settings — confirm which optimizer reads them.
    pub beta1: f64,
    pub beta2: f64,
    pub eps: f64,
    /// Gradient-clipping threshold (presumably a norm bound — confirm).
    pub grad_clip: f32,
    /// RNG seed.
    pub seed: u64,
    /// Logging interval, in steps.
    pub log_every: usize,
    /// Free-form device label; `"auto"` by default.
    pub device: String,
    /// Optional output directory; `None` by default.
    pub out_dir: Option<std::path::PathBuf>,
}
242
243impl Default for TrainConfig {
244 fn default() -> Self {
245 Self {
246 model: FftLearnConfig::tiny(),
247 direction: TransformDir::Forward,
248 steps: 500,
249 lr: 1e-3,
250 weight_decay: 0.0,
251 beta1: 0.9,
252 beta2: 0.999,
253 eps: 1e-8,
254 grad_clip: 1.0,
255 seed: 42,
256 log_every: 50,
257 device: "auto".to_string(),
258 out_dir: None,
259 }
260 }
261}
262
/// Settings for a phased run with separate encoder, decoder and joint step budgets.
///
/// NOTE(review): the phase order is presumably encoder → decoder → joint;
/// confirm in the trainer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PhasedTrainConfig {
    /// Model shape (`n_fft`, `batch`).
    pub model: FftLearnConfig,
    /// Steps in the encoder phase.
    pub encoder_steps: usize,
    /// Steps in the decoder phase.
    pub decoder_steps: usize,
    /// Steps in the joint phase.
    pub joint_steps: usize,
    /// Learning rate.
    pub lr: f64,
    /// Weight of the spectrum term (presumably in the loss — confirm).
    pub spectrum_weight: f32,
    /// RNG seed.
    pub seed: u64,
    /// Logging interval, in steps.
    pub log_every: usize,
    /// Optional output directory; `None` by default.
    pub out_dir: Option<std::path::PathBuf>,
}
276
277impl Default for PhasedTrainConfig {
278 fn default() -> Self {
279 Self {
280 model: FftLearnConfig::tiny(),
281 encoder_steps: 300,
282 decoder_steps: 300,
283 joint_steps: 300,
284 lr: 5e-4,
285 spectrum_weight: 1.0,
286 seed: 42,
287 log_every: 50,
288 out_dir: None,
289 }
290 }
291}
292
/// Settings for an encoder/decoder training run.
///
/// `grad_clip` and `project_twiddles` carry serde defaults. When they are
/// missing from the serialized input, `default_grad_clip()` (1.0) and
/// `default_project_twiddles()` (true) are used.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncDecTrainConfig {
    /// Model shape (`n_fft`, `batch`).
    pub model: FftLearnConfig,
    /// Number of training steps.
    pub steps: usize,
    /// Learning rate.
    pub lr: f64,
    /// Weight of the spectrum term (presumably in the loss — confirm).
    pub spectrum_weight: f32,
    /// RNG seed.
    pub seed: u64,
    /// Logging interval, in steps.
    pub log_every: usize,
    /// Free-form device label; `"auto"` by default.
    pub device: String,
    /// Optional output directory; `None` by default.
    pub out_dir: Option<std::path::PathBuf>,
    /// Gradient-clipping threshold; falls back to 1.0 when absent on deserialize.
    #[serde(default = "default_grad_clip")]
    pub grad_clip: f32,
    /// Twiddle-projection flag; falls back to `true` when absent on deserialize.
    #[serde(default = "default_project_twiddles")]
    pub project_twiddles: bool,
}
310
/// Fallback `grad_clip` for `EncDecTrainConfig`.
///
/// Used by its `#[serde(default)]` attribute and its `Default` impl.
fn default_grad_clip() -> f32 {
    const GRAD_CLIP: f32 = 1.0;
    GRAD_CLIP
}

/// Fallback `project_twiddles` for `EncDecTrainConfig`.
///
/// Used by its `#[serde(default)]` attribute and its `Default` impl.
fn default_project_twiddles() -> bool {
    const PROJECT_TWIDDLES: bool = true;
    PROJECT_TWIDDLES
}
318
/// Schedule option for a multi-size training run (see `MultiTrainConfig::schedules`).
///
/// The text form of each variant comes from `MultiTrainSchedule::label` and is
/// parsed by `MultiTrainSchedule::parse_csv`.
/// NOTE(review): what each schedule actually does is implemented elsewhere —
/// document the policy there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MultiTrainSchedule {
    /// Label `single`.
    Single,
    /// Label `round_robin`; also parsed from `round-robin` and `rr`.
    RoundRobin,
    /// Label `random`.
    Random,
    /// Label `balanced`.
    Balanced,
}
331
332impl MultiTrainSchedule {
333 pub fn label(self) -> &'static str {
334 match self {
335 Self::Single => "single",
336 Self::RoundRobin => "round_robin",
337 Self::Random => "random",
338 Self::Balanced => "balanced",
339 }
340 }
341
342 pub fn all() -> &'static [Self] {
343 &[Self::Single, Self::RoundRobin, Self::Random, Self::Balanced]
344 }
345
346 pub fn parse_csv(s: &str) -> anyhow::Result<Vec<Self>> {
347 let mut out = Vec::new();
348 for part in s.split(',') {
349 let part = part.trim().to_ascii_lowercase();
350 if part.is_empty() {
351 continue;
352 }
353 out.push(match part.as_str() {
354 "single" => Self::Single,
355 "round_robin" | "round-robin" | "rr" => Self::RoundRobin,
356 "random" => Self::Random,
357 "balanced" => Self::Balanced,
358 other => anyhow::bail!(
359 "unknown schedule {other} (use single,round_robin,random,balanced)"
360 ),
361 });
362 }
363 anyhow::ensure!(!out.is_empty(), "schedules list is empty");
364 Ok(out)
365 }
366}
367
/// Settings for training across several FFT sizes under one or more schedules.
///
/// Default values come from the `Default` impl for this type. Unlike
/// `EncDecTrainConfig`, no field has a serde default, so every field must be
/// present when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiTrainConfig {
    /// FFT sizes included in the run.
    pub n_ffts: Vec<usize>,
    /// Batch size.
    pub batch: usize,
    /// Number of training steps.
    pub steps: usize,
    /// Schedules to run; all of `MultiTrainSchedule::all()` by default.
    pub schedules: Vec<MultiTrainSchedule>,
    /// Learning rate.
    pub lr: f64,
    /// Weight of the spectrum term (presumably in the loss — confirm).
    pub spectrum_weight: f32,
    /// RNG seed.
    pub seed: u64,
    /// Logging interval, in steps.
    pub log_every: usize,
    /// Number of evaluation batches.
    pub eval_batches: usize,
    /// Optional output directory; `None` by default.
    pub out_dir: Option<std::path::PathBuf>,
    /// Enables convergence-based stopping.
    /// NOTE(review): presumably governed by `min_steps` and the `converge_*`
    /// fields below — confirm in the trainer.
    pub until_converged: bool,
    /// Minimum steps before convergence may be declared (presumably — confirm).
    pub min_steps: usize,
    /// Convergence-check interval, in steps (presumably — confirm).
    pub converge_every: usize,
    /// Number of non-improving checks tolerated (presumably — confirm).
    pub converge_patience: usize,
    /// Minimum change that counts as improvement (presumably — confirm).
    pub converge_delta: f32,
    /// Gradient-clipping threshold.
    pub grad_clip: f32,
    /// Twiddle-projection flag.
    pub project_twiddles: bool,
    /// Selects the fused training path.
    pub use_fused_train: bool,
    /// Optimizer used for the twiddle parameters; `Sgd` by default.
    pub optimizer: crate::second_order::TwiddleOptimizer,
}
399
400impl Default for MultiTrainConfig {
401 fn default() -> Self {
402 Self {
403 n_ffts: vec![64, 256],
404 batch: 8,
405 steps: 10_000,
406 schedules: MultiTrainSchedule::all().to_vec(),
407 lr: 5e-4,
408 spectrum_weight: 1.0,
409 seed: 42,
410 log_every: 50,
411 eval_batches: 8,
412 out_dir: None,
413 until_converged: true,
414 min_steps: 300,
415 converge_every: 25,
416 converge_patience: 5,
417 converge_delta: 1e-4,
418 grad_clip: 1.0,
419 project_twiddles: true,
420 use_fused_train: true,
421 optimizer: crate::second_order::TwiddleOptimizer::Sgd,
422 }
423 }
424}
425
426impl Default for EncDecTrainConfig {
427 fn default() -> Self {
428 Self {
429 model: FftLearnConfig::tiny(),
430 steps: 500,
431 lr: 1e-3,
432 spectrum_weight: 1.0,
433 seed: 42,
434 log_every: 50,
435 device: "auto".to_string(),
436 out_dir: None,
437 grad_clip: default_grad_clip(),
438 project_twiddles: default_project_twiddles(),
439 }
440 }
441}