1#![allow(dead_code)]
5
6use std::collections::HashMap;
7
/// Minimal deterministic pseudo-random generator built on a 64-bit linear
/// congruential recurrence.
///
/// Two generators created from the same seed produce the same sequence.
/// The generator is not cryptographically secure.
#[derive(Debug, Clone)]
pub struct Lcg {
    /// Current 64-bit state of the recurrence.
    state: u64,
}

impl Lcg {
    /// Creates a generator from `seed`.
    ///
    /// The stored state is `seed + 1` (wrapping on overflow).
    pub fn new(seed: u64) -> Self {
        Self {
            state: seed.wrapping_add(1),
        }
    }

    /// Advances the generator and returns a value in `[0, 1)`.
    pub fn next_f32(&mut self) -> f32 {
        self.state = self
            .state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        // Keep the 24 highest bits: an `f32` represents every integer below
        // 2^24 exactly, so the quotient by 2^24 spans the whole unit interval
        // and can never round up to 1.0. Taking a wider slice and dividing by
        // a mismatched power of two would either compress the range or allow
        // an exact 1.0.
        (self.state >> 40) as f32 / 16_777_216.0
    }

    /// Returns `min + u * (max - min)` for a fresh `u` from [`Lcg::next_f32`].
    ///
    /// Because of floating-point rounding the result may occasionally be
    /// equal to `max`.
    pub fn next_range(&mut self, min: f32, max: f32) -> f32 {
        min + self.next_f32() * (max - min)
    }

    /// Draws a normally distributed value with the given `mean` and
    /// `std_dev` using the Box–Muller transform.
    ///
    /// Consumes two values from [`Lcg::next_f32`] per call.
    pub fn next_gaussian(&mut self, mean: f32, std_dev: f32) -> f32 {
        // The tiny offset keeps `ln` away from zero when the draw is exactly 0.
        let u1 = self.next_f32() + 1e-10;
        let u2 = self.next_f32();
        let z = (-2.0_f32 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos();
        mean + std_dev * z
    }
}
46
/// Returns the radical inverse of `n` in the given `base`, i.e. the `n`-th
/// element of the van der Corput sequence.
///
/// The digits of `n` in `base` are mirrored around the radix point, so for
/// base 2 the inputs `1, 2, 3, 4` map to `0.5, 0.25, 0.75, 0.125`.
/// `n == 0` yields `0.0`.
///
/// A `base` below 2 has no positional representation; `0.0` is returned in
/// that case instead of dividing by zero (`base == 0`) or looping forever
/// (`base == 1`).
pub fn van_der_corput(n: usize, base: usize) -> f32 {
    if base < 2 {
        return 0.0;
    }

    let radix = base as f64;
    // Accumulate in f64 and narrow once at the end.
    let mut result = 0.0_f64;
    let mut denominator = 1.0_f64;
    let mut n_remaining = n;
    while n_remaining > 0 {
        denominator *= radix;
        result += (n_remaining % base) as f64 / denominator;
        n_remaining /= base;
    }
    result as f32
}
64
/// How a [`DiversitySampler`] distributes its samples across the parameter
/// ranges.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SamplingStrategy {
    /// Every parameter is drawn independently and uniformly between its
    /// `min` and `max`.
    Uniform,
    /// Values are drawn from a normal distribution and clamped to the
    /// parameter range.
    Gaussian {
        /// Standard deviation relative to the parameter range; the sampler
        /// multiplies it by `max - min` and by the parameter weight.
        std_dev: f32,
    },
    /// The range of each parameter is cut into as many strata as there are
    /// samples, and every stratum is used exactly once per parameter.
    LatinHypercube,
    /// Values follow van der Corput sequences with a different prime base
    /// per parameter.
    LowDiscrepancy,
}
80
/// Description of one named, bounded sampling parameter.
#[derive(Debug, Clone, PartialEq)]
pub struct ParamSpec {
    /// Key under which sampled values are stored.
    pub name: String,
    /// Lower bound of the parameter range.
    pub min: f32,
    /// Upper bound of the parameter range.
    pub max: f32,
    /// Value used as the centre when no explicit base value is supplied.
    pub default: f32,
    /// Multiplier applied to the spread of Gaussian sampling.
    pub weight: f32,
}

impl ParamSpec {
    /// Creates a spec with the given name, bounds and default value.
    ///
    /// The weight starts at `1.0`; use [`ParamSpec::with_weight`] to change
    /// it. The bounds are stored as given and are not validated.
    pub fn new(name: impl Into<String>, min: f32, max: f32, default: f32) -> Self {
        Self {
            name: name.into(),
            min,
            max,
            default,
            weight: 1.0,
        }
    }

    /// Returns the spec with its weight replaced by `weight`.
    pub fn with_weight(mut self, weight: f32) -> Self {
        self.weight = weight;
        self
    }
}
110
111pub struct DiversitySampler {
117 params: Vec<ParamSpec>,
118 strategy: SamplingStrategy,
119 seed: u64,
120}
121
122const LD_PRIMES: [usize; 6] = [2, 3, 5, 7, 11, 13];
124
125impl DiversitySampler {
126 pub fn new(strategy: SamplingStrategy) -> Self {
127 Self {
128 params: Vec::new(),
129 strategy,
130 seed: 42,
131 }
132 }
133
134 pub fn with_seed(mut self, seed: u64) -> Self {
135 self.seed = seed;
136 self
137 }
138
139 pub fn add_param(&mut self, spec: ParamSpec) {
140 self.params.push(spec);
141 }
142
143 pub fn param_count(&self) -> usize {
144 self.params.len()
145 }
146
147 pub fn sample(&self, n: usize) -> Vec<HashMap<String, f32>> {
149 if n == 0 || self.params.is_empty() {
150 return Vec::new();
151 }
152
153 let mut rng = Lcg::new(self.seed);
154
155 match &self.strategy {
156 SamplingStrategy::Uniform => self.sample_uniform(&mut rng, n),
157 SamplingStrategy::Gaussian { std_dev } => {
158 let base: HashMap<String, f32> = self
160 .params
161 .iter()
162 .map(|p| (p.name.clone(), p.default))
163 .collect();
164 self.sample_gaussian(&mut rng, &base, *std_dev, n)
165 }
166 SamplingStrategy::LatinHypercube => self.sample_lhs(&mut rng, n),
167 SamplingStrategy::LowDiscrepancy => self.sample_ld(n),
168 }
169 }
170
171 pub fn sample_near(&self, base: &HashMap<String, f32>, n: usize) -> Vec<HashMap<String, f32>> {
173 if n == 0 || self.params.is_empty() {
174 return Vec::new();
175 }
176 let mut rng = Lcg::new(self.seed);
177 let std_dev = match &self.strategy {
178 SamplingStrategy::Gaussian { std_dev } => *std_dev,
179 _ => 0.1,
180 };
181 self.sample_gaussian(&mut rng, base, std_dev, n)
182 }
183
184 pub fn sample_with_extremes(&self, n: usize) -> Vec<HashMap<String, f32>> {
186 if self.params.is_empty() {
187 return Vec::new();
188 }
189
190 let mut result = Vec::with_capacity(n);
191
192 let min_sample: HashMap<String, f32> = self
194 .params
195 .iter()
196 .map(|p| (p.name.clone(), p.min))
197 .collect();
198 result.push(min_sample);
199
200 if n >= 2 {
202 let max_sample: HashMap<String, f32> = self
203 .params
204 .iter()
205 .map(|p| (p.name.clone(), p.max))
206 .collect();
207 result.push(max_sample);
208 }
209
210 if n > 2 {
212 let remaining = self.sample(n - 2);
213 result.extend(remaining);
214 }
215
216 result.truncate(n);
217 result
218 }
219
220 pub fn diversity_score(samples: &[HashMap<String, f32>]) -> f32 {
222 if samples.len() < 2 {
223 return 0.0;
224 }
225 let mut total = 0.0_f32;
226 let mut count = 0usize;
227
228 for i in 0..samples.len() {
229 for j in (i + 1)..samples.len() {
230 let sq_dist: f32 = samples[i]
231 .iter()
232 .filter_map(|(k, v)| samples[j].get(k).map(|w| (v - w).powi(2)))
233 .sum();
234 total += sq_dist.sqrt();
235 count += 1;
236 }
237 }
238
239 if count == 0 {
240 0.0
241 } else {
242 total / count as f32
243 }
244 }
245
246 fn sample_uniform(&self, rng: &mut Lcg, n: usize) -> Vec<HashMap<String, f32>> {
251 (0..n)
252 .map(|_| {
253 self.params
254 .iter()
255 .map(|p| (p.name.clone(), rng.next_range(p.min, p.max)))
256 .collect()
257 })
258 .collect()
259 }
260
261 fn sample_gaussian(
262 &self,
263 rng: &mut Lcg,
264 base: &HashMap<String, f32>,
265 std_dev: f32,
266 n: usize,
267 ) -> Vec<HashMap<String, f32>> {
268 (0..n)
269 .map(|_| {
270 self.params
271 .iter()
272 .map(|p| {
273 let center = base.get(&p.name).copied().unwrap_or(p.default);
274 let range = p.max - p.min;
275 let val = rng.next_gaussian(center, std_dev * range * p.weight);
276 (p.name.clone(), val.clamp(p.min, p.max))
277 })
278 .collect()
279 })
280 .collect()
281 }
282
283 fn sample_lhs(&self, rng: &mut Lcg, n: usize) -> Vec<HashMap<String, f32>> {
284 let param_strata: Vec<Vec<usize>> = self
286 .params
287 .iter()
288 .map(|_| {
289 let mut strata: Vec<usize> = (0..n).collect();
290 for i in (1..strata.len()).rev() {
292 let j = (rng.next_f32() * (i + 1) as f32) as usize;
293 let j = j.min(i);
294 strata.swap(i, j);
295 }
296 strata
297 })
298 .collect();
299
300 (0..n)
301 .map(|i| {
302 self.params
303 .iter()
304 .enumerate()
305 .map(|(dim, p)| {
306 let stratum = param_strata[dim][i];
307 let lo = stratum as f32 / n as f32;
309 let hi = (stratum + 1) as f32 / n as f32;
310 let t = lo + rng.next_f32() * (hi - lo);
311 let val = p.min + t * (p.max - p.min);
312 (p.name.clone(), val)
313 })
314 .collect()
315 })
316 .collect()
317 }
318
319 fn sample_ld(&self, n: usize) -> Vec<HashMap<String, f32>> {
320 (0..n)
321 .map(|i| {
322 self.params
323 .iter()
324 .enumerate()
325 .map(|(dim, p)| {
326 let t = if dim < LD_PRIMES.len() {
327 van_der_corput(i + 1, LD_PRIMES[dim])
329 } else {
330 let mut rng =
332 Lcg::new(self.seed.wrapping_add(dim as u64).wrapping_add(i as u64));
333 rng.next_f32()
334 };
335 let val = p.min + t.clamp(0.0, 1.0) * (p.max - p.min);
336 (p.name.clone(), val)
337 })
338 .collect()
339 })
340 .collect()
341 }
342}
343
344pub fn default_body_params() -> Vec<ParamSpec> {
350 vec![
351 ParamSpec::new("height", 0.0, 1.0, 0.5),
352 ParamSpec::new("weight", 0.0, 1.0, 0.5),
353 ParamSpec::new("muscle", 0.0, 1.0, 0.3),
354 ParamSpec::new("age", 0.0, 1.0, 0.35),
355 ParamSpec::new("bmi_factor", 0.0, 1.0, 0.4),
356 ParamSpec::new("shoulder_width", 0.0, 1.0, 0.5),
357 ParamSpec::new("hip_width", 0.0, 1.0, 0.5),
358 ]
359}
360
361pub fn generate_population(n: usize, seed: u64) -> Vec<HashMap<String, f32>> {
363 let mut sampler = DiversitySampler::new(SamplingStrategy::LatinHypercube).with_seed(seed);
364 for spec in default_body_params() {
365 sampler.add_param(spec);
366 }
367 sampler.sample(n)
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// The stored state is the seed plus one.
    #[test]
    fn test_lcg_new() {
        let lcg = Lcg::new(0);
        assert_eq!(lcg.state, 1);

        let lcg2 = Lcg::new(42);
        assert_eq!(lcg2.state, 43);
    }

    /// Raw draws stay inside the half-open unit interval.
    #[test]
    fn test_lcg_next_f32_range() {
        let mut lcg = Lcg::new(12345);
        for _ in 0..100 {
            let v = lcg.next_f32();
            assert!((0.0..1.0).contains(&v), "Expected [0,1), got {v}");
        }
    }

    /// Ranged draws stay inside the requested interval.
    #[test]
    fn test_lcg_next_range() {
        let mut lcg = Lcg::new(99);
        for _ in 0..100 {
            let v = lcg.next_range(2.0, 5.0);
            assert!((2.0..5.0).contains(&v), "Expected [2,5), got {v}");
        }
    }

    /// The empirical mean of Gaussian draws is close to the requested mean.
    #[test]
    fn test_lcg_next_gaussian() {
        let mut lcg = Lcg::new(777);
        let mut sum = 0.0_f32;
        let n = 1000;
        for _ in 0..n {
            sum += lcg.next_gaussian(0.5, 0.1);
        }
        let mean = sum / n as f32;
        assert!((mean - 0.5).abs() < 0.05, "Mean {mean} not near 0.5");
    }

    /// First elements of the base-2 van der Corput sequence.
    #[test]
    fn test_van_der_corput_base2() {
        assert!((van_der_corput(1, 2) - 0.5).abs() < 1e-6);
        assert!((van_der_corput(2, 2) - 0.25).abs() < 1e-6);
        assert!((van_der_corput(3, 2) - 0.75).abs() < 1e-6);
        assert!((van_der_corput(4, 2) - 0.125).abs() < 1e-6);
        assert_eq!(van_der_corput(0, 2), 0.0);
    }

    /// Constructor defaults and the weight builder.
    #[test]
    fn test_param_spec_new() {
        let spec = ParamSpec::new("height", 0.0, 1.0, 0.5);
        assert_eq!(spec.name, "height");
        assert_eq!(spec.min, 0.0);
        assert_eq!(spec.max, 1.0);
        assert_eq!(spec.default, 0.5);
        assert_eq!(spec.weight, 1.0);

        let spec2 = spec.with_weight(2.5);
        assert_eq!(spec2.weight, 2.5);
    }

    /// Three-parameter sampler with a fixed seed, shared by the tests below.
    fn make_sampler(strategy: SamplingStrategy) -> DiversitySampler {
        let mut s = DiversitySampler::new(strategy).with_seed(42);
        s.add_param(ParamSpec::new("height", 0.0, 1.0, 0.5));
        s.add_param(ParamSpec::new("weight", 0.0, 1.0, 0.5));
        s.add_param(ParamSpec::new("age", 0.0, 1.0, 0.35));
        s
    }

    #[test]
    fn test_sampler_uniform() {
        let s = make_sampler(SamplingStrategy::Uniform);
        let samples = s.sample(20);
        assert_eq!(samples.len(), 20);
        for sample in &samples {
            assert_eq!(sample.len(), 3);
            for v in sample.values() {
                assert!(*v >= 0.0 && *v <= 1.0, "Out of range: {v}");
            }
        }
    }

    #[test]
    fn test_sampler_gaussian() {
        let s = make_sampler(SamplingStrategy::Gaussian { std_dev: 0.1 });
        let samples = s.sample(50);
        assert_eq!(samples.len(), 50);
        for sample in &samples {
            for v in sample.values() {
                assert!(*v >= 0.0 && *v <= 1.0, "Out of [0,1]: {v}");
            }
        }
    }

    #[test]
    fn test_sampler_latin_hypercube() {
        let s = make_sampler(SamplingStrategy::LatinHypercube);
        let samples = s.sample(10);
        assert_eq!(samples.len(), 10);
        for sample in &samples {
            for v in sample.values() {
                assert!(*v >= 0.0 && *v <= 1.0, "Out of range: {v}");
            }
        }
        // Every sample sits in its own stratum, so no two heights coincide.
        let heights: Vec<f32> = samples
            .iter()
            .map(|m| *m.get("height").expect("should succeed"))
            .collect();
        for i in 0..heights.len() {
            for j in (i + 1)..heights.len() {
                assert!(
                    (heights[i] - heights[j]).abs() > 1e-6,
                    "LHS produced duplicate heights at {i},{j}"
                );
            }
        }
    }

    #[test]
    fn test_sampler_low_discrepancy() {
        let s = make_sampler(SamplingStrategy::LowDiscrepancy);
        let samples = s.sample(16);
        assert_eq!(samples.len(), 16);
        for sample in &samples {
            for v in sample.values() {
                assert!(*v >= 0.0 && *v <= 1.0, "Out of range: {v}");
            }
        }
        // The first parameter follows the base-2 sequence starting at index 1.
        for (i, sample) in samples.iter().enumerate() {
            let expected = van_der_corput(i + 1, 2);
            let actual = *sample.get("height").expect("should succeed");
            assert!(
                (actual - expected).abs() < 1e-5,
                "LD mismatch at i={i}: expected {expected}, got {actual}"
            );
        }
    }

    #[test]
    fn test_sample_near() {
        let mut s =
            DiversitySampler::new(SamplingStrategy::Gaussian { std_dev: 0.05 }).with_seed(7);
        s.add_param(ParamSpec::new("height", 0.0, 1.0, 0.5));
        s.add_param(ParamSpec::new("weight", 0.0, 1.0, 0.5));

        let base: HashMap<String, f32> =
            [("height".to_string(), 0.8), ("weight".to_string(), 0.2)].into();

        let samples = s.sample_near(&base, 30);
        assert_eq!(samples.len(), 30);

        // Count the samples that stay close to the base point.
        let mut near_count = 0;
        for sample in &samples {
            let h = sample["height"];
            let w = sample["weight"];
            if (h - 0.8).abs() < 0.3 && (w - 0.2).abs() < 0.3 {
                near_count += 1;
            }
        }
        assert!(
            near_count >= 20,
            "Expected most samples near base, got {near_count}/30"
        );
    }

    #[test]
    fn test_diversity_score() {
        let s1: HashMap<String, f32> = [("a".to_string(), 0.5)].into();
        // Identical samples have zero distance.
        let identical = vec![s1.clone(), s1.clone()];
        assert_eq!(DiversitySampler::diversity_score(&identical), 0.0);

        // Opposite corners of the unit square are sqrt(2) apart.
        let lo: HashMap<String, f32> = [("x".to_string(), 0.0), ("y".to_string(), 0.0)].into();
        let hi: HashMap<String, f32> = [("x".to_string(), 1.0), ("y".to_string(), 1.0)].into();
        let spread = vec![lo, hi];
        let score = DiversitySampler::diversity_score(&spread);
        assert!(
            (score - 2.0_f32.sqrt()).abs() < 1e-5,
            "Expected sqrt(2), got {score}"
        );

        // A single sample has no pairs.
        let single = vec![s1];
        assert_eq!(DiversitySampler::diversity_score(&single), 0.0);
    }

    #[test]
    fn test_default_body_params() {
        let params = default_body_params();
        assert_eq!(params.len(), 7);

        let names: Vec<&str> = params.iter().map(|p| p.name.as_str()).collect();
        assert!(names.contains(&"height"));
        assert!(names.contains(&"weight"));
        assert!(names.contains(&"muscle"));
        assert!(names.contains(&"age"));
        assert!(names.contains(&"bmi_factor"));
        assert!(names.contains(&"shoulder_width"));
        assert!(names.contains(&"hip_width"));

        for p in &params {
            assert_eq!(p.min, 0.0);
            assert_eq!(p.max, 1.0);
            assert!(p.default >= 0.0 && p.default <= 1.0);
        }
    }

    #[test]
    fn test_generate_population() {
        let pop = generate_population(20, 42);
        assert_eq!(pop.len(), 20);
        for individual in &pop {
            assert_eq!(individual.len(), 7);
            for v in individual.values() {
                assert!(*v >= 0.0 && *v <= 1.0, "Out of range: {v}");
            }
        }
        // The same seed must reproduce the same population.
        let pop2 = generate_population(20, 42);
        assert_eq!(pop.len(), pop2.len());
        for (a, b) in pop.iter().zip(pop2.iter()) {
            for (k, v) in a {
                assert_eq!(*v, *b.get(k).expect("should succeed"));
            }
        }
    }

    #[test]
    fn test_sample_with_extremes() {
        let s = make_sampler(SamplingStrategy::Uniform);
        let samples = s.sample_with_extremes(10);
        assert_eq!(samples.len(), 10);

        let first = &samples[0];
        for v in first.values() {
            assert_eq!(*v, 0.0, "First sample should be all mins");
        }

        let second = &samples[1];
        for v in second.values() {
            assert_eq!(*v, 1.0, "Second sample should be all maxes");
        }

        for sample in &samples {
            for v in sample.values() {
                assert!(*v >= 0.0 && *v <= 1.0);
            }
        }

        // A sampler without parameters yields nothing.
        let empty = DiversitySampler::new(SamplingStrategy::Uniform).sample_with_extremes(5);
        assert!(empty.is_empty());
    }
}