1use std::collections::HashMap;
7
/// A corrective blend shape that fades in as the current pose approaches a target pose.
///
/// The target pose is `driver_params`; `corrective_distance` measures how far the
/// current parameters are from it and `corrective_weight` turns that distance into a
/// Gaussian weight in `[0, 1]`.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub struct CorrectiveShape {
    /// Identifier reported back in `CorrectiveEvalResult::active_shapes`.
    pub name: String,
    /// Parameter name → value at which this shape reaches full weight (distance 0).
    /// Only the parameters listed here participate in the distance.
    pub driver_params: HashMap<String, f32>,
    /// Per-vertex offsets, indexed by vertex; scaled by the shape's weight when blended.
    pub deltas: Vec<[f32; 3]>,
    /// Parameter-space distance at which the weight has fallen to `exp(-1) ≈ 0.37`.
    pub influence_radius: f32,
}
24
25#[allow(dead_code)]
27#[derive(Debug, Clone)]
28pub struct CorrectiveShapeLibrary {
29 pub shapes: Vec<CorrectiveShape>,
30 pub vertex_count: usize,
31}
32
/// Outcome of evaluating a `CorrectiveShapeLibrary` against a set of pose parameters.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub struct CorrectiveEvalResult {
    /// Weighted sum of every active shape's deltas, one entry per vertex.
    pub combined_deltas: Vec<[f32; 3]>,
    /// `(shape name, weight)` for each shape whose weight exceeded the activation
    /// threshold, in library order.
    pub active_shapes: Vec<(String, f32)>,
}
41
42impl CorrectiveShapeLibrary {
47 #[allow(dead_code)]
48 pub fn new(vertex_count: usize) -> Self {
49 Self {
50 shapes: Vec::new(),
51 vertex_count,
52 }
53 }
54
55 #[allow(dead_code)]
56 pub fn add_shape(&mut self, shape: CorrectiveShape) {
57 self.shapes.push(shape);
58 }
59
60 #[allow(dead_code)]
62 pub fn evaluate(&self, current_params: &HashMap<String, f32>) -> CorrectiveEvalResult {
63 let mut pairs: Vec<(Vec<[f32; 3]>, f32)> = Vec::new();
64 let mut active_shapes = Vec::new();
65
66 for shape in &self.shapes {
67 let dist = corrective_distance(current_params, &shape.driver_params);
68 let w = corrective_weight(dist, shape.influence_radius);
69 if w > 0.01 {
70 active_shapes.push((shape.name.clone(), w));
71 pairs.push((shape.deltas.clone(), w));
72 }
73 }
74
75 let combined_deltas = combine_corrective_deltas(&pairs, self.vertex_count);
76 CorrectiveEvalResult {
77 combined_deltas,
78 active_shapes,
79 }
80 }
81}
82
/// Euclidean distance between `current` and `driver`, measured only over the
/// parameters that `driver` names.
///
/// A parameter missing from `current` counts as `0.0`; parameters present only in
/// `current` are ignored. An empty `driver` therefore yields a distance of `0.0`.
#[allow(dead_code)]
pub fn corrective_distance(current: &HashMap<String, f32>, driver: &HashMap<String, f32>) -> f32 {
    driver
        .iter()
        .map(|(param, &target)| {
            let diff = current.get(param).map_or(0.0, |&value| value) - target;
            diff * diff
        })
        // Explicit `0.0` seed (rather than `Sum`) keeps the empty case at +0.0.
        .fold(0.0f32, |total, sq| total + sq)
        .sqrt()
}
97
/// Gaussian falloff `exp(-(distance / radius)^2)`.
///
/// Returns `1.0` at distance `0`, `exp(-1) ≈ 0.37` at `distance == radius`, and
/// tends to `0` beyond. `radius` is clamped to at least `f32::EPSILON` so a zero
/// or negative radius does not divide by zero.
#[allow(dead_code)]
pub fn corrective_weight(distance: f32, radius: f32) -> f32 {
    let safe_radius = f32::max(radius, f32::EPSILON);
    let normalized = distance / safe_radius;
    f32::exp(-normalized * normalized)
}
105
/// Sums weighted per-vertex deltas into a buffer of exactly `vertex_count` entries.
///
/// Each `(deltas, weight)` pair contributes `deltas[i] * weight` to vertex `i`.
/// Pairs are applied in slice order. Deltas past `vertex_count` are ignored, and
/// vertices a pair has no delta for receive nothing from it.
#[allow(dead_code)]
pub fn combine_corrective_deltas(
    deltas_and_weights: &[(Vec<[f32; 3]>, f32)],
    vertex_count: usize,
) -> Vec<[f32; 3]> {
    deltas_and_weights.iter().fold(
        vec![[0.0f32; 3]; vertex_count],
        |mut acc, (shape_deltas, weight)| {
            // `zip` stops at the shorter side, which handles both length mismatches.
            for (slot, delta) in acc.iter_mut().zip(shape_deltas) {
                for (component, offset) in slot.iter_mut().zip(delta) {
                    *component += offset * weight;
                }
            }
            acc
        },
    )
}
123
124#[allow(dead_code)]
126pub fn apply_corrective_to_mesh(base: &[[f32; 3]], result: &CorrectiveEvalResult) -> Vec<[f32; 3]> {
127 let mut out: Vec<[f32; 3]> = base.to_vec();
128 for (out_v, delta_v) in out.iter_mut().zip(result.combined_deltas.iter()) {
129 out_v[0] += delta_v[0];
130 out_v[1] += delta_v[1];
131 out_v[2] += delta_v[2];
132 }
133 out
134}
135
136#[allow(dead_code)]
138pub fn standard_corrective_shapes(vertex_count: usize) -> CorrectiveShapeLibrary {
139 let mut lib = CorrectiveShapeLibrary::new(vertex_count);
140
141 {
143 let mut driver = HashMap::new();
144 driver.insert("shoulder_raise_l".into(), 1.0);
145 let deltas: Vec<[f32; 3]> = (0..vertex_count)
146 .map(|i| {
147 let t = (i as f32) / (vertex_count.max(1) as f32);
148 [0.0, t * 0.02, 0.0]
149 })
150 .collect();
151 lib.add_shape(CorrectiveShape {
152 name: "shoulder_raise_left".into(),
153 driver_params: driver,
154 deltas,
155 influence_radius: 1.0,
156 });
157 }
158
159 {
161 let mut driver = HashMap::new();
162 driver.insert("elbow_bend_r".into(), 1.0);
163 let deltas: Vec<[f32; 3]> = (0..vertex_count)
164 .map(|i| {
165 let t = (i as f32) / (vertex_count.max(1) as f32);
166 [t * 0.01, 0.0, 0.0]
167 })
168 .collect();
169 lib.add_shape(CorrectiveShape {
170 name: "elbow_bend_right".into(),
171 driver_params: driver,
172 deltas,
173 influence_radius: 1.0,
174 });
175 }
176
177 {
179 let mut driver = HashMap::new();
180 driver.insert("knee_bend".into(), 1.0);
181 let deltas: Vec<[f32; 3]> = (0..vertex_count)
182 .map(|i| {
183 let t = (i as f32) / (vertex_count.max(1) as f32);
184 [0.0, 0.0, t * 0.015]
185 })
186 .collect();
187 lib.add_shape(CorrectiveShape {
188 name: "squat_knee".into(),
189 driver_params: driver,
190 deltas,
191 influence_radius: 1.0,
192 });
193 }
194
195 {
197 let mut driver = HashMap::new();
198 driver.insert("belly_weight".into(), 1.0);
199 let deltas: Vec<[f32; 3]> = (0..vertex_count)
200 .map(|i| {
201 let t = (i as f32) / (vertex_count.max(1) as f32);
202 [0.0, -t * 0.01, t * 0.03]
203 })
204 .collect();
205 lib.add_shape(CorrectiveShape {
206 name: "heavy_belly".into(),
207 driver_params: driver,
208 deltas,
209 influence_radius: 1.0,
210 });
211 }
212
213 lib
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_corrective_weight_at_zero() {
        assert!((corrective_weight(0.0, 1.0) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_corrective_weight_at_radius() {
        // exp(-1) ≈ 0.3679
        let w = corrective_weight(1.0, 1.0);
        assert!(w < 0.37 && w > 0.35, "w={w}");
    }

    #[test]
    fn test_corrective_weight_large_distance() {
        let w = corrective_weight(100.0, 1.0);
        assert!(w < 1e-10, "w={w}");
    }

    #[test]
    fn test_corrective_distance_same_params() {
        let mut p = HashMap::new();
        p.insert("a".into(), 1.0);
        p.insert("b".into(), 2.0);
        assert!(corrective_distance(&p, &p) < 1e-6);
    }

    #[test]
    fn test_corrective_distance_different_params() {
        let mut current = HashMap::new();
        current.insert("x".into(), 0.0);
        let mut driver = HashMap::new();
        driver.insert("x".into(), 3.0);
        driver.insert("y".into(), 4.0);
        // "y" is missing from `current`, so it counts as 0: sqrt(3² + 4²) = 5.
        let d = corrective_distance(&current, &driver);
        assert!((d - 5.0).abs() < 1e-5, "d={d}");
    }

    #[test]
    fn test_combine_corrective_deltas_single_weight() {
        let deltas = vec![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let combined = combine_corrective_deltas(&[(deltas, 0.5)], 2);
        assert!((combined[0][0] - 0.5).abs() < 1e-5);
        assert!((combined[1][2] - 3.0).abs() < 1e-5);
    }

    #[test]
    fn test_combine_corrective_deltas_two_shapes() {
        let d1 = vec![[1.0f32, 0.0, 0.0]];
        let d2 = vec![[0.0f32, 1.0, 0.0]];
        let combined = combine_corrective_deltas(&[(d1, 1.0), (d2, 1.0)], 1);
        assert!((combined[0][0] - 1.0).abs() < 1e-5);
        assert!((combined[0][1] - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_evaluate_matching_params() {
        let lib = standard_corrective_shapes(4);
        let mut params = HashMap::new();
        params.insert("shoulder_raise_l".into(), 1.0);
        let result = lib.evaluate(&params);
        assert!(!result.active_shapes.is_empty());
        assert!(result
            .active_shapes
            .iter()
            .any(|(n, _)| n == "shoulder_raise_left"));
    }

    #[test]
    fn test_evaluate_no_matching_params() {
        let lib = standard_corrective_shapes(4);
        let params = HashMap::new();
        let result = lib.evaluate(&params);
        assert_eq!(result.combined_deltas.len(), 4);
        // Every standard driver targets 1.0 with radius 1.0, so an empty (rest)
        // pose is exactly one radius away and each shape still gets exp(-1).
        assert_eq!(result.active_shapes.len(), 4);
        for (name, w) in &result.active_shapes {
            assert!((w - (-1.0f32).exp()).abs() < 1e-6, "{name}: w={w}");
        }
    }

    #[test]
    fn test_evaluate_far_params_near_zero() {
        let lib = standard_corrective_shapes(4);
        let mut params = HashMap::new();
        params.insert("shoulder_raise_l".into(), 1000.0);
        let result = lib.evaluate(&params);
        let shoulder = result
            .active_shapes
            .iter()
            .find(|(n, _)| n == "shoulder_raise_left");
        // Distance 999 gives weight exp(-999²), which underflows to 0 and so is
        // below the activation threshold: the shape must not be reported at all.
        assert!(shoulder.is_none(), "unexpectedly active: {shoulder:?}");
        assert_eq!(result.combined_deltas.len(), 4);
    }

    #[test]
    fn test_evaluate_combined_matches_combine_helper() {
        let lib = standard_corrective_shapes(5);
        let mut params = HashMap::new();
        params.insert("knee_bend".into(), 0.8);
        params.insert("belly_weight".into(), 1.2);
        let result = lib.evaluate(&params);

        // Rebuild the active (deltas, weight) pairs in library order and blend
        // them with the standalone helper; `evaluate` must agree exactly.
        let pairs: Vec<(Vec<[f32; 3]>, f32)> = lib
            .shapes
            .iter()
            .filter_map(|s| {
                result
                    .active_shapes
                    .iter()
                    .find(|(n, _)| *n == s.name)
                    .map(|&(_, w)| (s.deltas.clone(), w))
            })
            .collect();
        assert_eq!(pairs.len(), result.active_shapes.len());
        let expected = combine_corrective_deltas(&pairs, lib.vertex_count);
        assert_eq!(result.combined_deltas, expected);
    }

    #[test]
    fn test_standard_corrective_shapes_has_4() {
        let lib = standard_corrective_shapes(10);
        assert_eq!(lib.shapes.len(), 4);
    }

    #[test]
    fn test_apply_corrective_to_mesh_adds_deltas() {
        let base = vec![[1.0f32, 1.0, 1.0], [2.0, 2.0, 2.0]];
        let combined_deltas = vec![[0.1f32, 0.2, 0.3], [0.4, 0.5, 0.6]];
        let result = CorrectiveEvalResult {
            combined_deltas,
            active_shapes: Vec::new(),
        };
        let out = apply_corrective_to_mesh(&base, &result);
        assert!((out[0][0] - 1.1).abs() < 1e-5);
        assert!((out[1][2] - 2.6).abs() < 1e-5);
    }

    #[test]
    fn test_apply_corrective_zero_weight_no_change() {
        let base = vec![[5.0f32, 5.0, 5.0]];
        let combined_deltas = vec![[0.0f32, 0.0, 0.0]];
        let result = CorrectiveEvalResult {
            combined_deltas,
            active_shapes: Vec::new(),
        };
        let out = apply_corrective_to_mesh(&base, &result);
        assert!((out[0][0] - 5.0).abs() < 1e-5);
    }

    #[test]
    fn test_combine_corrective_deltas_empty() {
        let combined = combine_corrective_deltas(&[], 3);
        assert_eq!(combined.len(), 3);
        assert_eq!(combined[0], [0.0, 0.0, 0.0]);
    }
}