// deep_delta_learning/spectral.rs

1use serde::{Deserialize, Serialize};
2
3use crate::error::SpectralError;
4use crate::utils::cosine_similarity;
5
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
7pub enum DeltaRegime {
8    NearIdentity,
9    Interpolating,
10    NearProjection,
11    NearReflection,
12    StrongReflection,
13}
14
15impl DeltaRegime {
16    pub fn from_beta(beta: f32) -> Self {
17        match beta {
18            beta if beta < 0.3 => Self::NearIdentity,
19            beta if beta < 0.7 => Self::Interpolating,
20            beta if beta < 1.3 => Self::NearProjection,
21            beta if beta < 1.7 => Self::NearReflection,
22            _ => Self::StrongReflection,
23        }
24    }
25}
26
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
28pub enum Sublayer {
29    Attention,
30    Mlp,
31}
32
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
34pub struct BetaHistogram {
35    pub identity_regime: usize,
36    pub projection_regime: usize,
37    pub reflection_regime: usize,
38}
39
40impl BetaHistogram {
41    pub fn observe(&mut self, beta: f32) {
42        if beta < 0.5 {
43            self.identity_regime += 1;
44        } else if beta < 1.5 {
45            self.projection_regime += 1;
46        } else {
47            self.reflection_regime += 1;
48        }
49    }
50
51    pub fn from_betas(betas: &[f32]) -> Self {
52        let mut histogram = Self::default();
53        for beta in betas {
54            histogram.observe(*beta);
55        }
56        histogram
57    }
58
59    pub fn total(&self) -> usize {
60        self.identity_regime + self.projection_regime + self.reflection_regime
61    }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct SpectralInfo {
66    pub beta_mean: f32,
67    pub k_eigenvalue_mean: f32,
68    pub determinant_mean: f32,
69    pub lifted_determinant_mean: f32,
70    pub regime: DeltaRegime,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct DeltaResInfo {
75    pub beta_mean: f32,
76    pub k_eigenvalue_mean: f32,
77    pub correction_norm: f32,
78    pub regime: DeltaRegime,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct LayerDiagnostics {
83    pub attention: DeltaResInfo,
84    pub mlp: DeltaResInfo,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct LayerSpectralInfo {
89    pub layer_idx: usize,
90    pub sublayer: Sublayer,
91    pub beta_mean: f32,
92    pub beta_std: f32,
93    pub k_eigenvalue_mean: f32,
94    pub spatial_determinant_mean: f32,
95    pub lifted_determinant_mean: f32,
96    pub regime: DeltaRegime,
97    pub correction_magnitude: f32,
98    pub k_direction_entropy: f32,
99    k_direction: Vec<f32>,
100}
101
102impl LayerSpectralInfo {
103    #[allow(clippy::too_many_arguments)]
104    pub fn new(
105        layer_idx: usize,
106        sublayer: Sublayer,
107        beta_mean: f32,
108        beta_std: f32,
109        k_eigenvalue_mean: f32,
110        spatial_determinant_mean: f32,
111        lifted_determinant_mean: f32,
112        regime: DeltaRegime,
113        correction_magnitude: f32,
114        k_direction_entropy: f32,
115        k_direction: Vec<f32>,
116    ) -> Self {
117        Self {
118            layer_idx,
119            sublayer,
120            beta_mean,
121            beta_std,
122            k_eigenvalue_mean,
123            spatial_determinant_mean,
124            lifted_determinant_mean,
125            regime,
126            correction_magnitude,
127            k_direction_entropy,
128            k_direction,
129        }
130    }
131
132    pub fn k_direction(&self) -> &[f32] {
133        &self.k_direction
134    }
135}
136
137#[derive(Debug, Clone, Default, Serialize, Deserialize)]
138pub struct SpectralDiagnostics {
139    pub layers: Vec<LayerSpectralInfo>,
140    pub beta_per_layer: Vec<f32>,
141    pub k_eigenvalue_per_layer: Vec<f32>,
142    pub spatial_determinant_per_layer: Vec<f32>,
143    pub lifted_determinant_per_layer: Vec<f32>,
144    pub regime_per_layer: Vec<DeltaRegime>,
145    pub beta_histogram: BetaHistogram,
146    pub k_coherence_per_layer: Vec<f32>,
147    pub correction_norm_per_layer: Vec<f32>,
148}
149
150impl SpectralDiagnostics {
151    pub fn from_layers(layers: Vec<LayerSpectralInfo>) -> Self {
152        let beta_per_layer = layers
153            .iter()
154            .map(|layer| layer.beta_mean)
155            .collect::<Vec<_>>();
156        let k_eigenvalue_per_layer = layers
157            .iter()
158            .map(|layer| layer.k_eigenvalue_mean)
159            .collect::<Vec<_>>();
160        let spatial_determinant_per_layer = layers
161            .iter()
162            .map(|layer| layer.spatial_determinant_mean)
163            .collect::<Vec<_>>();
164        let lifted_determinant_per_layer = layers
165            .iter()
166            .map(|layer| layer.lifted_determinant_mean)
167            .collect::<Vec<_>>();
168        let regime_per_layer = layers.iter().map(|layer| layer.regime).collect::<Vec<_>>();
169        let correction_norm_per_layer = layers
170            .iter()
171            .map(|layer| layer.correction_magnitude)
172            .collect::<Vec<_>>();
173        let beta_histogram = BetaHistogram::from_betas(&beta_per_layer);
174
175        let mut k_coherence_per_layer = Vec::with_capacity(layers.len());
176        let mut previous = None;
177        for layer in &layers {
178            let coherence =
179                previous.map_or(1.0, |prev| cosine_similarity(prev, layer.k_direction()));
180            k_coherence_per_layer.push(coherence);
181            previous = Some(layer.k_direction());
182        }
183
184        Self {
185            layers,
186            beta_per_layer,
187            k_eigenvalue_per_layer,
188            spatial_determinant_per_layer,
189            lifted_determinant_per_layer,
190            regime_per_layer,
191            beta_histogram,
192            k_coherence_per_layer,
193            correction_norm_per_layer,
194        }
195    }
196}
197
198#[derive(Debug, Clone, Default, Serialize, Deserialize)]
199pub struct ModelDiagnostics {
200    pub layers: Vec<LayerDiagnostics>,
201}
202
203impl ModelDiagnostics {
204    pub fn beta_per_layer(&self) -> Vec<f32> {
205        self.layers
206            .iter()
207            .flat_map(|layer| [layer.attention.beta_mean, layer.mlp.beta_mean])
208            .collect()
209    }
210}
211
212#[derive(Debug, Clone, Serialize, Deserialize)]
213pub struct SpectralCollector {
214    history: Vec<SpectralDiagnostics>,
215    max_history: usize,
216}
217
218impl SpectralCollector {
219    pub fn new(max_history: usize) -> Self {
220        Self::try_new(max_history)
221            .unwrap_or_else(|error| panic!("invalid spectral collector configuration: {error}"))
222    }
223
224    pub fn try_new(max_history: usize) -> Result<Self, SpectralError> {
225        if max_history == 0 {
226            return Err(SpectralError::InvalidMaxHistory(max_history));
227        }
228        Ok(Self {
229            history: Vec::new(),
230            max_history,
231        })
232    }
233
234    pub fn record(&mut self, diagnostics: SpectralDiagnostics) {
235        self.history.push(diagnostics);
236        if self.history.len() > self.max_history {
237            let excess = self.history.len() - self.max_history;
238            self.history.drain(0..excess);
239        }
240    }
241
242    pub fn latest(&self) -> Option<&SpectralDiagnostics> {
243        self.history.last()
244    }
245
246    pub fn history(&self) -> &[SpectralDiagnostics] {
247        &self.history
248    }
249
250    pub fn beta_evolution(&self, layer: usize) -> Vec<f32> {
251        self.history
252            .iter()
253            .filter_map(|diagnostics| diagnostics.beta_per_layer.get(layer).copied())
254            .collect()
255    }
256
257    pub fn regime_transitions(&self, layer: usize) -> Vec<(usize, DeltaRegime)> {
258        let mut transitions = Vec::new();
259        let mut previous = None;
260
261        for (step, diagnostics) in self.history.iter().enumerate() {
262            let Some(regime) = diagnostics.regime_per_layer.get(layer).copied() else {
263                continue;
264            };
265
266            if previous != Some(regime) {
267                transitions.push((step, regime));
268                previous = Some(regime);
269            }
270        }
271
272        transitions
273    }
274}
275
/// Full per-sublayer trace, compiled only with the `spectral` feature.
///
/// Bundles the [`DeltaResInfo`] summary with the extra statistics needed to
/// build a [`LayerSpectralInfo`].
#[cfg(feature = "spectral")]
#[derive(Debug, Clone)]
pub(crate) struct DeltaResTrace {
    /// Summary that is also exposed through `LayerDiagnostics`.
    pub info: DeltaResInfo,
    /// Forwarded to `LayerSpectralInfo::beta_std`.
    pub beta_std: f32,
    /// Forwarded to `LayerSpectralInfo::spatial_determinant_mean`.
    pub spatial_determinant_mean: f32,
    /// Forwarded to `LayerSpectralInfo::lifted_determinant_mean`.
    pub lifted_determinant_mean: f32,
    /// Forwarded to `LayerSpectralInfo::k_direction_entropy`.
    pub k_direction_entropy: f32,
    /// Forwarded to the record's `k` direction.
    pub k_direction: Vec<f32>,
}

#[cfg(feature = "spectral")]
impl DeltaResTrace {
    /// Consumes the trace and turns it into a [`LayerSpectralInfo`] tagged with
    /// `layer_idx` and `sublayer`.
    pub fn layer_spectral(self, layer_idx: usize, sublayer: Sublayer) -> LayerSpectralInfo {
        // Exhaustive destructuring: adding a field to either struct makes this
        // conversion fail to compile until it is updated.
        let Self {
            info,
            beta_std,
            spatial_determinant_mean,
            lifted_determinant_mean,
            k_direction_entropy,
            k_direction,
        } = self;
        let DeltaResInfo {
            beta_mean,
            k_eigenvalue_mean,
            correction_norm,
            regime,
        } = info;

        LayerSpectralInfo::new(
            layer_idx,
            sublayer,
            beta_mean,
            beta_std,
            k_eigenvalue_mean,
            spatial_determinant_mean,
            lifted_determinant_mean,
            regime,
            correction_norm,
            k_direction_entropy,
            k_direction,
        )
    }
}
305
/// Attention and MLP traces for one layer, compiled only with the `spectral`
/// feature.
#[cfg(feature = "spectral")]
#[derive(Debug, Clone)]
pub(crate) struct LayerTrace {
    /// Trace of the attention sublayer.
    pub attention: DeltaResTrace,
    /// Trace of the MLP sublayer.
    pub mlp: DeltaResTrace,
}

#[cfg(feature = "spectral")]
impl LayerTrace {
    /// Copies out the lightweight summaries of both sublayers.
    pub fn diagnostics(&self) -> LayerDiagnostics {
        let Self { attention, mlp } = self;
        LayerDiagnostics {
            attention: attention.info.clone(),
            mlp: mlp.info.clone(),
        }
    }

    /// Consumes the trace, yielding `[attention, mlp]` spectral records that
    /// are both tagged with `layer_idx`.
    pub fn into_spectral_layers(self, layer_idx: usize) -> [LayerSpectralInfo; 2] {
        let Self { attention, mlp } = self;
        let attention_record = attention.layer_spectral(layer_idx, Sublayer::Attention);
        let mlp_record = mlp.layer_spectral(layer_idx, Sublayer::Mlp);
        [attention_record, mlp_record]
    }
}
330
/// Splits per-layer traces into model-level summaries and spectral diagnostics.
///
/// Layers are numbered by their position in `layer_traces`. Each trace
/// contributes one [`LayerDiagnostics`] entry and two spectral records, the
/// attention record first and then the MLP record. Only compiled with the
/// `spectral` feature.
#[cfg(feature = "spectral")]
pub(crate) fn summarize_layer_traces(
    layer_traces: Vec<LayerTrace>,
) -> (ModelDiagnostics, SpectralDiagnostics) {
    let layer_count = layer_traces.len();
    let mut model = ModelDiagnostics {
        layers: Vec::with_capacity(layer_count),
    };
    let mut records = Vec::with_capacity(layer_count * 2);

    for (layer_idx, trace) in layer_traces.into_iter().enumerate() {
        // Summaries are cloned out first because the spectral conversion
        // consumes the trace.
        model.layers.push(trace.diagnostics());
        records.extend(trace.into_spectral_layers(layer_idx));
    }

    (model, SpectralDiagnostics::from_layers(records))
}