Skip to main content

rlx_fft/
study_telemetry.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Training telemetry for study reports: loss curves, param counts, activation heatmaps, loss landscape.
17
18use crate::config::FftLearnConfig;
19use crate::pruned::init_gates;
20use crate::twiddle::exact_twiddles;
21use crate::unitary::UnitaryWeights;
22use crate::variants::FftVariantId;
23use serde::{Deserialize, Serialize};
24
/// Size in bytes of one `f32`; used to turn parameter counts into a memory estimate.
pub const F32_BYTES: usize = std::mem::size_of::<f32>();
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct ParamBreakdown {
29    pub twiddles: usize,
30    pub gates: usize,
31    pub freq_mask: usize,
32    pub denoiser: usize,
33    pub unitary: usize,
34    pub mel_filters: usize,
35    pub q8_packed: usize,
36    pub total_params: usize,
37    pub memory_bytes: usize,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct LossPoint {
42    pub step: usize,
43    pub total_loss: f32,
44    pub mel_err: f32,
45    pub spec_err: f32,
46    pub welch_err: f32,
47    pub mean_gate: f32,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct LossLandscape3D {
52    /// Grid over twiddle[0] (real) vs twiddle[1] (imag) with other weights fixed.
53    pub x_label: String,
54    pub y_label: String,
55    pub x: Vec<f32>,
56    pub y: Vec<f32>,
57    pub z: Vec<Vec<f32>>,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct ActivationHeatmap {
62    /// Row = FFT stage, col = butterfly index; values in [0,1] (gate activation).
63    pub stages: usize,
64    pub butterflies: usize,
65    pub gates: Vec<f32>,
66    /// Optional per-bin frequency mask (n_fft complex bins).
67    #[serde(default, skip_serializing_if = "Vec::is_empty")]
68    pub freq_mask: Vec<f32>,
69    /// Twiddle magnitude per stage×butterfly (|w|).
70    #[serde(default, skip_serializing_if = "Vec::is_empty")]
71    pub twiddle_mag: Vec<f32>,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ModelTrainingTrace {
76    pub model_id: String,
77    pub variant: String,
78    pub n_fft: usize,
79    pub batch: usize,
80    pub train_steps: usize,
81    pub params: ParamBreakdown,
82    pub loss_curve: Vec<LossPoint>,
83    pub heatmap: ActivationHeatmap,
84    pub landscape: Option<LossLandscape3D>,
85    pub final_mel_err: f32,
86    pub final_spec_err: f32,
87}
88
89#[derive(Debug, Clone, Default, Serialize, Deserialize)]
90pub struct StudyTelemetryBundle {
91    pub models: Vec<ModelTrainingTrace>,
92}
93
94pub fn variant_param_breakdown(variant: FftVariantId, cfg: &FftLearnConfig) -> ParamBreakdown {
95    let stages = cfg.num_stages();
96    let half = cfg.n_fft / 2;
97    let tw = stages * half * 2;
98    let (unitary, gates, freq_mask, denoiser, mel_filters, q8) = match variant {
99        FftVariantId::Rustfft | FftVariantId::RlxOpFft | FftVariantId::RlxOpIfft => {
100            (0, 0, 0, 0, 0, 0)
101        }
102        FftVariantId::ButterflyUnitary => (UnitaryWeights::param_count(cfg.n_fft), 0, 0, 0, 0, 0),
103        FftVariantId::ButterflyQ8 => (tw, 0, 0, 0, 0, tw / 2),
104        FftVariantId::WelchRustfft
105        | FftVariantId::WelchRlxOpFft
106        | FftVariantId::WelchButterflyEager
107        | FftVariantId::WelchButterflyCompiled => (tw, 0, 0, 0, 0, 0),
108        _ => (tw, 0, 0, 0, 0, 0),
109    };
110    let total = tw + unitary + gates + freq_mask + denoiser + mel_filters + q8;
111    ParamBreakdown {
112        twiddles: tw,
113        gates,
114        freq_mask,
115        denoiser,
116        unitary,
117        mel_filters,
118        q8_packed: q8,
119        total_params: total,
120        memory_bytes: total * F32_BYTES,
121    }
122}
123
124pub fn learned_model_param_breakdown(n_fft: usize, n_mels: usize) -> ParamBreakdown {
125    let cfg = FftLearnConfig::new(n_fft, 1).expect("n_fft");
126    let tw = exact_twiddles(&cfg).len();
127    let gates = init_gates(n_fft).len();
128    let fm = n_fft * 2;
129    let dn = n_fft * 2 * 2;
130    let mf = n_mels * (n_fft / 2 + 1);
131    let total = tw + gates + fm + dn + mf;
132    ParamBreakdown {
133        twiddles: tw,
134        gates,
135        freq_mask: fm,
136        denoiser: dn,
137        unitary: 0,
138        mel_filters: mf,
139        q8_packed: 0,
140        total_params: total,
141        memory_bytes: total * F32_BYTES,
142    }
143}
144
145pub fn gate_heatmap_from_vec(gates: &[f32], n_fft: usize) -> ActivationHeatmap {
146    let stages = crate::butterfly::num_stages(n_fft);
147    let half = n_fft / 2;
148    let mut tw_mag = vec![0f32; stages * half];
149    let tw = exact_twiddles(&FftLearnConfig::new(n_fft, 1).unwrap());
150    for s in 0..stages {
151        for b in 0..half {
152            let w_base = crate::twiddle::twiddle_index(s, b, half, 0);
153            let re = tw[w_base];
154            let im = tw[w_base + 1];
155            tw_mag[s * half + b] = (re * re + im * im).sqrt();
156        }
157    }
158    ActivationHeatmap {
159        stages,
160        butterflies: half,
161        gates: gates.to_vec(),
162        freq_mask: Vec::new(),
163        twiddle_mag: tw_mag,
164    }
165}