/// Settings for fitting morph-target weights by gradient descent with box constraints.
///
/// Build one with [`WeightOptimizer::new`] (or `Default::default()`) and override the fields
/// that need tuning; every field is public.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub struct WeightOptimizer {
    /// Maximum number of gradient-descent iterations `optimize` will run.
    pub max_iterations: u32,
    /// Step size multiplied with the gradient on every iteration.
    pub learning_rate: f32,
    /// Gradient-norm tolerance below which `optimize` stops and reports convergence.
    pub convergence_eps: f32,
    /// Lower bound every weight is clamped to after each step.
    pub weight_min: f32,
    /// Upper bound every weight is clamped to after each step.
    pub weight_max: f32,
}
15
16impl Default for WeightOptimizer {
17 fn default() -> Self {
18 Self::new()
19 }
20}
21
/// Outcome of [`WeightOptimizer::optimize`].
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub struct OptimizationResult {
    /// Fitted weight for each morph delta, in the same order as the input deltas.
    pub weights: Vec<f32>,
    /// Value of `reconstruction_error` evaluated at `weights`.
    pub final_error: f32,
    /// Number of iterations actually performed.
    pub iterations: u32,
    /// `true` if the convergence tolerance was reached before the iteration limit
    /// (also `true` when there were no morph deltas to fit).
    pub converged: bool,
}
30
31impl WeightOptimizer {
32 #[allow(dead_code)]
33 pub fn new() -> Self {
34 Self {
35 max_iterations: 200,
36 learning_rate: 0.01,
37 convergence_eps: 1e-5,
38 weight_min: 0.0,
39 weight_max: 1.0,
40 }
41 }
42
43 #[allow(dead_code)]
44 pub fn optimize(
45 &self,
46 base: &[[f32; 3]],
47 target: &[[f32; 3]],
48 morph_deltas: &[Vec<[f32; 3]>],
49 ) -> OptimizationResult {
50 let k = morph_deltas.len();
51 if k == 0 {
52 let error = reconstruction_error(base, target, morph_deltas, &[]);
53 return OptimizationResult {
54 weights: vec![],
55 final_error: error,
56 iterations: 0,
57 converged: true,
58 };
59 }
60
61 let mut weights = vec![0.0f32; k];
62 let mut converged = false;
63 let mut iters = 0u32;
64
65 for iter in 0..self.max_iterations {
66 iters = iter + 1;
67 let grad = gradient_wrt_weights(base, target, morph_deltas, &weights);
68 let grad_norm: f32 = grad.iter().map(|g| g * g).sum::<f32>().sqrt();
69
70 for i in 0..k {
71 weights[i] -= self.learning_rate * grad[i];
72 }
73 clamp_weights(&mut weights, self.weight_min, self.weight_max);
74
75 if grad_norm < self.convergence_eps {
76 converged = true;
77 break;
78 }
79 }
80
81 let final_error = reconstruction_error(base, target, morph_deltas, &weights);
82 OptimizationResult {
83 weights,
84 final_error,
85 iterations: iters,
86 converged,
87 }
88 }
89}
90
91#[allow(dead_code)]
92pub fn reconstruction_error(
93 base: &[[f32; 3]],
94 target: &[[f32; 3]],
95 deltas: &[Vec<[f32; 3]>],
96 weights: &[f32],
97) -> f32 {
98 if base.is_empty() {
99 return 0.0;
100 }
101 let blended = apply_weights(base, deltas, weights);
102 let n = base.len() as f32;
103 blended
104 .iter()
105 .zip(target.iter())
106 .map(|(b, t)| {
107 let dx = b[0] - t[0];
108 let dy = b[1] - t[1];
109 let dz = b[2] - t[2];
110 dx * dx + dy * dy + dz * dz
111 })
112 .sum::<f32>()
113 / n
114}
115
116#[allow(dead_code)]
117pub fn gradient_wrt_weights(
118 base: &[[f32; 3]],
119 target: &[[f32; 3]],
120 deltas: &[Vec<[f32; 3]>],
121 weights: &[f32],
122) -> Vec<f32> {
123 let k = weights.len();
124 let n = base.len();
125 if n == 0 || k == 0 {
126 return vec![0.0; k];
127 }
128
129 let blended = apply_weights(base, deltas, weights);
130 let scale = 2.0 / n as f32;
131
132 (0..k)
133 .map(|i| {
134 let delta_i = &deltas[i];
135 let dlen = delta_i.len().min(n);
136 let mut g = 0.0f32;
137 for v in 0..dlen {
138 let rx = blended[v][0] - target[v][0];
139 let ry = blended[v][1] - target[v][1];
140 let rz = blended[v][2] - target[v][2];
141 g += rx * delta_i[v][0] + ry * delta_i[v][1] + rz * delta_i[v][2];
142 }
143 g * scale
144 })
145 .collect()
146}
147
/// Blends morph deltas onto `base`: `result[v] = base[v] + Σ_i weights[i] * deltas[i][v]`.
///
/// Only the first `min(weights.len(), deltas.len())` morphs contribute. A delta shorter than
/// `base` offsets only its leading vertices, and delta entries past `base.len()` are ignored.
/// The output always has exactly `base.len()` vertices.
#[allow(dead_code)]
pub fn apply_weights(
    base: &[[f32; 3]],
    deltas: &[Vec<[f32; 3]>],
    weights: &[f32],
) -> Vec<[f32; 3]> {
    let mut blended = base.to_vec();
    for (weight, delta) in weights.iter().zip(deltas.iter()) {
        for (vertex, offset) in blended.iter_mut().zip(delta.iter()) {
            for (coord, component) in vertex.iter_mut().zip(offset.iter()) {
                *coord += weight * component;
            }
        }
    }
    blended
}
170
/// Clamps every weight in place to the closed interval `[min, max]`.
///
/// # Panics
///
/// Panics if `min > max` or either bound is NaN, per the contract of [`f32::clamp`].
#[allow(dead_code)]
pub fn clamp_weights(weights: &mut [f32], min: f32, max: f32) {
    weights
        .iter_mut()
        .for_each(|weight| *weight = weight.clamp(min, max));
}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Vertices spaced one unit apart along the x axis.
    fn make_base(n: usize) -> Vec<[f32; 3]> {
        (0..n).map(|i| [i as f32, 0.0, 0.0]).collect()
    }

    #[test]
    fn test_apply_weights_no_deltas() {
        let base = make_base(3);
        let result = apply_weights(&base, &[], &[]);
        assert_eq!(result, base);
    }

    #[test]
    fn test_apply_weights_single() {
        let base = vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]];
        let deltas = vec![vec![[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]];
        let weights = [0.5f32];
        let result = apply_weights(&base, &deltas, &weights);
        assert!((result[0][0] - 0.5).abs() < 1e-6);
        assert!((result[1][0] - 1.5).abs() < 1e-6);
    }

    #[test]
    fn test_apply_weights_ignores_surplus_weights() {
        let base = vec![[0.0, 0.0, 0.0]];
        let deltas = vec![vec![[1.0, 0.0, 0.0]]];
        let result = apply_weights(&base, &deltas, &[0.5, 7.0]);
        assert!((result[0][0] - 0.5).abs() < 1e-6, "extra weight must be ignored");
    }

    #[test]
    fn test_apply_weights_short_delta() {
        let base = make_base(2);
        let deltas = vec![vec![[0.0, 1.0, 0.0]]];
        let result = apply_weights(&base, &deltas, &[1.0]);
        assert_eq!(result.len(), 2);
        assert!((result[0][1] - 1.0).abs() < 1e-6);
        assert!(result[1][1].abs() < 1e-6, "vertex without a delta is untouched");
    }

    #[test]
    fn test_clamp_weights() {
        let mut w = vec![-0.5, 0.5, 1.5, 0.0];
        clamp_weights(&mut w, 0.0, 1.0);
        assert!((w[0] - 0.0).abs() < 1e-6);
        assert!((w[1] - 0.5).abs() < 1e-6);
        assert!((w[2] - 1.0).abs() < 1e-6);
        assert!((w[3] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_reconstruction_error_zero() {
        let base = vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]];
        let target = base.clone();
        let err = reconstruction_error(&base, &target, &[], &[]);
        assert!(err.abs() < 1e-6, "identical base/target: {err}");
    }

    #[test]
    fn test_reconstruction_error_nonzero() {
        let base = vec![[0.0, 0.0, 0.0]];
        let target = vec![[1.0, 0.0, 0.0]];
        let err = reconstruction_error(&base, &target, &[], &[]);
        assert!((err - 1.0).abs() < 1e-6, "error = 1.0: {err}");
    }

    #[test]
    fn test_gradient_direction() {
        // Target lies in the +x direction of the only delta, so increasing the weight helps.
        let base = vec![[0.0, 0.0, 0.0]];
        let target = vec![[1.0, 0.0, 0.0]];
        let deltas = vec![vec![[1.0, 0.0, 0.0]]];
        let weights = [0.0f32];
        let grad = gradient_wrt_weights(&base, &target, &deltas, &weights);
        assert!(
            grad[0] < 0.0,
            "gradient should be negative to increase weight: {}",
            grad[0]
        );
    }

    #[test]
    fn test_gradient_zero_at_perfect_fit() {
        let base = vec![[0.0, 0.0, 0.0]];
        let target = vec![[1.0, 0.0, 0.0]];
        let deltas = vec![vec![[1.0, 0.0, 0.0]]];
        let weights = [1.0f32];
        let grad = gradient_wrt_weights(&base, &target, &deltas, &weights);
        assert!(
            grad[0].abs() < 1e-6,
            "at perfect fit gradient is zero: {}",
            grad[0]
        );
    }

    #[test]
    fn test_single_target_perfect_fit() {
        // target = base + delta, so the exact solution is weight 1.0 (inside the bounds).
        let n = 4;
        let base: Vec<[f32; 3]> = (0..n).map(|i| [i as f32, 0.0, 0.0]).collect();
        let target: Vec<[f32; 3]> = (0..n).map(|i| [i as f32 + 1.0, 0.0, 0.0]).collect();
        let deltas = vec![(0..n).map(|_| [1.0f32, 0.0, 0.0]).collect::<Vec<_>>()];
        let opt = WeightOptimizer {
            max_iterations: 1000,
            learning_rate: 0.1,
            convergence_eps: 1e-6,
            weight_min: 0.0,
            weight_max: 1.0,
        };
        let result = opt.optimize(&base, &target, &deltas);
        assert!(
            (result.weights[0] - 1.0).abs() < 0.01,
            "should converge to 1.0, got {}",
            result.weights[0]
        );
        assert!(result.final_error < 1e-4);
        assert!(result.converged, "interior optimum should be reached");
    }

    #[test]
    fn test_zero_target_weight_stays_zero() {
        // target == base, so any non-zero weight only adds error.
        let n = 3;
        let base: Vec<[f32; 3]> = (0..n).map(|i| [i as f32, 0.0, 0.0]).collect();
        let target = base.clone();
        let deltas = vec![(0..n).map(|_| [1.0f32, 0.0, 0.0]).collect::<Vec<_>>()];
        let opt = WeightOptimizer::new();
        let result = opt.optimize(&base, &target, &deltas);
        assert!(
            result.weights[0] < 0.01,
            "weight should stay near 0: {}",
            result.weights[0]
        );
    }

    #[test]
    fn test_empty_deltas() {
        let base = vec![[0.0, 0.0, 0.0]];
        let target = vec![[1.0, 0.0, 0.0]];
        let opt = WeightOptimizer::new();
        let result = opt.optimize(&base, &target, &[]);
        assert_eq!(result.weights.len(), 0);
        assert!(result.converged);
    }

    #[test]
    fn test_convergence_flag() {
        let base = vec![[0.0f32, 0.0, 0.0]];
        let target = vec![[0.0f32, 0.0, 0.0]];
        let deltas = vec![vec![[1.0f32, 0.0, 0.0]]];
        let opt = WeightOptimizer {
            max_iterations: 500,
            learning_rate: 0.1,
            convergence_eps: 1e-5,
            weight_min: 0.0,
            weight_max: 1.0,
        };
        let result = opt.optimize(&base, &target, &deltas);
        assert!(result.converged, "should converge when target==base");
    }

    #[test]
    fn test_reconstruction_error_formula() {
        // Each vertex is 2 units away: (2² + 2²) / 2 = 4.
        let base = vec![[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]];
        let target = vec![[2.0, 0.0, 0.0], [2.0, 0.0, 0.0]];
        let err = reconstruction_error(&base, &target, &[], &[]);
        assert!((err - 4.0).abs() < 1e-5, "err={err}");
    }

    #[test]
    fn test_multiple_morphs_add() {
        // One morph is capped at weight 1.0 and cannot reach x = 2; two morphs can.
        let base = vec![[0.0, 0.0, 0.0]];
        let target = vec![[2.0, 0.0, 0.0]];
        let d1 = vec![vec![[1.0f32, 0.0, 0.0]]];
        let d2 = vec![vec![[1.0f32, 0.0, 0.0]], vec![[1.0f32, 0.0, 0.0]]];
        let opt = WeightOptimizer {
            max_iterations: 2000,
            learning_rate: 0.05,
            convergence_eps: 1e-6,
            weight_min: 0.0,
            weight_max: 1.0,
        };
        let r1 = opt.optimize(&base, &target, &d1);
        let r2 = opt.optimize(&base, &target, &d2);
        assert!(
            r2.final_error < r1.final_error,
            "two morphs should fit better: {} vs {}",
            r2.final_error,
            r1.final_error
        );
        assert!(
            r2.final_error < 1e-4,
            "two morphs should reach the target: {}",
            r2.final_error
        );
    }
}