// oxihuman_morph/param_space_optimizer.rs
1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Optimize morph target parameter space via dimensionality reduction and
5//! redundancy elimination.
6
7use std::collections::HashMap;
8
/// Configuration for parameter-space optimization.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct ParamSpaceConfig {
    /// Remove parameters whose pairwise Pearson correlation exceeds this (default 0.95).
    pub correlation_threshold: f32,
    /// Remove parameters with variance below this (default 1e-4).
    pub variance_threshold: f32,
    /// Keep at most N parameters ranked by variance; `None` keeps all survivors.
    pub n_keep: Option<usize>,
}
19
20impl Default for ParamSpaceConfig {
21    fn default() -> Self {
22        Self {
23            correlation_threshold: 0.95,
24            variance_threshold: 1e-4,
25            n_keep: None,
26        }
27    }
28}
29
/// Result of a parameter-space analysis pass.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct ParamSpaceAnalysis {
    /// Number of parameters before any filtering.
    pub original_count: usize,
    /// Names that survived the variance and correlation filters.
    pub kept_params: Vec<String>,
    /// Names removed by the filters (complement of `kept_params`).
    pub removed_params: Vec<String>,
    /// `correlation_matrix[i][j]` — n_params × n_params, indexed in the
    /// original parameter order (computed before any filtering).
    pub correlation_matrix: Vec<Vec<f32>>,
    /// Per-parameter population variance, in original parameter order.
    pub variances: Vec<f32>,
}
40
41// ── core statistics ───────────────────────────────────────────────────────────
42
/// Population variance (divides by N, not N-1) of a slice of values.
/// An empty slice yields 0.0.
#[allow(dead_code)]
pub fn param_variance(values: &[f32]) -> f32 {
    if values.is_empty() {
        return 0.0;
    }
    let count = values.len() as f32;
    let mean = values.iter().sum::<f32>() / count;
    let squared_deviation_sum: f32 = values
        .iter()
        .map(|v| {
            let d = v - mean;
            d * d
        })
        .sum();
    squared_deviation_sum / count
}
53
/// Pearson correlation coefficient between two slices.
/// Only the first `min(a.len(), b.len())` pairs are used; returns 0.0 when
/// either side is empty or has (near-)zero variance.
#[allow(dead_code)]
pub fn param_correlation(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len().min(b.len());
    if len == 0 {
        return 0.0;
    }
    let (xs, ys) = (&a[..len], &b[..len]);
    let mean_x = xs.iter().sum::<f32>() / len as f32;
    let mean_y = ys.iter().sum::<f32>() / len as f32;

    // Accumulate covariance and both variances in a single pass.
    let (mut sxy, mut sxx, mut syy) = (0.0_f32, 0.0_f32, 0.0_f32);
    for (&x, &y) in xs.iter().zip(ys) {
        let dx = x - mean_x;
        let dy = y - mean_y;
        sxy += dx * dy;
        sxx += dx * dx;
        syy += dy * dy;
    }

    let denom = (sxx * syy).sqrt();
    // Guard against a zero-variance input, where correlation is undefined.
    if denom < 1e-12 {
        0.0
    } else {
        sxy / denom
    }
}
80
81/// Build an n_params × n_params correlation matrix.
82/// `samples[i]` is a vector of all sample values for parameter i.
83#[allow(dead_code)]
84pub fn build_correlation_matrix(samples: &[Vec<f32>]) -> Vec<Vec<f32>> {
85    let n = samples.len();
86    let mut mat = vec![vec![0.0_f32; n]; n];
87    for i in 0..n {
88        for j in 0..n {
89            if i == j {
90                mat[i][j] = 1.0;
91            } else {
92                mat[i][j] = param_correlation(&samples[i], &samples[j]);
93            }
94        }
95    }
96    mat
97}
98
/// Greedy redundancy removal: for each pair with |corr| > threshold,
/// remove the member with lower variance.  Returns names of removed params.
///
/// Variance cannot be recovered from a correlation matrix, so index order is
/// used as a proxy: the caller is expected to list higher-variance parameters
/// first, and the higher-indexed member of each redundant pair is dropped.
/// `corr` must be an n×n matrix built from the same parameter ordering.
#[allow(dead_code)]
pub fn find_redundant_params(corr: &[Vec<f32>], names: &[String], threshold: f32) -> Vec<String> {
    let count = names.len();
    let mut dropped = vec![false; count];
    for keeper in 0..count {
        if dropped[keeper] {
            continue;
        }
        // Every still-alive later parameter that correlates too strongly with
        // `keeper` is marked for removal (later index = lower-variance proxy).
        for candidate in (keeper + 1)..count {
            if !dropped[candidate] && corr[keeper][candidate].abs() > threshold {
                dropped[candidate] = true;
            }
        }
    }
    names
        .iter()
        .zip(&dropped)
        .filter_map(|(name, &gone)| if gone { Some(name.clone()) } else { None })
        .collect()
}
131
132/// Return kept parameter names after applying variance and correlation filters.
133#[allow(dead_code)]
134pub fn reduce_param_set(
135    names: &[String],
136    samples: &[HashMap<String, f32>],
137    cfg: &ParamSpaceConfig,
138) -> Vec<String> {
139    if names.is_empty() || samples.is_empty() {
140        return names.to_vec();
141    }
142
143    // Gather per-param value vectors
144    let param_values: Vec<Vec<f32>> = names
145        .iter()
146        .map(|n| samples.iter().map(|s| *s.get(n).unwrap_or(&0.0)).collect())
147        .collect();
148
149    let variances: Vec<f32> = param_values.iter().map(|v| param_variance(v)).collect();
150
151    // Step 1: remove low-variance params
152    let mut kept: Vec<usize> = (0..names.len())
153        .filter(|&i| variances[i] >= cfg.variance_threshold)
154        .collect();
155
156    // Step 2: remove highly correlated (greedy)
157    let kept_values: Vec<Vec<f32>> = kept.iter().map(|&i| param_values[i].clone()).collect();
158    let corr = build_correlation_matrix(&kept_values);
159    let kept_names: Vec<String> = kept.iter().map(|&i| names[i].clone()).collect();
160    let redundant = find_redundant_params(&corr, &kept_names, cfg.correlation_threshold);
161    let redundant_set: std::collections::HashSet<&String> = redundant.iter().collect();
162    kept.retain(|&i| !redundant_set.contains(&names[i]));
163
164    // Step 3: keep top-N by variance
165    if let Some(n_keep) = cfg.n_keep {
166        kept.sort_by(|&a, &b| {
167            variances[b]
168                .partial_cmp(&variances[a])
169                .unwrap_or(std::cmp::Ordering::Equal)
170        });
171        kept.truncate(n_keep);
172    }
173
174    kept.iter().map(|&i| names[i].clone()).collect()
175}
176
/// Min/max normalize each parameter across samples in-place.
/// Returns `(min, max)` per parameter name.
///
/// Parameter names are collected from the union of ALL samples (previously
/// only the first sample's keys were used, so a parameter introduced in a
/// later sample was silently skipped).  A sample that lacks a parameter is
/// treated as having value 0.0 and has the normalized value inserted.
/// Constant parameters (max − min ≈ 0) normalize to 0.0.
#[allow(dead_code)]
pub fn normalize_param_samples(
    samples: &mut [HashMap<String, f32>],
) -> HashMap<String, (f32, f32)> {
    if samples.is_empty() {
        return HashMap::new();
    }

    // Union of parameter names across all samples.
    let mut name_set: std::collections::HashSet<String> = std::collections::HashSet::new();
    for s in samples.iter() {
        name_set.extend(s.keys().cloned());
    }
    let names: Vec<String> = name_set.into_iter().collect();

    // Per-parameter (min, max); missing values count as 0.0, matching the
    // lookup used during normalization below.
    let mut ranges: HashMap<String, (f32, f32)> = HashMap::new();
    for name in &names {
        let mut min = f32::MAX;
        let mut max = f32::MIN;
        for s in samples.iter() {
            let v = *s.get(name).unwrap_or(&0.0);
            min = min.min(v);
            max = max.max(v);
        }
        ranges.insert(name.clone(), (min, max));
    }

    for s in samples.iter_mut() {
        for name in &names {
            let (min, max) = ranges[name];
            let span = max - min;
            if span > 1e-12 {
                // Insert-then-normalize so absent params get a defined value.
                let v = s.entry(name.clone()).or_insert(0.0);
                *v = (*v - min) / span;
            } else if let Some(v) = s.get_mut(name) {
                // Degenerate (constant) parameter: collapse to 0.0.
                *v = 0.0;
            }
        }
    }

    ranges
}
216
217/// Importance score = variance / max_variance across all parameters.
218#[allow(dead_code)]
219pub fn param_importance_score(name: &str, samples: &[HashMap<String, f32>]) -> f32 {
220    if samples.is_empty() {
221        return 0.0;
222    }
223    let names: Vec<String> = samples[0].keys().cloned().collect();
224    let variances: Vec<f32> = names
225        .iter()
226        .map(|n| {
227            let vals: Vec<f32> = samples.iter().map(|s| *s.get(n).unwrap_or(&0.0)).collect();
228            param_variance(&vals)
229        })
230        .collect();
231    let max_var = variances.iter().cloned().fold(0.0_f32, f32::max);
232    if max_var < 1e-12 {
233        return 0.0;
234    }
235    let my_vals: Vec<f32> = samples
236        .iter()
237        .map(|s| *s.get(name).unwrap_or(&0.0))
238        .collect();
239    param_variance(&my_vals) / max_var
240}
241
242/// Analyze a parameter space and return the full analysis result.
243#[allow(dead_code)]
244pub fn analyze_param_space(
245    param_names: &[String],
246    param_samples: &[HashMap<String, f32>],
247) -> ParamSpaceAnalysis {
248    let cfg = ParamSpaceConfig::default();
249    let original_count = param_names.len();
250
251    let param_values: Vec<Vec<f32>> = param_names
252        .iter()
253        .map(|n| {
254            param_samples
255                .iter()
256                .map(|s| *s.get(n).unwrap_or(&0.0))
257                .collect()
258        })
259        .collect();
260
261    let variances: Vec<f32> = param_values.iter().map(|v| param_variance(v)).collect();
262    let correlation_matrix = build_correlation_matrix(&param_values);
263
264    let kept_names = reduce_param_set(param_names, param_samples, &cfg);
265    let kept_set: std::collections::HashSet<&String> = kept_names.iter().collect();
266    let removed_params: Vec<String> = param_names
267        .iter()
268        .filter(|n| !kept_set.contains(n))
269        .cloned()
270        .collect();
271
272    ParamSpaceAnalysis {
273        original_count,
274        kept_params: kept_names,
275        removed_params,
276        correlation_matrix,
277        variances,
278    }
279}
280
281// ── tests ─────────────────────────────────────────────────────────────────────
282
// Unit tests: statistics helpers, redundancy pruning, and the end-to-end
// analyze/reduce pipeline.
#[cfg(test)]
mod tests {
    use super::*;

    /// Transpose column-oriented fixture data — one `(name, per-sample values)`
    /// pair per parameter — into the row-oriented `Vec<HashMap>` shape that
    /// the library functions consume. All value vectors must share one length.
    fn make_samples(data: &[(&str, Vec<f32>)]) -> Vec<HashMap<String, f32>> {
        if data.is_empty() {
            return vec![];
        }
        let n = data[0].1.len();
        (0..n)
            .map(|i| {
                data.iter()
                    .map(|(name, vals)| (name.to_string(), vals[i]))
                    .collect()
            })
            .collect()
    }

    // 1. param_variance formula
    #[test]
    fn test_param_variance_formula() {
        let v = vec![2.0_f32, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let var = param_variance(&v);
        // population variance = 4.0
        assert!((var - 4.0).abs() < 1e-4, "expected ~4.0 got {var}");
    }

    // 2. param_correlation perfect positive = 1.0
    #[test]
    fn test_correlation_perfect_positive() {
        let a = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let b = vec![2.0_f32, 4.0, 6.0, 8.0, 10.0];
        let r = param_correlation(&a, &b);
        assert!((r - 1.0).abs() < 1e-5, "expected 1.0 got {r}");
    }

    // 3. param_correlation perfect negative = -1.0
    #[test]
    fn test_correlation_perfect_negative() {
        let a = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let b = vec![5.0_f32, 4.0, 3.0, 2.0, 1.0];
        let r = param_correlation(&a, &b);
        assert!((r + 1.0).abs() < 1e-5, "expected -1.0 got {r}");
    }

    // 4. param_correlation uncorrelated ≈ 0
    #[test]
    fn test_correlation_uncorrelated() {
        let a = vec![1.0_f32, 1.0, 1.0, 1.0];
        let b = vec![1.0_f32, 2.0, 3.0, 4.0];
        // a has zero variance → correlation = 0 (denominator guard path)
        let r = param_correlation(&a, &b);
        assert!(r.abs() < 1e-5, "expected ~0 got {r}");
    }

    // 5. build_correlation_matrix diagonal = 1
    #[test]
    fn test_correlation_matrix_diagonal() {
        let samples = vec![
            vec![1.0_f32, 2.0, 3.0],
            vec![4.0_f32, 5.0, 6.0],
            vec![7.0_f32, 8.0, 9.0],
        ];
        let mat = build_correlation_matrix(&samples);
        for (i, row) in mat.iter().enumerate().take(3) {
            assert!((row[i] - 1.0).abs() < 1e-5, "diagonal[{i}] != 1");
        }
    }

    // 6. find_redundant_params removes correlated
    #[test]
    fn test_find_redundant_removes_correlated() {
        // Build a 2×2 matrix where params 0 and 1 are perfectly correlated
        let corr = vec![vec![1.0, 0.99], vec![0.99, 1.0]];
        let names = vec!["a".to_string(), "b".to_string()];
        let redundant = find_redundant_params(&corr, &names, 0.95);
        assert_eq!(redundant.len(), 1);
        // The higher-indexed member of the pair is the one dropped.
        assert_eq!(redundant[0], "b");
    }

    // 7. find_redundant_params keeps uncorrelated
    #[test]
    fn test_find_redundant_keeps_uncorrelated() {
        let corr = vec![vec![1.0, 0.1], vec![0.1, 1.0]];
        let names = vec!["a".to_string(), "b".to_string()];
        let redundant = find_redundant_params(&corr, &names, 0.95);
        assert!(redundant.is_empty());
    }

    // 8. reduce_param_set respects n_keep
    #[test]
    fn test_reduce_param_set_n_keep() {
        // Use orthogonal (uncorrelated) signals so correlation pruning keeps all,
        // then n_keep=2 selects the top 2 by variance.
        let names: Vec<String> = (0..4).map(|i| format!("p{i}")).collect();
        let samples = make_samples(&[
            // p0: high variance, orthogonal
            ("p0", vec![0.0, 10.0, 0.0, 10.0]),
            // p1: high variance, orthogonal
            ("p1", vec![0.0, 0.0, 10.0, 10.0]),
            // p2: low variance
            ("p2", vec![0.1, 0.2, 0.1, 0.2]),
            // p3: low variance
            ("p3", vec![0.01, 0.02, 0.01, 0.02]),
        ]);
        let cfg = ParamSpaceConfig {
            n_keep: Some(2),
            correlation_threshold: 1.0, // never remove on correlation
            variance_threshold: 0.0,
        };
        let kept = reduce_param_set(&names, &samples, &cfg);
        assert_eq!(kept.len(), 2, "expected 2 kept params, got {}", kept.len());
    }

    // 9. normalize_param_samples produces 0..1 range
    #[test]
    fn test_normalize_param_samples_range() {
        let mut samples = make_samples(&[("x", vec![1.0, 5.0, 3.0])]);
        normalize_param_samples(&mut samples);
        let vals: Vec<f32> = samples
            .iter()
            .map(|s| *s.get("x").expect("should succeed"))
            .collect();
        let min = vals.iter().cloned().fold(f32::MAX, f32::min);
        let max = vals.iter().cloned().fold(f32::MIN, f32::max);
        assert!((min - 0.0).abs() < 1e-5, "min should be 0, got {min}");
        assert!((max - 1.0).abs() < 1e-5, "max should be 1, got {max}");
    }

    // 10. analyze_param_space removes zero-variance
    #[test]
    fn test_analyze_removes_zero_variance() {
        let names = vec!["vary".to_string(), "const".to_string()];
        let samples = make_samples(&[
            ("vary", vec![1.0, 2.0, 3.0, 4.0]),
            ("const", vec![5.0, 5.0, 5.0, 5.0]),
        ]);
        let analysis = analyze_param_space(&names, &samples);
        assert!(
            analysis.removed_params.contains(&"const".to_string()),
            "zero-variance param should be removed"
        );
    }

    // 11. original_count = n params
    #[test]
    fn test_original_count() {
        let names: Vec<String> = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let samples = make_samples(&[
            ("a", vec![1.0, 2.0]),
            ("b", vec![3.0, 4.0]),
            ("c", vec![5.0, 6.0]),
        ]);
        let analysis = analyze_param_space(&names, &samples);
        assert_eq!(analysis.original_count, 3);
    }

    // 12. kept + removed = original
    #[test]
    fn test_kept_plus_removed_eq_original() {
        let names: Vec<String> = (0..4).map(|i| format!("p{i}")).collect();
        let samples = make_samples(&[
            ("p0", vec![1.0, 2.0, 3.0]),
            ("p1", vec![1.0, 1.0, 1.0]), // zero variance → removed
            ("p2", vec![4.0, 5.0, 6.0]),
            ("p3", vec![7.0, 8.0, 9.0]),
        ]);
        let analysis = analyze_param_space(&names, &samples);
        assert_eq!(
            analysis.kept_params.len() + analysis.removed_params.len(),
            analysis.original_count
        );
    }

    // 13. param_importance_score returns 1.0 for highest-variance param
    #[test]
    fn test_param_importance_score_max() {
        let samples = make_samples(&[
            ("big", vec![0.0, 10.0, 20.0, 30.0]),
            ("small", vec![0.0, 0.1, 0.2, 0.3]),
        ]);
        let score = param_importance_score("big", &samples);
        assert!(
            (score - 1.0).abs() < 1e-4,
            "highest-variance param should score 1.0, got {score}"
        );
    }

    // 14. normalize_param_samples returns correct (min, max) map
    #[test]
    fn test_normalize_returns_range_map() {
        let mut samples = make_samples(&[("y", vec![2.0, 4.0, 6.0])]);
        let ranges = normalize_param_samples(&mut samples);
        let (min, max) = ranges["y"];
        assert!((min - 2.0).abs() < 1e-5);
        assert!((max - 6.0).abs() < 1e-5);
    }
}