use rand::Rng;
use super::{SelectionContext, Strategy};
use crate::Weighted;
/// Selection strategy that picks a uniformly random candidate among those
/// not excluded by the [`SelectionContext`].
///
/// The strategy holds no state; all randomness is drawn at selection time.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct Random;

impl Random {
    /// Creates a new uniform-random selection strategy.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
impl<N> Strategy<N> for Random {
    /// Returns the index of a uniformly chosen candidate that `ctx` does not
    /// exclude, or `None` when no such candidate exists (including when
    /// `candidates` is empty).
    fn select(&self, candidates: &[N], ctx: &SelectionContext) -> Option<usize> {
        let mut pool: Vec<usize> = Vec::new();
        for idx in 0..candidates.len() {
            if ctx.is_excluded(idx) {
                continue;
            }
            pool.push(idx);
        }
        match pool.len() {
            0 => None,
            count => {
                // Draw a position within the eligible pool, then map it back
                // to the caller's candidate index.
                let slot = rand::rng().random_range(0..count);
                Some(pool[slot])
            }
        }
    }
}
/// Selection strategy that picks a non-excluded candidate with probability
/// proportional to its weight.
///
/// The strategy holds no state; all randomness is drawn at selection time.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct WeightedRandom;

impl WeightedRandom {
    /// Creates a new weighted-random selection strategy.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
impl<N: Weighted> Strategy<N> for WeightedRandom {
    /// Picks one candidate index with probability proportional to its
    /// `weight()`.
    ///
    /// Candidates excluded by `ctx` and candidates with weight `0` are never
    /// chosen. Returns `None` when no eligible candidate has a positive weight
    /// (this includes an empty `candidates` slice and an all-excluded context).
    fn select(&self, candidates: &[N], ctx: &SelectionContext) -> Option<usize> {
        // Read each eligible candidate's weight exactly once. The total and the
        // walk below then use the same numbers even if `weight()` is not a pure
        // function of the node. Zero weights can never be picked, so they are
        // dropped up front.
        let weighted: Vec<(usize, u64)> = candidates
            .iter()
            .enumerate()
            .filter(|&(i, _)| !ctx.is_excluded(i))
            .map(|(i, node)| (i, u64::from(node.weight())))
            .filter(|&(_, w)| w > 0)
            .collect();

        // u32 weights summed into u64 cannot overflow for any realistic slice length.
        let total_weight: u64 = weighted.iter().map(|&(_, w)| w).sum();
        if total_weight == 0 {
            return None;
        }

        // Pick a point in [0, total) and find the candidate whose cumulative
        // weight interval contains it.
        let mut point = rand::rng().random_range(0..total_weight);
        for &(i, w) in &weighted {
            if point < w {
                return Some(i);
            }
            point -= w;
        }

        // Unreachable: `point < total_weight`, which is exactly the sum of the
        // snapshotted weights. Fall back to the last positive-weight candidate
        // rather than panicking.
        weighted.last().map(|&(i, _)| i)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn random_empty_returns_none() {
        let s = Random::new();
        let nodes: [i32; 0] = [];
        assert_eq!(s.select(&nodes, &SelectionContext::default()), None);
    }

    #[test]
    fn random_all_excluded_returns_none() {
        let s = Random::new();
        let nodes = [1, 2];
        let ctx = SelectionContext::builder().exclude(vec![0, 1]).build();
        assert_eq!(s.select(&nodes, &ctx), None);
    }

    #[test]
    fn random_respects_exclude() {
        let s = Random::new();
        let nodes = [1, 2];
        let ctx = SelectionContext::builder().exclude(vec![0]).build();
        for _ in 0..20 {
            assert_eq!(s.select(&nodes, &ctx), Some(1));
        }
    }

    #[test]
    fn random_only_returns_eligible_indices() {
        let s = Random::new();
        let nodes = [1, 2, 3];
        let ctx = SelectionContext::builder().exclude(vec![1]).build();
        for _ in 0..50 {
            let picked = s.select(&nodes, &ctx);
            assert!(
                matches!(picked, Some(0) | Some(2)),
                "unexpected pick {:?}",
                picked
            );
        }
    }

    /// Minimal `Weighted` implementor whose weight is the wrapped value.
    struct W(u32);

    impl Weighted for W {
        fn weight(&self) -> u32 {
            self.0
        }
    }

    #[test]
    fn weighted_random_empty_returns_none() {
        let s = WeightedRandom::new();
        let nodes: [W; 0] = [];
        assert_eq!(s.select(&nodes, &SelectionContext::default()), None);
    }

    #[test]
    fn weighted_random_zero_weight_returns_none() {
        let s = WeightedRandom::new();
        let nodes = [W(0), W(0)];
        assert_eq!(s.select(&nodes, &SelectionContext::default()), None);
    }

    #[test]
    fn weighted_random_all_excluded_returns_none() {
        let s = WeightedRandom::new();
        let nodes = [W(1)];
        let ctx = SelectionContext::builder().exclude(vec![0]).build();
        assert_eq!(s.select(&nodes, &ctx), None);
    }

    #[test]
    fn weighted_random_never_picks_zero_weight() {
        let s = WeightedRandom::new();
        let nodes = [W(0), W(3), W(0)];
        for _ in 0..50 {
            assert_eq!(s.select(&nodes, &SelectionContext::default()), Some(1));
        }
    }

    #[test]
    fn weighted_random_respects_exclude() {
        let s = WeightedRandom::new();
        let nodes = [W(10), W(1)];
        let ctx = SelectionContext::builder().exclude(vec![0]).build();
        for _ in 0..50 {
            assert_eq!(s.select(&nodes, &ctx), Some(1));
        }
    }

    #[test]
    fn weighted_random_excluded_weight_is_not_counted() {
        // The only positive weight is excluded, so nothing is selectable.
        let s = WeightedRandom::new();
        let nodes = [W(0), W(4)];
        let ctx = SelectionContext::builder().exclude(vec![1]).build();
        assert_eq!(s.select(&nodes, &ctx), None);
    }
}