use rand::RngExt as _;
use crate::rng::Rng;
use crate::space::{Space, SpaceInfo};
/// A space of `n` consecutive integers beginning at `start`, i.e. the
/// half-open range `start..start + n`.
///
/// Build it with [`Discrete::new`] or [`Discrete::with_start`], which reject an
/// empty space. Both fields are public, so a struct literal can bypass that
/// check; prefer the constructors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Discrete {
    /// Number of elements in the space.
    pub n: u64,
    /// Smallest element of the space.
    pub start: i64,
}

impl Discrete {
    /// Creates the space `{0, 1, ..., n - 1}`.
    ///
    /// # Panics
    ///
    /// Panics if `n` is zero.
    #[must_use]
    pub fn new(n: u64) -> Self {
        Self::with_start(n, 0)
    }

    /// Creates the space `{start, start + 1, ..., start + n - 1}`.
    ///
    /// # Panics
    ///
    /// Panics if `n` is zero.
    #[must_use]
    pub fn with_start(n: u64, start: i64) -> Self {
        assert!(n > 0, "Discrete space must have at least one element");
        Self { n, start }
    }
}
impl Space for Discrete {
    type Element = i64;

    /// Draws an element uniformly from `start..start + n`.
    ///
    /// # Panics
    ///
    /// Panics if the drawn value does not fit in `i64`, which can only happen
    /// when `start + n - 1` exceeds `i64::MAX`.
    fn sample(&self, rng: &mut Rng) -> i64 {
        // Draw a u32 whenever `n` fits so seeded runs reproduce the same
        // samples as u32-only sampling; larger spaces draw a u64 instead of
        // panicking.
        let offset = match u32::try_from(self.n) {
            Ok(n32) => u64::from(rng.random_range(0..n32)),
            Err(_) => rng.random_range(0..self.n),
        };
        // Add in i128 so `start + offset` cannot overflow before the range check.
        i64::try_from(i128::from(self.start) + i128::from(offset))
            .expect("sampled value fits in i64")
    }

    /// Returns whether `value` lies in `start..start + n`.
    fn contains(&self, value: &i64) -> bool {
        // i128 holds every i64 difference and every u64, so neither the
        // subtraction nor the upper bound can overflow.
        let offset = i128::from(*value) - i128::from(self.start);
        (0..i128::from(self.n)).contains(&offset)
    }

    /// Elements are single `i64` values, so the shape is empty.
    fn shape(&self) -> &[usize] {
        &[]
    }

    /// Number of elements in the space.
    ///
    /// # Panics
    ///
    /// Panics if `n` does not fit in `usize`, which is only possible on targets
    /// where `usize` is narrower than 64 bits.
    fn flatdim(&self) -> usize {
        usize::try_from(self.n).expect("n fits usize")
    }

    /// Describes this space as `SpaceInfo::Discrete` with the same `n` and `start`.
    fn space_info(&self) -> SpaceInfo {
        SpaceInfo::Discrete {
            n: self.n,
            start: self.start,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rng::create_rng;

    #[test]
    fn sample_is_within_bounds() {
        let space = Discrete::new(5);
        let mut rng = create_rng(Some(42));
        for _ in 0..100 {
            let s = space.sample(&mut rng);
            assert!(space.contains(&s), "sample {s} not in space");
        }
    }

    #[test]
    fn sample_respects_start_offset() {
        let space = Discrete::with_start(4, -10);
        let mut rng = create_rng(Some(7));
        for _ in 0..100 {
            let s = space.sample(&mut rng);
            assert!((-10..-6).contains(&s), "sample {s} not in -10..-6");
        }
    }

    #[test]
    fn contains_checks_bounds() {
        let space = Discrete::with_start(3, -1);
        assert!(space.contains(&-1));
        assert!(space.contains(&0));
        assert!(space.contains(&1));
        assert!(!space.contains(&-2));
        assert!(!space.contains(&2));
    }

    #[test]
    fn shape_is_empty() {
        let space = Discrete::new(2);
        assert!(space.shape().is_empty());
    }

    #[test]
    fn flatdim_equals_n() {
        assert_eq!(Discrete::new(7).flatdim(), 7);
    }

    #[test]
    fn space_info_reports_n_and_start() {
        let info = Discrete::with_start(4, 7).space_info();
        assert!(matches!(info, SpaceInfo::Discrete { n: 4, start: 7 }));
    }

    #[test]
    #[should_panic(expected = "Discrete space must have at least one element")]
    fn zero_elements_panics() {
        let _ = Discrete::new(0);
    }
}