use rand::RngExt as _;
use crate::error::{Error, Result};
use crate::rng::Rng;
use crate::space::{Space, SpaceInfo};
/// A space made of several independent discrete dimensions.
///
/// Dimension `i` holds the `nvec[i]` consecutive integers beginning at
/// `start[i]`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MultiDiscrete {
    /// Number of distinct values in each dimension.
    nvec: Vec<u64>,
    /// Smallest value of each dimension; same length as `nvec`.
    start: Vec<i64>,
    /// Cached shape of an element: a single axis of length `nvec.len()`.
    shape: Vec<usize>,
}
impl MultiDiscrete {
    /// Creates a space in which every dimension starts at `0`.
    ///
    /// Dimension `i` then holds the values `0..nvec[i]`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidSpace`] under the same conditions as
    /// [`MultiDiscrete::with_start`] called with `start = None`.
    pub fn new(nvec: Vec<u64>) -> Result<Self> {
        Self::with_start(nvec, None)
    }

    /// Creates a space with an explicit smallest value per dimension.
    ///
    /// `nvec[i]` is the number of values in dimension `i`; `start[i]` is the
    /// smallest of them. Passing `None` for `start` uses `0` everywhere.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidSpace`] when
    /// - `nvec` is empty,
    /// - any entry of `nvec` is `0`,
    /// - `start` is given but its length differs from `nvec`'s, or
    /// - the largest value of a dimension, `start[i] + nvec[i] - 1`, does not
    ///   fit in `i64`.
    pub fn with_start(nvec: Vec<u64>, start: Option<Vec<i64>>) -> Result<Self> {
        if nvec.is_empty() {
            return Err(Error::InvalidSpace {
                reason: "nvec must not be empty".to_owned(),
            });
        }
        if nvec.contains(&0) {
            return Err(Error::InvalidSpace {
                reason: "all nvec elements must be > 0".to_owned(),
            });
        }

        let start = match start {
            Some(s) if s.len() != nvec.len() => {
                return Err(Error::InvalidSpace {
                    reason: format!(
                        "start length ({}) must match nvec length ({})",
                        s.len(),
                        nvec.len()
                    ),
                });
            }
            Some(s) => s,
            None => vec![0; nvec.len()],
        };

        // Elements are `i64`, so every value of every dimension has to be
        // representable; otherwise later bound arithmetic would overflow.
        for (index, (&n, &s)) in nvec.iter().zip(&start).enumerate() {
            // `n >= 1` was established above, so `n - 1` cannot underflow.
            if s.checked_add_unsigned(n - 1).is_none() {
                return Err(Error::InvalidSpace {
                    reason: format!(
                        "dimension {index}: start ({s}) + nvec ({n}) - 1 does not fit in i64"
                    ),
                });
            }
        }

        let shape = vec![nvec.len()];
        Ok(Self { nvec, start, shape })
    }

    /// Number of distinct values in each dimension.
    #[must_use]
    pub fn nvec(&self) -> &[u64] {
        &self.nvec
    }

    /// Smallest value of each dimension.
    #[must_use]
    pub fn start(&self) -> &[i64] {
        &self.start
    }
}
impl Space for MultiDiscrete {
    type Element = Vec<i64>;

    /// Draws one value per dimension: `start[i]` plus an offset sampled from
    /// `0..nvec[i]`.
    ///
    /// Sizes that fit in `u32` keep using the 32-bit range, so their draws
    /// are unaffected by the wider fallback used for larger sizes.
    ///
    /// # Panics
    ///
    /// Panics if `start[i]` plus the sampled offset is not representable as
    /// `i64`.
    fn sample(&self, rng: &mut Rng) -> Vec<i64> {
        self.nvec
            .iter()
            .zip(&self.start)
            .map(|(&n, &s)| {
                let offset: u64 = match u32::try_from(n) {
                    Ok(small) => u64::from(rng.random_range(0..small)),
                    // Too large for the 32-bit path: sample the full range
                    // instead of panicking on the conversion.
                    Err(_) => rng.random_range(0..n),
                };
                s.checked_add_unsigned(offset)
                    .expect("sampled element must be representable as i64")
            })
            .collect()
    }

    /// Returns `true` when `value` has one entry per dimension and each entry
    /// `v` satisfies `start[i] <= v < start[i] + nvec[i]`.
    fn contains(&self, value: &Vec<i64>) -> bool {
        value.len() == self.nvec.len()
            && value
                .iter()
                .zip(self.nvec.iter().zip(&self.start))
                // `abs_diff` yields the distance as `u64`, so the upper-bound
                // test cannot overflow the way `s + n` can.
                .all(|(&v, (&n, &s))| v >= s && v.abs_diff(s) < n)
    }

    /// Shape of an element: a single axis with one entry per dimension.
    fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Sum of all `nvec` entries, saturating at `usize::MAX` instead of
    /// truncating or overflowing.
    fn flatdim(&self) -> usize {
        self.nvec.iter().fold(0_usize, |total, &n| {
            total.saturating_add(usize::try_from(n).unwrap_or(usize::MAX))
        })
    }

    /// Describes this space as [`SpaceInfo::MultiDiscrete`] carrying a copy
    /// of `nvec`.
    fn space_info(&self) -> SpaceInfo {
        SpaceInfo::MultiDiscrete {
            nvec: self.nvec.clone(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rng::create_rng;

    /// Every sampled element must be accepted by `contains`.
    #[test]
    fn sample_within_bounds() {
        let mut rng = create_rng(Some(42));
        let space = MultiDiscrete::new(vec![5, 2, 2]).unwrap();
        for _ in 0..100 {
            let s = space.sample(&mut rng);
            assert!(space.contains(&s), "sample {s:?} not in space");
        }
    }

    /// Membership honours per-dimension start offsets on both bounds.
    #[test]
    fn contains_with_start() {
        let space = MultiDiscrete::with_start(vec![3, 2], Some(vec![-1, 5])).unwrap();

        let inside = [vec![-1, 5], vec![1, 6]];
        let outside = [vec![-2, 5], vec![2, 5], vec![0, 7]];

        for candidate in &inside {
            assert!(space.contains(candidate));
        }
        for candidate in &outside {
            assert!(!space.contains(candidate));
        }
    }

    /// A space needs at least one dimension.
    #[test]
    fn rejects_empty_nvec() {
        let result = MultiDiscrete::new(Vec::new());
        assert!(result.is_err());
    }

    /// A dimension with zero values is invalid.
    #[test]
    fn rejects_zero_element() {
        let result = MultiDiscrete::new(vec![3, 0, 2]);
        assert!(result.is_err());
    }

    /// `start` must have exactly one entry per dimension.
    #[test]
    fn rejects_mismatched_start() {
        let result = MultiDiscrete::with_start(vec![3, 2], Some(vec![0]));
        assert!(result.is_err());
    }

    /// The shape is a single axis whose length is the number of dimensions.
    #[test]
    fn shape_equals_ndims() {
        let space = MultiDiscrete::new(vec![5, 2, 2]).unwrap();
        let expected: &[usize] = &[3];
        assert_eq!(space.shape(), expected);
    }

    /// The flat dimension is the sum of the per-dimension sizes.
    #[test]
    fn flatdim_is_sum_of_nvec() {
        let space = MultiDiscrete::new(vec![5, 2, 2]).unwrap();
        let flat = space.flatdim();
        assert_eq!(flat, 9);
    }
}