1use crate::moments;
16
/// Shape and training hyper-parameters that the suggestion functions read.
///
/// Derives `PartialEq` like the result structs in this module so configs can be
/// compared directly (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq)]
pub struct TransformerConfig {
    /// Number of layers. `suggest_initialization` treats `0` as an invalid
    /// config, and both suggestion functions scale by `sqrt(n_layers)`.
    pub n_layers: usize,
    /// Hidden width, kept as `f64`; it is only used in floating-point
    /// expressions such as `hidden_dim / n_samples` and `1 / sqrt(hidden_dim)`.
    pub hidden_dim: f64,
    /// Sample count, kept as `f64`; it is the denominator of the
    /// `hidden_dim / n_samples` ratio.
    pub n_samples: f64,
    /// Current weight standard deviation. Not read by the suggestion
    /// functions in this module.
    pub weight_std: f64,
    /// Current learning rate. Not read by the suggestion functions in this
    /// module.
    pub learning_rate: f64,
}
26
/// Result of `suggest_initialization`.
#[derive(Clone, Debug, PartialEq)]
pub struct InitSuggestion {
    /// Suggested weight standard deviation: `0.02` in the fallback path,
    /// otherwise `1 / sqrt(hidden_dim) / sqrt(n_layers)` clamped to `[1e-6, 1]`.
    pub suggested_std: f64,
    /// Multiplicative learning-rate scale, within `[0.01, 1]`.
    pub suggested_lr_scale: f64,
    /// Ratio of the upper to the lower spectral edge, `(1 + sqrt(lambda))^2`
    /// over `(1 - sqrt(lambda))^2`, with the denominator floored at `1e-12`.
    pub condition_number: f64,
    /// Tail mass estimate: `0.0` in the fallback path and a fixed `0.05`
    /// otherwise.
    pub tail_mass: f64,
}
35
/// Result of `suggest_regularizer`.
#[derive(Clone, Debug, PartialEq)]
pub struct RegularizerSuggestion {
    /// Suggested regularisation strength. The base value is `1e-4`; it grows
    /// with excess kurtosis, with `outlier_fraction`, and with `sqrt(n_layers)`.
    pub lambda_reg: f64,
    /// Three times the standard deviation implied by the supplied moments
    /// (`1.0` when no moments are supplied).
    pub spectral_radius: f64,
    /// Relative amount by which `spectral_radius` exceeds the reference scale,
    /// in `[0, 1]`; `0.0` when it does not exceed it.
    pub outlier_fraction: f64,
    /// Score in `[0, 1]`; `1.0` means neither heavy tails nor outliers were
    /// detected.
    pub stability_score: f64,
}
44
/// Writes `n_layers * layer_cumulants[k]` into `combined_cumulants[k]`.
///
/// Only the first `min(layer_cumulants.len(), combined_cumulants.len())`
/// entries are written; any further entries of `combined_cumulants` keep
/// their previous values.
///
/// NOTE(review): multiplying every cumulant by the layer count presumably
/// models the sum of `n_layers` identically distributed, (freely) independent
/// layers, for which cumulants add — confirm against the crate's
/// `r_transform` module.
pub fn predict_combined_distribution(
    layer_cumulants: &[f64],
    n_layers: usize,
    combined_cumulants: &mut [f64],
) {
    let depth = n_layers as f64;
    // `zip` stops at the shorter slice, which matches the min-length contract.
    for (out, &per_layer) in combined_cumulants.iter_mut().zip(layer_cumulants.iter()) {
        *out = depth * per_layer;
    }
}
66
67pub fn suggest_regularizer(
77 moments: &[f64],
78 config: &TransformerConfig,
79) -> RegularizerSuggestion {
80 if moments.is_empty() {
81 return RegularizerSuggestion {
82 lambda_reg: 1e-4,
83 spectral_radius: 1.0,
84 outlier_fraction: 0.0,
85 stability_score: 1.0,
86 };
87 }
88
89 let n = moments.len();
90 let mut cumulants = vec![0.0_f64; n];
91 moments::moment_to_cumulant(moments, &mut cumulants);
92
93 let variance = if n >= 2 { moments[1] } else { 1.0 };
94 let mean = if n >= 1 { moments[0] } else { 0.0 };
95 let centered_var = (variance - mean * mean).max(0.0);
96
97 let spectral_radius = centered_var.sqrt() * 3.0;
98
99 let lambda = (config.hidden_dim / config.n_samples.max(1.0))
101 .max(0.01)
102 .min(4.0);
103
104 let mp_b = (1.0 + lambda.sqrt()) * (1.0 + lambda.sqrt());
105 let mp_scale = mp_b.sqrt().max(1.0);
106 let ratio = spectral_radius / mp_scale;
107
108 let outlier_fraction = if ratio > 1.0 {
109 ((ratio - 1.0) / ratio).min(1.0)
110 } else {
111 0.0
112 };
113
114 let kurtosis_excess = if n >= 4 && centered_var > 1e-12 {
116 let m4 = moments[3];
117 let m2 = centered_var;
118 (m4 / (m2 * m2) - 3.0).max(0.0)
119 } else {
120 0.0
121 };
122
123 let lambda_reg = 1e-4
124 * (1.0 + kurtosis_excess * 0.5)
125 * (1.0 + outlier_fraction)
126 * (config.n_layers as f64).sqrt();
127
128 let stability_score = (1.0 / (1.0 + 0.1 * kurtosis_excess + outlier_fraction))
129 .clamp(0.0, 1.0);
130
131 RegularizerSuggestion {
132 lambda_reg,
133 spectral_radius,
134 outlier_fraction,
135 stability_score,
136 }
137}
138
139pub fn suggest_initialization(config: &TransformerConfig) -> InitSuggestion {
147 if config.hidden_dim <= 0.0 || config.n_samples <= 0.0 || config.n_layers == 0 {
148 return InitSuggestion {
149 suggested_std: 0.02,
150 suggested_lr_scale: 1.0,
151 condition_number: 1.0,
152 tail_mass: 0.0,
153 };
154 }
155
156 let d = config.hidden_dim;
157 let n = config.n_samples.max(d);
158 let lambda = d / n;
159
160 let base_std = 1.0 / d.sqrt();
165 let suggested_std = (base_std / (config.n_layers as f64).sqrt())
166 .clamp(1e-6, 1.0);
167
168 let mp_upper = (1.0 + lambda.sqrt()) * (1.0 + lambda.sqrt());
173 let mp_lower = (1.0 - lambda.sqrt()).max(1e-12) * (1.0 - lambda.sqrt()).max(1e-12);
174
175 let condition_number = mp_upper / mp_lower.max(1e-12);
176
177 let suggested_lr_scale =
179 (1.0 / (1.0 + 0.1 * mp_upper.sqrt())).clamp(0.01, 1.0);
180
181 InitSuggestion {
182 suggested_std,
183 suggested_lr_scale,
184 condition_number,
185 tail_mass: 0.05, }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for every exact-value comparison below.
    const TOL: f64 = 1e-12;

    /// True when `actual` is within `TOL` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    /// Builds a config with the given shape; `weight_std` and `learning_rate`
    /// are the same in every test of this module.
    fn config(n_layers: usize, hidden_dim: f64, n_samples: f64) -> TransformerConfig {
        TransformerConfig {
            n_layers,
            hidden_dim,
            n_samples,
            weight_std: 0.02,
            learning_rate: 0.001,
        }
    }

    /// The six-layer configuration shared by the regulariser tests.
    fn regularizer_config() -> TransformerConfig {
        config(6, 512.0, 5000.0)
    }

    #[test]
    fn test_combined_distribution() {
        let layer_cum = vec![1.0, 0.5, 0.1, 0.01];
        let mut combined = vec![0.0_f64; 4];

        predict_combined_distribution(&layer_cum, 4, &mut combined);

        assert!(close(combined[0], 4.0));
        assert!(close(combined[1], 2.0));
        assert!(close(combined[2], 0.4));
        assert!(close(combined[3], 0.04));

        // The R-transform built from the combined cumulants must be four
        // times the single-layer one at the same point.
        let z = 0.5;
        let r_single = crate::r_transform::from_cumulants(&layer_cum, z);
        let r_combined = crate::r_transform::from_cumulants(&combined, z);
        assert!(close(r_combined, 4.0 * r_single));
    }

    #[test]
    fn test_suggest_initialization() {
        let sug = suggest_initialization(&config(12, 768.0, 10000.0));

        let std = sug.suggested_std;
        assert!(std > 0.0, "std = {}", std);
        assert!(std.is_finite(), "std is NaN/Inf");
        assert!(std < 1.0, "std = {} (too large)", std);
        assert!(std > 1e-10, "std = {} (too small)", std);

        let lr_scale = sug.suggested_lr_scale;
        assert!(lr_scale > 0.0, "lr_scale = {}", lr_scale);
        assert!(lr_scale.is_finite(), "lr_scale is NaN/Inf");
    }

    #[test]
    fn test_suggest_initialization_null_config_defaults() {
        let sug = suggest_initialization(&config(0, 0.0, 0.0));

        assert!(close(sug.suggested_std, 0.02));
        assert!(close(sug.suggested_lr_scale, 1.0));
        assert!(close(sug.condition_number, 1.0));
        assert!(close(sug.tail_mass, 0.0));
    }

    #[test]
    fn test_suggest_regularizer() {
        // Raw moments of a standard Gaussian up to order six.
        let moments = vec![0.0, 1.0, 0.0, 3.0, 0.0, 15.0];

        let sug = suggest_regularizer(&moments, &regularizer_config());

        assert!(sug.lambda_reg > 0.0, "lambda_reg = {}", sug.lambda_reg);
        assert!(sug.lambda_reg.is_finite(), "lambda_reg is NaN/Inf");
        assert!(
            (0.0..=1.0).contains(&sug.stability_score),
            "stability_score out of [0, 1]: {}",
            sug.stability_score
        );
        assert!(
            (0.0..=1.0).contains(&sug.outlier_fraction),
            "outlier_fraction out of [0, 1]: {}",
            sug.outlier_fraction
        );
    }

    #[test]
    fn test_suggest_regularizer_non_gaussian() {
        // Fourth and sixth moments far above the Gaussian values 3 and 15.
        let moments = vec![0.0, 1.0, 0.0, 10.0, 0.0, 100.0];

        let sug = suggest_regularizer(&moments, &regularizer_config());

        assert!(sug.lambda_reg > 0.0);
        assert!(sug.stability_score > 0.0 && sug.stability_score <= 1.0);
    }

    #[test]
    fn test_suggest_regularizer_empty_moments() {
        let moments: Vec<f64> = Vec::new();

        let sug = suggest_regularizer(&moments, &regularizer_config());

        assert!(close(sug.lambda_reg, 1e-4));
        assert!(close(sug.stability_score, 1.0));
    }
}