use std::collections::HashMap;
use crate::rng::Rng;
use crate::space::{
BoundedSpace, Discrete, MultiBinary, MultiDiscrete, Space, SpaceInfo, TextSpace,
};
/// A composite space whose elements map each entry name to a value drawn from
/// that entry's subspace.
///
/// Entries keep the order they were declared in (see [`DictSpace::iter`]).
/// Names are expected to be unique; the constructor stores the pairs as given
/// and does not check this.
///
/// The `S: Space` bound is placed on the `impl` blocks rather than on the type
/// itself, so code that merely stores, moves, clones or debug-prints a
/// `DictSpace<S>` does not have to repeat it.
#[derive(Debug, Clone)]
pub struct DictSpace<S> {
    /// `(name, subspace)` pairs in declaration order.
    entries: Vec<(String, S)>,
}
impl<S: Space> DictSpace<S> {
    /// Creates a dictionary space from `(name, subspace)` pairs.
    ///
    /// The pairs are stored exactly as given, in order; names are not checked
    /// for uniqueness.
    #[must_use]
    pub const fn new(entries: Vec<(String, S)>) -> Self {
        Self { entries }
    }

    /// Number of declared `(name, subspace)` pairs.
    #[must_use]
    pub const fn len(&self) -> usize {
        let Self { entries } = self;
        entries.len()
    }

    /// Returns `true` when no pairs were declared.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        let Self { entries } = self;
        entries.is_empty()
    }

    /// Iterates over `(name, subspace)` pairs in declaration order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &S)> {
        self.entries
            .iter()
            .map(|entry| (entry.0.as_str(), &entry.1))
    }

    /// Looks up the subspace registered under `name`.
    ///
    /// Scans the pairs linearly and returns the first one whose name matches,
    /// or `None` when `name` is not declared.
    #[must_use]
    pub fn get(&self, name: &str) -> Option<&S> {
        self.entries
            .iter()
            .find_map(|(key, space)| (key == name).then_some(space))
    }
}
impl<S: Space> DictSpace<S> {
    /// Yields each entry whose name has not appeared earlier in `entries`,
    /// in declaration order.
    ///
    /// This gives repeated names "first declaration wins" semantics, the same
    /// rule [`DictSpace::get`] uses. Quadratic in the number of entries, which
    /// is acceptable for small dictionaries.
    fn distinct_entries(&self) -> impl Iterator<Item = &(String, S)> {
        self.entries
            .iter()
            .enumerate()
            .filter(move |&(index, (name, _))| {
                self.entries[..index]
                    .iter()
                    .all(|(earlier, _)| earlier != name)
            })
            .map(|(_, entry)| entry)
    }
}

impl<S: Space> Space for DictSpace<S> {
    type Element = HashMap<String, S::Element>;

    /// Samples every named subspace, in declaration order.
    ///
    /// When a name is declared more than once only its first subspace is
    /// sampled, so the result holds exactly one value per distinct name and is
    /// always accepted by [`Space::contains`]. With unique names this draws
    /// from `rng` in exactly the declaration order.
    fn sample(&self, rng: &mut Rng) -> Self::Element {
        let mut out = HashMap::with_capacity(self.entries.len());
        for (name, space) in &self.entries {
            out.entry(name.clone())
                .or_insert_with(|| space.sample(rng));
        }
        out
    }

    /// Returns `true` when `value` has exactly the declared names and each
    /// value lies in its name's subspace (the first one declared, if a name
    /// repeats). Extra or missing keys are rejected.
    fn contains(&self, value: &Self::Element) -> bool {
        // More keys than declared entries can never match.
        if value.len() > self.entries.len() {
            return false;
        }
        // Every declared name must be present...
        self.entries
            .iter()
            .all(|(name, _)| value.contains_key(name))
            // ...and every supplied name must be declared with an in-range value
            // (this is also what rejects unexpected keys).
            && value
                .iter()
                .all(|(name, v)| self.get(name).is_some_and(|space| space.contains(v)))
    }

    /// A dictionary has no array shape of its own, so this is always empty.
    fn shape(&self) -> &[usize] {
        &[]
    }

    /// Sum of the flattened sizes of each distinct name's subspace.
    fn flatdim(&self) -> usize {
        self.distinct_entries()
            .map(|(_, space)| space.flatdim())
            .sum()
    }

    /// Describes each distinct name's subspace, in declaration order.
    fn space_info(&self) -> SpaceInfo {
        SpaceInfo::Dict(
            self.distinct_entries()
                .map(|(name, space)| (name.clone(), space.space_info()))
                .collect(),
        )
    }
}
/// Closed union of the concrete space types imported from `crate::space`,
/// letting spaces of different types share one container — e.g. a
/// heterogeneous `DictSpace<AnySpace>`.
///
/// Its elements are [`AnyValue`]s. Each variant pairs with the `AnyValue`
/// variant of the same name, except `Bounded`, whose elements are
/// `AnyValue::Continuous`. Any wrapped space converts with `.into()`.
#[derive(Debug, Clone)]
pub enum AnySpace {
/// Wraps a [`Discrete`] space; elements are `AnyValue::Discrete`.
Discrete(Discrete),
/// Wraps a [`BoundedSpace`]; elements are `AnyValue::Continuous`.
Bounded(BoundedSpace),
/// Wraps a [`MultiDiscrete`] space; elements are `AnyValue::MultiDiscrete`.
MultiDiscrete(MultiDiscrete),
/// Wraps a [`MultiBinary`] space; elements are `AnyValue::MultiBinary`.
MultiBinary(MultiBinary),
/// Wraps a [`TextSpace`]; elements are `AnyValue::Text`.
Text(TextSpace),
}
/// An element of an [`AnySpace`], tagged with the kind of space it belongs to.
///
/// `AnySpace::contains` returns `false` whenever this value's variant does not
/// pair with the space's variant, regardless of the payload.
#[derive(Debug, Clone, PartialEq)]
pub enum AnyValue {
/// Sample of a `Discrete` space.
Discrete(i64),
/// Sample of a `BoundedSpace` (the `AnySpace::Bounded` variant).
Continuous(Vec<f32>),
/// Sample of a `MultiDiscrete` space.
MultiDiscrete(Vec<i64>),
/// Sample of a `MultiBinary` space.
MultiBinary(Vec<u8>),
/// Sample of a `TextSpace`.
Text(String),
}
impl Space for AnySpace {
type Element = AnyValue;
fn sample(&self, rng: &mut Rng) -> AnyValue {
match self {
Self::Discrete(s) => AnyValue::Discrete(s.sample(rng)),
Self::Bounded(s) => AnyValue::Continuous(s.sample(rng)),
Self::MultiDiscrete(s) => AnyValue::MultiDiscrete(s.sample(rng)),
Self::MultiBinary(s) => AnyValue::MultiBinary(s.sample(rng)),
Self::Text(s) => AnyValue::Text(s.sample(rng)),
}
}
fn contains(&self, value: &AnyValue) -> bool {
match (self, value) {
(Self::Discrete(s), AnyValue::Discrete(v)) => s.contains(v),
(Self::Bounded(s), AnyValue::Continuous(v)) => s.contains(v),
(Self::MultiDiscrete(s), AnyValue::MultiDiscrete(v)) => s.contains(v),
(Self::MultiBinary(s), AnyValue::MultiBinary(v)) => s.contains(v),
(Self::Text(s), AnyValue::Text(v)) => s.contains(v),
_ => false,
}
}
fn shape(&self) -> &[usize] {
match self {
Self::Discrete(s) => s.shape(),
Self::Bounded(s) => s.shape(),
Self::MultiDiscrete(s) => s.shape(),
Self::MultiBinary(s) => s.shape(),
Self::Text(s) => s.shape(),
}
}
fn flatdim(&self) -> usize {
match self {
Self::Discrete(s) => s.flatdim(),
Self::Bounded(s) => s.flatdim(),
Self::MultiDiscrete(s) => s.flatdim(),
Self::MultiBinary(s) => s.flatdim(),
Self::Text(s) => s.flatdim(),
}
}
fn space_info(&self) -> SpaceInfo {
match self {
Self::Discrete(s) => s.space_info(),
Self::Bounded(s) => s.space_info(),
Self::MultiDiscrete(s) => s.space_info(),
Self::MultiBinary(s) => s.space_info(),
Self::Text(s) => s.space_info(),
}
}
}
impl From<Discrete> for AnySpace {
fn from(s: Discrete) -> Self {
Self::Discrete(s)
}
}
impl From<BoundedSpace> for AnySpace {
fn from(s: BoundedSpace) -> Self {
Self::Bounded(s)
}
}
impl From<MultiDiscrete> for AnySpace {
fn from(s: MultiDiscrete) -> Self {
Self::MultiDiscrete(s)
}
}
impl From<MultiBinary> for AnySpace {
fn from(s: MultiBinary) -> Self {
Self::MultiBinary(s)
}
}
impl From<TextSpace> for AnySpace {
fn from(s: TextSpace) -> Self {
Self::Text(s)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rng::create_rng;

    #[test]
    fn homogeneous_dict_sample_and_contains() {
        let space = DictSpace::new(vec![
            (
                "position".into(),
                BoundedSpace::new(vec![-1.0, -1.0], vec![1.0, 1.0]).unwrap(),
            ),
            (
                "velocity".into(),
                BoundedSpace::new(vec![-5.0], vec![5.0]).unwrap(),
            ),
        ]);
        let mut rng = create_rng(Some(42));
        let sample = space.sample(&mut rng);
        assert!(space.contains(&sample));
        assert_eq!(sample.len(), 2);
        assert_eq!(sample["position"].len(), 2);
        assert_eq!(sample["velocity"].len(), 1);
    }

    #[test]
    fn heterogeneous_dict_sample_and_contains() {
        let space = DictSpace::<AnySpace>::new(vec![
            (
                "obs".into(),
                BoundedSpace::new(vec![-1.0], vec![1.0]).unwrap().into(),
            ),
            ("mode".into(), Discrete::new(3).into()),
        ]);
        let mut rng = create_rng(Some(42));
        let sample = space.sample(&mut rng);
        assert!(space.contains(&sample));
        assert_eq!(sample.len(), 2);
    }

    #[test]
    fn rejects_missing_key() {
        let space = DictSpace::new(vec![
            ("a".into(), Discrete::new(2)),
            ("b".into(), Discrete::new(3)),
        ]);
        let mut rng = create_rng(Some(42));
        let mut sample = space.sample(&mut rng);
        sample.remove("b");
        assert!(!space.contains(&sample));
    }

    /// A map carrying a key the space never declared must be rejected.
    #[test]
    fn rejects_unexpected_key() {
        let space = DictSpace::new(vec![
            ("a".into(), Discrete::new(2)),
            ("b".into(), Discrete::new(3)),
        ]);
        let mut rng = create_rng(Some(42));
        let mut sample = space.sample(&mut rng);
        sample.insert("c".into(), 0);
        assert!(!space.contains(&sample));
    }

    /// An empty dictionary samples an empty map, accepts it, and nothing else.
    #[test]
    fn empty_dict_has_no_entries() {
        let space: DictSpace<Discrete> = DictSpace::new(Vec::new());
        assert!(space.is_empty());
        assert_eq!(space.len(), 0);
        assert_eq!(space.flatdim(), 0);
        assert!(space.shape().is_empty());
        let mut rng = create_rng(Some(42));
        let sample = space.sample(&mut rng);
        assert!(sample.is_empty());
        assert!(space.contains(&sample));
        let mut extra = sample;
        extra.insert("x".into(), 0);
        assert!(!space.contains(&extra));
    }

    #[test]
    fn flatdim_sums_entries() {
        let space = DictSpace::new(vec![
            (
                "pos".into(),
                BoundedSpace::new(vec![0.0; 3], vec![1.0; 3]).unwrap(),
            ),
            (
                "vel".into(),
                BoundedSpace::new(vec![0.0; 2], vec![1.0; 2]).unwrap(),
            ),
        ]);
        assert_eq!(space.flatdim(), 5);
    }

    #[test]
    fn space_info_is_dict() {
        let space = DictSpace::new(vec![("x".into(), Discrete::new(4))]);
        let info = space.space_info();
        assert!(matches!(info, SpaceInfo::Dict(_)));
    }

    #[test]
    fn get_and_len() {
        let space = DictSpace::new(vec![
            ("a".into(), Discrete::new(2)),
            ("b".into(), Discrete::new(5)),
        ]);
        assert_eq!(space.len(), 2);
        assert!(!space.is_empty());
        assert!(space.get("a").is_some());
        assert!(space.get("c").is_none());
    }

    /// A value whose variant does not pair with the space's variant is never
    /// contained, whatever its payload.
    #[test]
    fn any_space_rejects_mismatched_variant() {
        let space = AnySpace::from(Discrete::new(3));
        assert!(!space.contains(&AnyValue::Text("0".into())));
        assert!(!space.contains(&AnyValue::Continuous(vec![0.0])));
        let mut rng = create_rng(Some(42));
        let sample = space.sample(&mut rng);
        assert!(matches!(sample, AnyValue::Discrete(_)));
        assert!(space.contains(&sample));
    }
}