// datasynth_core/utils.rs

//! Shared generator utilities: weighted selection, decimal range sampling, and seeded RNG construction.
2
3use rand::Rng;
4
5/// Select from weighted options. Weights don't need to sum to 1.0.
6pub fn weighted_select<'a, T, R: Rng>(rng: &mut R, options: &'a [(T, f64)]) -> &'a T {
7    let total: f64 = options.iter().map(|(_, w)| w).sum();
8    let mut roll = rng.gen::<f64>() * total;
9    for (item, weight) in options {
10        roll -= weight;
11        if roll <= 0.0 {
12            return item;
13        }
14    }
15    &options
16        .last()
17        .expect("weighted_select called with empty options")
18        .0
19}
20
21/// Sample a Decimal in a range using the RNG.
22pub fn sample_decimal_range<R: Rng>(
23    rng: &mut R,
24    min: rust_decimal::Decimal,
25    max: rust_decimal::Decimal,
26) -> rust_decimal::Decimal {
27    use rust_decimal::prelude::ToPrimitive;
28    let min_f = min.to_f64().unwrap_or(0.0);
29    let max_f = max.to_f64().unwrap_or(min_f + 1.0);
30    let val = rng.gen_range(min_f..=max_f);
31    rust_decimal::Decimal::from_f64_retain(val).unwrap_or(min)
32}
33
34/// Create a seeded RNG for a generator, with an optional discriminator for sub-generators.
35pub fn seeded_rng(seed: u64, discriminator: u64) -> rand_chacha::ChaCha8Rng {
36    use rand::SeedableRng;
37    rand_chacha::ChaCha8Rng::seed_from_u64(seed.wrapping_add(discriminator))
38}
39
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    /// Draw `n` values so RNG streams are compared by output.
    ///
    /// Comparing `Debug` output is not enough: rand_chacha's `Debug` impl omits the
    /// internal state, so unrelated RNGs would format identically.
    fn draws(rng: &mut ChaCha8Rng, n: usize) -> Vec<u64> {
        (0..n).map(|_| rng.gen::<u64>()).collect()
    }

    #[test]
    fn test_weighted_select_distribution() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let options = vec![("a", 0.9), ("b", 0.1)];
        let mut a_count = 0;
        for _ in 0..100 {
            if *weighted_select(&mut rng, &options) == "a" {
                a_count += 1;
            }
        }
        assert!(a_count > 70, "Expected ~90% 'a', got {}", a_count);
    }

    #[test]
    fn test_weighted_select_single_option() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let options = vec![("only", 1.0)];
        assert_eq!(*weighted_select(&mut rng, &options), "only");
    }

    #[test]
    fn test_weighted_select_never_picks_zero_weight() {
        let mut rng = ChaCha8Rng::seed_from_u64(7);
        let options = vec![("never", 0.0), ("always", 1.0)];
        for _ in 0..1_000 {
            assert_eq!(*weighted_select(&mut rng, &options), "always");
        }
    }

    #[test]
    #[should_panic(expected = "weighted_select called with empty options")]
    fn test_weighted_select_empty_panics() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let options: Vec<(&str, f64)> = Vec::new();
        weighted_select(&mut rng, &options);
    }

    #[test]
    fn test_sample_decimal_range() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let min = rust_decimal::Decimal::new(100, 0);
        let max = rust_decimal::Decimal::new(200, 0);
        for _ in 0..100 {
            let val = sample_decimal_range(&mut rng, min, max);
            assert!(
                val >= min && val <= max,
                "Value {} outside [{}, {}]",
                val,
                min,
                max
            );
        }
    }

    #[test]
    fn test_sample_decimal_range_degenerate() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let only = rust_decimal::Decimal::new(150, 0);
        assert_eq!(sample_decimal_range(&mut rng, only, only), only);
    }

    #[test]
    fn test_seeded_rng_deterministic() {
        let mut rng1 = seeded_rng(42, 100);
        let mut rng2 = seeded_rng(42, 100);
        // Same seed + discriminator must produce the same stream of values.
        assert_eq!(draws(&mut rng1, 16), draws(&mut rng2, 16));
    }

    #[test]
    fn test_seeded_rng_different_discriminators() {
        let mut rng1 = seeded_rng(42, 0);
        let mut rng2 = seeded_rng(42, 1);
        assert_ne!(
            draws(&mut rng1, 16),
            draws(&mut rng2, 16),
            "Different discriminators should produce different values"
        );
    }
}