// oxihuman_morph/weight_optimizer.rs
// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
// SPDX-License-Identifier: Apache-2.0

//! Least-squares blend weight optimization via projected gradient descent.
/// Hyper-parameters for least-squares blend weight optimization via
/// projected gradient descent: each iteration takes a gradient step and
/// then clamps every weight into `[weight_min, weight_max]`.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct WeightOptimizer {
    /// Maximum number of gradient-descent iterations before giving up.
    pub max_iterations: u32,
    /// Step size multiplied into the gradient each iteration.
    pub learning_rate: f32,
    /// Stop once the gradient L2 norm drops below this threshold.
    pub convergence_eps: f32,
    /// Lower clamp bound applied to every weight after each step.
    pub weight_min: f32,
    /// Upper clamp bound applied to every weight after each step.
    pub weight_max: f32,
}
15
16impl Default for WeightOptimizer {
17    fn default() -> Self {
18        Self::new()
19    }
20}
21
/// Outcome of a [`WeightOptimizer`] run.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct OptimizationResult {
    /// Optimized blend weight per morph, clamped to the configured bounds.
    pub weights: Vec<f32>,
    /// Mean-squared reconstruction error at the final weights.
    pub final_error: f32,
    /// Number of iterations actually executed.
    pub iterations: u32,
    /// `true` if the convergence threshold was reached before hitting
    /// `max_iterations`.
    pub converged: bool,
}
30
31impl WeightOptimizer {
32    #[allow(dead_code)]
33    pub fn new() -> Self {
34        Self {
35            max_iterations: 200,
36            learning_rate: 0.01,
37            convergence_eps: 1e-5,
38            weight_min: 0.0,
39            weight_max: 1.0,
40        }
41    }
42
43    #[allow(dead_code)]
44    pub fn optimize(
45        &self,
46        base: &[[f32; 3]],
47        target: &[[f32; 3]],
48        morph_deltas: &[Vec<[f32; 3]>],
49    ) -> OptimizationResult {
50        let k = morph_deltas.len();
51        if k == 0 {
52            let error = reconstruction_error(base, target, morph_deltas, &[]);
53            return OptimizationResult {
54                weights: vec![],
55                final_error: error,
56                iterations: 0,
57                converged: true,
58            };
59        }
60
61        let mut weights = vec![0.0f32; k];
62        let mut converged = false;
63        let mut iters = 0u32;
64
65        for iter in 0..self.max_iterations {
66            iters = iter + 1;
67            let grad = gradient_wrt_weights(base, target, morph_deltas, &weights);
68            let grad_norm: f32 = grad.iter().map(|g| g * g).sum::<f32>().sqrt();
69
70            for i in 0..k {
71                weights[i] -= self.learning_rate * grad[i];
72            }
73            clamp_weights(&mut weights, self.weight_min, self.weight_max);
74
75            if grad_norm < self.convergence_eps {
76                converged = true;
77                break;
78            }
79        }
80
81        let final_error = reconstruction_error(base, target, morph_deltas, &weights);
82        OptimizationResult {
83            weights,
84            final_error,
85            iterations: iters,
86            converged,
87        }
88    }
89}
90
91#[allow(dead_code)]
92pub fn reconstruction_error(
93    base: &[[f32; 3]],
94    target: &[[f32; 3]],
95    deltas: &[Vec<[f32; 3]>],
96    weights: &[f32],
97) -> f32 {
98    if base.is_empty() {
99        return 0.0;
100    }
101    let blended = apply_weights(base, deltas, weights);
102    let n = base.len() as f32;
103    blended
104        .iter()
105        .zip(target.iter())
106        .map(|(b, t)| {
107            let dx = b[0] - t[0];
108            let dy = b[1] - t[1];
109            let dz = b[2] - t[2];
110            dx * dx + dy * dy + dz * dz
111        })
112        .sum::<f32>()
113        / n
114}
115
116#[allow(dead_code)]
117pub fn gradient_wrt_weights(
118    base: &[[f32; 3]],
119    target: &[[f32; 3]],
120    deltas: &[Vec<[f32; 3]>],
121    weights: &[f32],
122) -> Vec<f32> {
123    let k = weights.len();
124    let n = base.len();
125    if n == 0 || k == 0 {
126        return vec![0.0; k];
127    }
128
129    let blended = apply_weights(base, deltas, weights);
130    let scale = 2.0 / n as f32;
131
132    (0..k)
133        .map(|i| {
134            let delta_i = &deltas[i];
135            let dlen = delta_i.len().min(n);
136            let mut g = 0.0f32;
137            for v in 0..dlen {
138                let rx = blended[v][0] - target[v][0];
139                let ry = blended[v][1] - target[v][1];
140                let rz = blended[v][2] - target[v][2];
141                g += rx * delta_i[v][0] + ry * delta_i[v][1] + rz * delta_i[v][2];
142            }
143            g * scale
144        })
145        .collect()
146}
147
/// Returns `base` with each morph delta added in, scaled by its weight.
/// Extra weights beyond `deltas.len()` are ignored, and a delta list
/// shorter than the mesh only displaces the vertices it covers.
#[allow(dead_code)]
pub fn apply_weights(
    base: &[[f32; 3]],
    deltas: &[Vec<[f32; 3]>],
    weights: &[f32],
) -> Vec<[f32; 3]> {
    let mut blended = base.to_vec();
    // zip truncates to the shorter of weights/deltas and of mesh/delta list,
    // reproducing the explicit bounds handling of an index loop.
    for (w, delta) in weights.iter().zip(deltas.iter()) {
        for (vertex, d) in blended.iter_mut().zip(delta.iter()) {
            vertex[0] += w * d[0];
            vertex[1] += w * d[1];
            vertex[2] += w * d[2];
        }
    }
    blended
}
170
/// Clamps every weight into the closed interval `[min, max]`, in place.
#[allow(dead_code)]
pub fn clamp_weights(weights: &mut [f32], min: f32, max: f32) {
    weights.iter_mut().for_each(|w| *w = w.clamp(min, max));
}
177
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a line of n vertices spaced 1.0 apart along the x axis.
    fn make_base(n: usize) -> Vec<[f32; 3]> {
        (0..n).map(|i| [i as f32, 0.0, 0.0]).collect()
    }

    #[test]
    fn test_apply_weights_no_deltas() {
        // With no morphs the blend must be the base mesh, unchanged.
        let base = make_base(3);
        let result = apply_weights(&base, &[], &[]);
        assert_eq!(result, base);
    }

    #[test]
    fn test_apply_weights_single() {
        // One morph at half weight shifts every vertex by 0.5 on x.
        let base = vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]];
        let deltas = vec![vec![[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]];
        let weights = [0.5f32];
        let result = apply_weights(&base, &deltas, &weights);
        assert!((result[0][0] - 0.5).abs() < 1e-6);
        assert!((result[1][0] - 1.5).abs() < 1e-6);
    }

    #[test]
    fn test_clamp_weights() {
        // Values below/above the bounds clamp; in-range values pass through.
        let mut w = vec![-0.5, 0.5, 1.5, 0.0];
        clamp_weights(&mut w, 0.0, 1.0);
        assert!((w[0] - 0.0).abs() < 1e-6);
        assert!((w[1] - 0.5).abs() < 1e-6);
        assert!((w[2] - 1.0).abs() < 1e-6);
        assert!((w[3] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_reconstruction_error_zero() {
        // Identical base and target with no morphs -> zero error.
        let base = vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]];
        let target = base.clone();
        let err = reconstruction_error(&base, &target, &[], &[]);
        assert!(err.abs() < 1e-6, "identical base/target: {err}");
    }

    #[test]
    fn test_reconstruction_error_nonzero() {
        // Single vertex displaced by 1.0 -> MSE of exactly 1.0.
        let base = vec![[0.0, 0.0, 0.0]];
        let target = vec![[1.0, 0.0, 0.0]];
        let err = reconstruction_error(&base, &target, &[], &[]);
        assert!((err - 1.0).abs() < 1e-6, "error = 1.0: {err}");
    }

    #[test]
    fn test_gradient_direction() {
        // If blended > target on x, gradient for that morph should be positive
        let base = vec![[0.0, 0.0, 0.0]];
        let target = vec![[1.0, 0.0, 0.0]];
        let deltas = vec![vec![[1.0, 0.0, 0.0]]];
        let weights = [0.0f32];
        let grad = gradient_wrt_weights(&base, &target, &deltas, &weights);
        // residual = blended - target = -1.0 => gradient should be negative => step toward target
        assert!(
            grad[0] < 0.0,
            "gradient should be negative to increase weight: {}",
            grad[0]
        );
    }

    #[test]
    fn test_gradient_zero_at_perfect_fit() {
        // At weight 1.0 the blend equals the target, so the residual
        // (and hence the gradient) vanishes.
        let base = vec![[0.0, 0.0, 0.0]];
        let target = vec![[1.0, 0.0, 0.0]];
        let deltas = vec![vec![[1.0, 0.0, 0.0]]];
        let weights = [1.0f32];
        let grad = gradient_wrt_weights(&base, &target, &deltas, &weights);
        assert!(
            grad[0].abs() < 1e-6,
            "at perfect fit gradient is zero: {}",
            grad[0]
        );
    }

    #[test]
    fn test_single_target_perfect_fit() {
        // Single morph that exactly spans base→target; optimizer should find weight≈1
        let n = 4;
        let base: Vec<[f32; 3]> = (0..n).map(|i| [i as f32, 0.0, 0.0]).collect();
        let target: Vec<[f32; 3]> = (0..n).map(|i| [i as f32 + 1.0, 0.0, 0.0]).collect();
        let deltas = vec![(0..n).map(|_| [1.0f32, 0.0, 0.0]).collect::<Vec<_>>()];
        let opt = WeightOptimizer {
            max_iterations: 1000,
            learning_rate: 0.1,
            convergence_eps: 1e-6,
            weight_min: 0.0,
            weight_max: 1.0,
        };
        let result = opt.optimize(&base, &target, &deltas);
        assert!(
            (result.weights[0] - 1.0).abs() < 0.01,
            "should converge to 1.0, got {}",
            result.weights[0]
        );
        assert!(result.final_error < 1e-4);
    }

    #[test]
    fn test_zero_target_weight_stays_zero() {
        // Target == base → zero delta needed → weight stays 0.0
        let n = 3;
        let base: Vec<[f32; 3]> = (0..n).map(|i| [i as f32, 0.0, 0.0]).collect();
        let target = base.clone();
        let deltas = vec![(0..n).map(|_| [1.0f32, 0.0, 0.0]).collect::<Vec<_>>()];
        let opt = WeightOptimizer::new();
        let result = opt.optimize(&base, &target, &deltas);
        assert!(
            result.weights[0] < 0.01,
            "weight should stay near 0: {}",
            result.weights[0]
        );
    }

    #[test]
    fn test_empty_deltas() {
        // No morphs at all: empty weights, immediately "converged".
        let base = vec![[0.0, 0.0, 0.0]];
        let target = vec![[1.0, 0.0, 0.0]];
        let opt = WeightOptimizer::new();
        let result = opt.optimize(&base, &target, &[]);
        assert_eq!(result.weights.len(), 0);
        assert!(result.converged);
    }

    #[test]
    fn test_convergence_flag() {
        // Zero residual from the start: the gradient vanishes immediately
        // and the optimizer must report convergence.
        let base = vec![[0.0f32, 0.0, 0.0]];
        let target = vec![[0.0f32, 0.0, 0.0]];
        let deltas = vec![vec![[1.0f32, 0.0, 0.0]]];
        let opt = WeightOptimizer {
            max_iterations: 500,
            learning_rate: 0.1,
            convergence_eps: 1e-5,
            weight_min: 0.0,
            weight_max: 1.0,
        };
        let result = opt.optimize(&base, &target, &deltas);
        assert!(result.converged, "should converge when target==base");
    }

    #[test]
    fn test_reconstruction_error_formula() {
        // MSE = sum(||blended - target||^2) / n
        let base = vec![[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]];
        let target = vec![[2.0, 0.0, 0.0], [2.0, 0.0, 0.0]];
        // (2^2 + 2^2) / 2 = 4.0
        let err = reconstruction_error(&base, &target, &[], &[]);
        assert!((err - 4.0).abs() < 1e-5, "err={err}");
    }

    #[test]
    fn test_multiple_morphs_add() {
        let base = vec![[0.0, 0.0, 0.0]];
        let target = vec![[2.0, 0.0, 0.0]];
        let d1 = vec![vec![[1.0f32, 0.0, 0.0]]];
        let d2 = vec![vec![[1.0f32, 0.0, 0.0]], vec![[1.0f32, 0.0, 0.0]]];
        let opt = WeightOptimizer {
            max_iterations: 2000,
            learning_rate: 0.05,
            convergence_eps: 1e-6,
            weight_min: 0.0,
            weight_max: 1.0,
        };
        let r1 = opt.optimize(&base, &target, &d1);
        let r2 = opt.optimize(&base, &target, &d2);
        // single morph needs weight=2 but clamped to 1; two morphs can each have weight=1
        assert!(
            r2.final_error < r1.final_error + 0.01,
            "two morphs should fit better: {} vs {}",
            r2.final_error,
            r1.final_error
        );
    }
}