1use std::collections::HashMap;
8
/// Tuning knobs for reducing a parameter set before analysis.
#[allow(dead_code)] // public configuration surface; not every field is read by every caller
pub struct ParamSpaceConfig {
    /// Absolute Pearson correlation above which two parameters are treated as redundant.
    pub correlation_threshold: f32,
    /// Parameters whose sample variance is below this are dropped as near-constant.
    pub variance_threshold: f32,
    /// Optional hard cap on the number of parameters kept (highest variance wins).
    pub n_keep: Option<usize>,
}
19
impl Default for ParamSpaceConfig {
    /// Conservative defaults: prune at |r| > 0.95, drop variance < 1e-4,
    /// and apply no cap on the number of kept parameters.
    fn default() -> Self {
        Self {
            correlation_threshold: 0.95,
            variance_threshold: 1e-4,
            n_keep: None,
        }
    }
}
29
/// Result of [`analyze_param_space`]: which parameters survived reduction
/// plus the raw statistics the decision was based on.
#[allow(dead_code)] // result struct; fields are consumed selectively by callers
pub struct ParamSpaceAnalysis {
    /// Number of parameters before reduction.
    pub original_count: usize,
    /// Parameters that survived variance and correlation filtering.
    pub kept_params: Vec<String>,
    /// Parameters removed by the reduction (complement of `kept_params`).
    pub removed_params: Vec<String>,
    /// Full pairwise correlation matrix over the ORIGINAL parameter order.
    pub correlation_matrix: Vec<Vec<f32>>,
    /// Per-parameter sample variance, in original parameter order.
    pub variances: Vec<f32>,
}
40
/// Population variance of `values` (divides by `n`, not `n - 1`).
///
/// Returns 0.0 for an empty slice.
#[allow(dead_code)]
pub fn param_variance(values: &[f32]) -> f32 {
    if values.is_empty() {
        return 0.0;
    }
    let count = values.len() as f32;
    let mean = values.iter().sum::<f32>() / count;
    let sum_sq: f32 = values
        .iter()
        .map(|v| {
            let d = v - mean;
            d * d
        })
        .sum();
    sum_sq / count
}
53
/// Pearson correlation coefficient over the common prefix of `a` and `b`.
///
/// Returns 0.0 when the common prefix is empty or when either series is
/// (near-)constant, i.e. the denominator would be degenerate.
#[allow(dead_code)]
pub fn param_correlation(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len().min(b.len());
    if len == 0 {
        return 0.0;
    }
    let (xs, ys) = (&a[..len], &b[..len]);
    let mean_x = xs.iter().sum::<f32>() / len as f32;
    let mean_y = ys.iter().sum::<f32>() / len as f32;

    let mut sxy = 0.0_f32; // covariance accumulator
    let mut sxx = 0.0_f32; // variance accumulator for `a`
    let mut syy = 0.0_f32; // variance accumulator for `b`
    for (x, y) in xs.iter().zip(ys) {
        let dx = x - mean_x;
        let dy = y - mean_y;
        sxy += dx * dy;
        sxx += dx * dx;
        syy += dy * dy;
    }

    let denom = (sxx * syy).sqrt();
    if denom < 1e-12 {
        0.0
    } else {
        sxy / denom
    }
}
80
81#[allow(dead_code)]
84pub fn build_correlation_matrix(samples: &[Vec<f32>]) -> Vec<Vec<f32>> {
85 let n = samples.len();
86 let mut mat = vec![vec![0.0_f32; n]; n];
87 for i in 0..n {
88 for j in 0..n {
89 if i == j {
90 mat[i][j] = 1.0;
91 } else {
92 mat[i][j] = param_correlation(&samples[i], &samples[j]);
93 }
94 }
95 }
96 mat
97}
98
/// Greedy redundancy pass over a correlation matrix.
///
/// Walks parameters in order; for each still-kept parameter, every LATER
/// parameter whose |correlation| with it exceeds `threshold` is marked
/// redundant. Earlier parameters therefore always win ties. Returns the
/// names of the marked (redundant) parameters, in original order.
#[allow(dead_code)]
pub fn find_redundant_params(corr: &[Vec<f32>], names: &[String], threshold: f32) -> Vec<String> {
    let n = names.len();
    let mut is_redundant = vec![false; n];
    for keeper in 0..n {
        if is_redundant[keeper] {
            continue;
        }
        for candidate in (keeper + 1)..n {
            if !is_redundant[candidate] && corr[keeper][candidate].abs() > threshold {
                is_redundant[candidate] = true;
            }
        }
    }
    is_redundant
        .iter()
        .zip(names)
        .filter_map(|(&gone, name)| gone.then(|| name.clone()))
        .collect()
}
131
132#[allow(dead_code)]
134pub fn reduce_param_set(
135 names: &[String],
136 samples: &[HashMap<String, f32>],
137 cfg: &ParamSpaceConfig,
138) -> Vec<String> {
139 if names.is_empty() || samples.is_empty() {
140 return names.to_vec();
141 }
142
143 let param_values: Vec<Vec<f32>> = names
145 .iter()
146 .map(|n| samples.iter().map(|s| *s.get(n).unwrap_or(&0.0)).collect())
147 .collect();
148
149 let variances: Vec<f32> = param_values.iter().map(|v| param_variance(v)).collect();
150
151 let mut kept: Vec<usize> = (0..names.len())
153 .filter(|&i| variances[i] >= cfg.variance_threshold)
154 .collect();
155
156 let kept_values: Vec<Vec<f32>> = kept.iter().map(|&i| param_values[i].clone()).collect();
158 let corr = build_correlation_matrix(&kept_values);
159 let kept_names: Vec<String> = kept.iter().map(|&i| names[i].clone()).collect();
160 let redundant = find_redundant_params(&corr, &kept_names, cfg.correlation_threshold);
161 let redundant_set: std::collections::HashSet<&String> = redundant.iter().collect();
162 kept.retain(|&i| !redundant_set.contains(&names[i]));
163
164 if let Some(n_keep) = cfg.n_keep {
166 kept.sort_by(|&a, &b| {
167 variances[b]
168 .partial_cmp(&variances[a])
169 .unwrap_or(std::cmp::Ordering::Equal)
170 });
171 kept.truncate(n_keep);
172 }
173
174 kept.iter().map(|&i| names[i].clone()).collect()
175}
176
177#[allow(dead_code)]
180pub fn normalize_param_samples(
181 samples: &mut [HashMap<String, f32>],
182) -> HashMap<String, (f32, f32)> {
183 if samples.is_empty() {
184 return HashMap::new();
185 }
186
187 let names: Vec<String> = samples[0].keys().cloned().collect();
189 let mut ranges: HashMap<String, (f32, f32)> = HashMap::new();
190
191 for name in &names {
192 let vals: Vec<f32> = samples
193 .iter()
194 .map(|s| *s.get(name).unwrap_or(&0.0))
195 .collect();
196 let min = vals.iter().cloned().fold(f32::MAX, f32::min);
197 let max = vals.iter().cloned().fold(f32::MIN, f32::max);
198 ranges.insert(name.clone(), (min, max));
199 }
200
201 for s in samples.iter_mut() {
202 for name in &names {
203 let (min, max) = ranges[name];
204 let span = max - min;
205 if span > 1e-12 {
206 let v = s.entry(name.clone()).or_insert(0.0);
207 *v = (*v - min) / span;
208 } else if let Some(v) = s.get_mut(name) {
209 *v = 0.0;
210 }
211 }
212 }
213
214 ranges
215}
216
217#[allow(dead_code)]
219pub fn param_importance_score(name: &str, samples: &[HashMap<String, f32>]) -> f32 {
220 if samples.is_empty() {
221 return 0.0;
222 }
223 let names: Vec<String> = samples[0].keys().cloned().collect();
224 let variances: Vec<f32> = names
225 .iter()
226 .map(|n| {
227 let vals: Vec<f32> = samples.iter().map(|s| *s.get(n).unwrap_or(&0.0)).collect();
228 param_variance(&vals)
229 })
230 .collect();
231 let max_var = variances.iter().cloned().fold(0.0_f32, f32::max);
232 if max_var < 1e-12 {
233 return 0.0;
234 }
235 let my_vals: Vec<f32> = samples
236 .iter()
237 .map(|s| *s.get(name).unwrap_or(&0.0))
238 .collect();
239 param_variance(&my_vals) / max_var
240}
241
242#[allow(dead_code)]
244pub fn analyze_param_space(
245 param_names: &[String],
246 param_samples: &[HashMap<String, f32>],
247) -> ParamSpaceAnalysis {
248 let cfg = ParamSpaceConfig::default();
249 let original_count = param_names.len();
250
251 let param_values: Vec<Vec<f32>> = param_names
252 .iter()
253 .map(|n| {
254 param_samples
255 .iter()
256 .map(|s| *s.get(n).unwrap_or(&0.0))
257 .collect()
258 })
259 .collect();
260
261 let variances: Vec<f32> = param_values.iter().map(|v| param_variance(v)).collect();
262 let correlation_matrix = build_correlation_matrix(¶m_values);
263
264 let kept_names = reduce_param_set(param_names, param_samples, &cfg);
265 let kept_set: std::collections::HashSet<&String> = kept_names.iter().collect();
266 let removed_params: Vec<String> = param_names
267 .iter()
268 .filter(|n| !kept_set.contains(n))
269 .cloned()
270 .collect();
271
272 ParamSpaceAnalysis {
273 original_count,
274 kept_params: kept_names,
275 removed_params,
276 correlation_matrix,
277 variances,
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds row-oriented samples from column-oriented test data:
    /// each `(name, values)` pair contributes one value to each of the
    /// `values.len()` samples. All columns are assumed to be equal length.
    fn make_samples(data: &[(&str, Vec<f32>)]) -> Vec<HashMap<String, f32>> {
        if data.is_empty() {
            return vec![];
        }
        let n = data[0].1.len();
        (0..n)
            .map(|i| {
                data.iter()
                    .map(|(name, vals)| (name.to_string(), vals[i]))
                    .collect()
            })
            .collect()
    }

    // Population variance of this classic data set is exactly 4.0.
    #[test]
    fn test_param_variance_formula() {
        let v = vec![2.0_f32, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let var = param_variance(&v);
        assert!((var - 4.0).abs() < 1e-4, "expected ~4.0 got {var}");
    }

    // b = 2a: perfectly linearly correlated, r = +1.
    #[test]
    fn test_correlation_perfect_positive() {
        let a = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let b = vec![2.0_f32, 4.0, 6.0, 8.0, 10.0];
        let r = param_correlation(&a, &b);
        assert!((r - 1.0).abs() < 1e-5, "expected 1.0 got {r}");
    }

    // b = 6 - a: perfectly anti-correlated, r = -1.
    #[test]
    fn test_correlation_perfect_negative() {
        let a = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let b = vec![5.0_f32, 4.0, 3.0, 2.0, 1.0];
        let r = param_correlation(&a, &b);
        assert!((r + 1.0).abs() < 1e-5, "expected -1.0 got {r}");
    }

    // A constant series has zero variance, so the degenerate-denominator
    // guard should yield r = 0.
    #[test]
    fn test_correlation_uncorrelated() {
        let a = vec![1.0_f32, 1.0, 1.0, 1.0];
        let b = vec![1.0_f32, 2.0, 3.0, 4.0];
        let r = param_correlation(&a, &b);
        assert!(r.abs() < 1e-5, "expected ~0 got {r}");
    }

    // The diagonal is defined as self-correlation and must always be 1.
    #[test]
    fn test_correlation_matrix_diagonal() {
        let samples = vec![
            vec![1.0_f32, 2.0, 3.0],
            vec![4.0_f32, 5.0, 6.0],
            vec![7.0_f32, 8.0, 9.0],
        ];
        let mat = build_correlation_matrix(&samples);
        for (i, row) in mat.iter().enumerate().take(3) {
            assert!((row[i] - 1.0).abs() < 1e-5, "diagonal[{i}] != 1");
        }
    }

    // 0.99 > 0.95 threshold: the LATER of the two names ("b") is removed.
    #[test]
    fn test_find_redundant_removes_correlated() {
        let corr = vec![vec![1.0, 0.99], vec![0.99, 1.0]];
        let names = vec!["a".to_string(), "b".to_string()];
        let redundant = find_redundant_params(&corr, &names, 0.95);
        assert_eq!(redundant.len(), 1);
        assert_eq!(redundant[0], "b");
    }

    // 0.1 < 0.95 threshold: nothing is redundant.
    #[test]
    fn test_find_redundant_keeps_uncorrelated() {
        let corr = vec![vec![1.0, 0.1], vec![0.1, 1.0]];
        let names = vec!["a".to_string(), "b".to_string()];
        let redundant = find_redundant_params(&corr, &names, 0.95);
        assert!(redundant.is_empty());
    }

    // With variance/correlation filters disabled (threshold 0.0 / 1.0),
    // only the n_keep cap applies: the two highest-variance params survive.
    #[test]
    fn test_reduce_param_set_n_keep() {
        let names: Vec<String> = (0..4).map(|i| format!("p{i}")).collect();
        let samples = make_samples(&[
            ("p0", vec![0.0, 10.0, 0.0, 10.0]),
            ("p1", vec![0.0, 0.0, 10.0, 10.0]),
            ("p2", vec![0.1, 0.2, 0.1, 0.2]),
            ("p3", vec![0.01, 0.02, 0.01, 0.02]),
        ]);
        let cfg = ParamSpaceConfig {
            n_keep: Some(2),
            correlation_threshold: 1.0,
            variance_threshold: 0.0,
        };
        let kept = reduce_param_set(&names, &samples, &cfg);
        assert_eq!(kept.len(), 2, "expected 2 kept params, got {}", kept.len());
    }

    // After normalization the observed values must span exactly [0, 1].
    #[test]
    fn test_normalize_param_samples_range() {
        let mut samples = make_samples(&[("x", vec![1.0, 5.0, 3.0])]);
        normalize_param_samples(&mut samples);
        let vals: Vec<f32> = samples
            .iter()
            .map(|s| *s.get("x").expect("should succeed"))
            .collect();
        let min = vals.iter().cloned().fold(f32::MAX, f32::min);
        let max = vals.iter().cloned().fold(f32::MIN, f32::max);
        assert!((min - 0.0).abs() < 1e-5, "min should be 0, got {min}");
        assert!((max - 1.0).abs() < 1e-5, "max should be 1, got {max}");
    }

    // A constant column falls below the default variance threshold and
    // should land in removed_params.
    #[test]
    fn test_analyze_removes_zero_variance() {
        let names = vec!["vary".to_string(), "const".to_string()];
        let samples = make_samples(&[
            ("vary", vec![1.0, 2.0, 3.0, 4.0]),
            ("const", vec![5.0, 5.0, 5.0, 5.0]),
        ]);
        let analysis = analyze_param_space(&names, &samples);
        assert!(
            analysis.removed_params.contains(&"const".to_string()),
            "zero-variance param should be removed"
        );
    }

    // original_count reflects the input size regardless of reduction.
    #[test]
    fn test_original_count() {
        let names: Vec<String> = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let samples = make_samples(&[
            ("a", vec![1.0, 2.0]),
            ("b", vec![3.0, 4.0]),
            ("c", vec![5.0, 6.0]),
        ]);
        let analysis = analyze_param_space(&names, &samples);
        assert_eq!(analysis.original_count, 3);
    }

    // kept and removed partition the original parameter set exactly.
    #[test]
    fn test_kept_plus_removed_eq_original() {
        let names: Vec<String> = (0..4).map(|i| format!("p{i}")).collect();
        let samples = make_samples(&[
            ("p0", vec![1.0, 2.0, 3.0]),
            ("p1", vec![1.0, 1.0, 1.0]),
            ("p2", vec![4.0, 5.0, 6.0]),
            ("p3", vec![7.0, 8.0, 9.0]),
        ]);
        let analysis = analyze_param_space(&names, &samples);
        assert_eq!(
            analysis.kept_params.len() + analysis.removed_params.len(),
            analysis.original_count
        );
    }

    // The highest-variance parameter is the normalization reference,
    // so its own score is exactly 1.0.
    #[test]
    fn test_param_importance_score_max() {
        let samples = make_samples(&[
            ("big", vec![0.0, 10.0, 20.0, 30.0]),
            ("small", vec![0.0, 0.1, 0.2, 0.3]),
        ]);
        let score = param_importance_score("big", &samples);
        assert!(
            (score - 1.0).abs() < 1e-4,
            "highest-variance param should score 1.0, got {score}"
        );
    }

    // The returned map carries the PRE-normalization (min, max) per param.
    #[test]
    fn test_normalize_returns_range_map() {
        let mut samples = make_samples(&[("y", vec![2.0, 4.0, 6.0])]);
        let ranges = normalize_param_samples(&mut samples);
        let (min, max) = ranges["y"];
        assert!((min - 2.0).abs() < 1e-5);
        assert!((max - 6.0).abs() < 1e-5);
    }
}