Skip to main content

free_probability/
prediction.rs

//! Gradient analysis and eigenvalue-distribution prediction for deep networks.
//!
//! Uses free probability (R-transform, S-transform, Marchenko–Pastur law) to:
//!
//! - Predict the combined eigenvalue spectrum when stacking multiple layers
//! - Suggest weight initialisation scales from Marchenko–Pastur theory
//! - Recommend regularisation strength from eigenvalue tail behaviour
//!
//! ## Why this matters
//!
//! Xavier (Glorot) init and He (Kaiming) init both correspond to setting
//! the weight variance so that the Marchenko–Pastur spectral bulk stays
//! within a bounded range.  Free probability is the unifying language.
14
15use crate::moments;
16
/// Parameters describing a transformer-like architecture.
///
/// The two dimensions are stored as `f64` rather than integers; within this
/// file they are only used in floating-point expressions such as the
/// Marchenko–Pastur aspect ratio `hidden_dim / n_samples`.
///
/// Derives `PartialEq` like the suggestion structs in this file, so two
/// configurations can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerConfig {
    /// Number of stacked layers; the suggestion helpers scale by `√n_layers`.
    pub n_layers: usize,
    /// Hidden (model) dimension `d`.
    pub hidden_dim: f64,
    /// Number of samples `n` in the aspect ratio `λ = d / n`.
    pub n_samples: f64,
    /// Weight standard deviation.
    /// NOTE(review): not read by any function in this file — confirm intended use.
    pub weight_std: f64,
    /// Learning rate.
    /// NOTE(review): not read by any function in this file — confirm intended use.
    pub learning_rate: f64,
}
26
/// Weight-initialisation advice derived from Marchenko–Pastur theory.
///
/// Produced by `suggest_initialization`.
#[derive(Debug, Clone, PartialEq)]
pub struct InitSuggestion {
    /// Recommended standard deviation for the weight entries.
    pub suggested_std: f64,
    /// Multiplicative factor to apply to the learning rate.
    pub suggested_lr_scale: f64,
    /// Ratio of the upper to the lower Marchenko–Pastur spectral edge.
    pub condition_number: f64,
    /// Heuristic estimate of the eigenvalue mass lying in the tail.
    pub tail_mass: f64,
}
35
/// Regularisation advice derived from eigenvalue tail analysis.
///
/// Produced by `suggest_regularizer`.
#[derive(Debug, Clone, PartialEq)]
pub struct RegularizerSuggestion {
    /// Recommended regularisation strength `λ_reg`.
    pub lambda_reg: f64,
    /// Estimated spectral radius (a `3σ` estimate of the eigenvalue spread).
    pub spectral_radius: f64,
    /// Fraction by which the spectral radius exceeds the Marchenko–Pastur
    /// scale; `0.0` when it does not exceed it.
    pub outlier_fraction: f64,
    /// Score in `[0, 1]`; `1.0` means no excess kurtosis and no outliers.
    pub stability_score: f64,
}
44
/// Predict the cumulants of the spectrum obtained by stacking `n_layers`
/// freely independent layers that all share the same eigenvalue cumulants.
///
/// The R-transform is additive under free convolution,
/// `R_{combined}(z) = Σ R_i(z)`, so with identical layers every cumulant is
/// simply multiplied by the depth:
///
/// ```text
/// κ_combined[k] = n_layers · κ_single[k]
/// ```
///
/// This predicts the spectrum of a deep network's covariance from a
/// single-layer measurement without ever forming the full matrix.
///
/// Only the first `min(layer_cumulants.len(), combined_cumulants.len())`
/// entries of `combined_cumulants` are written; any remaining entries are
/// left untouched.
pub fn predict_combined_distribution(
    layer_cumulants: &[f64],
    n_layers: usize,
    combined_cumulants: &mut [f64],
) {
    let depth = n_layers as f64;
    // `zip` stops at the shorter slice, which bounds the write range.
    for (out, &kappa) in combined_cumulants.iter_mut().zip(layer_cumulants) {
        *out = depth * kappa;
    }
}
66
67/// Suggest a regularisation strength from eigenvalue moment data.
68///
69/// Uses the following heuristic based on free probability:
70///
71/// - Compute free cumulants from moments
72/// - Estimate the spectral radius as `3σ` (3-sigma rule for eigenvalues)
73/// - Compare with the Marchenko–Pastur upper edge to detect outliers
74/// - Adjust λ_reg proportionally to kurtosis excess and outlier fraction
75/// - Scale by `√(n_layers)` (deeper → more cautious)
76pub fn suggest_regularizer(
77    moments: &[f64],
78    config: &TransformerConfig,
79) -> RegularizerSuggestion {
80    if moments.is_empty() {
81        return RegularizerSuggestion {
82            lambda_reg: 1e-4,
83            spectral_radius: 1.0,
84            outlier_fraction: 0.0,
85            stability_score: 1.0,
86        };
87    }
88
89    let n = moments.len();
90    let mut cumulants = vec![0.0_f64; n];
91    moments::moment_to_cumulant(moments, &mut cumulants);
92
93    let variance = if n >= 2 { moments[1] } else { 1.0 };
94    let mean = if n >= 1 { moments[0] } else { 0.0 };
95    let centered_var = (variance - mean * mean).max(0.0);
96
97    let spectral_radius = centered_var.sqrt() * 3.0;
98
99    // Compare with Marchenko–Pastur upper edge
100    let lambda = (config.hidden_dim / config.n_samples.max(1.0))
101        .max(0.01)
102        .min(4.0);
103
104    let mp_b = (1.0 + lambda.sqrt()) * (1.0 + lambda.sqrt());
105    let mp_scale = mp_b.sqrt().max(1.0);
106    let ratio = spectral_radius / mp_scale;
107
108    let outlier_fraction = if ratio > 1.0 {
109        ((ratio - 1.0) / ratio).min(1.0)
110    } else {
111        0.0
112    };
113
114    // Kurtosis excess = m₄/m₂² - 3 (for centred distribution)
115    let kurtosis_excess = if n >= 4 && centered_var > 1e-12 {
116        let m4 = moments[3];
117        let m2 = centered_var;
118        (m4 / (m2 * m2) - 3.0).max(0.0)
119    } else {
120        0.0
121    };
122
123    let lambda_reg = 1e-4
124        * (1.0 + kurtosis_excess * 0.5)
125        * (1.0 + outlier_fraction)
126        * (config.n_layers as f64).sqrt();
127
128    let stability_score = (1.0 / (1.0 + 0.1 * kurtosis_excess + outlier_fraction))
129        .clamp(0.0, 1.0);
130
131    RegularizerSuggestion {
132        lambda_reg,
133        spectral_radius,
134        outlier_fraction,
135        stability_score,
136    }
137}
138
139/// Suggest a weight initialisation standard deviation from Marchenko–Pastur theory.
140///
141/// For a random weight matrix of size `d × d` (hidden_dim × hidden_dim),
142/// the MP law with ratio `λ = d / n_samples` constrains the spectral bulk.
143///
144/// Xavier/He-like scaling: `σ = 1/√(hidden_dim)` gives eigenvalues in `[0, 4]`.
145/// Depth scaling: divide by `√(n_layers)` for deeper nets.
146pub fn suggest_initialization(config: &TransformerConfig) -> InitSuggestion {
147    if config.hidden_dim <= 0.0 || config.n_samples <= 0.0 || config.n_layers == 0 {
148        return InitSuggestion {
149            suggested_std: 0.02,
150            suggested_lr_scale: 1.0,
151            condition_number: 1.0,
152            tail_mass: 0.0,
153        };
154    }
155
156    let d = config.hidden_dim;
157    let n = config.n_samples.max(d);
158    let lambda = d / n;
159
160    // Xavier/He init: base_std = √(2/d) ≈ O(1/√d).
161    // MP theory says: for W_{ij} ~ N(0, σ²/d), eigenvalues of WWᵀ/n
162    // have support [σ²(1-√λ)², σ²(1+√λ)²].  Setting σ=1 gives support [0,4]
163    // at λ=1 for square matrices.
164    let base_std = 1.0 / d.sqrt();
165    let suggested_std = (base_std / (config.n_layers as f64).sqrt())
166        .clamp(1e-6, 1.0);
167
168    // MP upper / lower edges with the suggested init
169    // For W entries with std σ/√d, λ = d/n:
170    //   support = [ σ²(1-√λ)² , σ²(1+√λ)² ]
171    // where the eigenvalue scaling is independent of d.
172    let mp_upper = (1.0 + lambda.sqrt()) * (1.0 + lambda.sqrt());
173    let mp_lower = (1.0 - lambda.sqrt()).max(1e-12) * (1.0 - lambda.sqrt()).max(1e-12);
174
175    let condition_number = mp_upper / mp_lower.max(1e-12);
176
177    // Learning rate inversely proportional to spectral radius
178    let suggested_lr_scale =
179        (1.0 / (1.0 + 0.1 * mp_upper.sqrt())).clamp(0.01, 1.0);
180
181    InitSuggestion {
182        suggested_std,
183        suggested_lr_scale,
184        condition_number,
185        tail_mass: 0.05, // heuristic default
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by the exact-value checks below.
    const TOL: f64 = 1e-12;

    /// True when `actual` lies within `TOL` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    /// Config with the given shape and the fixed `weight_std` / `learning_rate`
    /// that every test in this module uses.
    fn make_config(n_layers: usize, hidden_dim: f64, n_samples: f64) -> TransformerConfig {
        TransformerConfig {
            n_layers,
            hidden_dim,
            n_samples,
            weight_std: 0.02,
            learning_rate: 0.001,
        }
    }

    #[test]
    fn test_combined_distribution() {
        let layer_cum = vec![1.0, 0.5, 0.1, 0.01];
        let mut combined = vec![0.0_f64; 4];

        predict_combined_distribution(&layer_cum, 4, &mut combined);

        // Each cumulant is multiplied by the number of layers.
        let expected = [4.0, 2.0, 0.4, 0.04];
        for (got, want) in combined.iter().zip(expected.iter()) {
            assert!(close(*got, *want));
        }

        // Verify R-transform additivity
        let z = 0.5;
        let r_single = crate::r_transform::from_cumulants(&layer_cum, z);
        let r_combined = crate::r_transform::from_cumulants(&combined, z);
        assert!(close(r_combined, 4.0 * r_single));
    }

    #[test]
    fn test_suggest_initialization() {
        let sug = suggest_initialization(&make_config(12, 768.0, 10000.0));

        assert!(sug.suggested_std > 0.0, "std = {}", sug.suggested_std);
        assert!(sug.suggested_std.is_finite(), "std is NaN/Inf");
        assert!(sug.suggested_std < 1.0, "std = {} (too large)", sug.suggested_std);
        assert!(sug.suggested_std > 1e-10, "std = {} (too small)", sug.suggested_std);
        assert!(sug.suggested_lr_scale > 0.0, "lr_scale = {}", sug.suggested_lr_scale);
        assert!(sug.suggested_lr_scale.is_finite(), "lr_scale is NaN/Inf");
    }

    #[test]
    fn test_suggest_initialization_null_config_defaults() {
        // All-zero shape must hit the fallback branch.
        let sug = suggest_initialization(&make_config(0, 0.0, 0.0));

        assert!(close(sug.suggested_std, 0.02));
        assert!(close(sug.suggested_lr_scale, 1.0));
        assert!(close(sug.condition_number, 1.0));
        assert!(close(sug.tail_mass, 0.0));
    }

    #[test]
    fn test_suggest_regularizer() {
        let moments = vec![0.0, 1.0, 0.0, 3.0, 0.0, 15.0];

        let sug = suggest_regularizer(&moments, &make_config(6, 512.0, 5000.0));

        assert!(sug.lambda_reg > 0.0, "lambda_reg = {}", sug.lambda_reg);
        assert!(sug.lambda_reg.is_finite(), "lambda_reg is NaN/Inf");
        assert!(
            (0.0..=1.0).contains(&sug.stability_score),
            "stability_score out of [0, 1]: {}",
            sug.stability_score
        );
        assert!(
            (0.0..=1.0).contains(&sug.outlier_fraction),
            "outlier_fraction out of [0, 1]: {}",
            sug.outlier_fraction
        );
    }

    #[test]
    fn test_suggest_regularizer_non_gaussian() {
        // Heavy-tailed distribution
        let moments = vec![0.0, 1.0, 0.0, 10.0, 0.0, 100.0];

        let sug = suggest_regularizer(&moments, &make_config(6, 512.0, 5000.0));

        assert!(sug.lambda_reg > 0.0);
        assert!(sug.stability_score > 0.0 && sug.stability_score <= 1.0);
    }

    #[test]
    fn test_suggest_regularizer_empty_moments() {
        let moments: Vec<f64> = Vec::new();

        let sug = suggest_regularizer(&moments, &make_config(6, 512.0, 5000.0));

        assert!(close(sug.lambda_reg, 1e-4));
        assert!(close(sug.stability_score, 1.0));
    }
}