1#![allow(dead_code)]
5#![allow(clippy::too_many_arguments)]
6
7use std::collections::HashMap;
13
/// Settings that control how a [`Crowd`] is generated.
///
/// Each `*_range` field is a `(min, max)` pair; unit samples drawn by the
/// generators are mapped linearly into that interval.
#[derive(Debug, Clone, PartialEq)]
pub struct CrowdConfig {
    /// Number of characters to generate.
    pub count: usize,
    /// Seed for the generators; the same seed reproduces the same samples.
    pub seed: u32,
    /// `(min, max)` bounds for the `"height"` parameter.
    pub height_range: (f32, f32),
    /// `(min, max)` bounds for the `"weight"` parameter.
    pub weight_range: (f32, f32),
    /// `(min, max)` bounds for the `"age"` parameter.
    pub age_range: (f32, f32),
    /// `(min, max)` bounds for the `"muscle"` parameter.
    pub muscle_range: (f32, f32),
    /// Spread control, clamped to `[0, 1]` by the generators: `1.0` keeps
    /// values as sampled, smaller values pull them toward the range centre.
    pub diversity_target: f32,
    /// When `false`, the generators post-process the crowd with
    /// `enforce_diversity` to separate near-identical characters.
    pub allow_duplicates: bool,
    /// Additional named parameters, each with its own `(min, max)` range.
    pub extra_params: HashMap<String, (f32, f32)>,
}

impl Default for CrowdConfig {
    /// Ten characters, seed 42, every built-in range spanning `[0, 1]`,
    /// half spread, duplicates allowed and no extra parameters.
    fn default() -> Self {
        // All four built-in traits default to the full normalised interval.
        let unit = (0.0_f32, 1.0_f32);
        Self {
            count: 10,
            seed: 42,
            height_range: unit,
            weight_range: unit,
            age_range: unit,
            muscle_range: unit,
            diversity_target: 0.5,
            allow_duplicates: true,
            extra_params: HashMap::new(),
        }
    }
}
55
/// Coarse body-type category derived from a character's parameters.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum VariationClass {
    /// Low height and low weight.
    Petite,
    /// Mid height and low weight.
    Slim,
    /// Mid height, mid weight, not muscular.
    Average,
    /// Mid weight and high muscle (when not already classified as tall).
    Athletic,
    /// Low height and high weight.
    Stocky,
    /// High height and mid weight.
    Tall,
    /// High weight (when not low height).
    Heavy,
    /// Any combination not covered by the other classes.
    Custom,
}

impl VariationClass {
    /// Classifies a parameter map by its `"height"`, `"weight"` and
    /// `"muscle"` entries.
    ///
    /// Missing entries fall back to 0.5, 0.5 and 0.3 respectively. Values
    /// below 0.35 count as low and values above 0.65 as high; everything
    /// else (including the cut-offs themselves) counts as mid.
    pub fn classify(params: &HashMap<String, f32>) -> VariationClass {
        enum Band {
            Low,
            Mid,
            High,
        }

        const LOW_CUTOFF: f32 = 0.35;
        const HIGH_CUTOFF: f32 = 0.65;

        let read = |key: &str, fallback: f32| params.get(key).copied().unwrap_or(fallback);
        let band_of = |value: f32| {
            if value < LOW_CUTOFF {
                Band::Low
            } else if value > HIGH_CUTOFF {
                Band::High
            } else {
                Band::Mid
            }
        };

        let height = band_of(read("height", 0.5));
        let weight = band_of(read("weight", 0.5));
        let muscular = read("muscle", 0.3) > HIGH_CUTOFF;

        // Arm order matters: earlier arms take precedence over later ones.
        match (height, weight, muscular) {
            (Band::Low, Band::Low, _) => VariationClass::Petite,
            (Band::Mid, Band::Low, _) => VariationClass::Slim,
            (Band::Low, Band::High, _) => VariationClass::Stocky,
            (_, Band::High, _) => VariationClass::Heavy,
            (Band::High, Band::Mid, _) => VariationClass::Tall,
            (_, Band::Mid, true) => VariationClass::Athletic,
            (Band::Mid, Band::Mid, false) => VariationClass::Average,
            _ => VariationClass::Custom,
        }
    }

    /// Returns every variant once, in declaration order.
    pub fn all() -> Vec<VariationClass> {
        Vec::from([
            Self::Petite,
            Self::Slim,
            Self::Average,
            Self::Athletic,
            Self::Stocky,
            Self::Tall,
            Self::Heavy,
            Self::Custom,
        ])
    }

    /// Returns the variant's display name (identical to its identifier).
    pub fn name(&self) -> &'static str {
        match self {
            Self::Petite => "Petite",
            Self::Slim => "Slim",
            Self::Average => "Average",
            Self::Athletic => "Athletic",
            Self::Stocky => "Stocky",
            Self::Tall => "Tall",
            Self::Heavy => "Heavy",
            Self::Custom => "Custom",
        }
    }
}
147
148pub struct CrowdCharacter {
154 pub id: usize,
156 pub params: HashMap<String, f32>,
158 pub variation_class: VariationClass,
160}
161
162pub struct Crowd {
168 pub characters: Vec<CrowdCharacter>,
170 pub config: CrowdConfig,
172}
173
174impl Crowd {
175 pub fn count(&self) -> usize {
177 self.characters.len()
178 }
179
180 pub fn get(&self, id: usize) -> Option<&CrowdCharacter> {
182 self.characters.iter().find(|c| c.id == id)
183 }
184
185 pub fn mean_params(&self) -> HashMap<String, f32> {
191 if self.characters.is_empty() {
192 return HashMap::new();
193 }
194 let n = self.characters.len() as f32;
195 let mut sums: HashMap<String, f32> = HashMap::new();
196 for ch in &self.characters {
197 for (k, v) in &ch.params {
198 *sums.entry(k.clone()).or_insert(0.0) += v;
199 }
200 }
201 sums.iter_mut().for_each(|(_, v)| *v /= n);
202 sums
203 }
204
205 pub fn std_params(&self) -> HashMap<String, f32> {
207 if self.characters.len() < 2 {
208 return HashMap::new();
209 }
210 let means = self.mean_params();
211 let n = self.characters.len() as f32;
212 let mut sq_sums: HashMap<String, f32> = HashMap::new();
213 for ch in &self.characters {
214 for (k, v) in &ch.params {
215 let mean = means.get(k).copied().unwrap_or(0.0);
216 let d = v - mean;
217 *sq_sums.entry(k.clone()).or_insert(0.0) += d * d;
218 }
219 }
220 sq_sums.iter_mut().for_each(|(_, v)| *v = (*v / n).sqrt());
221 sq_sums
222 }
223
224 pub fn diversity_score(&self) -> f32 {
226 let n = self.characters.len();
227 if n < 2 {
228 return 0.0;
229 }
230 let mut total = 0.0_f32;
231 let mut count = 0usize;
232 for i in 0..n {
233 for j in (i + 1)..n {
234 total += param_distance(&self.characters[i].params, &self.characters[j].params);
235 count += 1;
236 }
237 }
238 if count == 0 {
239 0.0
240 } else {
241 total / count as f32
242 }
243 }
244
245 pub fn by_class(&self, class: &VariationClass) -> Vec<&CrowdCharacter> {
251 self.characters
252 .iter()
253 .filter(|c| &c.variation_class == class)
254 .collect()
255 }
256
257 pub fn class_distribution(&self) -> HashMap<VariationClass, usize> {
259 let mut dist: HashMap<VariationClass, usize> = HashMap::new();
260 for ch in &self.characters {
261 *dist.entry(ch.variation_class.clone()).or_insert(0) += 1;
262 }
263 dist
264 }
265
266 pub fn sorted_by(&self, param: &str) -> Vec<&CrowdCharacter> {
270 let mut refs: Vec<&CrowdCharacter> = self.characters.iter().collect();
271 refs.sort_by(|a, b| {
272 let va = a.params.get(param).copied().unwrap_or(f32::MAX);
273 let vb = b.params.get(param).copied().unwrap_or(f32::MAX);
274 va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
275 });
276 refs
277 }
278
279 pub fn to_param_list(&self) -> Vec<HashMap<String, f32>> {
281 self.characters.iter().map(|c| c.params.clone()).collect()
282 }
283
284 pub fn summary(&self) -> String {
286 let mean = self.mean_params();
287 let std = self.std_params();
288 let dist = self.class_distribution();
289 let diversity = self.diversity_score();
290
291 let mut lines = Vec::new();
292 lines.push(format!(
293 "{{ \"count\": {}, \"diversity_score\": {:.4},",
294 self.count(),
295 diversity
296 ));
297 lines.push(" \"mean_params\": {".to_string());
298
299 let mut keys: Vec<&String> = mean.keys().collect();
300 keys.sort();
301 for (idx, k) in keys.iter().enumerate() {
302 let m = mean[*k];
303 let s = std.get(*k).copied().unwrap_or(0.0);
304 let comma = if idx + 1 < keys.len() { "," } else { "" };
305 lines.push(format!(
306 " \"{}\": {{ \"mean\": {:.4}, \"std\": {:.4} }}{}",
307 k, m, s, comma
308 ));
309 }
310 lines.push(" },".to_string());
311 lines.push(" \"class_distribution\": {".to_string());
312
313 let mut class_entries: Vec<(String, usize)> = dist
314 .iter()
315 .map(|(c, n)| (c.name().to_string(), *n))
316 .collect();
317 class_entries.sort_by(|a, b| a.0.cmp(&b.0));
318 for (idx, (name, count)) in class_entries.iter().enumerate() {
319 let comma = if idx + 1 < class_entries.len() {
320 ","
321 } else {
322 ""
323 };
324 lines.push(format!(" \"{}\": {}{}", name, count, comma));
325 }
326 lines.push(" }".to_string());
327 lines.push("}".to_string());
328 lines.join("\n")
329 }
330}
331
/// Returns element `i` of the Halton (radical-inverse) sequence in `base`.
///
/// The digits of `i` written in `base` are mirrored around the radix point,
/// giving a value in `[0, 1)`; `i == 0` yields `0.0`.
///
/// A base below 2 has no positional digits (base 1 would never terminate
/// and base 0 would divide by zero), so such a base yields `0.0`.
pub fn halton(i: usize, base: usize) -> f32 {
    if base < 2 {
        return 0.0;
    }
    let radix = base as f64;
    // Accumulate in f64 and round once at the end.
    let mut result = 0.0_f64;
    let mut denom = 1.0_f64;
    let mut remaining = i;
    while remaining > 0 {
        denom *= radix;
        result += (remaining % base) as f64 / denom;
        remaining /= base;
    }
    result as f32
}
350
/// Advances a 32-bit linear congruential generator and returns a sample in
/// `[0, 1)`.
///
/// `seed` is updated in place to the next state. The top 23 bits of the new
/// state are divided by 2^23 to form the result.
pub fn lcg_rand(seed: &mut u32) -> f32 {
    const MULTIPLIER: u32 = 1_664_525;
    const INCREMENT: u32 = 1_013_904_223;
    // 2^23: a 23-bit integer divided by this is always below 1.0.
    const STEPS: f32 = 8_388_608.0;

    let next = seed.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    *seed = next;
    let top_bits = next >> 9;
    top_bits as f32 / STEPS
}
358
/// Euclidean distance between two parameter maps.
///
/// Only keys present in both maps contribute; keys found in just one of
/// them are ignored, so maps with no common key are at distance zero.
pub fn param_distance(a: &HashMap<String, f32>, b: &HashMap<String, f32>) -> f32 {
    let squared_total = a
        .iter()
        .filter_map(|(key, &left)| {
            let right = *b.get(key)?;
            Some((left - right).powi(2))
        })
        .sum::<f32>();
    squared_total.sqrt()
}
369
/// Maps a unit sample `v` linearly from `min` (at 0) to `max` (at 1) and
/// clamps the result into the interval spanned by the two bounds.
///
/// The bounds may be given in either order; a reversed range is ordered
/// before clamping instead of panicking.
///
/// # Panics
///
/// Panics if either bound is NaN (a requirement of `f32::clamp`).
#[inline]
fn scale(v: f32, min: f32, max: f32) -> f32 {
    // `f32::clamp` requires lower <= upper, so order the bounds first.
    let (lower, upper) = if min <= max { (min, max) } else { (max, min) };
    (min + v * (max - min)).clamp(lower, upper)
}
375
/// Pulls `v` toward the midpoint of `min`/`max` by `diversity_target`.
///
/// A target of `1.0` leaves `v` unchanged and `0.0` collapses it onto the
/// midpoint. The result is clamped into the interval spanned by the two
/// bounds, which may be given in either order.
///
/// # Panics
///
/// Panics if either bound is NaN (a requirement of `f32::clamp`).
#[inline]
fn apply_diversity(v: f32, min: f32, max: f32, diversity_target: f32) -> f32 {
    // `f32::clamp` requires lower <= upper, so order the bounds first.
    let (lower, upper) = if min <= max { (min, max) } else { (max, min) };
    let centre = (min + max) / 2.0;
    (centre + (v - centre) * diversity_target).clamp(lower, upper)
}
386
387fn build_params(values: &[(&str, f32, (f32, f32))], diversity: f32) -> HashMap<String, f32> {
389 values
390 .iter()
391 .map(|(name, raw, (lo, hi))| {
392 let scaled = scale(*raw, *lo, *hi);
393 let final_val = if (diversity - 1.0).abs() < f32::EPSILON {
394 scaled
395 } else {
396 apply_diversity(scaled, *lo, *hi, diversity)
397 };
398 (name.to_string(), final_val)
399 })
400 .collect()
401}
402
403pub fn generate_crowd(config: CrowdConfig) -> Crowd {
405 let mut seed = config.seed;
406 let diversity = config.diversity_target.clamp(0.0, 1.0);
407
408 let mut extra_names: Vec<String> = config.extra_params.keys().cloned().collect();
410 extra_names.sort();
411
412 let mut characters: Vec<CrowdCharacter> = (0..config.count)
413 .map(|id| {
414 let h = lcg_rand(&mut seed);
415 let w = lcg_rand(&mut seed);
416 let a = lcg_rand(&mut seed);
417 let m = lcg_rand(&mut seed);
418
419 let mut entries: Vec<(&str, f32, (f32, f32))> = vec![
420 ("height", h, config.height_range),
421 ("weight", w, config.weight_range),
422 ("age", a, config.age_range),
423 ("muscle", m, config.muscle_range),
424 ];
425
426 for name in &extra_names {
427 let Some(&range) = config.extra_params.get(name) else {
428 continue;
429 };
430 let v = lcg_rand(&mut seed);
431 entries.push((name.as_str(), v, range));
434 }
435
436 let params = build_params(&entries, diversity);
437 let variation_class = VariationClass::classify(¶ms);
438 CrowdCharacter {
439 id,
440 params,
441 variation_class,
442 }
443 })
444 .collect();
445
446 if !config.allow_duplicates {
447 enforce_diversity(&mut characters, 0.01, config.seed);
448 }
449
450 Crowd { characters, config }
451}
452
453pub fn generate_crowd_halton(config: CrowdConfig) -> Crowd {
457 const BASES: [usize; 8] = [2, 3, 5, 7, 11, 13, 17, 19];
459
460 let diversity = config.diversity_target.clamp(0.0, 1.0);
461
462 let mut extra_names: Vec<String> = config.extra_params.keys().cloned().collect();
463 extra_names.sort();
464
465 let offset = (config.seed as usize) % 97 + 1;
467
468 let mut characters: Vec<CrowdCharacter> = (0..config.count)
469 .map(|id| {
470 let idx = id + offset;
471 let h = halton(idx, BASES[0]);
472 let w = halton(idx, BASES[1]);
473 let a = halton(idx, BASES[2]);
474 let m = halton(idx, BASES[3]);
475
476 let mut entries: Vec<(&str, f32, (f32, f32))> = vec![
477 ("height", h, config.height_range),
478 ("weight", w, config.weight_range),
479 ("age", a, config.age_range),
480 ("muscle", m, config.muscle_range),
481 ];
482
483 for (ei, name) in extra_names.iter().enumerate() {
484 let base = BASES.get(4 + ei).copied().unwrap_or(23 + ei * 2);
485 let Some(&range) = config.extra_params.get(name) else {
486 continue;
487 };
488 let v = halton(idx, base);
489 entries.push((name.as_str(), v, range));
490 }
491
492 let params = build_params(&entries, diversity);
493 let variation_class = VariationClass::classify(¶ms);
494 CrowdCharacter {
495 id,
496 params,
497 variation_class,
498 }
499 })
500 .collect();
501
502 if !config.allow_duplicates {
503 enforce_diversity(&mut characters, 0.01, config.seed);
504 }
505
506 Crowd { characters, config }
507}
508
509pub fn enforce_diversity(chars: &mut [CrowdCharacter], min_distance: f32, seed: u32) {
514 let mut rng_seed = seed.wrapping_add(0xDEAD_BEEF);
515 let n = chars.len();
516 for i in 0..n {
518 for j in (i + 1)..n {
519 let dist = param_distance(&chars[i].params, &chars[j].params);
520 if dist < min_distance {
521 let keys: Vec<String> = chars[j].params.keys().cloned().collect();
523 for key in &keys {
524 let v = lcg_rand(&mut rng_seed);
525 chars[j].params.insert(key.clone(), v);
526 }
527 chars[j].variation_class = VariationClass::classify(&chars[j].params);
528 }
529 }
530 }
531}
532
/// Unit tests for crowd generation, classification and statistics.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Small deterministic configuration shared by several tests.
    fn small_config() -> CrowdConfig {
        CrowdConfig {
            count: 8,
            seed: 1234,
            ..Default::default()
        }
    }

    /// Builds a parameter map holding the three traits `classify` reads.
    fn body(height: f32, weight: f32, muscle: f32) -> HashMap<String, f32> {
        [
            ("height".to_string(), height),
            ("weight".to_string(), weight),
            ("muscle".to_string(), muscle),
        ]
        .into()
    }

    #[test]
    fn test_lcg_rand_range() {
        let mut seed = 42_u32;
        for _ in 0..1000 {
            let v = lcg_rand(&mut seed);
            assert!((0.0..1.0).contains(&v), "lcg_rand out of range: {v}");
        }
    }

    #[test]
    fn test_halton_base2() {
        assert!((halton(1, 2) - 0.5).abs() < 1e-6, "h(1,2) = 0.5");
        assert!((halton(2, 2) - 0.25).abs() < 1e-6, "h(2,2) = 0.25");
        assert!((halton(3, 2) - 0.75).abs() < 1e-6, "h(3,2) = 0.75");
        assert!((halton(4, 2) - 0.125).abs() < 1e-6, "h(4,2) = 0.125");
        assert_eq!(halton(0, 2), 0.0, "h(0,2) = 0");
    }

    #[test]
    fn test_halton_base3() {
        assert!((halton(1, 3) - 1.0 / 3.0).abs() < 1e-5, "h(1,3) = 1/3");
        assert!((halton(2, 3) - 2.0 / 3.0).abs() < 1e-5, "h(2,3) = 2/3");
        assert!((halton(3, 3) - 1.0 / 9.0).abs() < 1e-5, "h(3,3) = 1/9");
    }

    #[test]
    fn test_classify_petite() {
        let params = body(0.2, 0.2, 0.3);
        assert_eq!(VariationClass::classify(&params), VariationClass::Petite);
    }

    #[test]
    fn test_classify_slim() {
        let params = body(0.5, 0.2, 0.3);
        assert_eq!(VariationClass::classify(&params), VariationClass::Slim);
    }

    #[test]
    fn test_classify_tall() {
        let params = body(0.8, 0.5, 0.3);
        assert_eq!(VariationClass::classify(&params), VariationClass::Tall);
    }

    #[test]
    fn test_classify_heavy() {
        let params = body(0.5, 0.8, 0.3);
        assert_eq!(VariationClass::classify(&params), VariationClass::Heavy);
    }

    #[test]
    fn test_classify_stocky() {
        let params = body(0.2, 0.8, 0.3);
        assert_eq!(VariationClass::classify(&params), VariationClass::Stocky);
    }

    #[test]
    fn test_classify_athletic() {
        let params = body(0.5, 0.5, 0.8);
        assert_eq!(VariationClass::classify(&params), VariationClass::Athletic);
    }

    #[test]
    fn test_classify_average() {
        let params = body(0.5, 0.5, 0.3);
        assert_eq!(VariationClass::classify(&params), VariationClass::Average);
    }

    #[test]
    fn test_generate_crowd_count() {
        let cfg = CrowdConfig {
            count: 20,
            seed: 7,
            ..Default::default()
        };
        let crowd = generate_crowd(cfg);
        assert_eq!(crowd.count(), 20);
    }

    #[test]
    fn test_generate_crowd_params_in_range() {
        let cfg = CrowdConfig {
            count: 50,
            seed: 99,
            height_range: (0.2, 0.8),
            weight_range: (0.1, 0.9),
            age_range: (0.0, 1.0),
            muscle_range: (0.0, 0.5),
            ..Default::default()
        };
        let crowd = generate_crowd(cfg);
        for ch in &crowd.characters {
            let h = ch.params["height"];
            let w = ch.params["weight"];
            let m = ch.params["muscle"];
            assert!((0.2..=0.8).contains(&h), "height {h} out of range");
            assert!((0.1..=0.9).contains(&w), "weight {w} out of range");
            assert!((0.0..=0.5).contains(&m), "muscle {m} out of range");
        }
    }

    #[test]
    fn test_generate_crowd_deterministic() {
        let c1 = generate_crowd(small_config());
        let c2 = generate_crowd(small_config());
        assert_eq!(c1.count(), c2.count());
        for (a, b) in c1.characters.iter().zip(c2.characters.iter()) {
            for (k, va) in &a.params {
                let vb = b.params[k];
                assert!((va - vb).abs() < 1e-6, "Non-deterministic at param {k}");
            }
        }
    }

    #[test]
    fn test_generate_crowd_halton_range() {
        let cfg = CrowdConfig {
            count: 30,
            seed: 5,
            ..Default::default()
        };
        let crowd = generate_crowd_halton(cfg);
        assert_eq!(crowd.count(), 30);
        for ch in &crowd.characters {
            for (k, v) in &ch.params {
                assert!(
                    *v >= 0.0 && *v <= 1.0,
                    "halton param {k} = {v} out of [0,1]"
                );
            }
        }
    }

    #[test]
    fn test_param_distance() {
        let a: HashMap<String, f32> = [("x".to_string(), 0.0), ("y".to_string(), 0.0)].into();
        let b: HashMap<String, f32> = [("x".to_string(), 1.0), ("y".to_string(), 0.0)].into();
        let c: HashMap<String, f32> = [("x".to_string(), 1.0), ("y".to_string(), 1.0)].into();

        let dab = param_distance(&a, &b);
        let dba = param_distance(&b, &a);
        assert!((dab - 1.0).abs() < 1e-5, "d(a,b) should be 1.0, got {dab}");
        assert!((dab - dba).abs() < 1e-6, "symmetry broken");

        let dac = param_distance(&a, &c);
        assert!((dac - 2.0_f32.sqrt()).abs() < 1e-5, "d(a,c) = sqrt(2)");
    }

    #[test]
    fn test_mean_std_params() {
        let ch0 = CrowdCharacter {
            id: 0,
            params: [("height".to_string(), 0.2_f32)].into(),
            variation_class: VariationClass::Custom,
        };
        let ch1 = CrowdCharacter {
            id: 1,
            params: [("height".to_string(), 0.8_f32)].into(),
            variation_class: VariationClass::Custom,
        };
        let crowd = Crowd {
            characters: vec![ch0, ch1],
            config: CrowdConfig::default(),
        };
        let mean = crowd.mean_params();
        let std = crowd.std_params();
        assert!((mean["height"] - 0.5).abs() < 1e-5, "mean = 0.5");
        assert!(
            (std["height"] - 0.3).abs() < 1e-5,
            "std = 0.3, got {}",
            std["height"]
        );
    }

    #[test]
    fn test_diversity_score_monotone() {
        let narrow = CrowdConfig {
            count: 20,
            seed: 1,
            diversity_target: 0.0,
            ..Default::default()
        };
        let wide = CrowdConfig {
            count: 20,
            seed: 1,
            diversity_target: 1.0,
            ..Default::default()
        };
        let s_narrow = generate_crowd(narrow).diversity_score();
        let s_wide = generate_crowd(wide).diversity_score();
        assert!(s_narrow >= 0.0);
        assert!(s_wide >= 0.0);
        assert!(
            s_wide >= s_narrow - 1e-4,
            "wide diversity {s_wide} should be >= narrow diversity {s_narrow}"
        );
    }

    #[test]
    fn test_by_class_and_distribution() {
        let cfg = CrowdConfig {
            count: 100,
            seed: 2025,
            ..Default::default()
        };
        let crowd = generate_crowd(cfg);
        let dist = crowd.class_distribution();
        let total: usize = dist.values().sum();
        assert_eq!(total, 100);
        for (class, &count) in &dist {
            assert_eq!(crowd.by_class(class).len(), count);
        }
    }

    #[test]
    fn test_sorted_by() {
        let cfg = CrowdConfig {
            count: 15,
            seed: 77,
            ..Default::default()
        };
        let crowd = generate_crowd(cfg);
        let sorted = crowd.sorted_by("height");
        for window in sorted.windows(2) {
            let h0 = window[0].params["height"];
            let h1 = window[1].params["height"];
            assert!(h0 <= h1, "Not sorted: {h0} > {h1}");
        }
    }

    #[test]
    fn test_to_param_list_and_get() {
        let crowd = generate_crowd(small_config());
        let list = crowd.to_param_list();
        assert_eq!(list.len(), crowd.count());
        for (i, ch) in crowd.characters.iter().enumerate() {
            let from_get = crowd.get(ch.id).expect("should succeed");
            for (k, v) in &from_get.params {
                let lv = list[i][k];
                assert!((v - lv).abs() < 1e-6, "list mismatch for {k}");
            }
        }
    }

    #[test]
    fn test_summary_content() {
        let crowd = generate_crowd(small_config());
        let s = crowd.summary();
        assert!(s.contains("count"), "summary missing 'count'");
        assert!(
            s.contains("diversity_score"),
            "summary missing 'diversity_score'"
        );
        assert!(s.contains("mean_params"), "summary missing 'mean_params'");
        assert!(
            s.contains("class_distribution"),
            "summary missing 'class_distribution'"
        );
        // Use the platform temp directory rather than a hard-coded `/tmp`.
        let path = std::env::temp_dir().join("oxihuman_crowd_summary.txt");
        let mut f = std::fs::File::create(&path).expect("should succeed");
        f.write_all(s.as_bytes()).expect("should succeed");
    }

    #[test]
    fn test_extra_params() {
        let mut extra = HashMap::new();
        extra.insert("nose_width".to_string(), (0.2_f32, 0.7_f32));
        extra.insert("jaw_size".to_string(), (0.1_f32, 0.9_f32));
        let cfg = CrowdConfig {
            count: 20,
            seed: 31415,
            extra_params: extra,
            ..Default::default()
        };
        let crowd = generate_crowd(cfg);
        for ch in &crowd.characters {
            let nw = ch.params["nose_width"];
            let js = ch.params["jaw_size"];
            assert!((0.2..=0.7).contains(&nw), "nose_width {nw} out of range");
            assert!((0.1..=0.9).contains(&js), "jaw_size {js} out of range");
        }
    }

    #[test]
    fn test_enforce_diversity() {
        let p: HashMap<String, f32> =
            [("height".to_string(), 0.5), ("weight".to_string(), 0.5)].into();
        let mut chars = vec![
            CrowdCharacter {
                id: 0,
                params: p.clone(),
                variation_class: VariationClass::Average,
            },
            CrowdCharacter {
                id: 1,
                params: p.clone(),
                variation_class: VariationClass::Average,
            },
        ];
        let before = param_distance(&chars[0].params, &chars[1].params);
        assert!(before < 1e-5, "should start as identical");

        enforce_diversity(&mut chars, 0.01, 42);

        // Only the later character of a colliding pair is re-randomised.
        assert_eq!(chars[0].params, p, "first character must stay untouched");
        let after = param_distance(&chars[0].params, &chars[1].params);
        assert!(
            after >= 0.01 || after == 0.0,
            "after enforce_diversity distance = {after}; expected >= 0.01 or randomised away"
        );
    }

    #[test]
    fn test_variation_class_all() {
        let all = VariationClass::all();
        assert!(all.contains(&VariationClass::Petite));
        assert!(all.contains(&VariationClass::Slim));
        assert!(all.contains(&VariationClass::Average));
        assert!(all.contains(&VariationClass::Athletic));
        assert!(all.contains(&VariationClass::Stocky));
        assert!(all.contains(&VariationClass::Tall));
        assert!(all.contains(&VariationClass::Heavy));
        assert!(all.contains(&VariationClass::Custom));
    }

    #[test]
    fn test_variation_class_name() {
        for class in VariationClass::all() {
            assert!(!class.name().is_empty());
        }
    }

    #[test]
    fn test_halton_crowd_to_file() {
        let cfg = CrowdConfig {
            count: 16,
            seed: 3,
            ..Default::default()
        };
        let crowd = generate_crowd_halton(cfg);
        let list = crowd.to_param_list();
        // Use the platform temp directory rather than a hard-coded `/tmp`.
        let path = std::env::temp_dir().join("oxihuman_halton_crowd.csv");
        let mut f = std::fs::File::create(&path).expect("should succeed");
        writeln!(f, "id,height,weight,age,muscle").expect("should succeed");
        for (i, p) in list.iter().enumerate() {
            writeln!(
                f,
                "{},{:.4},{:.4},{:.4},{:.4}",
                i,
                p.get("height").copied().unwrap_or(0.0),
                p.get("weight").copied().unwrap_or(0.0),
                p.get("age").copied().unwrap_or(0.0),
                p.get("muscle").copied().unwrap_or(0.0),
            )
            .expect("should succeed");
        }
    }

    #[test]
    fn test_get_out_of_range() {
        let crowd = generate_crowd(small_config());
        assert!(crowd.get(9999).is_none());
    }

    #[test]
    fn test_empty_crowd() {
        let cfg = CrowdConfig {
            count: 0,
            ..Default::default()
        };
        let crowd = generate_crowd(cfg);
        assert_eq!(crowd.count(), 0);
        assert_eq!(crowd.diversity_score(), 0.0);
        assert!(crowd.mean_params().is_empty());
        assert!(crowd.to_param_list().is_empty());
    }
}