1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
//! Batch.
/// A batch of transitions for training agents.
///
/// This trait represents a standard transition `(o, a, o', r, is_done)`,
/// where `o` is an observation, `a` is an action, `o'` is an observation
/// after some time steps. Typically, `o'` is for the next step and used as
/// single-step backup. `o'` can also be for the multiple steps after `o` and
/// in this case it is sometimes called n-step backup.
///
/// The type of `o` and `o'` is the associated type `ObsBatch`.
/// The type of `a` is the associated type `ActBatch`.
pub trait StdBatchBase {
    /// A set of observations in a batch.
    type ObsBatch;
    /// A set of actions in a batch.
    type ActBatch;

    /// Unpacks the data `(o_t, a_t, o_t+n, r_t, is_done_t)`, consuming the batch.
    ///
    /// The returned tuple is, in order:
    /// `(obs, act, next_obs, reward, is_done, ix_sample, weight)`.
    ///
    /// Optionally, the return value has sample indices in the replay buffer and
    /// their weights. Those are used for prioritized experience replay (PER).
    /// They are `None` when the batch carries no such information.
    fn unpack(
        self,
    ) -> (
        Self::ObsBatch,
        Self::ActBatch,
        Self::ObsBatch,
        Vec<f32>,
        Vec<i8>,
        Option<Vec<usize>>,
        Option<Vec<f32>>,
    );

    /// Returns the number of samples in the batch.
    fn len(&self) -> usize;

    /// Returns `true` if the batch contains no samples.
    ///
    /// Defaults to `self.len() == 0`; implementors may override it.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `o_t`.
    fn obs(&self) -> &Self::ObsBatch;

    /// Returns `a_t`.
    fn act(&self) -> &Self::ActBatch;

    /// Returns `o_t+n`, the observation `n` steps after `o_t`
    /// (`n = 1` for single-step backup).
    fn next_obs(&self) -> &Self::ObsBatch;

    /// Returns `r_t`.
    fn reward(&self) -> &Vec<f32>;

    /// Returns `is_done_t`.
    fn is_done(&self) -> &Vec<i8>;

    /// Returns `weight`, the per-sample weights used for PER, if present.
    fn weight(&self) -> &Option<Vec<f32>>;

    /// Returns `ix_sample`, the sample indices in the replay buffer used for
    /// PER, if present.
    fn ix_sample(&self) -> &Option<Vec<usize>>;

    /// Creates an empty batch.
    fn empty() -> Self;
}