Skip to main content

oxihuman_morph/
diversity.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4#![allow(dead_code)]
5
6use std::collections::HashMap;
7
8// ---------------------------------------------------------------------------
9// Simple LCG pseudo-random number generator (no rand crate needed)
10// ---------------------------------------------------------------------------
11
/// Simple Linear Congruential Generator for deterministic randomness.
///
/// The 64-bit state is advanced with wrapping multiply-and-add, so a given
/// seed always reproduces the same sequence. Not suitable for cryptography.
#[derive(Debug, Clone)]
pub struct Lcg {
    /// Current generator state; replaced on every draw.
    state: u64,
}

impl Lcg {
    /// Multiplier of the state recurrence.
    const MULTIPLIER: u64 = 6364136223846793005;
    /// Increment of the state recurrence.
    const INCREMENT: u64 = 1442695040888963407;
    /// Number of high-order state bits used per `f32` sample. An `f32` has a
    /// 24-bit significand, so every integer below `2^24` converts exactly.
    const F32_BITS: u32 = 24;

    /// Creates a generator whose initial state is `seed + 1` (wrapping, so
    /// `u64::MAX` is a valid seed).
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Next f32 in [0, 1).
    ///
    /// Advances the state once and maps its top 24 bits onto `k / 2^24`.
    /// Both the numerator and the divisor are exactly representable, so the
    /// result never rounds up to `1.0`.
    pub fn next_f32(&mut self) -> f32 {
        self.state = self
            .state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        // Keep only the high bits: they are the best-mixed part of an LCG state.
        let top_bits = self.state >> (u64::BITS - Self::F32_BITS);
        top_bits as f32 / (1_u32 << Self::F32_BITS) as f32
    }

    /// Next f32 in [min, max).
    ///
    /// Computed as `min + u * (max - min)` with `u` from [`Lcg::next_f32`].
    /// Floating-point rounding of that sum can, in rare cases, land exactly
    /// on `max`.
    pub fn next_range(&mut self, min: f32, max: f32) -> f32 {
        min + self.next_f32() * (max - min)
    }

    /// Box-Muller transform: N(mean, std_dev).
    ///
    /// Consumes two uniform draws and returns `mean + std_dev * z`, where `z`
    /// is the cosine branch of the transform.
    pub fn next_gaussian(&mut self, mean: f32, std_dev: f32) -> f32 {
        // `1 - u` lies in (0, 1], so the logarithm is finite and non-positive.
        let u1 = 1.0 - self.next_f32();
        let u2 = self.next_f32();
        let radius = (-2.0_f32 * u1.ln()).sqrt();
        let angle = 2.0 * std::f32::consts::PI * u2;
        let z = radius * angle.cos();
        mean + std_dev * z
    }
}
46
47// ---------------------------------------------------------------------------
48// Van der Corput low-discrepancy sequence
49// ---------------------------------------------------------------------------
50
/// Van der Corput sequence value for index `n` in base `base`.
/// Reflects n's base-b digits about the decimal point.
///
/// The accumulation is done in `f64` and narrowed to `f32` at the end.
/// `n == 0` yields `0.0`. A `base` below 2 has no positional digit
/// expansion (0 would divide by zero, 1 would never terminate), so it also
/// yields `0.0` instead of panicking or hanging.
pub fn van_der_corput(n: usize, base: usize) -> f32 {
    if base < 2 {
        return 0.0;
    }

    let radix = base as f64;
    let mut value = 0.0_f64;
    let mut place = 1.0_f64;
    let mut rest = n;
    // Peel digits off least-significant first; the k-th digit lands at base^-(k+1).
    while rest > 0 {
        place *= radix;
        value += (rest % base) as f64 / place;
        rest /= base;
    }
    value as f32
}
64
65// ---------------------------------------------------------------------------
66// Sampling strategy
67// ---------------------------------------------------------------------------
68
/// Strategy for generating varied parameter sets.
#[derive(Debug, Clone, PartialEq)]
pub enum SamplingStrategy {
    /// Independent uniform random draws over each parameter's `[min, max)` range.
    Uniform,
    /// Normal distribution centered at the base (or default) value of each
    /// parameter and clamped to that parameter's `[min, max]` range.
    /// `std_dev` is relative: it is scaled by the parameter's range and weight.
    Gaussian { std_dev: f32 },
    /// Latin hypercube sampling for uniform coverage: each parameter's range
    /// is split into one stratum per requested sample.
    LatinHypercube,
    /// Sobol-like low-discrepancy (simple van der Corput sequence).
    LowDiscrepancy,
}
80
81// ---------------------------------------------------------------------------
82// Parameter specification
83// ---------------------------------------------------------------------------
84
/// A parameter specification with name, range, and distribution hint.
#[derive(Debug, Clone, PartialEq)]
pub struct ParamSpec {
    /// Key under which sampled values are stored in each output map.
    pub name: String,
    /// Lower bound of the sampled range.
    pub min: f32,
    /// Upper bound of the sampled range.
    pub max: f32,
    /// Value used as the Gaussian center when no base value is supplied.
    pub default: f32,
    /// Multiplier applied to the Gaussian standard deviation for this parameter.
    pub weight: f32,
}

impl ParamSpec {
    /// Creates a spec with the given name, bounds and default; `weight` starts at `1.0`.
    ///
    /// The arguments are stored as given: no ordering or range validation is
    /// performed here.
    pub fn new(name: impl Into<String>, min: f32, max: f32, default: f32) -> Self {
        Self {
            name: name.into(),
            min,
            max,
            default,
            weight: 1.0,
        }
    }

    /// Builder-style setter: returns the spec with `weight` replaced.
    #[must_use]
    pub fn with_weight(mut self, weight: f32) -> Self {
        self.weight = weight;
        self
    }
}
110
111// ---------------------------------------------------------------------------
112// Diversity sampler
113// ---------------------------------------------------------------------------
114
115/// Diversity sampler that generates varied body parameter sets.
116pub struct DiversitySampler {
117    params: Vec<ParamSpec>,
118    strategy: SamplingStrategy,
119    seed: u64,
120}
121
122/// First 6 primes for low-discrepancy sequence (one per dimension).
123const LD_PRIMES: [usize; 6] = [2, 3, 5, 7, 11, 13];
124
125impl DiversitySampler {
126    pub fn new(strategy: SamplingStrategy) -> Self {
127        Self {
128            params: Vec::new(),
129            strategy,
130            seed: 42,
131        }
132    }
133
134    pub fn with_seed(mut self, seed: u64) -> Self {
135        self.seed = seed;
136        self
137    }
138
139    pub fn add_param(&mut self, spec: ParamSpec) {
140        self.params.push(spec);
141    }
142
143    pub fn param_count(&self) -> usize {
144        self.params.len()
145    }
146
147    /// Generate N diverse parameter sets.
148    pub fn sample(&self, n: usize) -> Vec<HashMap<String, f32>> {
149        if n == 0 || self.params.is_empty() {
150            return Vec::new();
151        }
152
153        let mut rng = Lcg::new(self.seed);
154
155        match &self.strategy {
156            SamplingStrategy::Uniform => self.sample_uniform(&mut rng, n),
157            SamplingStrategy::Gaussian { std_dev } => {
158                // Use defaults as the base
159                let base: HashMap<String, f32> = self
160                    .params
161                    .iter()
162                    .map(|p| (p.name.clone(), p.default))
163                    .collect();
164                self.sample_gaussian(&mut rng, &base, *std_dev, n)
165            }
166            SamplingStrategy::LatinHypercube => self.sample_lhs(&mut rng, n),
167            SamplingStrategy::LowDiscrepancy => self.sample_ld(n),
168        }
169    }
170
171    /// Generate one sample near given base parameters.
172    pub fn sample_near(&self, base: &HashMap<String, f32>, n: usize) -> Vec<HashMap<String, f32>> {
173        if n == 0 || self.params.is_empty() {
174            return Vec::new();
175        }
176        let mut rng = Lcg::new(self.seed);
177        let std_dev = match &self.strategy {
178            SamplingStrategy::Gaussian { std_dev } => *std_dev,
179            _ => 0.1,
180        };
181        self.sample_gaussian(&mut rng, base, std_dev, n)
182    }
183
184    /// Generate population with guaranteed coverage of extremes.
185    pub fn sample_with_extremes(&self, n: usize) -> Vec<HashMap<String, f32>> {
186        if self.params.is_empty() {
187            return Vec::new();
188        }
189
190        let mut result = Vec::with_capacity(n);
191
192        // First sample: all minimums
193        let min_sample: HashMap<String, f32> = self
194            .params
195            .iter()
196            .map(|p| (p.name.clone(), p.min))
197            .collect();
198        result.push(min_sample);
199
200        // Second sample: all maximums
201        if n >= 2 {
202            let max_sample: HashMap<String, f32> = self
203                .params
204                .iter()
205                .map(|p| (p.name.clone(), p.max))
206                .collect();
207            result.push(max_sample);
208        }
209
210        // Fill the rest with normal sampling
211        if n > 2 {
212            let remaining = self.sample(n - 2);
213            result.extend(remaining);
214        }
215
216        result.truncate(n);
217        result
218    }
219
220    /// Compute diversity score: average pairwise L2 distance between samples.
221    pub fn diversity_score(samples: &[HashMap<String, f32>]) -> f32 {
222        if samples.len() < 2 {
223            return 0.0;
224        }
225        let mut total = 0.0_f32;
226        let mut count = 0usize;
227
228        for i in 0..samples.len() {
229            for j in (i + 1)..samples.len() {
230                let sq_dist: f32 = samples[i]
231                    .iter()
232                    .filter_map(|(k, v)| samples[j].get(k).map(|w| (v - w).powi(2)))
233                    .sum();
234                total += sq_dist.sqrt();
235                count += 1;
236            }
237        }
238
239        if count == 0 {
240            0.0
241        } else {
242            total / count as f32
243        }
244    }
245
246    // -----------------------------------------------------------------------
247    // Internal helpers
248    // -----------------------------------------------------------------------
249
250    fn sample_uniform(&self, rng: &mut Lcg, n: usize) -> Vec<HashMap<String, f32>> {
251        (0..n)
252            .map(|_| {
253                self.params
254                    .iter()
255                    .map(|p| (p.name.clone(), rng.next_range(p.min, p.max)))
256                    .collect()
257            })
258            .collect()
259    }
260
261    fn sample_gaussian(
262        &self,
263        rng: &mut Lcg,
264        base: &HashMap<String, f32>,
265        std_dev: f32,
266        n: usize,
267    ) -> Vec<HashMap<String, f32>> {
268        (0..n)
269            .map(|_| {
270                self.params
271                    .iter()
272                    .map(|p| {
273                        let center = base.get(&p.name).copied().unwrap_or(p.default);
274                        let range = p.max - p.min;
275                        let val = rng.next_gaussian(center, std_dev * range * p.weight);
276                        (p.name.clone(), val.clamp(p.min, p.max))
277                    })
278                    .collect()
279            })
280            .collect()
281    }
282
283    fn sample_lhs(&self, rng: &mut Lcg, n: usize) -> Vec<HashMap<String, f32>> {
284        // For each parameter, create a permutation of strata [0..n)
285        let param_strata: Vec<Vec<usize>> = self
286            .params
287            .iter()
288            .map(|_| {
289                let mut strata: Vec<usize> = (0..n).collect();
290                // Fisher-Yates shuffle using our LCG
291                for i in (1..strata.len()).rev() {
292                    let j = (rng.next_f32() * (i + 1) as f32) as usize;
293                    let j = j.min(i);
294                    strata.swap(i, j);
295                }
296                strata
297            })
298            .collect();
299
300        (0..n)
301            .map(|i| {
302                self.params
303                    .iter()
304                    .enumerate()
305                    .map(|(dim, p)| {
306                        let stratum = param_strata[dim][i];
307                        // Sample uniformly within stratum
308                        let lo = stratum as f32 / n as f32;
309                        let hi = (stratum + 1) as f32 / n as f32;
310                        let t = lo + rng.next_f32() * (hi - lo);
311                        let val = p.min + t * (p.max - p.min);
312                        (p.name.clone(), val)
313                    })
314                    .collect()
315            })
316            .collect()
317    }
318
319    fn sample_ld(&self, n: usize) -> Vec<HashMap<String, f32>> {
320        (0..n)
321            .map(|i| {
322                self.params
323                    .iter()
324                    .enumerate()
325                    .map(|(dim, p)| {
326                        let t = if dim < LD_PRIMES.len() {
327                            // Use 1-indexed to avoid the trivial 0 value
328                            van_der_corput(i + 1, LD_PRIMES[dim])
329                        } else {
330                            // Fall back to uniform via van_der_corput base 2 offset
331                            let mut rng =
332                                Lcg::new(self.seed.wrapping_add(dim as u64).wrapping_add(i as u64));
333                            rng.next_f32()
334                        };
335                        let val = p.min + t.clamp(0.0, 1.0) * (p.max - p.min);
336                        (p.name.clone(), val)
337                    })
338                    .collect()
339            })
340            .collect()
341    }
342}
343
344// ---------------------------------------------------------------------------
345// Default human body parameter specs
346// ---------------------------------------------------------------------------
347
348/// Human body parameter specs (height, weight, muscle, age, etc.).
349pub fn default_body_params() -> Vec<ParamSpec> {
350    vec![
351        ParamSpec::new("height", 0.0, 1.0, 0.5),
352        ParamSpec::new("weight", 0.0, 1.0, 0.5),
353        ParamSpec::new("muscle", 0.0, 1.0, 0.3),
354        ParamSpec::new("age", 0.0, 1.0, 0.35),
355        ParamSpec::new("bmi_factor", 0.0, 1.0, 0.4),
356        ParamSpec::new("shoulder_width", 0.0, 1.0, 0.5),
357        ParamSpec::new("hip_width", 0.0, 1.0, 0.5),
358    ]
359}
360
361/// Quick-generate N random body profiles using LatinHypercube strategy.
362pub fn generate_population(n: usize, seed: u64) -> Vec<HashMap<String, f32>> {
363    let mut sampler = DiversitySampler::new(SamplingStrategy::LatinHypercube).with_seed(seed);
364    for spec in default_body_params() {
365        sampler.add_param(spec);
366    }
367    sampler.sample(n)
368}
369
370// ---------------------------------------------------------------------------
371// Tests
372// ---------------------------------------------------------------------------
373
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor stores `seed + 1` as the starting state.
    #[test]
    fn test_lcg_new() {
        assert_eq!(Lcg::new(0).state, 1);
        assert_eq!(Lcg::new(42).state, 43);
    }

    /// Unit draws stay inside the half-open unit interval.
    #[test]
    fn test_lcg_next_f32_range() {
        let mut rng = Lcg::new(12345);
        for v in (0..100).map(|_| rng.next_f32()) {
            assert!((0.0..1.0).contains(&v), "Expected [0,1), got {v}");
        }
    }

    /// Ranged draws stay inside the requested half-open interval.
    #[test]
    fn test_lcg_next_range() {
        let mut rng = Lcg::new(99);
        for v in (0..100).map(|_| rng.next_range(2.0, 5.0)) {
            assert!((2.0..5.0).contains(&v), "Expected [2,5), got {v}");
        }
    }

    /// The sample mean of many Gaussian draws is close to the requested mean.
    #[test]
    fn test_lcg_next_gaussian() {
        let mut rng = Lcg::new(777);
        let n = 1000;
        let mut sum = 0.0_f32;
        for _ in 0..n {
            sum += rng.next_gaussian(0.5, 0.1);
        }
        let mean = sum / n as f32;
        assert!((mean - 0.5).abs() < 0.05, "Mean {mean} not near 0.5");
    }

    /// Base-2 values are the bit-reversed binary fractions of the index.
    #[test]
    fn test_van_der_corput_base2() {
        // 1 -> 0.1b, 2 -> 0.01b, 3 -> 0.11b, 4 -> 0.001b
        let expected = [(1, 0.5), (2, 0.25), (3, 0.75), (4, 0.125)];
        for (index, want) in expected {
            assert!((van_der_corput(index, 2) - want).abs() < 1e-6);
        }
        // Index 0 has no digits, so the value is 0.
        assert_eq!(van_der_corput(0, 2), 0.0);
    }

    /// `new` stores its arguments and `with_weight` overrides the weight.
    #[test]
    fn test_param_spec_new() {
        let height = ParamSpec::new("height", 0.0, 1.0, 0.5);
        assert_eq!(height.name, "height");
        assert_eq!(height.min, 0.0);
        assert_eq!(height.max, 1.0);
        assert_eq!(height.default, 0.5);
        assert_eq!(height.weight, 1.0);

        let weighted = height.with_weight(2.5);
        assert_eq!(weighted.weight, 2.5);
    }

    /// Three-parameter sampler (height, weight, age) with seed 42.
    fn make_sampler(strategy: SamplingStrategy) -> DiversitySampler {
        let mut sampler = DiversitySampler::new(strategy).with_seed(42);
        for (name, default) in [("height", 0.5), ("weight", 0.5), ("age", 0.35)] {
            sampler.add_param(ParamSpec::new(name, 0.0, 1.0, default));
        }
        sampler
    }

    #[test]
    fn test_sampler_uniform() {
        let samples = make_sampler(SamplingStrategy::Uniform).sample(20);
        assert_eq!(samples.len(), 20);
        for sample in &samples {
            assert_eq!(sample.len(), 3);
            for &v in sample.values() {
                assert!((0.0..=1.0).contains(&v), "Out of range: {v}");
            }
        }
    }

    #[test]
    fn test_sampler_gaussian() {
        let samples = make_sampler(SamplingStrategy::Gaussian { std_dev: 0.1 }).sample(50);
        assert_eq!(samples.len(), 50);
        for &v in samples.iter().flat_map(|sample| sample.values()) {
            assert!((0.0..=1.0).contains(&v), "Out of [0,1]: {v}");
        }
    }

    #[test]
    fn test_sampler_latin_hypercube() {
        let samples = make_sampler(SamplingStrategy::LatinHypercube).sample(10);
        assert_eq!(samples.len(), 10);

        // Every value stays in range.
        for &v in samples.iter().flat_map(|sample| sample.values()) {
            assert!((0.0..=1.0).contains(&v), "Out of range: {v}");
        }

        // Each height comes from its own stratum, so no two heights may coincide.
        let heights: Vec<f32> = samples
            .iter()
            .map(|sample| *sample.get("height").expect("should succeed"))
            .collect();
        for (i, a) in heights.iter().enumerate() {
            for (offset, b) in heights[i + 1..].iter().enumerate() {
                let j = i + 1 + offset;
                assert!(
                    (a - b).abs() > 1e-6,
                    "LHS produced duplicate heights at {i},{j}"
                );
            }
        }
    }

    #[test]
    fn test_sampler_low_discrepancy() {
        let samples = make_sampler(SamplingStrategy::LowDiscrepancy).sample(16);
        assert_eq!(samples.len(), 16);
        for &v in samples.iter().flat_map(|sample| sample.values()) {
            assert!((0.0..=1.0).contains(&v), "Out of range: {v}");
        }

        // The first dimension (height) follows van_der_corput(1..=16, 2).
        for (i, sample) in samples.iter().enumerate() {
            let expected = van_der_corput(i + 1, 2);
            let actual = *sample.get("height").expect("should succeed");
            assert!(
                (actual - expected).abs() < 1e-5,
                "LD mismatch at i={i}: expected {expected}, got {actual}"
            );
        }
    }

    #[test]
    fn test_sample_near() {
        let mut sampler =
            DiversitySampler::new(SamplingStrategy::Gaussian { std_dev: 0.05 }).with_seed(7);
        for name in ["height", "weight"] {
            sampler.add_param(ParamSpec::new(name, 0.0, 1.0, 0.5));
        }

        let base: HashMap<String, f32> =
            [("height".to_string(), 0.8), ("weight".to_string(), 0.2)].into();

        let samples = sampler.sample_near(&base, 30);
        assert_eq!(samples.len(), 30);

        // Most samples should sit close to the base values.
        let near_count = samples
            .iter()
            .filter(|sample| {
                (sample["height"] - 0.8).abs() < 0.3 && (sample["weight"] - 0.2).abs() < 0.3
            })
            .count();
        assert!(
            near_count >= 20,
            "Expected most samples near base, got {near_count}/30"
        );
    }

    #[test]
    fn test_diversity_score() {
        let point: HashMap<String, f32> = [("a".to_string(), 0.5)].into();

        // Identical samples have zero spread.
        let identical = vec![point.clone(), point.clone()];
        assert_eq!(DiversitySampler::diversity_score(&identical), 0.0);

        // Opposite corners of the unit square are sqrt(2) apart.
        let lo: HashMap<String, f32> = [("x".to_string(), 0.0), ("y".to_string(), 0.0)].into();
        let hi: HashMap<String, f32> = [("x".to_string(), 1.0), ("y".to_string(), 1.0)].into();
        let score = DiversitySampler::diversity_score(&[lo, hi]);
        assert!(
            (score - 2.0_f32.sqrt()).abs() < 1e-5,
            "Expected sqrt(2), got {score}"
        );

        // A single sample has no pairs, hence score 0.
        assert_eq!(DiversitySampler::diversity_score(&[point]), 0.0);
    }

    #[test]
    fn test_default_body_params() {
        let params = default_body_params();
        assert_eq!(params.len(), 7);

        let names: Vec<&str> = params.iter().map(|p| p.name.as_str()).collect();
        for expected in [
            "height",
            "weight",
            "muscle",
            "age",
            "bmi_factor",
            "shoulder_width",
            "hip_width",
        ] {
            assert!(names.contains(&expected));
        }

        for p in &params {
            assert_eq!(p.min, 0.0);
            assert_eq!(p.max, 1.0);
            assert!((0.0..=1.0).contains(&p.default));
        }
    }

    #[test]
    fn test_generate_population() {
        let pop = generate_population(20, 42);
        assert_eq!(pop.len(), 20);
        for individual in &pop {
            assert_eq!(individual.len(), 7);
            for &v in individual.values() {
                assert!((0.0..=1.0).contains(&v), "Out of range: {v}");
            }
        }

        // Same seed must reproduce the same population.
        let again = generate_population(20, 42);
        assert_eq!(pop.len(), again.len());
        for (first, second) in pop.iter().zip(&again) {
            for (key, value) in first {
                assert_eq!(*value, *second.get(key).expect("should succeed"));
            }
        }
    }

    #[test]
    fn test_sample_with_extremes() {
        let samples = make_sampler(SamplingStrategy::Uniform).sample_with_extremes(10);
        assert_eq!(samples.len(), 10);

        // Sample 0 pins every parameter to its minimum.
        for &v in samples[0].values() {
            assert_eq!(v, 0.0, "First sample should be all mins");
        }

        // Sample 1 pins every parameter to its maximum.
        for &v in samples[1].values() {
            assert_eq!(v, 1.0, "Second sample should be all maxes");
        }

        // Everything, extremes included, stays in range.
        for &v in samples.iter().flat_map(|sample| sample.values()) {
            assert!((0.0..=1.0).contains(&v));
        }

        // A sampler without parameters yields nothing.
        let empty = DiversitySampler::new(SamplingStrategy::Uniform).sample_with_extremes(5);
        assert!(empty.is_empty());
    }
}
644
645        // Empty case
646        let empty = DiversitySampler::new(SamplingStrategy::Uniform).sample_with_extremes(5);
647        assert!(empty.is_empty());
648    }
649}