rlmesh-spaces 0.1.0-rc.1

Internal RLMesh crate (unstable Rust API): space specifications and value model.
Documentation
use crate::errors::{SpaceError, err_space};
use crate::spaces::{
    Conformance, SpaceKind, SpaceSpec, SpaceValue, conform_at, validate_space, validate_space_at,
};
use crate::{DType, TupleSpec};

/// Builder that collects child space specs and assembles them into a Tuple space.
///
/// Children are added with `with` / `extend`; `build` wraps them in a Tuple
/// `SpaceSpec` and validates the result.
#[must_use = "a space builder does nothing until .build() is called"]
pub struct TupleSpaceBuilder {
    /// Child spaces, kept in the order they were added.
    spaces: Vec<SpaceSpec>,
}

impl Default for TupleSpaceBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl TupleSpaceBuilder {
    /// Creates a builder that holds no child spaces yet.
    pub fn new() -> Self {
        Self { spaces: vec![] }
    }

    /// Appends one child space after any already added and returns the builder.
    pub fn with(self, space: SpaceSpec) -> Self {
        // A single child is just the one-element case of `extend`.
        self.extend(std::iter::once(space))
    }

    /// Appends every space yielded by `spaces`, preserving iteration order,
    /// and returns the builder.
    pub fn extend<I>(mut self, spaces: I) -> Self
    where
        I: IntoIterator<Item = SpaceSpec>,
    {
        self.spaces.extend(spaces);
        self
    }

    /// Consumes the builder and produces the Tuple `SpaceSpec`.
    ///
    /// # Errors
    ///
    /// Returns the `SpaceError` reported by validation of the assembled spec.
    pub fn build(self) -> Result<SpaceSpec, SpaceError> {
        let Self { spaces } = self;
        make_tuple_space(spaces)
    }
}

/// Wraps `spaces` in a Tuple `SpaceSpec` (empty shape, unspecified dtype) and
/// validates it before handing it back.
///
/// # Errors
///
/// Propagates whatever error `validate_space` reports for the assembled spec.
fn make_tuple_space(spaces: Vec<SpaceSpec>) -> Result<SpaceSpec, SpaceError> {
    let kind = SpaceKind::Tuple(TupleSpec { spaces });
    let candidate = SpaceSpec {
        shape: Vec::new(),
        dtype: DType::Unspecified,
        spec: Some(kind),
    };

    // Only a spec that passed validation is ever returned to the caller.
    validate_space(&candidate)?;
    Ok(candidate)
}

/// Validates a Tuple space spec located at `path`.
///
/// Checks, in this order: the shape is empty, the dtype is
/// `DType::Unspecified`, and the spec kind is `SpaceKind::Tuple`. Each child
/// space is then validated with `validate_space_at` under the path
/// `"{path}[{index}]"`.
///
/// # Errors
///
/// Returns the first failing check, or the first child validation error.
pub(crate) fn validate_tuple_at(spec: &SpaceSpec, path: &str) -> Result<(), SpaceError> {
    // A Tuple carries no shape or dtype of its own; those live on its children.
    if !spec.shape.is_empty() {
        return err_space!(path, "Tuple", "shape must be empty");
    }
    if spec.dtype != DType::Unspecified {
        return err_space!(path, "Tuple", "dtype must be 'UNSPECIFIED'");
    }

    let Some(SpaceKind::Tuple(tuple_spec)) = &spec.spec else {
        return err_space!(path, "Tuple", "spec.tuple must be set");
    };

    // Stop at the first invalid child; its path pinpoints the position.
    for (i, child) in tuple_spec.spaces.iter().enumerate() {
        let child_path = format!("{path}[{i}]");
        validate_space_at(child, &child_path)?;
    }

    Ok(())
}

/// Checks a value against a Tuple space, element by element.
///
/// Returns `Conformance::Structural` when `value` is not a Tuple value,
/// when `space` is not a Tuple space, when the element counts differ, or as
/// soon as any element reports a structural problem. Otherwise returns the
/// first `Conformance::Range` reported by an element, or `Conformance::Ok`
/// when every element conforms. Element paths are `"{path}[{index}]"`.
pub(crate) fn conform_tuple(space: &SpaceSpec, value: &SpaceValue, path: &str) -> Conformance {
    // The value is inspected before the space, so a non-Tuple value wins
    // when both are wrong.
    let SpaceValue::Tuple(items) = value else {
        return Conformance::Structural(SpaceError::invalid(path, "expected Tuple value"));
    };
    let Some(SpaceKind::Tuple(tuple_spec)) = &space.spec else {
        return Conformance::Structural(SpaceError::invalid(path, "space is not Tuple"));
    };

    let expected = tuple_spec.spaces.len();
    let actual = items.len();
    if actual != expected {
        let message = format!(
            "tuple length mismatch: expected {}, got {}",
            expected, actual
        );
        return Conformance::Structural(SpaceError::invalid(path, message));
    }

    // Structural failures abort immediately; range failures are remembered
    // (first one only) so later elements can still surface a structural error.
    let mut first_range_err: Option<SpaceError> = None;
    let pairs = tuple_spec.spaces.iter().zip(items.iter());
    for (i, (child_space, child_value)) in pairs.enumerate() {
        let child_path = format!("{path}[{i}]");
        match conform_at(child_space, child_value, &child_path) {
            Conformance::Ok => {}
            Conformance::Range(err) => {
                first_range_err.get_or_insert(err);
            }
            structural @ Conformance::Structural(_) => return structural,
        }
    }

    first_range_err.map_or(Conformance::Ok, Conformance::Range)
}

#[cfg(test)]
mod tests {
    use crate::DType;
    use crate::spaces::composite::TupleSpaceBuilder;
    use crate::spaces::fundamental::{BoxSpaceBuilder, DiscreteBuilder};
    use crate::spaces::{SpaceValue, contains};
    use crate::tensor::Tensor;

    /// A (Box, Discrete) tuple space accepts a matching two-element value and
    /// rejects a value with the wrong number of elements.
    #[test]
    fn test_tuple_contains() {
        let space = TupleSpaceBuilder::new()
            .with(BoxSpaceBuilder::scalar(0.0, 1.0, vec![3]).build().unwrap())
            .with(DiscreteBuilder::new(4).build().unwrap())
            .build()
            .unwrap();

        // 12 zero bytes with shape [3] and dtype Float32
        // (presumably three 4-byte elements — confirm against `Tensor::from_vec`).
        let zeros =
            Tensor::from_vec(vec![0u8; 12], vec![3], DType::Float32).expect("valid tensor");
        let matching = SpaceValue::Tuple(vec![SpaceValue::Box(zeros), SpaceValue::Discrete(2)]);
        assert!(contains(&space, &matching).is_ok());

        // One element where the space declares two.
        let too_short = SpaceValue::Tuple(vec![SpaceValue::Discrete(2)]);
        assert!(contains(&space, &too_short).is_err());
    }
}