border_core/base/batch.rs

//! Types and traits for handling batches of transitions in reinforcement learning.
//!
//! This module provides abstractions for working with batches of transitions,
//! which are essential for training reinforcement learning agents. A transition
//! represents a single step in the environment, containing the observation,
//! action, next observation, reward, and termination information.

/// A batch of transitions used for training reinforcement learning agents.
///
/// This trait represents a collection of transitions in the form
/// `(o_t, a_t, o_{t+n}, r_t, is_terminated_t, is_truncated_t)`, where:
/// - `o_t` is the observation at time step t
/// - `a_t` is the action taken at time step t
/// - `o_{t+n}` is the observation n steps after t
/// - `r_t` is the reward received after taking action `a_t`
/// - `is_terminated_t` indicates if the episode terminated at this step
/// - `is_truncated_t` indicates if the episode was truncated at this step
///
/// The value of n determines the type of backup:
/// - When n = 1, it represents a standard one-step transition
/// - When n > 1, it represents an n-step transition, which can be used for
///   n-step temporal difference learning
///
/// # Associated Types
///
/// * `ObsBatch` - The type used to store batches of observations
/// * `ActBatch` - The type used to store batches of actions
///
/// # Examples
///
/// A typical use case is in Q-learning, where transitions are used to update
/// the Q-function:
/// ```ignore
/// let (obs, act, next_obs, reward, is_terminated, is_truncated, _, _) = batch.unpack();
/// let target = reward + gamma * (1 - is_terminated) * max_a Q(next_obs, a);
/// ```
pub trait TransitionBatch {
    /// The type used to store batches of observations.
    ///
    /// This type must be able to efficiently store and access multiple observations
    /// simultaneously, typically implemented as a tensor or array-like structure.
    type ObsBatch;

    /// The type used to store batches of actions.
    ///
    /// This type must be able to efficiently store and access multiple actions
    /// simultaneously, typically implemented as a tensor or array-like structure.
    type ActBatch;

    /// Unpacks the batch into its constituent parts.
    ///
    /// Returns a tuple containing:
    /// 1. The batch of observations at time t
    /// 2. The batch of actions taken at time t
    /// 3. The batch of observations at time t+n
    /// 4. The batch of rewards received
    /// 5. The batch of termination flags
    /// 6. The batch of truncation flags
    /// 7. Optional sample indices (used for prioritized experience replay)
    /// 8. Optional importance weights (used for prioritized experience replay)
    ///
    /// # Returns
    ///
    /// A tuple containing all components of the transition batch, with optional
    /// metadata for prioritized experience replay.
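    ///
    /// # Examples
    ///
    /// A sketch of how a consumer might use the prioritized-replay metadata
    /// returned by this method (`td_errors` is a hypothetical per-sample
    /// error computed by the agent, not part of this trait):
    /// ```ignore
    /// let (obs, act, next_obs, reward, is_terminated, is_truncated, _ixs, weights) =
    ///     batch.unpack();
    /// if let Some(weights) = weights {
    ///     // Weight each sample's squared TD error by its importance weight.
    ///     let loss: f32 = td_errors
    ///         .iter()
    ///         .zip(weights.iter())
    ///         .map(|(e, w)| w * e * e)
    ///         .sum::<f32>()
    ///         / td_errors.len() as f32;
    /// }
    /// ```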
    fn unpack(
        self,
    ) -> (
        Self::ObsBatch,
        Self::ActBatch,
        Self::ObsBatch,
        Vec<f32>,
        Vec<i8>,
        Vec<i8>,
        Option<Vec<usize>>,
        Option<Vec<f32>>,
    );

    /// Returns the number of transitions in the batch.
    ///
    /// This is typically used to determine the batch size for optimization steps
    /// and to verify that all components of the batch have consistent sizes.
    fn len(&self) -> usize;

    /// Returns a reference to the batch of observations at time t.
    ///
    /// This provides efficient access to the observations without unpacking the
    /// entire batch.
    fn obs(&self) -> &Self::ObsBatch;

    /// Returns a reference to the batch of actions taken at time t.
    ///
    /// This provides efficient access to the actions without unpacking the
    /// entire batch.
    fn act(&self) -> &Self::ActBatch;
}
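
// Illustrative sketch only, not part of the crate's public API: a minimal
// in-memory implementor of `TransitionBatch` backed by plain `Vec`s, showing
// how the trait's pieces fit together. The struct and field names below are
// hypothetical, and the module is gated behind `cfg(test)` so it does not
// affect the library surface.
#[cfg(test)]
mod transition_batch_sketch {
    use super::TransitionBatch;

    /// A toy batch whose observations and actions are rows of `f32` values.
    struct VecBatch {
        obs: Vec<Vec<f32>>,
        act: Vec<Vec<f32>>,
        next_obs: Vec<Vec<f32>>,
        reward: Vec<f32>,
        is_terminated: Vec<i8>,
        is_truncated: Vec<i8>,
    }

    impl TransitionBatch for VecBatch {
        type ObsBatch = Vec<Vec<f32>>;
        type ActBatch = Vec<Vec<f32>>;

        fn unpack(
            self,
        ) -> (
            Self::ObsBatch,
            Self::ActBatch,
            Self::ObsBatch,
            Vec<f32>,
            Vec<i8>,
            Vec<i8>,
            Option<Vec<usize>>,
            Option<Vec<f32>>,
        ) {
            // This toy batch carries no prioritized-replay metadata.
            (
                self.obs,
                self.act,
                self.next_obs,
                self.reward,
                self.is_terminated,
                self.is_truncated,
                None,
                None,
            )
        }

        fn len(&self) -> usize {
            self.reward.len()
        }

        fn obs(&self) -> &Self::ObsBatch {
            &self.obs
        }

        fn act(&self) -> &Self::ActBatch {
            &self.act
        }
    }

    #[test]
    fn unpack_round_trips_components() {
        let batch = VecBatch {
            obs: vec![vec![0.0, 1.0]],
            act: vec![vec![1.0]],
            next_obs: vec![vec![1.0, 2.0]],
            reward: vec![0.5],
            is_terminated: vec![0],
            is_truncated: vec![0],
        };
        assert_eq!(batch.len(), 1);
        let (_obs, _act, _next_obs, reward, is_terminated, _, ixs, weights) = batch.unpack();
        assert_eq!(reward, vec![0.5f32]);
        assert_eq!(is_terminated, vec![0i8]);
        assert!(ixs.is_none() && weights.is_none());
    }
}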