1#![allow(dead_code)]
5
6use std::collections::HashMap;
7
/// Parameter values keyed by parameter name.
///
/// Constraints in this module read a missing key as `0.0`.
pub type Params = HashMap<String, f32>;
9
/// A relationship between named parameters that can be checked and enforced.
///
/// Every constraint reads parameters by name from a [`Params`] map and treats a missing
/// entry as `0.0`.
#[derive(Clone, Debug, PartialEq)]
pub enum Constraint {
    /// Lower bound: `target >= source * factor + offset`.
    MinRelative {
        /// Parameter the bound is computed from.
        source: String,
        /// Parameter raised to the bound when it falls below it.
        target: String,
        /// Multiplier applied to `source`.
        factor: f32,
        /// Constant added after scaling.
        offset: f32,
    },
    /// Upper bound: `target <= source * factor + offset`.
    MaxRelative {
        /// Parameter the bound is computed from.
        source: String,
        /// Parameter lowered to the bound when it exceeds it.
        target: String,
        /// Multiplier applied to `source`.
        factor: f32,
        /// Constant added after scaling.
        offset: f32,
    },
    /// Assignment: `target = source * factor + offset`, clamped to `[0, 1]`.
    Driven {
        /// Parameter the value is computed from.
        source: String,
        /// Parameter that is overwritten.
        target: String,
        /// Multiplier applied to `source`.
        factor: f32,
        /// Constant added after scaling.
        offset: f32,
    },
    /// The listed parameters must sum to `total`.
    SumEquals {
        /// Names of the parameters whose values are summed.
        params: Vec<String>,
        /// Required sum.
        total: f32,
    },
    /// `param` must lie within `[min, max]`.
    Clamp {
        /// Parameter being bounded.
        param: String,
        /// Inclusive lower bound.
        min: f32,
        /// Inclusive upper bound.
        max: f32,
    },
    /// `param_a + param_b` must not exceed `max_sum`; when it does, both are rescaled so the
    /// sum equals `max_sum`.
    MaxSum {
        /// First summand.
        param_a: String,
        /// Second summand.
        param_b: String,
        /// Largest allowed sum.
        max_sum: f32,
    },
    /// When `condition >= threshold`, `target` must equal `value`; otherwise no requirement.
    Conditional {
        /// Parameter compared against `threshold`.
        condition: String,
        /// Inclusive trigger level for `condition`.
        threshold: f32,
        /// Parameter forced to `value` while the condition holds.
        target: String,
        /// Value assigned to `target`.
        value: f32,
    },
}
52
53impl Constraint {
54 pub fn apply(&self, params: &mut Params) -> bool {
56 const EPS: f32 = 1e-6;
57 match self {
58 Constraint::MinRelative {
59 source,
60 target,
61 factor,
62 offset,
63 } => {
64 let src = *params.get(source).unwrap_or(&0.0);
65 let min_val = src * factor + offset;
66 let cur = *params.get(target).unwrap_or(&0.0);
67 if cur < min_val - EPS {
68 params.insert(target.clone(), min_val);
69 true
70 } else {
71 false
72 }
73 }
74 Constraint::MaxRelative {
75 source,
76 target,
77 factor,
78 offset,
79 } => {
80 let src = *params.get(source).unwrap_or(&0.0);
81 let max_val = src * factor + offset;
82 let cur = *params.get(target).unwrap_or(&0.0);
83 if cur > max_val + EPS {
84 params.insert(target.clone(), max_val);
85 true
86 } else {
87 false
88 }
89 }
90 Constraint::Driven {
91 source,
92 target,
93 factor,
94 offset,
95 } => {
96 let src = *params.get(source).unwrap_or(&0.0);
97 let new_val = (src * factor + offset).clamp(0.0, 1.0);
98 let cur = *params.get(target).unwrap_or(&0.0);
99 if (cur - new_val).abs() > EPS {
100 params.insert(target.clone(), new_val);
101 true
102 } else {
103 false
104 }
105 }
106 Constraint::SumEquals {
107 params: keys,
108 total,
109 } => {
110 let current_sum: f32 = keys.iter().map(|k| *params.get(k).unwrap_or(&0.0)).sum();
111 if (current_sum - total).abs() <= EPS {
112 return false;
113 }
114 if current_sum.abs() < EPS {
115 let equal_share = total / keys.len() as f32;
117 for k in keys {
118 params.insert(k.clone(), equal_share);
119 }
120 } else {
121 let scale = total / current_sum;
122 for k in keys {
123 let v = *params.get(k).unwrap_or(&0.0);
124 params.insert(k.clone(), v * scale);
125 }
126 }
127 true
128 }
129 Constraint::Clamp { param, min, max } => {
130 let cur = *params.get(param).unwrap_or(&0.0);
131 let clamped = cur.clamp(*min, *max);
132 if (cur - clamped).abs() > EPS {
133 params.insert(param.clone(), clamped);
134 true
135 } else {
136 false
137 }
138 }
139 Constraint::MaxSum {
140 param_a,
141 param_b,
142 max_sum,
143 } => {
144 let a = *params.get(param_a).unwrap_or(&0.0);
145 let b = *params.get(param_b).unwrap_or(&0.0);
146 let sum = a + b;
147 if sum > max_sum + 1e-6 {
148 let scale = max_sum / sum;
149 params.insert(param_a.clone(), a * scale);
150 params.insert(param_b.clone(), b * scale);
151 true
152 } else {
153 false
154 }
155 }
156 Constraint::Conditional {
157 condition,
158 threshold,
159 target,
160 value,
161 } => {
162 let cond_val = *params.get(condition).unwrap_or(&0.0);
163 if cond_val >= *threshold {
164 let cur = *params.get(target).unwrap_or(&0.0);
165 if (cur - value).abs() > 1e-6 {
166 params.insert(target.clone(), *value);
167 return true;
168 }
169 }
170 false
171 }
172 }
173 }
174
175 pub fn is_satisfied(&self, params: &Params) -> bool {
177 const EPS: f32 = 1e-5;
178 match self {
179 Constraint::MinRelative {
180 source,
181 target,
182 factor,
183 offset,
184 } => {
185 let src = *params.get(source).unwrap_or(&0.0);
186 let cur = *params.get(target).unwrap_or(&0.0);
187 cur >= src * factor + offset - EPS
188 }
189 Constraint::MaxRelative {
190 source,
191 target,
192 factor,
193 offset,
194 } => {
195 let src = *params.get(source).unwrap_or(&0.0);
196 let cur = *params.get(target).unwrap_or(&0.0);
197 cur <= src * factor + offset + EPS
198 }
199 Constraint::Driven {
200 source,
201 target,
202 factor,
203 offset,
204 } => {
205 let src = *params.get(source).unwrap_or(&0.0);
206 let expected = (src * factor + offset).clamp(0.0, 1.0);
207 let cur = *params.get(target).unwrap_or(&0.0);
208 (cur - expected).abs() <= EPS
209 }
210 Constraint::SumEquals {
211 params: keys,
212 total,
213 } => {
214 let s: f32 = keys.iter().map(|k| *params.get(k).unwrap_or(&0.0)).sum();
215 (s - total).abs() <= EPS
216 }
217 Constraint::Clamp { param, min, max } => {
218 let v = *params.get(param).unwrap_or(&0.0);
219 v >= *min - EPS && v <= *max + EPS
220 }
221 Constraint::MaxSum {
222 param_a,
223 param_b,
224 max_sum,
225 } => {
226 let a = *params.get(param_a).unwrap_or(&0.0);
227 let b = *params.get(param_b).unwrap_or(&0.0);
228 a + b <= max_sum + EPS
229 }
230 Constraint::Conditional {
231 condition,
232 threshold,
233 target,
234 value,
235 } => {
236 let cond_val = *params.get(condition).unwrap_or(&0.0);
237 if cond_val >= *threshold {
238 let cur = *params.get(target).unwrap_or(&0.0);
239 (cur - value).abs() <= EPS
240 } else {
241 true
242 }
243 }
244 }
245 }
246
247 pub fn describe(&self) -> String {
249 match self {
250 Constraint::MinRelative {
251 source,
252 target,
253 factor,
254 offset,
255 } => format!(
256 "MinRelative: {} >= {} * {} + {}",
257 target, source, factor, offset
258 ),
259 Constraint::MaxRelative {
260 source,
261 target,
262 factor,
263 offset,
264 } => format!(
265 "MaxRelative: {} <= {} * {} + {}",
266 target, source, factor, offset
267 ),
268 Constraint::Driven {
269 source,
270 target,
271 factor,
272 offset,
273 } => format!(
274 "Driven: {} = {} * {} + {} (clamped to [0,1])",
275 target, source, factor, offset
276 ),
277 Constraint::SumEquals { params, total } => {
278 format!("SumEquals: {:?} sums to {}", params, total)
279 }
280 Constraint::Clamp { param, min, max } => {
281 format!("Clamp: {} in [{}, {}]", param, min, max)
282 }
283 Constraint::MaxSum {
284 param_a,
285 param_b,
286 max_sum,
287 } => format!("MaxSum: {} + {} <= {}", param_a, param_b, max_sum),
288 Constraint::Conditional {
289 condition,
290 threshold,
291 target,
292 value,
293 } => format!(
294 "Conditional: if {} >= {} then {} = {}",
295 condition, threshold, target, value
296 ),
297 }
298 }
299}
300
301pub struct ConstraintSolver {
303 constraints: Vec<Constraint>,
304 max_iterations: usize,
305 tolerance: f32,
306}
307
/// Summary of one solver run.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SolveResult {
    /// Number of passes over the constraint list that were started.
    pub iterations: usize,
    /// `true` if the run stopped because the parameters settled rather than because the
    /// iteration budget ran out.
    pub converged: bool,
    /// Number of constraints still unsatisfied after the final pass.
    pub violations_remaining: usize,
    /// Total number of constraint applications that wrote to the parameters, across all passes.
    pub changes_made: usize,
}
315
316impl ConstraintSolver {
317 pub fn new() -> Self {
318 Self {
319 constraints: Vec::new(),
320 max_iterations: 100,
321 tolerance: 1e-5,
322 }
323 }
324
325 pub fn with_max_iterations(mut self, n: usize) -> Self {
326 self.max_iterations = n;
327 self
328 }
329
330 pub fn with_tolerance(mut self, tol: f32) -> Self {
331 self.tolerance = tol;
332 self
333 }
334
335 pub fn add(&mut self, constraint: Constraint) {
336 self.constraints.push(constraint);
337 }
338
339 pub fn constraint_count(&self) -> usize {
340 self.constraints.len()
341 }
342
343 pub fn remove(&mut self, index: usize) {
344 if index < self.constraints.len() {
345 self.constraints.remove(index);
346 }
347 }
348
349 pub fn solve(&self, params: &mut Params) -> SolveResult {
351 let mut total_changes = 0usize;
352 let mut iterations = 0usize;
353 let mut converged = false;
354
355 for _ in 0..self.max_iterations {
356 iterations += 1;
357 let mut changed_this_iter = false;
358
359 for constraint in &self.constraints {
360 if constraint.apply(params) {
361 total_changes += 1;
362 changed_this_iter = true;
363 }
364 }
365
366 if !changed_this_iter {
367 converged = true;
368 break;
369 }
370 }
371
372 let violations_remaining = self.check_violations(params).len();
373
374 SolveResult {
375 iterations,
376 converged,
377 violations_remaining,
378 changes_made: total_changes,
379 }
380 }
381
382 pub fn check_violations(&self, params: &Params) -> Vec<usize> {
384 self.constraints
385 .iter()
386 .enumerate()
387 .filter(|(_, c)| !c.is_satisfied(params))
388 .map(|(i, _)| i)
389 .collect()
390 }
391
392 pub fn is_satisfied(&self, params: &Params) -> bool {
394 self.constraints.iter().all(|c| c.is_satisfied(params))
395 }
396}
397
398impl Default for ConstraintSolver {
399 fn default() -> Self {
400 Self::new()
401 }
402}
403
404pub fn bmi_constraints() -> Vec<Constraint> {
408 vec![
409 Constraint::Driven {
411 source: "bmi_factor".to_string(),
412 target: "weight".to_string(),
413 factor: 1.0,
414 offset: 0.0,
415 },
416 Constraint::MaxRelative {
418 source: "weight".to_string(),
419 target: "muscle".to_string(),
420 factor: 0.9,
421 offset: 0.1,
422 },
423 Constraint::Clamp {
425 param: "muscle".to_string(),
426 min: 0.0,
427 max: 1.0,
428 },
429 Constraint::Clamp {
431 param: "weight".to_string(),
432 min: 0.0,
433 max: 1.0,
434 },
435 ]
436}
437
438pub fn proportion_constraints() -> Vec<Constraint> {
440 vec![
441 Constraint::MinRelative {
443 source: "height".to_string(),
444 target: "shoulder_width".to_string(),
445 factor: 0.3,
446 offset: 0.05,
447 },
448 Constraint::MaxRelative {
450 source: "height".to_string(),
451 target: "shoulder_width".to_string(),
452 factor: 0.6,
453 offset: 0.1,
454 },
455 Constraint::Clamp {
457 param: "shoulder_width".to_string(),
458 min: 0.0,
459 max: 1.0,
460 },
461 Constraint::Driven {
463 source: "height".to_string(),
464 target: "leg_length".to_string(),
465 factor: 0.85,
466 offset: 0.05,
467 },
468 ]
469}
470
471pub fn age_constraints() -> Vec<Constraint> {
475 vec![
476 Constraint::MaxRelative {
478 source: "age".to_string(),
479 target: "muscle".to_string(),
480 factor: -0.5,
481 offset: 1.0,
482 },
483 Constraint::Conditional {
485 condition: "age".to_string(),
486 threshold: 0.8,
487 target: "elderly_flag".to_string(),
488 value: 1.0,
489 },
490 Constraint::Clamp {
492 param: "muscle".to_string(),
493 min: 0.0,
494 max: 1.0,
495 },
496 ]
497}
498
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a parameter map from `(name, value)` pairs.
    fn make_params(pairs: &[(&str, f32)]) -> Params {
        pairs.iter().map(|(k, v)| (k.to_string(), *v)).collect()
    }

    /// Asserts `|actual - expected| < tol`, reporting both values on failure.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {} (+/- {}), got {}",
            expected,
            tol,
            actual
        );
    }

    /// Builds a solver holding `constraints` in the given order.
    fn solver_with(constraints: Vec<Constraint>) -> ConstraintSolver {
        let mut solver = ConstraintSolver::new();
        for constraint in constraints {
            solver.add(constraint);
        }
        solver
    }

    /// A `Clamp` of `param` to `[0, 1]`.
    fn unit_clamp(param: &str) -> Constraint {
        Constraint::Clamp {
            param: param.to_string(),
            min: 0.0,
            max: 1.0,
        }
    }

    #[test]
    fn test_clamp_constraint() {
        let c = unit_clamp("height");
        let mut p = make_params(&[("height", 1.5)]);
        assert!(c.apply(&mut p));
        assert_close(p["height"], 1.0, 1e-6);

        let mut p2 = make_params(&[("height", 0.5)]);
        assert!(!c.apply(&mut p2));
        assert_close(p2["height"], 0.5, 1e-6);
    }

    #[test]
    fn test_min_relative_constraint() {
        let c = Constraint::MinRelative {
            source: "height".to_string(),
            target: "shoulder_width".to_string(),
            factor: 0.3,
            offset: 0.0,
        };
        // 0.1 is below 0.8 * 0.3 = 0.24, so it is raised.
        let mut p = make_params(&[("height", 0.8), ("shoulder_width", 0.1)]);
        assert!(c.apply(&mut p));
        assert_close(p["shoulder_width"], 0.24, 1e-5);

        // 0.5 already satisfies the bound.
        let mut p2 = make_params(&[("height", 0.8), ("shoulder_width", 0.5)]);
        assert!(!c.apply(&mut p2));
    }

    #[test]
    fn test_max_relative_constraint() {
        let c = Constraint::MaxRelative {
            source: "weight".to_string(),
            target: "muscle".to_string(),
            factor: 0.9,
            offset: 0.1,
        };
        // 0.8 exceeds 0.5 * 0.9 + 0.1 = 0.55, so it is lowered.
        let mut p = make_params(&[("weight", 0.5), ("muscle", 0.8)]);
        assert!(c.apply(&mut p));
        assert_close(p["muscle"], 0.55, 1e-5);

        let mut p2 = make_params(&[("weight", 0.5), ("muscle", 0.3)]);
        assert!(!c.apply(&mut p2));
    }

    #[test]
    fn test_driven_constraint() {
        let c = Constraint::Driven {
            source: "bmi_factor".to_string(),
            target: "weight".to_string(),
            factor: 1.0,
            offset: 0.0,
        };
        let mut p = make_params(&[("bmi_factor", 0.7), ("weight", 0.0)]);
        assert!(c.apply(&mut p));
        assert_close(p["weight"], 0.7, 1e-6);

        // Already at the driven value.
        let mut p2 = make_params(&[("bmi_factor", 0.7), ("weight", 0.7)]);
        assert!(!c.apply(&mut p2));

        // Driven values are clamped to [0, 1].
        let mut p3 = make_params(&[("bmi_factor", 1.5), ("weight", 0.0)]);
        assert!(c.apply(&mut p3));
        assert_close(p3["weight"], 1.0, 1e-6);
    }

    #[test]
    fn test_sum_equals_constraint() {
        let c = Constraint::SumEquals {
            params: vec!["a".to_string(), "b".to_string(), "c".to_string()],
            total: 1.0,
        };
        let mut p = make_params(&[("a", 0.2), ("b", 0.3), ("c", 0.5)]);
        assert!(!c.apply(&mut p));

        // Proportional rescale from 1.5 down to 1.0.
        let mut p2 = make_params(&[("a", 0.5), ("b", 0.5), ("c", 0.5)]);
        assert!(c.apply(&mut p2));
        let new_sum: f32 = ["a", "b", "c"].iter().map(|k| p2[*k]).sum();
        assert_close(new_sum, 1.0, 1e-5);

        // All-zero input is split evenly.
        let mut p3 = make_params(&[("a", 0.0), ("b", 0.0), ("c", 0.0)]);
        assert!(c.apply(&mut p3));
        assert_close(p3["a"], 1.0 / 3.0, 1e-5);
    }

    #[test]
    fn test_max_sum_constraint() {
        let c = Constraint::MaxSum {
            param_a: "muscle".to_string(),
            param_b: "fat".to_string(),
            max_sum: 1.0,
        };
        let mut p = make_params(&[("muscle", 0.7), ("fat", 0.6)]);
        assert!(c.apply(&mut p));
        assert_close(p["muscle"] + p["fat"], 1.0, 1e-5);

        let mut p2 = make_params(&[("muscle", 0.4), ("fat", 0.4)]);
        assert!(!c.apply(&mut p2));
    }

    #[test]
    fn test_conditional_constraint() {
        let c = Constraint::Conditional {
            condition: "age".to_string(),
            threshold: 0.8,
            target: "elderly_flag".to_string(),
            value: 1.0,
        };
        let mut p = make_params(&[("age", 0.9), ("elderly_flag", 0.0)]);
        assert!(c.apply(&mut p));
        assert_close(p["elderly_flag"], 1.0, 1e-6);
        // Applying again is a no-op.
        assert!(!c.apply(&mut p));

        // Below the threshold nothing is written.
        let mut p2 = make_params(&[("age", 0.5), ("elderly_flag", 0.0)]);
        assert!(!c.apply(&mut p2));
    }

    #[test]
    fn test_constraint_is_satisfied() {
        let c = unit_clamp("x");
        assert!(c.is_satisfied(&make_params(&[("x", 0.5)])));
        assert!(!c.is_satisfied(&make_params(&[("x", 1.5)])));

        let d = Constraint::Driven {
            source: "s".to_string(),
            target: "t".to_string(),
            factor: 2.0,
            offset: 0.0,
        };
        // t must equal 0.4 * 2.0 = 0.8.
        assert!(d.is_satisfied(&make_params(&[("s", 0.4), ("t", 0.8)])));
        assert!(!d.is_satisfied(&make_params(&[("s", 0.4), ("t", 0.5)])));
    }

    #[test]
    fn test_describe() {
        assert_eq!(unit_clamp("x").describe(), "Clamp: x in [0, 1]");
        let max_sum = Constraint::MaxSum {
            param_a: "muscle".to_string(),
            param_b: "fat".to_string(),
            max_sum: 1.0,
        };
        assert_eq!(max_sum.describe(), "MaxSum: muscle + fat <= 1");
        let sum = Constraint::SumEquals {
            params: vec!["a".to_string(), "b".to_string()],
            total: 1.0,
        };
        assert_eq!(sum.describe(), "SumEquals: [\"a\", \"b\"] sums to 1");
        let driven = Constraint::Driven {
            source: "bmi_factor".to_string(),
            target: "weight".to_string(),
            factor: 1.0,
            offset: 0.0,
        };
        assert_eq!(
            driven.describe(),
            "Driven: weight = bmi_factor * 1 + 0 (clamped to [0,1])"
        );
        let conditional = Constraint::Conditional {
            condition: "age".to_string(),
            threshold: 0.8,
            target: "elderly_flag".to_string(),
            value: 1.0,
        };
        assert_eq!(
            conditional.describe(),
            "Conditional: if age >= 0.8 then elderly_flag = 1"
        );
    }

    #[test]
    fn test_solver_new() {
        let s = ConstraintSolver::new();
        assert_eq!(s.constraint_count(), 0);
        assert_eq!(s.max_iterations, 100);
    }

    #[test]
    fn test_solver_add_and_count() {
        let mut s = ConstraintSolver::new();
        s.add(unit_clamp("x"));
        s.add(unit_clamp("y"));
        assert_eq!(s.constraint_count(), 2);
        s.remove(0);
        assert_eq!(s.constraint_count(), 1);
    }

    #[test]
    fn test_solver_empty_converges_immediately() {
        let s = ConstraintSolver::new();
        let mut p = make_params(&[("x", 3.0)]);
        let r = s.solve(&mut p);
        assert!(r.converged);
        assert_eq!(r.iterations, 1);
        assert_eq!(r.changes_made, 0);
        assert_eq!(r.violations_remaining, 0);
    }

    #[test]
    fn test_solver_solve_converges() {
        let s = solver_with(vec![unit_clamp("height"), unit_clamp("weight")]);
        let mut p = make_params(&[("height", 1.5), ("weight", -0.2)]);
        let result = s.solve(&mut p);

        assert!(result.converged);
        assert_eq!(result.violations_remaining, 0);
        assert_close(p["height"], 1.0, 1e-5);
        assert_close(p["weight"], 0.0, 1e-5);
    }

    #[test]
    fn test_solver_check_violations() {
        let s = solver_with(vec![unit_clamp("x"), unit_clamp("y")]);

        let p = make_params(&[("x", 1.5), ("y", 0.5)]);
        assert_eq!(s.check_violations(&p), vec![0]);

        let p2 = make_params(&[("x", 0.5), ("y", 0.5)]);
        assert!(s.check_violations(&p2).is_empty());
        assert!(s.is_satisfied(&p2));
    }

    #[test]
    fn test_bmi_constraints() {
        let constraints = bmi_constraints();
        assert!(!constraints.is_empty());

        let s = solver_with(constraints);
        let mut p = make_params(&[("bmi_factor", 0.6), ("weight", 0.0), ("muscle", 0.9)]);
        let result = s.solve(&mut p);
        assert!(result.converged);
        assert_close(p["weight"], 0.6, 1e-5);
        // Muscle is capped at 0.9 * 0.6 + 0.1 = 0.64.
        assert!(p["muscle"] <= 0.64 + 1e-4);
        assert_close(p["muscle"], 0.64, 1e-5);
    }

    #[test]
    fn test_proportion_constraints() {
        let s = solver_with(proportion_constraints());
        let mut p = make_params(&[("height", 0.8), ("shoulder_width", 0.0), ("leg_length", 0.0)]);
        let result = s.solve(&mut p);
        assert!(result.converged);
        assert_eq!(result.violations_remaining, 0);
        // Raised to 0.3 * 0.8 + 0.05.
        assert_close(p["shoulder_width"], 0.29, 1e-5);
        // Driven to 0.85 * 0.8 + 0.05.
        assert_close(p["leg_length"], 0.73, 1e-5);
    }

    #[test]
    fn test_age_constraints() {
        let s = solver_with(age_constraints());

        let mut old = make_params(&[("age", 0.9), ("muscle", 0.9)]);
        let result = s.solve(&mut old);
        assert!(result.converged);
        assert_eq!(result.violations_remaining, 0);
        // Capped at 1.0 - 0.5 * 0.9.
        assert_close(old["muscle"], 0.55, 1e-5);
        assert_close(old["elderly_flag"], 1.0, 1e-6);

        let mut young = make_params(&[("age", 0.5), ("muscle", 0.9)]);
        assert!(s.solve(&mut young).converged);
        assert_close(young["muscle"], 0.75, 1e-5);
        assert!(!young.contains_key("elderly_flag"));
    }

    #[test]
    fn test_solve_result_fields() {
        let s = solver_with(vec![Constraint::Clamp {
            param: "z".to_string(),
            min: 0.0,
            max: 1.0,
        }]);

        let mut p = make_params(&[("z", 2.0)]);
        let r = s.solve(&mut p);

        assert!(r.iterations >= 1);
        assert!(r.converged);
        assert_eq!(r.violations_remaining, 0);
        assert_eq!(r.changes_made, 1);
    }
}