1use crate::config::FftLearnConfig;
19use crate::pruned::init_gates;
20use crate::twiddle::exact_twiddles;
21use crate::unitary::UnitaryWeights;
22use crate::variants::FftVariantId;
23use serde::{Deserialize, Serialize};
24
/// Size in bytes of one `f32`; used to turn parameter counts into memory estimates.
pub const F32_BYTES: usize = std::mem::size_of::<f32>();
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct ParamBreakdown {
29 pub twiddles: usize,
30 pub gates: usize,
31 pub freq_mask: usize,
32 pub denoiser: usize,
33 pub unitary: usize,
34 pub mel_filters: usize,
35 pub q8_packed: usize,
36 pub total_params: usize,
37 pub memory_bytes: usize,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct LossPoint {
42 pub step: usize,
43 pub total_loss: f32,
44 pub mel_err: f32,
45 pub spec_err: f32,
46 pub welch_err: f32,
47 pub mean_gate: f32,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct LossLandscape3D {
52 pub x_label: String,
54 pub y_label: String,
55 pub x: Vec<f32>,
56 pub y: Vec<f32>,
57 pub z: Vec<Vec<f32>>,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct ActivationHeatmap {
62 pub stages: usize,
64 pub butterflies: usize,
65 pub gates: Vec<f32>,
66 #[serde(default, skip_serializing_if = "Vec::is_empty")]
68 pub freq_mask: Vec<f32>,
69 #[serde(default, skip_serializing_if = "Vec::is_empty")]
71 pub twiddle_mag: Vec<f32>,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ModelTrainingTrace {
76 pub model_id: String,
77 pub variant: String,
78 pub n_fft: usize,
79 pub batch: usize,
80 pub train_steps: usize,
81 pub params: ParamBreakdown,
82 pub loss_curve: Vec<LossPoint>,
83 pub heatmap: ActivationHeatmap,
84 pub landscape: Option<LossLandscape3D>,
85 pub final_mel_err: f32,
86 pub final_spec_err: f32,
87}
88
89#[derive(Debug, Clone, Default, Serialize, Deserialize)]
90pub struct StudyTelemetryBundle {
91 pub models: Vec<ModelTrainingTrace>,
92}
93
94pub fn variant_param_breakdown(variant: FftVariantId, cfg: &FftLearnConfig) -> ParamBreakdown {
95 let stages = cfg.num_stages();
96 let half = cfg.n_fft / 2;
97 let tw = stages * half * 2;
98 let (unitary, gates, freq_mask, denoiser, mel_filters, q8) = match variant {
99 FftVariantId::Rustfft | FftVariantId::RlxOpFft | FftVariantId::RlxOpIfft => {
100 (0, 0, 0, 0, 0, 0)
101 }
102 FftVariantId::ButterflyUnitary => (UnitaryWeights::param_count(cfg.n_fft), 0, 0, 0, 0, 0),
103 FftVariantId::ButterflyQ8 => (tw, 0, 0, 0, 0, tw / 2),
104 FftVariantId::WelchRustfft
105 | FftVariantId::WelchRlxOpFft
106 | FftVariantId::WelchButterflyEager
107 | FftVariantId::WelchButterflyCompiled => (tw, 0, 0, 0, 0, 0),
108 _ => (tw, 0, 0, 0, 0, 0),
109 };
110 let total = tw + unitary + gates + freq_mask + denoiser + mel_filters + q8;
111 ParamBreakdown {
112 twiddles: tw,
113 gates,
114 freq_mask,
115 denoiser,
116 unitary,
117 mel_filters,
118 q8_packed: q8,
119 total_params: total,
120 memory_bytes: total * F32_BYTES,
121 }
122}
123
124pub fn learned_model_param_breakdown(n_fft: usize, n_mels: usize) -> ParamBreakdown {
125 let cfg = FftLearnConfig::new(n_fft, 1).expect("n_fft");
126 let tw = exact_twiddles(&cfg).len();
127 let gates = init_gates(n_fft).len();
128 let fm = n_fft * 2;
129 let dn = n_fft * 2 * 2;
130 let mf = n_mels * (n_fft / 2 + 1);
131 let total = tw + gates + fm + dn + mf;
132 ParamBreakdown {
133 twiddles: tw,
134 gates,
135 freq_mask: fm,
136 denoiser: dn,
137 unitary: 0,
138 mel_filters: mf,
139 q8_packed: 0,
140 total_params: total,
141 memory_bytes: total * F32_BYTES,
142 }
143}
144
145pub fn gate_heatmap_from_vec(gates: &[f32], n_fft: usize) -> ActivationHeatmap {
146 let stages = crate::butterfly::num_stages(n_fft);
147 let half = n_fft / 2;
148 let mut tw_mag = vec![0f32; stages * half];
149 let tw = exact_twiddles(&FftLearnConfig::new(n_fft, 1).unwrap());
150 for s in 0..stages {
151 for b in 0..half {
152 let w_base = crate::twiddle::twiddle_index(s, b, half, 0);
153 let re = tw[w_base];
154 let im = tw[w_base + 1];
155 tw_mag[s * half + b] = (re * re + im * im).sqrt();
156 }
157 }
158 ActivationHeatmap {
159 stages,
160 butterflies: half,
161 gates: gates.to_vec(),
162 freq_mask: Vec::new(),
163 twiddle_mag: tw_mag,
164 }
165}