Skip to main content

oxihuman_morph/
crowd_generator.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4#![allow(dead_code)]
5#![allow(clippy::too_many_arguments)]
6
7//! Crowd generator: produce diverse crowds of character parameter sets.
8//!
9//! Supports both LCG pseudo-random and Halton quasi-random sequence generation
10//! with configurable variation ranges, diversity enforcement, and statistics.
11
12use std::collections::HashMap;
13
14// ---------------------------------------------------------------------------
15// CrowdConfig
16// ---------------------------------------------------------------------------
17
/// Configuration for crowd generation.
///
/// Consumed by value by [`generate_crowd`] and [`generate_crowd_halton`], which
/// store it unchanged in the resulting [`Crowd`].
#[derive(Debug, Clone, PartialEq)]
pub struct CrowdConfig {
    /// Number of characters to generate.
    pub count: usize,
    /// Deterministic seed for reproducibility.
    pub seed: u32,
    /// Height parameter range `[min, max]` in `[0, 1]`.
    pub height_range: (f32, f32),
    /// Weight parameter range `[min, max]` in `[0, 1]`.
    pub weight_range: (f32, f32),
    /// Age parameter range `[min, max]` in `[0, 1]`.
    pub age_range: (f32, f32),
    /// Muscle parameter range `[min, max]` in `[0, 1]`.
    pub muscle_range: (f32, f32),
    /// Spread factor, clamped to `[0, 1]` by the generators.
    ///
    /// `0.0` collapses every parameter onto the midpoint of its range,
    /// `1.0` uses the full range; values in between scale the distance
    /// from the midpoint linearly.
    pub diversity_target: f32,
    /// If `false`, the generators run [`enforce_diversity`] after generation
    /// (minimum pairwise distance `0.01`) to break up near-duplicate param sets.
    pub allow_duplicates: bool,
    /// Additional named parameter ranges: name → `(min, max)`.
    pub extra_params: HashMap<String, (f32, f32)>,
}

impl Default for CrowdConfig {
    /// Ten characters, seed `42`, every core range set to `[0, 1]`,
    /// half spread, duplicates allowed and no extra parameters.
    fn default() -> Self {
        const UNIT_RANGE: (f32, f32) = (0.0, 1.0);
        Self {
            count: 10,
            seed: 42,
            height_range: UNIT_RANGE,
            weight_range: UNIT_RANGE,
            age_range: UNIT_RANGE,
            muscle_range: UNIT_RANGE,
            diversity_target: 0.5,
            allow_duplicates: true,
            extra_params: HashMap::new(),
        }
    }
}
55
56// ---------------------------------------------------------------------------
57// VariationClass
58// ---------------------------------------------------------------------------
59
/// Broad variation class used for diversity tracking in a crowd.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum VariationClass {
    /// Low height, low weight.
    Petite,
    /// Average height, low weight.
    Slim,
    /// Average across all parameters.
    Average,
    /// High muscle, moderate weight.
    Athletic,
    /// Low height, high weight.
    Stocky,
    /// High height, average weight.
    Tall,
    /// High weight.
    Heavy,
    /// Does not fit any standard class.
    Custom,
}

impl VariationClass {
    /// Classify a character based on their parameter map.
    ///
    /// Reads `"height"`, `"weight"` and `"muscle"`; a missing key falls back to
    /// `0.5`, `0.5` and `0.3` respectively. Height and weight are bucketed as
    /// low (`< 0.35`), high (`> 0.65`) or mid (everything else, including both
    /// thresholds); muscle only counts when it is `> 0.65`.
    ///
    /// Rules are checked in priority order, so e.g. a tall, muscular character
    /// of mid weight is `Tall`, not `Athletic`. Combinations without a rule
    /// (tall and light, or short with mid weight and low muscle) are `Custom`.
    pub fn classify(params: &HashMap<String, f32>) -> VariationClass {
        // Position of a value relative to the two thresholds.
        enum Band {
            Low,
            Mid,
            High,
        }

        const LO: f32 = 0.35;
        const HI: f32 = 0.65;

        let band = |value: f32| {
            if value < LO {
                Band::Low
            } else if value > HI {
                Band::High
            } else {
                Band::Mid
            }
        };
        let lookup = |key: &str, fallback: f32| params.get(key).copied().unwrap_or(fallback);

        let height = band(lookup("height", 0.5));
        let weight = band(lookup("weight", 0.5));
        let muscular = lookup("muscle", 0.3) > HI;

        // Arm order mirrors the rule priority; do not reorder.
        match (height, weight) {
            (Band::Low, Band::Low) => VariationClass::Petite,
            (Band::Mid, Band::Low) => VariationClass::Slim,
            (Band::High, Band::Low) => VariationClass::Custom,
            (Band::Low, Band::High) => VariationClass::Stocky,
            (_, Band::High) => VariationClass::Heavy,
            (Band::High, Band::Mid) => VariationClass::Tall,
            (_, Band::Mid) if muscular => VariationClass::Athletic,
            (Band::Mid, Band::Mid) => VariationClass::Average,
            (Band::Low, Band::Mid) => VariationClass::Custom,
        }
    }

    /// Return every variation class in declaration order, **including** `Custom`.
    pub fn all() -> Vec<VariationClass> {
        vec![
            VariationClass::Petite,
            VariationClass::Slim,
            VariationClass::Average,
            VariationClass::Athletic,
            VariationClass::Stocky,
            VariationClass::Tall,
            VariationClass::Heavy,
            VariationClass::Custom,
        ]
    }

    /// Human-readable name of the variation class (identical to the variant name).
    pub fn name(&self) -> &'static str {
        match self {
            VariationClass::Petite => "Petite",
            VariationClass::Slim => "Slim",
            VariationClass::Average => "Average",
            VariationClass::Athletic => "Athletic",
            VariationClass::Stocky => "Stocky",
            VariationClass::Tall => "Tall",
            VariationClass::Heavy => "Heavy",
            VariationClass::Custom => "Custom",
        }
    }
}
147
148// ---------------------------------------------------------------------------
149// CrowdCharacter
150// ---------------------------------------------------------------------------
151
152/// A single generated character's full parameter set.
153pub struct CrowdCharacter {
154    /// Zero-based index assigned at generation time.
155    pub id: usize,
156    /// Named parameter values in `[0, 1]`.
157    pub params: HashMap<String, f32>,
158    /// Broad variation class determined from `params`.
159    pub variation_class: VariationClass,
160}
161
162// ---------------------------------------------------------------------------
163// Crowd
164// ---------------------------------------------------------------------------
165
166/// The generated crowd of [`CrowdCharacter`] instances.
167pub struct Crowd {
168    /// All characters in generation order.
169    pub characters: Vec<CrowdCharacter>,
170    /// The config that produced this crowd.
171    pub config: CrowdConfig,
172}
173
174impl Crowd {
175    /// Number of characters in the crowd.
176    pub fn count(&self) -> usize {
177        self.characters.len()
178    }
179
180    /// Look up a character by `id`.
181    pub fn get(&self, id: usize) -> Option<&CrowdCharacter> {
182        self.characters.iter().find(|c| c.id == id)
183    }
184
185    // -----------------------------------------------------------------------
186    // Statistics
187    // -----------------------------------------------------------------------
188
189    /// Compute the mean value of each parameter across the crowd.
190    pub fn mean_params(&self) -> HashMap<String, f32> {
191        if self.characters.is_empty() {
192            return HashMap::new();
193        }
194        let n = self.characters.len() as f32;
195        let mut sums: HashMap<String, f32> = HashMap::new();
196        for ch in &self.characters {
197            for (k, v) in &ch.params {
198                *sums.entry(k.clone()).or_insert(0.0) += v;
199            }
200        }
201        sums.iter_mut().for_each(|(_, v)| *v /= n);
202        sums
203    }
204
205    /// Compute the standard deviation of each parameter across the crowd.
206    pub fn std_params(&self) -> HashMap<String, f32> {
207        if self.characters.len() < 2 {
208            return HashMap::new();
209        }
210        let means = self.mean_params();
211        let n = self.characters.len() as f32;
212        let mut sq_sums: HashMap<String, f32> = HashMap::new();
213        for ch in &self.characters {
214            for (k, v) in &ch.params {
215                let mean = means.get(k).copied().unwrap_or(0.0);
216                let d = v - mean;
217                *sq_sums.entry(k.clone()).or_insert(0.0) += d * d;
218            }
219        }
220        sq_sums.iter_mut().for_each(|(_, v)| *v = (*v / n).sqrt());
221        sq_sums
222    }
223
224    /// Mean pairwise parameter distance (diversity score).
225    pub fn diversity_score(&self) -> f32 {
226        let n = self.characters.len();
227        if n < 2 {
228            return 0.0;
229        }
230        let mut total = 0.0_f32;
231        let mut count = 0usize;
232        for i in 0..n {
233            for j in (i + 1)..n {
234                total += param_distance(&self.characters[i].params, &self.characters[j].params);
235                count += 1;
236            }
237        }
238        if count == 0 {
239            0.0
240        } else {
241            total / count as f32
242        }
243    }
244
245    // -----------------------------------------------------------------------
246    // Filtering / sorting
247    // -----------------------------------------------------------------------
248
249    /// Return all characters that belong to the given [`VariationClass`].
250    pub fn by_class(&self, class: &VariationClass) -> Vec<&CrowdCharacter> {
251        self.characters
252            .iter()
253            .filter(|c| &c.variation_class == class)
254            .collect()
255    }
256
257    /// Count characters per variation class.
258    pub fn class_distribution(&self) -> HashMap<VariationClass, usize> {
259        let mut dist: HashMap<VariationClass, usize> = HashMap::new();
260        for ch in &self.characters {
261            *dist.entry(ch.variation_class.clone()).or_insert(0) += 1;
262        }
263        dist
264    }
265
266    /// Return references to characters sorted in ascending order by `param`.
267    ///
268    /// Characters missing the requested parameter are placed at the end.
269    pub fn sorted_by(&self, param: &str) -> Vec<&CrowdCharacter> {
270        let mut refs: Vec<&CrowdCharacter> = self.characters.iter().collect();
271        refs.sort_by(|a, b| {
272            let va = a.params.get(param).copied().unwrap_or(f32::MAX);
273            let vb = b.params.get(param).copied().unwrap_or(f32::MAX);
274            va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
275        });
276        refs
277    }
278
279    /// Export the crowd as a plain list of param maps (for batch processing).
280    pub fn to_param_list(&self) -> Vec<HashMap<String, f32>> {
281        self.characters.iter().map(|c| c.params.clone()).collect()
282    }
283
284    /// Generate a JSON-like human-readable summary string.
285    pub fn summary(&self) -> String {
286        let mean = self.mean_params();
287        let std = self.std_params();
288        let dist = self.class_distribution();
289        let diversity = self.diversity_score();
290
291        let mut lines = Vec::new();
292        lines.push(format!(
293            "{{ \"count\": {}, \"diversity_score\": {:.4},",
294            self.count(),
295            diversity
296        ));
297        lines.push("  \"mean_params\": {".to_string());
298
299        let mut keys: Vec<&String> = mean.keys().collect();
300        keys.sort();
301        for (idx, k) in keys.iter().enumerate() {
302            let m = mean[*k];
303            let s = std.get(*k).copied().unwrap_or(0.0);
304            let comma = if idx + 1 < keys.len() { "," } else { "" };
305            lines.push(format!(
306                "    \"{}\": {{ \"mean\": {:.4}, \"std\": {:.4} }}{}",
307                k, m, s, comma
308            ));
309        }
310        lines.push("  },".to_string());
311        lines.push("  \"class_distribution\": {".to_string());
312
313        let mut class_entries: Vec<(String, usize)> = dist
314            .iter()
315            .map(|(c, n)| (c.name().to_string(), *n))
316            .collect();
317        class_entries.sort_by(|a, b| a.0.cmp(&b.0));
318        for (idx, (name, count)) in class_entries.iter().enumerate() {
319            let comma = if idx + 1 < class_entries.len() {
320                ","
321            } else {
322                ""
323            };
324            lines.push(format!("    \"{}\": {}{}", name, count, comma));
325        }
326        lines.push("  }".to_string());
327        lines.push("}".to_string());
328        lines.join("\n")
329    }
330}
331
332// ---------------------------------------------------------------------------
333// Core public API
334// ---------------------------------------------------------------------------
335
/// Compute the Halton quasi-random sequence value for index `i` in base `base`.
///
/// This is the radical inverse of `i`: the base-`base` digits of `i` mirrored
/// around the radix point, so the result lies in `[0, 1)`. `halton(0, _)` is
/// `0.0`; callers that want to skip that trivial point should start at `i = 1`.
///
/// A base below 2 has no valid digit expansion (base 1 would never terminate
/// and base 0 would divide by zero), so `0.0` is returned for it.
pub fn halton(i: usize, base: usize) -> f32 {
    if base < 2 {
        return 0.0;
    }
    let radix = base as f64;
    let mut acc = 0.0_f64;
    let mut place = 1.0_f64;
    let mut rest = i;
    while rest > 0 {
        // Peel off the least significant digit and move it behind the radix point.
        place *= radix;
        acc += (rest % base) as f64 / place;
        rest /= base;
    }
    acc as f32
}
350
/// Draw the next pseudo-random value in `[0, 1)` from a 32-bit LCG.
///
/// `*seed` is advanced in place with wrapping arithmetic
/// (multiplier 1664525, increment 1013904223 — the Numerical Recipes constants).
/// The top 23 bits of the new state become the returned fraction.
pub fn lcg_rand(seed: &mut u32) -> f32 {
    const MULTIPLIER: u32 = 1_664_525;
    const INCREMENT: u32 = 1_013_904_223;
    // 2^23: the number of distinct values the 23 retained bits can take.
    const RESOLUTION: f32 = 8_388_608.0;

    let next = seed.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    *seed = next;
    let top_bits = next >> 9;
    top_bits as f32 / RESOLUTION
}
358
/// Euclidean (L2) distance between two parameter maps.
///
/// The distance is taken over the keys the two maps share; a key found in
/// only one of them is ignored. Maps with no common key are at distance zero.
pub fn param_distance(a: &HashMap<String, f32>, b: &HashMap<String, f32>) -> f32 {
    let squared_sum = a
        .iter()
        .filter_map(|(key, lhs)| {
            let rhs = b.get(key)?;
            Some((lhs - rhs).powi(2))
        })
        .sum::<f32>();
    squared_sum.sqrt()
}
369
/// Map a unit-interval sample `v` linearly onto the range spanned by `min` and `max`.
///
/// `v = 0` maps to `min` and `v = 1` maps to `max`; the result is clamped to
/// the range so rounding cannot push it outside. An inverted range
/// (`min > max`) is accepted: the mapping direction is kept and the clamp uses
/// the ordered bounds instead of panicking.
///
/// # Panics
///
/// Panics if `min` or `max` is NaN (a requirement of [`f32::clamp`]).
#[inline]
fn scale(v: f32, min: f32, max: f32) -> f32 {
    // `f32::clamp` requires lower <= upper, so order the bounds first.
    let (lower, upper) = if min <= max { (min, max) } else { (max, min) };
    (min + v * (max - min)).clamp(lower, upper)
}
375
/// Apply the `diversity_target` bias by shrinking `v` towards the midpoint of the range.
///
/// The distance between `v` and the midpoint of `min`/`max` is multiplied by
/// `diversity_target`: `0.0` yields the midpoint itself, `1.0` leaves `v`
/// unchanged. The result is clamped to the range. An inverted range
/// (`min > max`) is accepted; the clamp then uses the ordered bounds instead
/// of panicking.
///
/// # Panics
///
/// Panics if `min` or `max` is NaN (a requirement of [`f32::clamp`]).
#[inline]
fn apply_diversity(v: f32, min: f32, max: f32, diversity_target: f32) -> f32 {
    let centre = (min + max) / 2.0;
    let biased = centre + (v - centre) * diversity_target;
    // `f32::clamp` requires lower <= upper, so order the bounds first.
    let (lower, upper) = if min <= max { (min, max) } else { (max, min) };
    biased.clamp(lower, upper)
}
386
387/// Build a [`HashMap`] of parameters from a slice of `(value, range)` pairs.
388fn build_params(values: &[(&str, f32, (f32, f32))], diversity: f32) -> HashMap<String, f32> {
389    values
390        .iter()
391        .map(|(name, raw, (lo, hi))| {
392            let scaled = scale(*raw, *lo, *hi);
393            let final_val = if (diversity - 1.0).abs() < f32::EPSILON {
394                scaled
395            } else {
396                apply_diversity(scaled, *lo, *hi, diversity)
397            };
398            (name.to_string(), final_val)
399        })
400        .collect()
401}
402
403/// Generate a crowd using LCG pseudo-random numbers.
404pub fn generate_crowd(config: CrowdConfig) -> Crowd {
405    let mut seed = config.seed;
406    let diversity = config.diversity_target.clamp(0.0, 1.0);
407
408    // Collect extra param names for deterministic ordering.
409    let mut extra_names: Vec<String> = config.extra_params.keys().cloned().collect();
410    extra_names.sort();
411
412    let mut characters: Vec<CrowdCharacter> = (0..config.count)
413        .map(|id| {
414            let h = lcg_rand(&mut seed);
415            let w = lcg_rand(&mut seed);
416            let a = lcg_rand(&mut seed);
417            let m = lcg_rand(&mut seed);
418
419            let mut entries: Vec<(&str, f32, (f32, f32))> = vec![
420                ("height", h, config.height_range),
421                ("weight", w, config.weight_range),
422                ("age", a, config.age_range),
423                ("muscle", m, config.muscle_range),
424            ];
425
426            for name in &extra_names {
427                let Some(&range) = config.extra_params.get(name) else {
428                    continue;
429                };
430                let v = lcg_rand(&mut seed);
431                // SAFETY: the string slice lives for the iteration body only —
432                // we'll collect into owned Strings via build_params immediately.
433                entries.push((name.as_str(), v, range));
434            }
435
436            let params = build_params(&entries, diversity);
437            let variation_class = VariationClass::classify(&params);
438            CrowdCharacter {
439                id,
440                params,
441                variation_class,
442            }
443        })
444        .collect();
445
446    if !config.allow_duplicates {
447        enforce_diversity(&mut characters, 0.01, config.seed);
448    }
449
450    Crowd { characters, config }
451}
452
453/// Generate a crowd using the Halton quasi-random sequence for better coverage.
454///
455/// Core parameters use Halton bases 2, 3, 5, 7; extra parameters use bases 11, 13, 17, …
456pub fn generate_crowd_halton(config: CrowdConfig) -> Crowd {
457    // Prime bases for the Halton sequence.
458    const BASES: [usize; 8] = [2, 3, 5, 7, 11, 13, 17, 19];
459
460    let diversity = config.diversity_target.clamp(0.0, 1.0);
461
462    let mut extra_names: Vec<String> = config.extra_params.keys().cloned().collect();
463    extra_names.sort();
464
465    // Offset by seed to shift the sequence start.
466    let offset = (config.seed as usize) % 97 + 1;
467
468    let mut characters: Vec<CrowdCharacter> = (0..config.count)
469        .map(|id| {
470            let idx = id + offset;
471            let h = halton(idx, BASES[0]);
472            let w = halton(idx, BASES[1]);
473            let a = halton(idx, BASES[2]);
474            let m = halton(idx, BASES[3]);
475
476            let mut entries: Vec<(&str, f32, (f32, f32))> = vec![
477                ("height", h, config.height_range),
478                ("weight", w, config.weight_range),
479                ("age", a, config.age_range),
480                ("muscle", m, config.muscle_range),
481            ];
482
483            for (ei, name) in extra_names.iter().enumerate() {
484                let base = BASES.get(4 + ei).copied().unwrap_or(23 + ei * 2);
485                let Some(&range) = config.extra_params.get(name) else {
486                    continue;
487                };
488                let v = halton(idx, base);
489                entries.push((name.as_str(), v, range));
490            }
491
492            let params = build_params(&entries, diversity);
493            let variation_class = VariationClass::classify(&params);
494            CrowdCharacter {
495                id,
496                params,
497                variation_class,
498            }
499        })
500        .collect();
501
502    if !config.allow_duplicates {
503        enforce_diversity(&mut characters, 0.01, config.seed);
504    }
505
506    Crowd { characters, config }
507}
508
509/// Ensure minimum pairwise diversity in a set of characters.
510///
511/// Any pair closer than `min_distance` in parameter space has one of the two
512/// regenerated via LCG. Runs at most `O(n^2)` passes (one sweep).
513pub fn enforce_diversity(chars: &mut [CrowdCharacter], min_distance: f32, seed: u32) {
514    let mut rng_seed = seed.wrapping_add(0xDEAD_BEEF);
515    let n = chars.len();
516    // One forward sweep: for each pair (i, j) that is too close, randomise j.
517    for i in 0..n {
518        for j in (i + 1)..n {
519            let dist = param_distance(&chars[i].params, &chars[j].params);
520            if dist < min_distance {
521                // Regenerate character j with fresh random params.
522                let keys: Vec<String> = chars[j].params.keys().cloned().collect();
523                for key in &keys {
524                    let v = lcg_rand(&mut rng_seed);
525                    chars[j].params.insert(key.clone(), v);
526                }
527                chars[j].variation_class = VariationClass::classify(&chars[j].params);
528            }
529        }
530    }
531}
532
533// ---------------------------------------------------------------------------
534// Tests
535// ---------------------------------------------------------------------------
536
537#[cfg(test)]
538mod tests {
539    use super::*;
540    use std::io::Write;
541
542    // Helper: default config with small count.
543    fn small_config() -> CrowdConfig {
544        CrowdConfig {
545            count: 8,
546            seed: 1234,
547            ..Default::default()
548        }
549    }
550
551    // -----------------------------------------------------------------------
552    // 1. lcg_rand stays in [0, 1)
553    // -----------------------------------------------------------------------
554    #[test]
555    fn test_lcg_rand_range() {
556        let mut seed = 42_u32;
557        for _ in 0..1000 {
558            let v = lcg_rand(&mut seed);
559            assert!((0.0..1.0).contains(&v), "lcg_rand out of range: {v}");
560        }
561    }
562
563    // -----------------------------------------------------------------------
564    // 2. halton sequence correctness (base 2)
565    // -----------------------------------------------------------------------
566    #[test]
567    fn test_halton_base2() {
568        // Known values for Halton base-2
569        assert!((halton(1, 2) - 0.5).abs() < 1e-6, "h(1,2) = 0.5");
570        assert!((halton(2, 2) - 0.25).abs() < 1e-6, "h(2,2) = 0.25");
571        assert!((halton(3, 2) - 0.75).abs() < 1e-6, "h(3,2) = 0.75");
572        assert!((halton(4, 2) - 0.125).abs() < 1e-6, "h(4,2) = 0.125");
573        assert_eq!(halton(0, 2), 0.0, "h(0,2) = 0");
574    }
575
576    // -----------------------------------------------------------------------
577    // 3. halton sequence correctness (base 3)
578    // -----------------------------------------------------------------------
579    #[test]
580    fn test_halton_base3() {
581        // h(1, 3) = 1/3
582        assert!((halton(1, 3) - 1.0 / 3.0).abs() < 1e-5, "h(1,3) = 1/3");
583        // h(2, 3) = 2/3
584        assert!((halton(2, 3) - 2.0 / 3.0).abs() < 1e-5, "h(2,3) = 2/3");
585        // h(3, 3) = 1/9
586        assert!((halton(3, 3) - 1.0 / 9.0).abs() < 1e-5, "h(3,3) = 1/9");
587    }
588
589    // -----------------------------------------------------------------------
590    // 4. VariationClass::classify covers all major branches
591    // -----------------------------------------------------------------------
592    #[test]
593    fn test_classify_petite() {
594        let params: HashMap<String, f32> = [
595            ("height".to_string(), 0.2),
596            ("weight".to_string(), 0.2),
597            ("muscle".to_string(), 0.3),
598        ]
599        .into();
600        assert_eq!(VariationClass::classify(&params), VariationClass::Petite);
601    }
602
603    #[test]
604    fn test_classify_slim() {
605        let params: HashMap<String, f32> = [
606            ("height".to_string(), 0.5),
607            ("weight".to_string(), 0.2),
608            ("muscle".to_string(), 0.3),
609        ]
610        .into();
611        assert_eq!(VariationClass::classify(&params), VariationClass::Slim);
612    }
613
614    #[test]
615    fn test_classify_tall() {
616        let params: HashMap<String, f32> = [
617            ("height".to_string(), 0.8),
618            ("weight".to_string(), 0.5),
619            ("muscle".to_string(), 0.3),
620        ]
621        .into();
622        assert_eq!(VariationClass::classify(&params), VariationClass::Tall);
623    }
624
625    #[test]
626    fn test_classify_heavy() {
627        let params: HashMap<String, f32> = [
628            ("height".to_string(), 0.5),
629            ("weight".to_string(), 0.8),
630            ("muscle".to_string(), 0.3),
631        ]
632        .into();
633        assert_eq!(VariationClass::classify(&params), VariationClass::Heavy);
634    }
635
636    #[test]
637    fn test_classify_stocky() {
638        let params: HashMap<String, f32> = [
639            ("height".to_string(), 0.2),
640            ("weight".to_string(), 0.8),
641            ("muscle".to_string(), 0.3),
642        ]
643        .into();
644        assert_eq!(VariationClass::classify(&params), VariationClass::Stocky);
645    }
646
647    #[test]
648    fn test_classify_athletic() {
649        let params: HashMap<String, f32> = [
650            ("height".to_string(), 0.5),
651            ("weight".to_string(), 0.5),
652            ("muscle".to_string(), 0.8),
653        ]
654        .into();
655        assert_eq!(VariationClass::classify(&params), VariationClass::Athletic);
656    }
657
658    #[test]
659    fn test_classify_average() {
660        let params: HashMap<String, f32> = [
661            ("height".to_string(), 0.5),
662            ("weight".to_string(), 0.5),
663            ("muscle".to_string(), 0.3),
664        ]
665        .into();
666        assert_eq!(VariationClass::classify(&params), VariationClass::Average);
667    }
668
669    // -----------------------------------------------------------------------
670    // 5. generate_crowd basic properties
671    // -----------------------------------------------------------------------
672    #[test]
673    fn test_generate_crowd_count() {
674        let cfg = CrowdConfig {
675            count: 20,
676            seed: 7,
677            ..Default::default()
678        };
679        let crowd = generate_crowd(cfg);
680        assert_eq!(crowd.count(), 20);
681    }
682
683    #[test]
684    fn test_generate_crowd_params_in_range() {
685        let cfg = CrowdConfig {
686            count: 50,
687            seed: 99,
688            height_range: (0.2, 0.8),
689            weight_range: (0.1, 0.9),
690            age_range: (0.0, 1.0),
691            muscle_range: (0.0, 0.5),
692            ..Default::default()
693        };
694        let crowd = generate_crowd(cfg);
695        for ch in &crowd.characters {
696            let h = ch.params["height"];
697            let w = ch.params["weight"];
698            let m = ch.params["muscle"];
699            assert!((0.2..=0.8).contains(&h), "height {h} out of range");
700            assert!((0.1..=0.9).contains(&w), "weight {w} out of range");
701            assert!((0.0..=0.5).contains(&m), "muscle {m} out of range");
702        }
703    }
704
705    // -----------------------------------------------------------------------
706    // 6. Determinism: same seed → same crowd
707    // -----------------------------------------------------------------------
708    #[test]
709    fn test_generate_crowd_deterministic() {
710        let cfg1 = small_config();
711        let cfg2 = small_config();
712        let c1 = generate_crowd(cfg1);
713        let c2 = generate_crowd(cfg2);
714        assert_eq!(c1.count(), c2.count());
715        for (a, b) in c1.characters.iter().zip(c2.characters.iter()) {
716            for (k, va) in &a.params {
717                let vb = b.params[k];
718                assert!((va - vb).abs() < 1e-6, "Non-deterministic at param {k}");
719            }
720        }
721    }
722
723    // -----------------------------------------------------------------------
724    // 7. generate_crowd_halton: values in [0, 1]
725    // -----------------------------------------------------------------------
726    #[test]
727    fn test_generate_crowd_halton_range() {
728        let cfg = CrowdConfig {
729            count: 30,
730            seed: 5,
731            ..Default::default()
732        };
733        let crowd = generate_crowd_halton(cfg);
734        assert_eq!(crowd.count(), 30);
735        for ch in &crowd.characters {
736            for (k, v) in &ch.params {
737                assert!(
738                    *v >= 0.0 && *v <= 1.0,
739                    "halton param {k} = {v} out of [0,1]"
740                );
741            }
742        }
743    }
744
745    // -----------------------------------------------------------------------
    // 8. param_distance symmetry and known distances
747    // -----------------------------------------------------------------------
748    #[test]
749    fn test_param_distance() {
750        let a: HashMap<String, f32> = [("x".to_string(), 0.0), ("y".to_string(), 0.0)].into();
751        let b: HashMap<String, f32> = [("x".to_string(), 1.0), ("y".to_string(), 0.0)].into();
752        let c: HashMap<String, f32> = [("x".to_string(), 1.0), ("y".to_string(), 1.0)].into();
753
754        let dab = param_distance(&a, &b);
755        let dba = param_distance(&b, &a);
756        assert!((dab - 1.0).abs() < 1e-5, "d(a,b) should be 1.0, got {dab}");
757        assert!((dab - dba).abs() < 1e-6, "symmetry broken");
758
759        let dac = param_distance(&a, &c);
760        assert!((dac - 2.0_f32.sqrt()).abs() < 1e-5, "d(a,c) = sqrt(2)");
761    }
762
763    // -----------------------------------------------------------------------
764    // 9. mean_params and std_params
765    // -----------------------------------------------------------------------
766    #[test]
767    fn test_mean_std_params() {
768        // Two characters: height 0.2 and 0.8 → mean 0.5, std = 0.3
769        let ch0 = CrowdCharacter {
770            id: 0,
771            params: [("height".to_string(), 0.2_f32)].into(),
772            variation_class: VariationClass::Custom,
773        };
774        let ch1 = CrowdCharacter {
775            id: 1,
776            params: [("height".to_string(), 0.8_f32)].into(),
777            variation_class: VariationClass::Custom,
778        };
779        let crowd = Crowd {
780            characters: vec![ch0, ch1],
781            config: CrowdConfig::default(),
782        };
783        let mean = crowd.mean_params();
784        let std = crowd.std_params();
785        assert!((mean["height"] - 0.5).abs() < 1e-5, "mean = 0.5");
786        // std = sqrt(((0.2-0.5)^2 + (0.8-0.5)^2) / 2) = sqrt(0.09) = 0.3
787        assert!(
788            (std["height"] - 0.3).abs() < 1e-5,
789            "std = 0.3, got {}",
790            std["height"]
791        );
792    }
793
794    // -----------------------------------------------------------------------
795    // 10. diversity_score is non-negative and increases with spread
796    // -----------------------------------------------------------------------
797    #[test]
798    fn test_diversity_score_monotone() {
799        // Narrow config (low spread)
800        let narrow = CrowdConfig {
801            count: 20,
802            seed: 1,
803            diversity_target: 0.0,
804            ..Default::default()
805        };
806        // Wide config (high spread)
807        let wide = CrowdConfig {
808            count: 20,
809            seed: 1,
810            diversity_target: 1.0,
811            ..Default::default()
812        };
813        let c_narrow = generate_crowd(narrow);
814        let c_wide = generate_crowd(wide);
815        let s_narrow = c_narrow.diversity_score();
816        let s_wide = c_wide.diversity_score();
817        assert!(s_narrow >= 0.0);
818        assert!(s_wide >= 0.0);
819        // Wide should have at least as large a score as narrow.
820        assert!(
821            s_wide >= s_narrow - 1e-4,
822            "wide diversity {s_wide} should be >= narrow diversity {s_narrow}"
823        );
824    }
825
826    // -----------------------------------------------------------------------
827    // 11. by_class and class_distribution
828    // -----------------------------------------------------------------------
829    #[test]
830    fn test_by_class_and_distribution() {
831        let cfg = CrowdConfig {
832            count: 100,
833            seed: 2025,
834            ..Default::default()
835        };
836        let crowd = generate_crowd(cfg);
837        let dist = crowd.class_distribution();
838        // Sum of all class counts should equal total count
839        let total: usize = dist.values().sum();
840        assert_eq!(total, 100);
841        // by_class counts should match distribution counts
842        for (class, &count) in &dist {
843            assert_eq!(crowd.by_class(class).len(), count);
844        }
845    }
846
847    // -----------------------------------------------------------------------
848    // 12. sorted_by returns correct ascending order
849    // -----------------------------------------------------------------------
850    #[test]
851    fn test_sorted_by() {
852        let cfg = CrowdConfig {
853            count: 15,
854            seed: 77,
855            ..Default::default()
856        };
857        let crowd = generate_crowd(cfg);
858        let sorted = crowd.sorted_by("height");
859        for window in sorted.windows(2) {
860            let h0 = window[0].params["height"];
861            let h1 = window[1].params["height"];
862            assert!(h0 <= h1, "Not sorted: {h0} > {h1}");
863        }
864    }
865
866    // -----------------------------------------------------------------------
867    // 13. to_param_list and get round-trips
868    // -----------------------------------------------------------------------
869    #[test]
870    fn test_to_param_list_and_get() {
871        let cfg = small_config();
872        let crowd = generate_crowd(cfg);
873        let list = crowd.to_param_list();
874        assert_eq!(list.len(), crowd.count());
875        for (i, ch) in crowd.characters.iter().enumerate() {
876            let from_get = crowd.get(ch.id).expect("should succeed");
877            // Params from get() and from to_param_list() should match
878            for (k, v) in &from_get.params {
879                let lv = list[i][k];
880                assert!((v - lv).abs() < 1e-6, "list mismatch for {k}");
881            }
882        }
883    }
884
885    // -----------------------------------------------------------------------
886    // 14. summary produces non-empty string with key fields
887    // -----------------------------------------------------------------------
888    #[test]
889    fn test_summary_content() {
890        let crowd = generate_crowd(small_config());
891        let s = crowd.summary();
892        assert!(s.contains("count"), "summary missing 'count'");
893        assert!(
894            s.contains("diversity_score"),
895            "summary missing 'diversity_score'"
896        );
897        assert!(s.contains("mean_params"), "summary missing 'mean_params'");
898        assert!(
899            s.contains("class_distribution"),
900            "summary missing 'class_distribution'"
901        );
902        // Write to /tmp/ for inspection
903        let path = "/tmp/oxihuman_crowd_summary.txt";
904        let mut f = std::fs::File::create(path).expect("should succeed");
905        f.write_all(s.as_bytes()).expect("should succeed");
906    }
907
908    // -----------------------------------------------------------------------
909    // 15. extra_params are generated and in range
910    // -----------------------------------------------------------------------
911    #[test]
912    fn test_extra_params() {
913        let mut extra = HashMap::new();
914        extra.insert("nose_width".to_string(), (0.2_f32, 0.7_f32));
915        extra.insert("jaw_size".to_string(), (0.1_f32, 0.9_f32));
916        let cfg = CrowdConfig {
917            count: 20,
918            seed: 31415,
919            extra_params: extra,
920            ..Default::default()
921        };
922        let crowd = generate_crowd(cfg);
923        for ch in &crowd.characters {
924            let nw = ch.params["nose_width"];
925            let js = ch.params["jaw_size"];
926            assert!((0.2..=0.7).contains(&nw), "nose_width {nw} out of range");
927            assert!((0.1..=0.9).contains(&js), "jaw_size {js} out of range");
928        }
929    }
930
931    // -----------------------------------------------------------------------
932    // 16. enforce_diversity increases minimum pairwise distance
933    // -----------------------------------------------------------------------
934    #[test]
935    fn test_enforce_diversity() {
936        // Create two nearly identical characters
937        let p: HashMap<String, f32> =
938            [("height".to_string(), 0.5), ("weight".to_string(), 0.5)].into();
939        let mut chars = vec![
940            CrowdCharacter {
941                id: 0,
942                params: p.clone(),
943                variation_class: VariationClass::Average,
944            },
945            CrowdCharacter {
946                id: 1,
947                params: p.clone(),
948                variation_class: VariationClass::Average,
949            },
950        ];
951        let before = param_distance(&chars[0].params, &chars[1].params);
952        assert!(before < 1e-5, "should start as identical");
953
954        enforce_diversity(&mut chars, 0.01, 42);
955
956        let after = param_distance(&chars[0].params, &chars[1].params);
957        assert!(
958            after >= 0.01 || after == 0.0,
959            "after enforce_diversity distance = {after}; expected >= 0.01 or randomised away"
960        );
961    }
962
963    // -----------------------------------------------------------------------
964    // 17. VariationClass::all() covers all variants
965    // -----------------------------------------------------------------------
966    #[test]
967    fn test_variation_class_all() {
968        let all = VariationClass::all();
969        assert!(all.contains(&VariationClass::Petite));
970        assert!(all.contains(&VariationClass::Slim));
971        assert!(all.contains(&VariationClass::Average));
972        assert!(all.contains(&VariationClass::Athletic));
973        assert!(all.contains(&VariationClass::Stocky));
974        assert!(all.contains(&VariationClass::Tall));
975        assert!(all.contains(&VariationClass::Heavy));
976        assert!(all.contains(&VariationClass::Custom));
977    }
978
979    // -----------------------------------------------------------------------
980    // 18. VariationClass::name() returns non-empty strings
981    // -----------------------------------------------------------------------
982    #[test]
983    fn test_variation_class_name() {
984        for class in VariationClass::all() {
985            assert!(!class.name().is_empty());
986        }
987    }
988
989    // -----------------------------------------------------------------------
990    // 19. Halton crowd writes to /tmp for visual inspection
991    // -----------------------------------------------------------------------
992    #[test]
993    fn test_halton_crowd_to_file() {
994        let cfg = CrowdConfig {
995            count: 16,
996            seed: 3,
997            ..Default::default()
998        };
999        let crowd = generate_crowd_halton(cfg);
1000        let list = crowd.to_param_list();
1001        let path = "/tmp/oxihuman_halton_crowd.csv";
1002        let mut f = std::fs::File::create(path).expect("should succeed");
1003        writeln!(f, "id,height,weight,age,muscle").expect("should succeed");
1004        for (i, p) in list.iter().enumerate() {
1005            writeln!(
1006                f,
1007                "{},{:.4},{:.4},{:.4},{:.4}",
1008                i,
1009                p.get("height").copied().unwrap_or(0.0),
1010                p.get("weight").copied().unwrap_or(0.0),
1011                p.get("age").copied().unwrap_or(0.0),
1012                p.get("muscle").copied().unwrap_or(0.0),
1013            )
1014            .expect("should succeed");
1015        }
1016    }
1017
1018    // -----------------------------------------------------------------------
1019    // 20. get() returns None for out-of-range id
1020    // -----------------------------------------------------------------------
1021    #[test]
1022    fn test_get_out_of_range() {
1023        let crowd = generate_crowd(small_config());
1024        assert!(crowd.get(9999).is_none());
1025    }
1026
1027    // -----------------------------------------------------------------------
1028    // 21. Crowd with zero characters is safe
1029    // -----------------------------------------------------------------------
1030    #[test]
1031    fn test_empty_crowd() {
1032        let cfg = CrowdConfig {
1033            count: 0,
1034            ..Default::default()
1035        };
1036        let crowd = generate_crowd(cfg);
1037        assert_eq!(crowd.count(), 0);
1038        assert_eq!(crowd.diversity_score(), 0.0);
1039        assert!(crowd.mean_params().is_empty());
1040        assert!(crowd.to_param_list().is_empty());
1041    }
1042}