// border_core/base/batch.rs
//! Types and traits for handling batches of transitions in reinforcement learning.
//!
//! This module provides abstractions for working with batches of transitions,
//! which are essential for training reinforcement learning agents. A transition
//! represents a single step in the environment, containing the observation,
//! action, next observation, reward, and termination information.
/// A batch of transitions used for training reinforcement learning agents.
///
/// This trait represents a collection of transitions in the form
/// `(o_t, a_t, o_t+n, r_t, is_terminated_t, is_truncated_t)`, where:
/// - `o_t` is the observation at time step t
/// - `a_t` is the action taken at time step t
/// - `o_t+n` is the observation n steps after t
/// - `r_t` is the reward received after taking action `a_t`
/// - `is_terminated_t` indicates if the episode terminated at this step
/// - `is_truncated_t` indicates if the episode was truncated at this step
///
/// The value of n determines the type of backup:
/// - When n = 1, it represents a standard one-step transition
/// - When n > 1, it represents an n-step transition, which can be used for
///   n-step temporal difference learning
///
/// # Associated Types
///
/// * `ObsBatch` - The type used to store batches of observations
/// * `ActBatch` - The type used to store batches of actions
///
/// # Examples
///
/// A typical use case is in Q-learning, where transitions are used to update
/// the Q-function:
/// ```ignore
/// let (obs, act, next_obs, reward, is_terminated, is_truncated, _, _) = batch.unpack();
/// let target = reward + gamma * (1 - is_terminated) * max_a Q(next_obs, a);
/// ```
pub trait TransitionBatch {
    /// The type used to store batches of observations.
    ///
    /// This type must be able to efficiently store and access multiple observations
    /// simultaneously, typically implemented as a tensor or array-like structure.
    type ObsBatch;

    /// The type used to store batches of actions.
    ///
    /// This type must be able to efficiently store and access multiple actions
    /// simultaneously, typically implemented as a tensor or array-like structure.
    type ActBatch;

    /// Unpacks the batch into its constituent parts, consuming `self`.
    ///
    /// Returns a tuple containing:
    /// 1. The batch of observations at time t
    /// 2. The batch of actions taken at time t
    /// 3. The batch of observations at time t+n
    /// 4. The batch of rewards received
    /// 5. The batch of termination flags
    /// 6. The batch of truncation flags
    /// 7. Optional sample indices (used for prioritized experience replay)
    /// 8. Optional importance weights (used for prioritized experience replay)
    ///
    /// # Returns
    ///
    /// A tuple containing all components of the transition batch, with optional
    /// metadata for prioritized experience replay.
    fn unpack(
        self,
    ) -> (
        Self::ObsBatch,
        Self::ActBatch,
        Self::ObsBatch,
        Vec<f32>,
        Vec<i8>,
        Vec<i8>,
        Option<Vec<usize>>,
        Option<Vec<f32>>,
    );

    /// Returns the number of transitions in the batch.
    ///
    /// This is typically used to determine the batch size for optimization steps
    /// and to verify that all components of the batch have consistent sizes.
    fn len(&self) -> usize;

    /// Returns `true` if the batch contains no transitions.
    ///
    /// Provided default implementation in terms of [`TransitionBatch::len`];
    /// following the Rust API convention that any type exposing `len` should
    /// also expose `is_empty` (Clippy's `len_without_is_empty` lint).
    /// Adding this provided method is backward-compatible for implementors.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns a reference to the batch of observations at time t.
    ///
    /// This provides efficient access to the observations without unpacking the
    /// entire batch.
    fn obs(&self) -> &Self::ObsBatch;

    /// Returns a reference to the batch of actions taken at time t.
    ///
    /// This provides efficient access to the actions without unpacking the
    /// entire batch.
    fn act(&self) -> &Self::ActBatch;
}
96}