Skip to main content

rlevo_core/
state.rs

//! Advanced state abstraction traits for non-Markovian and latent representations.
//!
//! This module extends the base [`State`] contract with higher-level abstractions
//! needed for POMDPs, recurrent policies, and world-model-based agents:
//! - [`MarkovState`] — declares whether a representation satisfies the Markov property
//! - [`BeliefState`] — probability distribution over possible states (POMDP)
//! - [`HiddenState`] — recurrent agent memory (e.g., RNN hidden state)
//! - [`LatentState`] — learned compact representation with encode/predict/decode
//! - [`StateAggregation`] — maps concrete states to abstract representatives
//!
//! It also defines [`StateError`], the error type for state validation failures.
//!
//! [`State`]: crate::base::State
12
13use crate::base::{Action, Observation, State};
14
/// Error type for state validation failures.
///
/// Returned by validation logic when a state's shape, contents, or element
/// count do not match the expectations of the calling code.
///
/// Every payload type has total equality, so the enum derives `Eq` in
/// addition to `PartialEq`; errors can be compared directly in tests and used
/// wherever an `Eq` bound is required.
///
/// # Examples
///
/// ```
/// use rlevo_core::state::StateError;
///
/// let err = StateError::InvalidShape {
///     expected: vec![4, 4],
///     got: vec![4, 3],
/// };
/// assert!(err.to_string().contains("Invalid shape"));
///
/// let err = StateError::InvalidData("NaN in position field".into());
/// assert!(err.to_string().contains("NaN in position field"));
///
/// let err = StateError::InvalidSize { expected: 16, got: 12 };
/// assert!(err.to_string().contains("Invalid size"));
/// assert_eq!(err, StateError::InvalidSize { expected: 16, got: 12 });
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StateError {
    /// Shape dimensions do not match expectations.
    InvalidShape {
        /// Shape the caller required.
        expected: Vec<usize>,
        /// Shape that was actually encountered.
        got: Vec<usize>,
    },
    /// Data contents violate invariants.
    ///
    /// The payload is a human-readable description of the violation.
    InvalidData(String),
    /// Total element count does not match expectations.
    InvalidSize {
        /// Element count the caller required.
        expected: usize,
        /// Element count that was actually encountered.
        got: usize,
    },
}
52
53impl std::fmt::Display for StateError {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        match self {
56            StateError::InvalidShape { expected, got } => {
57                write!(f, "Invalid shape: expected {:?}, got {:?}", expected, got)
58            }
59            StateError::InvalidData(msg) => write!(f, "Invalid data: {}", msg),
60            StateError::InvalidSize { expected, got } => {
61                write!(f, "Invalid size: expected {}, got {}", expected, got)
62            }
63        }
64    }
65}
66
67impl std::error::Error for StateError {}
68
/// Declares whether a state representation satisfies the Markov property.
///
/// Under the Markov property the future is conditionally independent of the
/// past given the present, so the current representation alone is enough to
/// predict what comes next. Tabular and neural Q-learning both assume this.
pub trait MarkovState {
    /// Reports whether the implementing representation is Markov.
    ///
    /// This is an associated function (no receiver), so it is queried on the
    /// type itself, e.g. `T::is_markov()`.
    ///
    /// Implementors inherit `true`, which is correct for most fully-observable
    /// environments. Override it to return `false` for raw pixel or
    /// partially-observable representations that require history stacking.
    fn is_markov() -> bool {
        const MARKOV_BY_DEFAULT: bool = true;
        MARKOV_BY_DEFAULT
    }
}
83
84/// A probability distribution over possible environment states (POMDP belief).
85///
86/// Belief states are used in partially-observable settings where the agent
87/// cannot observe the true state directly. The belief is updated via Bayes'
88/// rule as the most recent action and new observation arrive.
89///
90/// # Type Parameters
91///
92/// - `SR`: Rank of the state space tensor (number of axes).
93/// - `AR`: Rank of the action space tensor (number of axes).
94/// - `S`: The underlying environment [`State`] type.
95/// - `A`: The [`Action`] type taken by the agent.
96pub trait BeliefState<const SR: usize, const AR: usize, S: State<SR>, A: Action<AR>>: Clone {
97    /// Updates the belief distribution given the last action taken and the
98    /// newly received observation.
99    fn update(&self, action: &A, observation: &S::Observation) -> Self;
100
101    /// Draws a state sample from the current belief distribution.
102    fn sample(&self) -> S;
103
104    /// Returns the probability (or unnormalized weight) assigned to `state`.
105    fn probability(&self, state: &S) -> f64;
106}
107
108/// Recurrent agent memory analogous to an RNN hidden state.
109///
110/// Implementations hold the internal summary of past observations (e.g., the
111/// `h_t` vector of a GRU or LSTM). The hidden state is updated at each step
112/// with the latest [`Observation`] and reset to a
113/// zero vector at episode start.
114///
115/// # Type Parameters
116///
117/// - `R`: Rank of the observation space tensor used to update this state.
118pub trait HiddenState<const R: usize>: Clone {
119    /// The observation type used to update this hidden state.
120    type Observation: Observation<R>;
121
122    /// Incorporates `observation` into the hidden state in-place.
123    fn update(&mut self, observation: &Self::Observation);
124
125    /// Resets the hidden state to its initial value at episode start.
126    fn reset(&mut self);
127}
128
129/// Learned compact representation with encode, predict, and decode steps.
130///
131/// Used by world-model agents (e.g., DreamerV3) that operate in a learned
132/// latent space rather than the raw observation space.
133///
134/// # Type Parameters
135///
136/// - `R`: Rank of the observation space tensor this latent state is derived from.
137/// - `AR`: Rank of the action space tensor used in the transition prediction step.
138pub trait LatentState<const R: usize, const AR: usize>: Clone {
139    /// The observation type this latent state is derived from.
140    type Observation: Observation<R>;
141
142    /// Projects `observation` into the latent space.
143    fn encode(observation: &Self::Observation) -> Self;
144
145    /// Rolls the latent state forward by one step given `action`.
146    fn predict_next<A: Action<AR>>(&self, action: &A) -> Self;
147
148    /// Reconstructs an observation from the latent representation.
149    fn decode(&self) -> Self::Observation;
150}
151
152/// Maps concrete states to abstract representatives for state aggregation.
153///
154/// State aggregation is used in function approximation and hierarchical RL to
155/// group similar states under a shared abstract representation.
156///
157/// # Type Parameters
158///
159/// - `SR`: Rank of the concrete state space tensor.
160/// - `S`: The concrete [`State`] type being aggregated.
161pub trait StateAggregation<const SR: usize, S: State<SR>> {
162    /// The abstract state type produced by aggregation.
163    type AbstractState: Clone + Eq;
164
165    /// Returns the abstract representative for `state`.
166    fn aggregate(&self, state: &S) -> Self::AbstractState;
167
168    /// Returns `true` when `state1` and `state2` map to the same abstract state.
169    fn same_aggregate(&self, state1: &S, state2: &S) -> bool {
170        self.aggregate(state1) == self.aggregate(state2)
171    }
172}