// oxihuman_morph/param_constraint.rs

// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
// SPDX-License-Identifier: Apache-2.0

4#![allow(dead_code)]
5
6use std::collections::HashMap;
7
/// Named morph parameters: each parameter name maps to its current value.
///
/// The constraints in this module read a missing entry as `0.0`.
pub type Params = std::collections::HashMap<String, f32>;

/// Constraint types between parameters.
///
/// Each variant names the parameters it reads and writes. A parameter that is
/// absent from the map is read as `0.0`.
#[derive(Clone, Debug, PartialEq)]
pub enum Constraint {
    /// `target` must be `>= source * factor + offset`; it is raised to the bound when below it.
    MinRelative {
        source: String,
        target: String,
        factor: f32,
        offset: f32,
    },
    /// `target` must be `<= source * factor + offset`; it is lowered to the bound when above it.
    MaxRelative {
        source: String,
        target: String,
        factor: f32,
        offset: f32,
    },
    /// `target = source * factor + offset` (hard link), clamped to `[0, 1]`.
    Driven {
        source: String,
        target: String,
        factor: f32,
        offset: f32,
    },
    /// The named params must sum to `total`. Values are rescaled proportionally,
    /// or each set to an equal share of `total` when they currently sum to zero.
    SumEquals { params: Vec<String>, total: f32 },
    /// `param` must lie in `[min, max]`.
    Clamp { param: String, min: f32, max: f32 },
    /// `param_a + param_b <= max_sum`; when exceeded, both values are rescaled so
    /// that their sum equals `max_sum`.
    MaxSum {
        param_a: String,
        param_b: String,
        max_sum: f32,
    },
    /// If `condition >= threshold`, then `target` is assigned `value`; otherwise
    /// the constraint has no effect.
    Conditional {
        condition: String,
        threshold: f32,
        target: String,
        value: f32,
    },
}

53impl Constraint {
54    /// Apply this constraint to params in-place, return true if any change was made.
55    pub fn apply(&self, params: &mut Params) -> bool {
56        const EPS: f32 = 1e-6;
57        match self {
58            Constraint::MinRelative {
59                source,
60                target,
61                factor,
62                offset,
63            } => {
64                let src = *params.get(source).unwrap_or(&0.0);
65                let min_val = src * factor + offset;
66                let cur = *params.get(target).unwrap_or(&0.0);
67                if cur < min_val - EPS {
68                    params.insert(target.clone(), min_val);
69                    true
70                } else {
71                    false
72                }
73            }
74            Constraint::MaxRelative {
75                source,
76                target,
77                factor,
78                offset,
79            } => {
80                let src = *params.get(source).unwrap_or(&0.0);
81                let max_val = src * factor + offset;
82                let cur = *params.get(target).unwrap_or(&0.0);
83                if cur > max_val + EPS {
84                    params.insert(target.clone(), max_val);
85                    true
86                } else {
87                    false
88                }
89            }
90            Constraint::Driven {
91                source,
92                target,
93                factor,
94                offset,
95            } => {
96                let src = *params.get(source).unwrap_or(&0.0);
97                let new_val = (src * factor + offset).clamp(0.0, 1.0);
98                let cur = *params.get(target).unwrap_or(&0.0);
99                if (cur - new_val).abs() > EPS {
100                    params.insert(target.clone(), new_val);
101                    true
102                } else {
103                    false
104                }
105            }
106            Constraint::SumEquals {
107                params: keys,
108                total,
109            } => {
110                let current_sum: f32 = keys.iter().map(|k| *params.get(k).unwrap_or(&0.0)).sum();
111                if (current_sum - total).abs() <= EPS {
112                    return false;
113                }
114                if current_sum.abs() < EPS {
115                    // Distribute equally
116                    let equal_share = total / keys.len() as f32;
117                    for k in keys {
118                        params.insert(k.clone(), equal_share);
119                    }
120                } else {
121                    let scale = total / current_sum;
122                    for k in keys {
123                        let v = *params.get(k).unwrap_or(&0.0);
124                        params.insert(k.clone(), v * scale);
125                    }
126                }
127                true
128            }
129            Constraint::Clamp { param, min, max } => {
130                let cur = *params.get(param).unwrap_or(&0.0);
131                let clamped = cur.clamp(*min, *max);
132                if (cur - clamped).abs() > EPS {
133                    params.insert(param.clone(), clamped);
134                    true
135                } else {
136                    false
137                }
138            }
139            Constraint::MaxSum {
140                param_a,
141                param_b,
142                max_sum,
143            } => {
144                let a = *params.get(param_a).unwrap_or(&0.0);
145                let b = *params.get(param_b).unwrap_or(&0.0);
146                let sum = a + b;
147                if sum > max_sum + 1e-6 {
148                    let scale = max_sum / sum;
149                    params.insert(param_a.clone(), a * scale);
150                    params.insert(param_b.clone(), b * scale);
151                    true
152                } else {
153                    false
154                }
155            }
156            Constraint::Conditional {
157                condition,
158                threshold,
159                target,
160                value,
161            } => {
162                let cond_val = *params.get(condition).unwrap_or(&0.0);
163                if cond_val >= *threshold {
164                    let cur = *params.get(target).unwrap_or(&0.0);
165                    if (cur - value).abs() > 1e-6 {
166                        params.insert(target.clone(), *value);
167                        return true;
168                    }
169                }
170                false
171            }
172        }
173    }
174
175    /// Check if the constraint is currently satisfied.
176    pub fn is_satisfied(&self, params: &Params) -> bool {
177        const EPS: f32 = 1e-5;
178        match self {
179            Constraint::MinRelative {
180                source,
181                target,
182                factor,
183                offset,
184            } => {
185                let src = *params.get(source).unwrap_or(&0.0);
186                let cur = *params.get(target).unwrap_or(&0.0);
187                cur >= src * factor + offset - EPS
188            }
189            Constraint::MaxRelative {
190                source,
191                target,
192                factor,
193                offset,
194            } => {
195                let src = *params.get(source).unwrap_or(&0.0);
196                let cur = *params.get(target).unwrap_or(&0.0);
197                cur <= src * factor + offset + EPS
198            }
199            Constraint::Driven {
200                source,
201                target,
202                factor,
203                offset,
204            } => {
205                let src = *params.get(source).unwrap_or(&0.0);
206                let expected = (src * factor + offset).clamp(0.0, 1.0);
207                let cur = *params.get(target).unwrap_or(&0.0);
208                (cur - expected).abs() <= EPS
209            }
210            Constraint::SumEquals {
211                params: keys,
212                total,
213            } => {
214                let s: f32 = keys.iter().map(|k| *params.get(k).unwrap_or(&0.0)).sum();
215                (s - total).abs() <= EPS
216            }
217            Constraint::Clamp { param, min, max } => {
218                let v = *params.get(param).unwrap_or(&0.0);
219                v >= *min - EPS && v <= *max + EPS
220            }
221            Constraint::MaxSum {
222                param_a,
223                param_b,
224                max_sum,
225            } => {
226                let a = *params.get(param_a).unwrap_or(&0.0);
227                let b = *params.get(param_b).unwrap_or(&0.0);
228                a + b <= max_sum + EPS
229            }
230            Constraint::Conditional {
231                condition,
232                threshold,
233                target,
234                value,
235            } => {
236                let cond_val = *params.get(condition).unwrap_or(&0.0);
237                if cond_val >= *threshold {
238                    let cur = *params.get(target).unwrap_or(&0.0);
239                    (cur - value).abs() <= EPS
240                } else {
241                    true
242                }
243            }
244        }
245    }
246
247    /// Name/description of constraint for debugging.
248    pub fn describe(&self) -> String {
249        match self {
250            Constraint::MinRelative {
251                source,
252                target,
253                factor,
254                offset,
255            } => format!(
256                "MinRelative: {} >= {} * {} + {}",
257                target, source, factor, offset
258            ),
259            Constraint::MaxRelative {
260                source,
261                target,
262                factor,
263                offset,
264            } => format!(
265                "MaxRelative: {} <= {} * {} + {}",
266                target, source, factor, offset
267            ),
268            Constraint::Driven {
269                source,
270                target,
271                factor,
272                offset,
273            } => format!(
274                "Driven: {} = {} * {} + {} (clamped to [0,1])",
275                target, source, factor, offset
276            ),
277            Constraint::SumEquals { params, total } => {
278                format!("SumEquals: {:?} sums to {}", params, total)
279            }
280            Constraint::Clamp { param, min, max } => {
281                format!("Clamp: {} in [{}, {}]", param, min, max)
282            }
283            Constraint::MaxSum {
284                param_a,
285                param_b,
286                max_sum,
287            } => format!("MaxSum: {} + {} <= {}", param_a, param_b, max_sum),
288            Constraint::Conditional {
289                condition,
290                threshold,
291                target,
292                value,
293            } => format!(
294                "Conditional: if {} >= {} then {} = {}",
295                condition, threshold, target, value
296            ),
297        }
298    }
299}
300
301/// Iterative constraint solver.
302pub struct ConstraintSolver {
303    constraints: Vec<Constraint>,
304    max_iterations: usize,
305    tolerance: f32,
306}
307
308/// Result of a solve pass.
309pub struct SolveResult {
310    pub iterations: usize,
311    pub converged: bool,
312    pub violations_remaining: usize,
313    pub changes_made: usize,
314}
315
316impl ConstraintSolver {
317    pub fn new() -> Self {
318        Self {
319            constraints: Vec::new(),
320            max_iterations: 100,
321            tolerance: 1e-5,
322        }
323    }
324
325    pub fn with_max_iterations(mut self, n: usize) -> Self {
326        self.max_iterations = n;
327        self
328    }
329
330    pub fn with_tolerance(mut self, tol: f32) -> Self {
331        self.tolerance = tol;
332        self
333    }
334
335    pub fn add(&mut self, constraint: Constraint) {
336        self.constraints.push(constraint);
337    }
338
339    pub fn constraint_count(&self) -> usize {
340        self.constraints.len()
341    }
342
343    pub fn remove(&mut self, index: usize) {
344        if index < self.constraints.len() {
345            self.constraints.remove(index);
346        }
347    }
348
349    /// Solve all constraints iteratively until convergence or max_iterations.
350    pub fn solve(&self, params: &mut Params) -> SolveResult {
351        let mut total_changes = 0usize;
352        let mut iterations = 0usize;
353        let mut converged = false;
354
355        for _ in 0..self.max_iterations {
356            iterations += 1;
357            let mut changed_this_iter = false;
358
359            for constraint in &self.constraints {
360                if constraint.apply(params) {
361                    total_changes += 1;
362                    changed_this_iter = true;
363                }
364            }
365
366            if !changed_this_iter {
367                converged = true;
368                break;
369            }
370        }
371
372        let violations_remaining = self.check_violations(params).len();
373
374        SolveResult {
375            iterations,
376            converged,
377            violations_remaining,
378            changes_made: total_changes,
379        }
380    }
381
382    /// Check which constraints (by index) are violated.
383    pub fn check_violations(&self, params: &Params) -> Vec<usize> {
384        self.constraints
385            .iter()
386            .enumerate()
387            .filter(|(_, c)| !c.is_satisfied(params))
388            .map(|(i, _)| i)
389            .collect()
390    }
391
392    /// True if all constraints are satisfied.
393    pub fn is_satisfied(&self, params: &Params) -> bool {
394        self.constraints.iter().all(|c| c.is_satisfied(params))
395    }
396}
397
398impl Default for ConstraintSolver {
399    fn default() -> Self {
400        Self::new()
401    }
402}
403
404/// Human body BMI-related constraint presets.
405///
406/// Encodes: weight is driven by bmi_factor; muscle is bounded relative to weight.
407pub fn bmi_constraints() -> Vec<Constraint> {
408    vec![
409        // weight is driven directly by bmi_factor (normalized)
410        Constraint::Driven {
411            source: "bmi_factor".to_string(),
412            target: "weight".to_string(),
413            factor: 1.0,
414            offset: 0.0,
415        },
416        // muscle must be <= weight * 0.9 + 0.1 (heavier bodies can have more muscle)
417        Constraint::MaxRelative {
418            source: "weight".to_string(),
419            target: "muscle".to_string(),
420            factor: 0.9,
421            offset: 0.1,
422        },
423        // muscle must stay in [0, 1]
424        Constraint::Clamp {
425            param: "muscle".to_string(),
426            min: 0.0,
427            max: 1.0,
428        },
429        // weight must stay in [0, 1]
430        Constraint::Clamp {
431            param: "weight".to_string(),
432            min: 0.0,
433            max: 1.0,
434        },
435    ]
436}
437
438/// Proportion constraints linking limb dimensions to height.
439pub fn proportion_constraints() -> Vec<Constraint> {
440    vec![
441        // shoulder_width >= height * 0.3 + 0.05
442        Constraint::MinRelative {
443            source: "height".to_string(),
444            target: "shoulder_width".to_string(),
445            factor: 0.3,
446            offset: 0.05,
447        },
448        // shoulder_width <= height * 0.6 + 0.1
449        Constraint::MaxRelative {
450            source: "height".to_string(),
451            target: "shoulder_width".to_string(),
452            factor: 0.6,
453            offset: 0.1,
454        },
455        // shoulder_width must stay in [0, 1]
456        Constraint::Clamp {
457            param: "shoulder_width".to_string(),
458            min: 0.0,
459            max: 1.0,
460        },
461        // leg_length is driven by height
462        Constraint::Driven {
463            source: "height".to_string(),
464            target: "leg_length".to_string(),
465            factor: 0.85,
466            offset: 0.05,
467        },
468    ]
469}
470
471/// Age-related constraint presets.
472///
473/// Muscle mass decreases with age; body fat increases slightly.
474pub fn age_constraints() -> Vec<Constraint> {
475    vec![
476        // muscle decreases as age increases: muscle <= 1.0 - age * 0.5
477        Constraint::MaxRelative {
478            source: "age".to_string(),
479            target: "muscle".to_string(),
480            factor: -0.5,
481            offset: 1.0,
482        },
483        // at advanced age (>= 0.8), assign lower muscle cap
484        Constraint::Conditional {
485            condition: "age".to_string(),
486            threshold: 0.8,
487            target: "elderly_flag".to_string(),
488            value: 1.0,
489        },
490        // muscle clamped
491        Constraint::Clamp {
492            param: "muscle".to_string(),
493            min: 0.0,
494            max: 1.0,
495        },
496    ]
497}
498
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Params` map from `(name, value)` pairs.
    fn make_params(pairs: &[(&str, f32)]) -> Params {
        pairs.iter().map(|(k, v)| (k.to_string(), *v)).collect()
    }

    #[test]
    fn test_clamp_constraint() {
        let c = Constraint::Clamp {
            param: "height".to_string(),
            min: 0.0,
            max: 1.0,
        };
        let mut p = make_params(&[("height", 1.5)]);
        assert!(c.apply(&mut p));
        assert!((p["height"] - 1.0).abs() < 1e-6);

        let mut p2 = make_params(&[("height", 0.5)]);
        assert!(!c.apply(&mut p2));
        assert!((p2["height"] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_min_relative_constraint() {
        let c = Constraint::MinRelative {
            source: "height".to_string(),
            target: "shoulder_width".to_string(),
            factor: 0.3,
            offset: 0.0,
        };
        // height=0.8 → min shoulder=0.24; shoulder=0.1 should be raised
        let mut p = make_params(&[("height", 0.8), ("shoulder_width", 0.1)]);
        assert!(c.apply(&mut p));
        assert!((p["shoulder_width"] - 0.24).abs() < 1e-5);

        // shoulder already satisfies: no change
        let mut p2 = make_params(&[("height", 0.8), ("shoulder_width", 0.5)]);
        assert!(!c.apply(&mut p2));
    }

    #[test]
    fn test_max_relative_constraint() {
        let c = Constraint::MaxRelative {
            source: "weight".to_string(),
            target: "muscle".to_string(),
            factor: 0.9,
            offset: 0.1,
        };
        // weight=0.5 → max muscle=0.55; muscle=0.8 should be capped
        let mut p = make_params(&[("weight", 0.5), ("muscle", 0.8)]);
        assert!(c.apply(&mut p));
        assert!((p["muscle"] - 0.55).abs() < 1e-5);

        // muscle already satisfies: no change
        let mut p2 = make_params(&[("weight", 0.5), ("muscle", 0.3)]);
        assert!(!c.apply(&mut p2));
    }

    #[test]
    fn test_driven_constraint() {
        let c = Constraint::Driven {
            source: "bmi_factor".to_string(),
            target: "weight".to_string(),
            factor: 1.0,
            offset: 0.0,
        };
        let mut p = make_params(&[("bmi_factor", 0.7), ("weight", 0.0)]);
        assert!(c.apply(&mut p));
        assert!((p["weight"] - 0.7).abs() < 1e-6);

        // Already set correctly
        let mut p2 = make_params(&[("bmi_factor", 0.7), ("weight", 0.7)]);
        assert!(!c.apply(&mut p2));

        // Clamped to [0,1]
        let mut p3 = make_params(&[("bmi_factor", 1.5), ("weight", 0.0)]);
        assert!(c.apply(&mut p3));
        assert!((p3["weight"] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_sum_equals_constraint() {
        let c = Constraint::SumEquals {
            params: vec!["a".to_string(), "b".to_string(), "c".to_string()],
            total: 1.0,
        };
        let mut p = make_params(&[("a", 0.2), ("b", 0.3), ("c", 0.5)]);
        // sum = 1.0, already satisfied
        assert!(!c.apply(&mut p));

        let mut p2 = make_params(&[("a", 0.5), ("b", 0.5), ("c", 0.5)]);
        // sum = 1.5, scale to 1.0
        assert!(c.apply(&mut p2));
        let new_sum: f32 = ["a", "b", "c"].iter().map(|k| p2[*k]).sum();
        assert!((new_sum - 1.0).abs() < 1e-5);

        // Zero-sum case: distribute equally
        let mut p3 = make_params(&[("a", 0.0), ("b", 0.0), ("c", 0.0)]);
        assert!(c.apply(&mut p3));
        assert!((p3["a"] - 1.0 / 3.0).abs() < 1e-5);
    }

    #[test]
    fn test_max_sum_constraint() {
        let c = Constraint::MaxSum {
            param_a: "muscle".to_string(),
            param_b: "fat".to_string(),
            max_sum: 1.0,
        };
        let mut p = make_params(&[("muscle", 0.7), ("fat", 0.6)]);
        assert!(c.apply(&mut p));
        let s = p["muscle"] + p["fat"];
        assert!((s - 1.0).abs() < 1e-5);

        let mut p2 = make_params(&[("muscle", 0.4), ("fat", 0.4)]);
        assert!(!c.apply(&mut p2));
    }

    #[test]
    fn test_conditional_constraint() {
        let c = Constraint::Conditional {
            condition: "age".to_string(),
            threshold: 0.8,
            target: "elderly_flag".to_string(),
            value: 1.0,
        };
        // age >= 0.8 → elderly_flag should be set to 1.0
        let mut p = make_params(&[("age", 0.9), ("elderly_flag", 0.0)]);
        assert!(c.apply(&mut p));
        assert!((p["elderly_flag"] - 1.0).abs() < 1e-6);

        // age < 0.8 → no change
        let mut p2 = make_params(&[("age", 0.5), ("elderly_flag", 0.0)]);
        assert!(!c.apply(&mut p2));
    }

    #[test]
    fn test_constraint_is_satisfied() {
        let c = Constraint::Clamp {
            param: "x".to_string(),
            min: 0.0,
            max: 1.0,
        };
        let p_ok = make_params(&[("x", 0.5)]);
        assert!(c.is_satisfied(&p_ok));

        let p_bad = make_params(&[("x", 1.5)]);
        assert!(!c.is_satisfied(&p_bad));

        // Driven
        let d = Constraint::Driven {
            source: "s".to_string(),
            target: "t".to_string(),
            factor: 2.0,
            offset: 0.0,
        };
        // s=0.4 → t should be 0.8
        let p_driven_ok = make_params(&[("s", 0.4), ("t", 0.8)]);
        assert!(d.is_satisfied(&p_driven_ok));

        let p_driven_bad = make_params(&[("s", 0.4), ("t", 0.5)]);
        assert!(!d.is_satisfied(&p_driven_bad));
    }

    #[test]
    fn test_relative_and_conditional_is_satisfied() {
        let min_rel = Constraint::MinRelative {
            source: "h".to_string(),
            target: "s".to_string(),
            factor: 0.5,
            offset: 0.0,
        };
        assert!(min_rel.is_satisfied(&make_params(&[("h", 1.0), ("s", 0.5)])));
        assert!(!min_rel.is_satisfied(&make_params(&[("h", 1.0), ("s", 0.4)])));

        let max_rel = Constraint::MaxRelative {
            source: "h".to_string(),
            target: "s".to_string(),
            factor: 0.5,
            offset: 0.0,
        };
        assert!(max_rel.is_satisfied(&make_params(&[("h", 1.0), ("s", 0.5)])));
        assert!(!max_rel.is_satisfied(&make_params(&[("h", 1.0), ("s", 0.6)])));

        let cond = Constraint::Conditional {
            condition: "age".to_string(),
            threshold: 0.8,
            target: "flag".to_string(),
            value: 1.0,
        };
        // Below the threshold the constraint imposes nothing.
        assert!(cond.is_satisfied(&make_params(&[("age", 0.5)])));
        assert!(!cond.is_satisfied(&make_params(&[("age", 0.9), ("flag", 0.0)])));
        assert!(cond.is_satisfied(&make_params(&[("age", 0.9), ("flag", 1.0)])));
    }

    #[test]
    fn test_describe_formats() {
        let clamp = Constraint::Clamp {
            param: "x".to_string(),
            min: 0.0,
            max: 1.0,
        };
        assert_eq!(clamp.describe(), "Clamp: x in [0, 1]");

        let max_sum = Constraint::MaxSum {
            param_a: "a".to_string(),
            param_b: "b".to_string(),
            max_sum: 1.0,
        };
        assert_eq!(max_sum.describe(), "MaxSum: a + b <= 1");

        let cond = Constraint::Conditional {
            condition: "age".to_string(),
            threshold: 0.5,
            target: "flag".to_string(),
            value: 1.0,
        };
        assert_eq!(cond.describe(), "Conditional: if age >= 0.5 then flag = 1");
    }

    #[test]
    fn test_solver_new() {
        let s = ConstraintSolver::new();
        assert_eq!(s.constraint_count(), 0);
        assert_eq!(s.max_iterations, 100);
    }

    #[test]
    fn test_solver_add_and_count() {
        let mut s = ConstraintSolver::new();
        s.add(Constraint::Clamp {
            param: "x".to_string(),
            min: 0.0,
            max: 1.0,
        });
        s.add(Constraint::Clamp {
            param: "y".to_string(),
            min: 0.0,
            max: 1.0,
        });
        assert_eq!(s.constraint_count(), 2);
        // Out-of-range removal is a no-op.
        s.remove(10);
        assert_eq!(s.constraint_count(), 2);
        s.remove(0);
        assert_eq!(s.constraint_count(), 1);
    }

    #[test]
    fn test_solver_empty_converges_immediately() {
        let s = ConstraintSolver::default();
        let mut p = make_params(&[("x", 0.5)]);
        let r = s.solve(&mut p);
        assert!(r.converged);
        assert_eq!(r.iterations, 1);
        assert_eq!(r.changes_made, 0);
        assert_eq!(r.violations_remaining, 0);
    }

    #[test]
    fn test_solver_solve_converges() {
        let mut s = ConstraintSolver::new();
        s.add(Constraint::Clamp {
            param: "height".to_string(),
            min: 0.0,
            max: 1.0,
        });
        s.add(Constraint::Clamp {
            param: "weight".to_string(),
            min: 0.0,
            max: 1.0,
        });

        let mut p = make_params(&[("height", 1.5), ("weight", -0.2)]);
        let result = s.solve(&mut p);

        assert!(result.converged);
        assert_eq!(result.violations_remaining, 0);
        assert!((p["height"] - 1.0).abs() < 1e-5);
        assert!((p["weight"] - 0.0).abs() < 1e-5);
    }

    #[test]
    fn test_solver_check_violations() {
        let mut s = ConstraintSolver::new();
        s.add(Constraint::Clamp {
            param: "x".to_string(),
            min: 0.0,
            max: 1.0,
        });
        s.add(Constraint::Clamp {
            param: "y".to_string(),
            min: 0.0,
            max: 1.0,
        });

        let p = make_params(&[("x", 1.5), ("y", 0.5)]);
        let violations = s.check_violations(&p);
        assert_eq!(violations, vec![0]);

        let p2 = make_params(&[("x", 0.5), ("y", 0.5)]);
        assert!(s.check_violations(&p2).is_empty());
        assert!(s.is_satisfied(&p2));
    }

    #[test]
    fn test_bmi_constraints() {
        let constraints = bmi_constraints();
        assert!(!constraints.is_empty());

        let mut s = ConstraintSolver::new();
        for c in constraints {
            s.add(c);
        }

        let mut p = make_params(&[("bmi_factor", 0.6), ("weight", 0.0), ("muscle", 0.9)]);
        let result = s.solve(&mut p);
        assert!(result.converged);
        // weight driven to bmi_factor
        assert!((p["weight"] - 0.6).abs() < 1e-5);
        // muscle <= weight * 0.9 + 0.1 = 0.64
        assert!(p["muscle"] <= 0.64 + 1e-4);
    }

    #[test]
    fn test_proportion_constraints_solve() {
        let mut s = ConstraintSolver::new();
        for c in proportion_constraints() {
            s.add(c);
        }

        let mut p = make_params(&[("height", 1.0), ("shoulder_width", 0.1)]);
        let result = s.solve(&mut p);
        assert!(result.converged);
        assert_eq!(result.violations_remaining, 0);
        // 0.35 <= shoulder_width <= 0.7 for height = 1.0
        assert!(p["shoulder_width"] >= 0.35 - 1e-5);
        assert!(p["shoulder_width"] <= 0.7 + 1e-5);
        // leg_length = 1.0 * 0.85 + 0.05
        assert!((p["leg_length"] - 0.9).abs() < 1e-5);
    }

    #[test]
    fn test_age_constraints_solve() {
        let mut s = ConstraintSolver::new();
        for c in age_constraints() {
            s.add(c);
        }

        let mut p = make_params(&[("age", 0.9), ("muscle", 0.9)]);
        let result = s.solve(&mut p);
        assert!(result.converged);
        assert_eq!(result.violations_remaining, 0);
        // muscle <= 1.0 - 0.9 * 0.5 = 0.55
        assert!(p["muscle"] <= 0.55 + 1e-5);
        assert!((p["elderly_flag"] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_solve_result_fields() {
        let mut s = ConstraintSolver::new();
        s.add(Constraint::Clamp {
            param: "z".to_string(),
            min: 0.0,
            max: 1.0,
        });

        let mut p = make_params(&[("z", 2.0)]);
        let r = s.solve(&mut p);

        assert!(r.iterations >= 1); // At least one iteration ran
        assert!(r.converged);
        assert_eq!(r.violations_remaining, 0);
        assert_eq!(r.changes_made, 1);
    }
}