use rhai::{Engine, EvalAltResult};
use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::Mutex;
lazy_static::lazy_static! {
    /// Process-wide random generator used by `rand_float`, `rand_int_range` and `sample_prob`.
    ///
    /// Guarded by a `Mutex` so script functions running on any thread share one generator.
    static ref RNG: Mutex<fastrand::Rng> = Mutex::new(fastrand::Rng::new());
}

thread_local! {
    /// Call counters for `sample_every`, keyed by the script-supplied `n`.
    ///
    /// Thread-local: every thread keeps its own, independent set of counters.
    static SAMPLE_COUNTERS: RefCell<HashMap<i64, i64>> = RefCell::new(HashMap::default());
}
/// Returns a uniformly distributed float in `[0.0, 1.0)`.
///
/// Registered with Rhai as `rand()`. This function never returns `Err`.
fn rand_float() -> Result<f64, Box<EvalAltResult>> {
    // A poisoned mutex only means some other holder panicked. The generator has no
    // invariant a panic could break, so recover the guard instead of making every
    // later `rand()` call panic as well.
    let mut rng = RNG.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
    Ok(rng.f64())
}
/// Returns a uniformly distributed integer in the inclusive range `[min, max]`.
///
/// Registered with Rhai as `rand_int(min, max)`. `min == max` is allowed and always
/// yields that single value.
///
/// # Errors
///
/// Returns a script error when `min > max`. Rejecting the empty range here also keeps it
/// away from `fastrand`, which presumably panics on empty ranges (confirm against its docs).
fn rand_int_range(min: i64, max: i64) -> Result<i64, Box<EvalAltResult>> {
    if min > max {
        return Err(format!(
            "rand_int: min ({}) cannot be greater than max ({})",
            min, max
        )
        .into());
    }
    // A poisoned mutex only means some other holder panicked. The generator has no
    // invariant a panic could break, so recover the guard rather than panic forever.
    let mut rng = RNG.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
    Ok(rng.i64(min..=max))
}
/// Returns `true` on every `n`-th call made with the same `n`, and `false` otherwise.
///
/// Registered with Rhai as `sample_every(n)`. The first `true` arrives on call number `n`
/// (so `n == 1` always fires), after which the count restarts from zero.
///
/// The counter is keyed only by the value of `n`: every call site passing the same `n`
/// on the same thread advances one shared counter. Counters live in a thread-local map,
/// so each thread counts independently. This function never removes entries, so the map
/// grows by one entry per distinct `n` seen on a thread.
///
/// # Errors
///
/// Returns a script error when `n` is zero or negative.
fn sample_every(n: i64) -> Result<bool, Box<EvalAltResult>> {
    if n < 1 {
        return Err(format!("sample_every: n must be positive, got {}", n).into());
    }

    let fired = SAMPLE_COUNTERS.with(|cell| {
        let mut counts = cell.borrow_mut();
        let seen = counts.entry(n).or_default();
        *seen += 1;
        let reached = *seen >= n;
        if reached {
            // Restart so the next hit lands exactly `n` calls later.
            *seen = 0;
        }
        reached
    });

    Ok(fired)
}
/// Test-only helper that forgets every `sample_every` counter on the calling thread.
///
/// The counters are thread-local, so counters held by other threads are left untouched.
#[cfg(test)]
pub fn clear_sample_counters() {
    SAMPLE_COUNTERS.with(|cell| cell.borrow_mut().clear());
}
/// Returns `true` with probability `p`.
///
/// Registered with Rhai as `sample_prob(p)`. The draw is compared with a strict `<`
/// against a value in `[0.0, 1.0)`, so `p == 0.0` never fires and `p == 1.0` always fires.
///
/// # Errors
///
/// Returns a script error when `p` is outside `[0.0, 1.0]`. This includes NaN, for which
/// `RangeInclusive::contains` is false.
fn sample_prob(p: f64) -> Result<bool, Box<EvalAltResult>> {
    if !(0.0..=1.0).contains(&p) {
        return Err(format!(
            "sample_prob: probability must be between 0.0 and 1.0, got {}",
            p
        )
        .into());
    }
    // A poisoned mutex only means some other holder panicked. The generator has no
    // invariant a panic could break, so recover the guard rather than panic forever.
    let mut rng = RNG.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
    Ok(rng.f64() < p)
}
/// Registers this module's random-number and sampling helpers on `engine`.
///
/// Script-visible names:
/// - `rand()`: float in `[0.0, 1.0)`
/// - `rand_int(min, max)`: integer in `[min, max]`
/// - `sample_every(n)`: `true` on every `n`-th call for a given `n`
/// - `sample_prob(p)`: `true` with probability `p`
pub fn register_functions(engine: &mut Engine) {
    engine
        .register_fn("rand", rand_float)
        .register_fn("rand_int", rand_int_range)
        .register_fn("sample_every", sample_every)
        .register_fn("sample_prob", sample_prob);
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rand_float() {
        for _ in 0..100 {
            let val = rand_float().unwrap();
            assert!(
                (0.0..1.0).contains(&val),
                "rand() should return value in [0.0, 1.0), got {}",
                val
            );
        }
    }

    #[test]
    fn test_rand_int_range() {
        for _ in 0..100 {
            let val = rand_int_range(1, 10).unwrap();
            assert!(
                (1..=10).contains(&val),
                "rand_int(1, 10) should return value in [1, 10], got {}",
                val
            );
        }

        // A single-value range must always yield that value.
        assert_eq!(rand_int_range(5, 5).unwrap(), 5);

        for _ in 0..100 {
            let val = rand_int_range(-10, -1).unwrap();
            assert!(
                (-10..=-1).contains(&val),
                "rand_int(-10, -1) should return value in [-10, -1], got {}",
                val
            );
        }
    }

    #[test]
    fn test_rand_int_invalid_range() {
        let result = rand_int_range(10, 5);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("min (10) cannot be greater than max (5)"));
    }

    #[test]
    fn test_sample_every_basic() {
        clear_sample_counters();
        // Fires on every third call; the counter restarts after each hit.
        for (call, &expected) in [false, false, true, false, false, true].iter().enumerate() {
            assert_eq!(
                sample_every(3).unwrap(),
                expected,
                "sample_every(3) call #{}",
                call + 1
            );
        }
    }

    #[test]
    fn test_sample_every_n_equals_1() {
        clear_sample_counters();
        for _ in 0..3 {
            assert!(sample_every(1).unwrap());
        }
    }

    #[test]
    fn test_sample_every_independent_counters() {
        clear_sample_counters();
        // Interleaved calls with different `n` must not disturb each other's counters.
        let calls = [(2, false), (3, false), (2, true), (3, false), (2, false), (3, true)];
        for (step, &(n, expected)) in calls.iter().enumerate() {
            assert_eq!(
                sample_every(n).unwrap(),
                expected,
                "step #{}: sample_every({})",
                step + 1,
                n
            );
        }
    }

    #[test]
    fn test_sample_every_invalid_n() {
        clear_sample_counters();
        for &n in &[0, -5] {
            let result = sample_every(n);
            assert!(result.is_err());
            assert!(result
                .unwrap_err()
                .to_string()
                .contains("n must be positive"));
        }
    }

    #[test]
    fn test_sample_every_large_n() {
        clear_sample_counters();
        for i in 1..=100 {
            let result = sample_every(100).unwrap();
            if i == 100 {
                assert!(result, "Should return true on 100th call");
            } else {
                assert!(!result, "Should return false on call {}", i);
            }
        }
    }

    #[test]
    fn test_sample_every_with_rhai() {
        clear_sample_counters();
        let mut engine = Engine::new();
        register_functions(&mut engine);

        let result: bool = engine.eval("sample_every(2)").unwrap();
        assert!(!result);
        let result: bool = engine.eval("sample_every(2)").unwrap();
        assert!(result);

        let result: Result<bool, _> = engine.eval("sample_every(0)");
        assert!(result.is_err());
        let result: Result<bool, _> = engine.eval("sample_every(-1)");
        assert!(result.is_err());
    }

    #[test]
    fn test_sample_prob_always_true() {
        for _ in 0..100 {
            assert!(sample_prob(1.0).unwrap());
        }
    }

    #[test]
    fn test_sample_prob_always_false() {
        for _ in 0..100 {
            assert!(!sample_prob(0.0).unwrap());
        }
    }

    #[test]
    fn test_sample_prob_invalid_range() {
        assert!(sample_prob(-0.1).is_err());
        assert!(sample_prob(1.1).is_err());
        assert!(sample_prob(-1.0).is_err());
        assert!(sample_prob(2.0).is_err());
    }

    #[test]
    fn test_sample_prob_rejects_non_finite() {
        // NaN lies outside every range, so the bounds check must reject it rather than
        // letting `rng.f64() < NaN` silently evaluate to false.
        assert!(sample_prob(f64::NAN).is_err());
        assert!(sample_prob(f64::INFINITY).is_err());
        assert!(sample_prob(f64::NEG_INFINITY).is_err());
    }

    #[test]
    fn test_sample_prob_approximate_rate() {
        let trials = 10000;
        let count = (0..trials).filter(|_| sample_prob(0.5).unwrap()).count();
        // Bounds are ~20 standard deviations wide, so this is effectively deterministic.
        assert!(
            count > 4000 && count < 6000,
            "Expected ~5000 true out of 10000, got {}",
            count
        );
    }

    #[test]
    fn test_sample_prob_with_rhai() {
        let mut engine = Engine::new();
        register_functions(&mut engine);

        let _: bool = engine.eval("sample_prob(0.5)").unwrap();
        let _: bool = engine.eval("sample_prob(0.0)").unwrap();
        let _: bool = engine.eval("sample_prob(1.0)").unwrap();

        assert!(engine.eval::<bool>("sample_prob(-0.1)").is_err());
        assert!(engine.eval::<bool>("sample_prob(1.1)").is_err());
    }

    #[test]
    fn test_sample_every_use_case() {
        clear_sample_counters();
        let total = 1000;
        let kept = (0..total).filter(|_| sample_every(100).unwrap()).count();
        assert_eq!(kept, 10, "Should keep 10 out of 1000 events (1%)");
    }
}