use crate::base::{Action, Observation, State};
/// Errors reported when a state value fails validation.
///
/// `Eq` is derived in addition to `PartialEq` because every payload (`Vec<usize>`,
/// `String`, `usize`) has total equality; this lets the error be used wherever an
/// `Eq` bound is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StateError {
    /// A shape mismatch between what was required and what was supplied.
    InvalidShape {
        /// Shape that was required.
        expected: Vec<usize>,
        /// Shape that was actually supplied.
        got: Vec<usize>,
    },
    /// Data was rejected; the message explains why.
    InvalidData(String),
    /// A size mismatch between what was required and what was supplied.
    InvalidSize {
        /// Size that was required.
        expected: usize,
        /// Size that was actually supplied.
        got: usize,
    },
}

impl std::fmt::Display for StateError {
    /// Renders a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidShape { expected, got } => {
                write!(f, "Invalid shape: expected {:?}, got {:?}", expected, got)
            }
            Self::InvalidData(msg) => write!(f, "Invalid data: {}", msg),
            Self::InvalidSize { expected, got } => {
                write!(f, "Invalid size: expected {}, got {}", expected, got)
            }
        }
    }
}

// No `source()` override: none of the variants wraps an underlying error.
impl std::error::Error for StateError {}
/// Opt-in trait for state types that are treated as Markov.
///
/// An empty `impl MarkovState for T {}` is enough to opt in, because the provided
/// [`MarkovState::is_markov`] answers `true`; implementors may override it.
pub trait MarkovState {
    /// Reports whether the implementing type is Markov.
    ///
    /// The provided implementation always returns `true`.
    fn is_markov() -> bool {
        const DEFAULT_IS_MARKOV: bool = true;
        DEFAULT_IS_MARKOV
    }
}
/// A belief over states of type `S`, where `SD` is the state dimension and `AD` the
/// action dimension.
///
/// The `sample`/`probability` methods suggest a probability distribution over `S`;
/// the concrete representation is left to implementors.
pub trait BeliefState<const SD: usize, const AD: usize, S, A>: Clone
where
    S: State<SD>,
    A: Action<AD>,
{
    /// Returns the belief obtained after taking `action` and then receiving
    /// `observation`; `self` is left unchanged (a new value is returned).
    fn update(&self, action: &A, observation: &S::Observation) -> Self;

    /// Draws a single state from this belief.
    fn sample(&self) -> S;

    /// Returns the weight this belief assigns to `state`. Presumably a probability in
    /// `[0, 1]` — NOTE(review): the trait itself does not enforce a range.
    fn probability(&self, state: &S) -> f64;
}
/// A hidden state that is refined in place from observations of dimension `D`.
pub trait HiddenState<const D: usize>: Clone {
    /// Observation type consumed by [`update`](HiddenState::update).
    type Observation: Observation<D>;

    /// Returns this state to its starting condition, as defined by the implementor.
    fn reset(&mut self);

    /// Incorporates `observation` into this state, mutating it in place.
    fn update(&mut self, observation: &Self::Observation);
}
/// A latent representation built from observations of dimension `D` and advanced by
/// actions of dimension `AD`.
pub trait LatentState<const D: usize, const AD: usize>: Clone {
    /// Observation type this latent value is built from and decoded back into.
    type Observation: Observation<D>;

    /// Constructs a latent value from `observation`.
    fn encode(observation: &Self::Observation) -> Self;

    /// Maps this latent value back to an observation.
    fn decode(&self) -> Self::Observation;

    /// Returns the latent value predicted to follow `self` once `action` is applied;
    /// the prediction model is up to the implementor.
    fn predict_next<A>(&self, action: &A) -> Self
    where
        A: Action<AD>;
}
/// Maps concrete states `S` (of dimension `SD`) to abstract states; two concrete
/// states whose images compare equal are regarded as aggregated together.
pub trait StateAggregation<const SD: usize, S>
where
    S: State<SD>,
{
    /// Abstract representation produced by [`aggregate`](StateAggregation::aggregate).
    type AbstractState: Clone + Eq;

    /// Returns the abstract state that `state` is mapped to.
    fn aggregate(&self, state: &S) -> Self::AbstractState;

    /// Returns `true` when `state1` and `state2` map to equal abstract states.
    ///
    /// `state1` is aggregated before `state2`.
    fn same_aggregate(&self, state1: &S, state2: &S) -> bool {
        let first = self.aggregate(state1);
        let second = self.aggregate(state2);
        first == second
    }
}