use crate::errors::{SpaceError, err_space};
use crate::spaces::{SpaceKind, SpaceSpec, SpaceValue, validate_space};
use crate::{DType, MultiDiscreteNvec, MultiDiscreteSpec};
/// Builder for a MultiDiscrete [`SpaceSpec`].
///
/// Create one with [`MultiDiscreteBuilder::vector`] (rank-1 bounds) or
/// [`MultiDiscreteBuilder::matrix`] (rank-2 bounds), optionally override the
/// dtype, then call `build()` to validate and obtain the spec.
#[must_use = "a space builder does nothing until .build() is called"]
pub struct MultiDiscreteBuilder {
    /// Per-component exclusive upper bounds, either flat or as a matrix.
    nvec: MultiDiscreteNvec,
    /// Shape derived from `nvec` by the constructor that created the builder.
    shape: Vec<i64>,
    /// Element dtype; both constructors initialise it to `DType::Int64`.
    dtype: DType,
}
impl MultiDiscreteBuilder {
    /// Starts a builder for a rank-1 MultiDiscrete space.
    ///
    /// Each entry of `nvec` is the exclusive upper bound of one component.
    /// The shape is set to `[nvec.len()]` and the dtype defaults to
    /// [`DType::Int64`]. No validation happens until `build()`.
    pub fn vector(nvec: impl Into<Vec<i64>>) -> Self {
        let counts: Vec<i64> = nvec.into();
        let len = counts.len() as i64;
        Self {
            dtype: DType::Int64,
            shape: vec![len],
            nvec: MultiDiscreteNvec::Flat(counts),
        }
    }

    /// Starts a builder for a rank-2 MultiDiscrete space from a matrix of bounds.
    ///
    /// The shape is `[rows.len(), rows[0].len()]`, or `[0, 0]` when `rows` is
    /// empty. The dtype defaults to [`DType::Int64`]. Ragged or empty input is
    /// not rejected here; that is left to `validate_space` in `build()`.
    pub fn matrix(rows: impl Into<Vec<Vec<i64>>>) -> Self {
        let grid: Vec<Vec<i64>> = rows.into();
        let height = grid.len() as i64;
        let width = grid.first().map_or(0, Vec::len) as i64;
        Self {
            dtype: DType::Int64,
            shape: vec![height, width],
            nvec: MultiDiscreteNvec::Shaped(grid),
        }
    }

    /// Overrides the element dtype of the space being built.
    pub fn dtype(self, dtype: DType) -> Self {
        Self { dtype, ..self }
    }

    /// Assembles the [`SpaceSpec`] and validates it with `validate_space`.
    ///
    /// # Errors
    /// Propagates any error reported by `validate_space` for the assembled spec.
    pub fn build(self) -> Result<SpaceSpec, SpaceError> {
        let Self { dtype, shape, nvec } = self;
        let kind = SpaceKind::MultiDiscrete(MultiDiscreteSpec { nvec: Some(nvec) });
        let spec = SpaceSpec {
            shape,
            dtype,
            spec: Some(kind),
        };
        validate_space(&spec)?;
        Ok(spec)
    }
}
/// Validates that `space` is a well-formed MultiDiscrete space.
///
/// `path` locates the space within an enclosing structure and is forwarded to
/// every error produced here. Checks run in a fixed order, and the first
/// failure is reported:
/// 1. `shape` is non-empty, `dtype` is not `Unspecified`, and every dimension
///    of `shape` is strictly positive;
/// 2. `spec` holds a MultiDiscrete spec whose `nvec` is set;
/// 3. every `nvec` entry is strictly positive and, for a shaped `nvec`, the
///    rows form a non-empty rectangular matrix;
/// 4. `shape` equals `[len(nvec)]` (flat) or `[rows, cols]` (shaped).
///
/// # Errors
/// Returns a [`SpaceError`] describing the first violated rule.
pub(crate) fn validate_multidiscrete_at(space: &SpaceSpec, path: &str) -> Result<(), SpaceError> {
    if space.shape.is_empty() {
        return err_space!(path, "MultiDiscrete", "shape must be set (rank >= 1)");
    }
    if space.dtype == DType::Unspecified {
        return err_space!(path, "MultiDiscrete", "dtype must be set");
    }
    if let Some(axis) = space.shape.iter().position(|&dim| dim <= 0) {
        return err_space!(
            path,
            "MultiDiscrete",
            format!("MultiDiscrete.shape[{axis}] must be > 0")
        );
    }

    let Some(SpaceKind::MultiDiscrete(md)) = &space.spec else {
        return err_space!(path, "MultiDiscrete", "spec.multi_discrete must be set");
    };
    let Some(nvec) = &md.nvec else {
        return err_space!(path, "MultiDiscrete", "nvec must be set");
    };

    match nvec {
        MultiDiscreteNvec::Flat(counts) => {
            if counts.is_empty() {
                return err_space!(path, "MultiDiscrete", "nvec.flat.data must be non-empty");
            }
            if let Some(idx) = counts.iter().position(|&n| n <= 0) {
                return err_space!(
                    path,
                    "MultiDiscrete",
                    format!("nvec.flat.data[{idx}] must be > 0")
                );
            }
            // A flat nvec requires a rank-1 shape whose single dim is len(nvec).
            if space.shape != [counts.len() as i64] {
                return err_space!(
                    path,
                    "MultiDiscrete",
                    "shape mismatch: for flat nvec, expected shape == [len(nvec)]"
                );
            }
            Ok(())
        }
        MultiDiscreteNvec::Shaped(grid) => {
            let Some(first_row) = grid.first() else {
                return err_space!(path, "MultiDiscrete", "nvec.shaped.data must be non-empty");
            };
            let width = first_row.len();
            if width == 0 {
                return err_space!(path, "MultiDiscrete", "nvec.shaped rows must be non-empty");
            }
            // Rectangularity is checked for all rows before any value is inspected.
            if let Some(bad_row) = grid.iter().position(|row| row.len() != width) {
                return err_space!(
                    path,
                    "MultiDiscrete",
                    format!("nvec.shaped row {bad_row} length mismatch")
                );
            }
            // The first non-positive bound is reported in row-major order.
            let first_nonpositive = grid
                .iter()
                .enumerate()
                .find_map(|(r, row)| row.iter().position(|&n| n <= 0).map(|c| (r, c)));
            if let Some((r, c)) = first_nonpositive {
                return err_space!(
                    path,
                    "MultiDiscrete",
                    format!("nvec.shaped[{r}][{c}] must be > 0")
                );
            }
            if space.shape != [grid.len() as i64, width as i64] {
                return err_space!(
                    path,
                    "MultiDiscrete",
                    "MultiDiscrete shape mismatch: expected shape == [rows, cols] for shaped"
                );
            }
            Ok(())
        }
    }
}
/// Checks whether `value` is a member of the MultiDiscrete space `space`.
///
/// `value` must be a `SpaceValue::MultiDiscrete` with exactly one entry per
/// bound in `nvec`. A shaped (matrix) `nvec` is compared in row-major order.
/// Each entry must satisfy `0 <= value[i] < nvec[i]`. `path` is forwarded to
/// any error to locate the offending value.
///
/// # Errors
/// Returns a [`SpaceError`] if any of the following holds:
/// - `value` is not a MultiDiscrete value;
/// - `space` is not a MultiDiscrete space, or its `nvec` is unset;
/// - the entry count differs from the bound count;
/// - any entry is out of range.
pub(crate) fn contains_multidiscrete(
    space: &SpaceSpec,
    value: &SpaceValue,
    path: &str,
) -> Result<(), SpaceError> {
    use std::borrow::Cow;

    let vals = match value {
        SpaceValue::MultiDiscrete(v) => v,
        _ => return err_space!(path, "expected MultiDiscrete value"),
    };
    let md = match &space.spec {
        Some(SpaceKind::MultiDiscrete(md)) => md,
        _ => return err_space!(path, "space is not MultiDiscrete"),
    };
    // A flat nvec is borrowed as-is; only a shaped nvec needs a flattened
    // (row-major) copy. This avoids an allocation per membership check in
    // the common flat case.
    let nvec: Cow<'_, [i64]> = match &md.nvec {
        Some(MultiDiscreteNvec::Flat(v)) => Cow::Borrowed(v.as_slice()),
        Some(MultiDiscreteNvec::Shaped(m)) => Cow::Owned(m.iter().flatten().copied().collect()),
        None => return err_space!(path, "MultiDiscrete.nvec not set"),
    };
    if vals.len() != nvec.len() {
        return err_space!(
            path,
            format!(
                "MultiDiscrete size mismatch: expected {}, got {}",
                nvec.len(),
                vals.len()
            )
        );
    }
    for (i, (&val, &n)) in vals.iter().zip(nvec.iter()).enumerate() {
        if val < 0 || val >= n {
            return err_space!(
                path,
                format!("value[{}] = {} not in range [0, {})", i, val, n)
            );
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use crate::spaces::fundamental::MultiDiscreteBuilder;
    use crate::spaces::{SpaceValue, contains};

    /// Flat bounds: in-range, wrong-length, out-of-range and wrong-kind values.
    #[test]
    fn test_multidiscrete_contains() {
        let space = MultiDiscreteBuilder::vector(vec![2, 3]).build().unwrap();
        assert!(contains(&space, &SpaceValue::MultiDiscrete(vec![0, 2])).is_ok());
        assert!(contains(&space, &SpaceValue::MultiDiscrete(vec![1])).is_err());
        assert!(contains(&space, &SpaceValue::MultiDiscrete(vec![2, 0])).is_err());
        assert!(contains(&space, &SpaceValue::Discrete(1)).is_err());
    }

    /// Negative entries and too-long values must be rejected.
    #[test]
    fn test_multidiscrete_contains_rejects_negative_and_long() {
        let space = MultiDiscreteBuilder::vector(vec![2, 3]).build().unwrap();
        assert!(contains(&space, &SpaceValue::MultiDiscrete(vec![-1, 0])).is_err());
        assert!(contains(&space, &SpaceValue::MultiDiscrete(vec![0, 1, 2])).is_err());
    }

    /// Shaped bounds are compared against a flat value in row-major order.
    #[test]
    fn test_multidiscrete_matrix_contains() {
        let space = MultiDiscreteBuilder::matrix(vec![vec![2, 3], vec![4, 5]])
            .build()
            .unwrap();
        assert!(contains(&space, &SpaceValue::MultiDiscrete(vec![1, 2, 3, 4])).is_ok());
        assert!(contains(&space, &SpaceValue::MultiDiscrete(vec![1, 2, 3, 5])).is_err());
        assert!(contains(&space, &SpaceValue::MultiDiscrete(vec![1, 2, 3])).is_err());
    }

    /// `build()` must reject empty bounds, non-positive bounds and ragged matrices.
    #[test]
    fn test_multidiscrete_builder_rejects_invalid() {
        assert!(MultiDiscreteBuilder::vector(Vec::<i64>::new()).build().is_err());
        assert!(MultiDiscreteBuilder::vector(vec![2, 0]).build().is_err());
        assert!(
            MultiDiscreteBuilder::matrix(vec![vec![2, 3], vec![4]])
                .build()
                .is_err()
        );
    }
}