rlmesh-spaces 0.1.0-rc.1

Internal RLMesh crate (unstable Rust API): space specifications and value model.
Documentation
use crate::errors::{SpaceError, err_space};
use crate::spaces::{SpaceKind, SpaceSpec, SpaceValue, validate_space};
use crate::{DType, MultiDiscreteNvec, MultiDiscreteSpec};

/// Builder for a `MultiDiscrete` [`SpaceSpec`].
///
/// Start from [`MultiDiscreteBuilder::vector`] (rank-1 nvec) or
/// [`MultiDiscreteBuilder::matrix`] (rank-2 nvec), optionally override the
/// dtype, then call `build()` to obtain a validated spec.
#[must_use = "a space builder does nothing until .build() is called"]
pub struct MultiDiscreteBuilder {
    /// Element dtype of the space; both constructors initialise it to `DType::Int64`.
    dtype: DType,
    /// Space shape, derived from `nvec` by the constructor (`[len]` or `[rows, cols]`).
    shape: Vec<i64>,
    /// Per-element category counts, stored in flat or shaped form.
    nvec: MultiDiscreteNvec,
}

impl MultiDiscreteBuilder {
    /// Starts a rank-1 `MultiDiscrete(nvec: [n0, n1, ...])` space.
    ///
    /// The shape is derived from the input as `[nvec.len()]`, and the dtype
    /// starts out as [`DType::Int64`].
    pub fn vector(nvec: impl Into<Vec<i64>>) -> Self {
        let flat: Vec<i64> = nvec.into();
        let len = flat.len() as i64;

        Self {
            dtype: DType::Int64,
            shape: vec![len],
            nvec: MultiDiscreteNvec::Flat(flat),
        }
    }

    /// Starts a rank-2 `MultiDiscrete(nvec: [[...], [...]])` space.
    ///
    /// The shape is `[rows.len(), rows[0].len()]` (`[0, 0]` when there are no
    /// rows), and the dtype starts out as [`DType::Int64`]. Row lengths are not
    /// checked here; that is deferred to `build()`, which runs `validate_space`.
    pub fn matrix(rows: impl Into<Vec<Vec<i64>>>) -> Self {
        let grid: Vec<Vec<i64>> = rows.into();
        let (n_rows, n_cols) = (grid.len(), grid.first().map_or(0, Vec::len));

        Self {
            dtype: DType::Int64,
            shape: vec![n_rows as i64, n_cols as i64],
            nvec: MultiDiscreteNvec::Shaped(grid),
        }
    }

    /// Replaces the element dtype (both constructors default it to [`DType::Int64`]).
    pub fn dtype(self, dtype: DType) -> Self {
        Self { dtype, ..self }
    }

    /// Assembles the [`SpaceSpec`] and runs `validate_space` on it.
    ///
    /// # Errors
    ///
    /// Propagates the [`SpaceError`] from `validate_space` when the assembled
    /// spec is rejected.
    pub fn build(self) -> Result<SpaceSpec, SpaceError> {
        let Self { dtype, shape, nvec } = self;

        let spec = SpaceSpec {
            shape,
            dtype,
            spec: Some(SpaceKind::MultiDiscrete(MultiDiscreteSpec { nvec: Some(nvec) })),
        };

        validate_space(&spec)?;
        Ok(spec)
    }
}

/// Validates a `MultiDiscrete` space located at `path`.
///
/// Checks run in this order, and the first failure is returned:
/// 1. `shape` is non-empty, `dtype` is not `Unspecified`, and every `shape`
///    dimension is strictly positive;
/// 2. `spec` is the `MultiDiscrete` variant and carries an `nvec`;
/// 3. the `nvec` is well-formed and agrees with `shape`. The flat form must
///    match `[len]` and the shaped form `[rows, cols]`.
///
/// # Errors
///
/// Returns a [`SpaceError`] tagged with `path` and `"MultiDiscrete"` that
/// describes the first failed check.
pub(crate) fn validate_multidiscrete_at(space: &SpaceSpec, path: &str) -> Result<(), SpaceError> {
    if space.shape.is_empty() {
        return err_space!(path, "MultiDiscrete", "shape must be set (rank >= 1)");
    }
    if space.dtype == DType::Unspecified {
        return err_space!(path, "MultiDiscrete", "dtype must be set");
    }
    if let Some(i) = space.shape.iter().position(|&d| d <= 0) {
        return err_space!(
            path,
            "MultiDiscrete",
            format!("MultiDiscrete.shape[{i}] must be > 0")
        );
    }

    let Some(SpaceKind::MultiDiscrete(md)) = &space.spec else {
        return err_space!(path, "MultiDiscrete", "spec.multi_discrete must be set");
    };
    let Some(nvec) = &md.nvec else {
        return err_space!(path, "MultiDiscrete", "nvec must be set");
    };

    match nvec {
        MultiDiscreteNvec::Flat(values) => check_flat_nvec(&space.shape, values, path),
        MultiDiscreteNvec::Shaped(rows) => check_shaped_nvec(&space.shape, rows, path),
    }
}

/// Checks a rank-1 nvec: non-empty, all entries > 0, and `shape == [len(nvec)]`.
fn check_flat_nvec(shape: &[i64], values: &[i64], path: &str) -> Result<(), SpaceError> {
    if values.is_empty() {
        return err_space!(path, "MultiDiscrete", "nvec.flat.data must be non-empty");
    }
    if let Some(i) = values.iter().position(|&n| n <= 0) {
        return err_space!(
            path,
            "MultiDiscrete",
            format!("nvec.flat.data[{i}] must be > 0")
        );
    }

    // Canonical shape for the flat form is exactly one dimension of length len(nvec).
    if shape != [values.len() as i64] {
        return err_space!(
            path,
            "MultiDiscrete",
            "shape mismatch: for flat nvec, expected shape == [len(nvec)]"
        );
    }
    Ok(())
}

/// Checks a rank-2 nvec: non-empty, rectangular, all entries > 0, and
/// `shape == [rows, cols]`.
fn check_shaped_nvec(shape: &[i64], rows: &[Vec<i64>], path: &str) -> Result<(), SpaceError> {
    let Some(first) = rows.first() else {
        return err_space!(path, "MultiDiscrete", "nvec.shaped.data must be non-empty");
    };

    let cols = first.len();
    if cols == 0 {
        return err_space!(path, "MultiDiscrete", "nvec.shaped rows must be non-empty");
    }

    // Every row must have the same width as the first one.
    if let Some(ri) = rows.iter().position(|row| row.len() != cols) {
        return err_space!(
            path,
            "MultiDiscrete",
            format!("nvec.shaped row {ri} length mismatch")
        );
    }

    // Report the first non-positive entry, scanning rows top to bottom.
    for (ri, row) in rows.iter().enumerate() {
        if let Some(ci) = row.iter().position(|&n| n <= 0) {
            return err_space!(
                path,
                "MultiDiscrete",
                format!("nvec.shaped[{ri}][{ci}] must be > 0")
            );
        }
    }

    if shape != [rows.len() as i64, cols as i64] {
        return err_space!(
            path,
            "MultiDiscrete",
            "MultiDiscrete shape mismatch: expected shape == [rows, cols] for shaped"
        );
    }
    Ok(())
}

/// Checks that `value` is a member of the `MultiDiscrete` space `space`.
///
/// `value` must be a `SpaceValue::MultiDiscrete` with exactly one entry per
/// `nvec` element, and entry `i` must lie in `[0, nvec[i])`. A shaped nvec is
/// compared against the flat value row by row, in order.
///
/// The space itself is not re-validated here. A non-positive `nvec` entry
/// therefore makes every value at that position out of range.
///
/// # Errors
///
/// Returns a [`SpaceError`] tagged with `path` when the value has the wrong
/// variant, the space is not `MultiDiscrete` or has no `nvec`, the lengths
/// differ, or any entry is out of range.
pub(crate) fn contains_multidiscrete(
    space: &SpaceSpec,
    value: &SpaceValue,
    path: &str,
) -> Result<(), SpaceError> {
    let vals = match value {
        SpaceValue::MultiDiscrete(v) => v,
        _ => return err_space!(path, "expected MultiDiscrete value"),
    };

    let md = match &space.spec {
        Some(SpaceKind::MultiDiscrete(md)) => md,
        _ => return err_space!(path, "space is not MultiDiscrete"),
    };

    // View both nvec forms as a slice of rows (the flat form is a single row).
    // The bounds are then walked in place instead of cloning or flattening
    // nvec into a fresh Vec on every membership check.
    let rows: &[Vec<i64>] = match &md.nvec {
        Some(MultiDiscreteNvec::Flat(v)) => std::slice::from_ref(v),
        Some(MultiDiscreteNvec::Shaped(m)) => m.as_slice(),
        None => return err_space!(path, "MultiDiscrete.nvec not set"),
    };

    let expected_len: usize = rows.iter().map(Vec::len).sum();
    if vals.len() != expected_len {
        return err_space!(
            path,
            format!(
                "MultiDiscrete size mismatch: expected {}, got {}",
                expected_len,
                vals.len()
            )
        );
    }

    // Check each value is in range [0, nvec[i]), with nvec flattened row by row.
    for (i, (&val, &n)) in vals.iter().zip(rows.iter().flatten()).enumerate() {
        if val < 0 || val >= n {
            return err_space!(
                path,
                format!("value[{}] = {} not in range [0, {})", i, val, n)
            );
        }
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use crate::spaces::fundamental::MultiDiscreteBuilder;
    use crate::spaces::{SpaceValue, contains};

    #[test]
    fn test_multidiscrete_contains() {
        let space = MultiDiscreteBuilder::vector(vec![2, 3]).build().unwrap();

        assert!(contains(&space, &SpaceValue::MultiDiscrete(vec![0, 2])).is_ok());
        assert!(contains(&space, &SpaceValue::MultiDiscrete(vec![1])).is_err());
        assert!(contains(&space, &SpaceValue::MultiDiscrete(vec![2, 0])).is_err());
        assert!(contains(&space, &SpaceValue::Discrete(1)).is_err());
    }

    /// Shaped nvec: shape is derived as [rows, cols], and a flat value is
    /// compared against the rows in order.
    #[test]
    fn test_multidiscrete_matrix_contains() {
        let space = MultiDiscreteBuilder::matrix(vec![vec![2, 3], vec![4, 5]])
            .build()
            .unwrap();
        assert_eq!(space.shape, vec![2, 2]);

        assert!(contains(&space, &SpaceValue::MultiDiscrete(vec![1, 2, 3, 4])).is_ok());
        assert!(contains(&space, &SpaceValue::MultiDiscrete(vec![0, 0, 0, 0])).is_ok());
        // Second entry is bounded by nvec[0][1] = 3.
        assert!(contains(&space, &SpaceValue::MultiDiscrete(vec![0, 3, 0, 0])).is_err());
        assert!(contains(&space, &SpaceValue::MultiDiscrete(vec![0, 0, 0, -1])).is_err());
        assert!(contains(&space, &SpaceValue::MultiDiscrete(vec![0, 0, 0])).is_err());
    }

    /// Invalid nvec inputs must be rejected when the builder is finalised.
    #[test]
    fn test_multidiscrete_builder_rejects_invalid_nvec() {
        assert!(MultiDiscreteBuilder::vector(Vec::<i64>::new()).build().is_err());
        assert!(MultiDiscreteBuilder::vector(vec![2, 0]).build().is_err());
        assert!(MultiDiscreteBuilder::vector(vec![2, -1]).build().is_err());
        assert!(MultiDiscreteBuilder::matrix(Vec::<Vec<i64>>::new()).build().is_err());
        assert!(MultiDiscreteBuilder::matrix(vec![vec![2, 3], vec![4]]).build().is_err());
        assert!(MultiDiscreteBuilder::matrix(vec![vec![2, 0], vec![4, 5]]).build().is_err());
    }
}