Struct SimpleReplayBuffer

pub struct SimpleReplayBuffer<O, A>
where O: BatchBase, A: BatchBase,
{ /* private fields */ }

A generic implementation of a replay buffer for reinforcement learning.

This buffer can store transitions of arbitrary observation and action types, making it suitable for a wide range of reinforcement learning tasks. It supports:

  • Standard experience replay
  • Prioritized experience replay (optional)
  • Efficient sampling and storage

Type Parameters

  • O - The type of observations, must implement BatchBase
  • A - The type of actions, must implement BatchBase
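Both type parameters carry the same bound so that observations and actions can be stored and gathered in lockstep. The sketch below illustrates that idea with a hypothetical trait, BatchLike, standing in for BatchBase (whose methods are not listed on this page). The trait, its impl for Vec<T>, and sample_aligned are assumptions made for illustration only.

```rust
/// Hypothetical stand-in for `BatchBase`: a batch that reports its number of
/// rows and can gather rows by index.
pub trait BatchLike: Sized {
    fn n_rows(&self) -> usize;
    fn sample(&self, ixs: &[usize]) -> Self;
}

// In this sketch, any vector of cloneable rows can act as a batch.
impl<T: Clone> BatchLike for Vec<T> {
    fn n_rows(&self) -> usize {
        self.len()
    }
    fn sample(&self, ixs: &[usize]) -> Self {
        ixs.iter().map(|&i| self[i].clone()).collect()
    }
}

/// Gathers the same rows from an observation batch and an action batch, so
/// each sampled observation stays paired with the action taken in that state.
/// `O` and `A` may be different types, as in `SimpleReplayBuffer<O, A>`.
pub fn sample_aligned<O: BatchLike, A: BatchLike>(obs: &O, act: &A, ixs: &[usize]) -> (O, A) {
    assert_eq!(obs.n_rows(), act.n_rows(), "obs and act must have the same number of rows");
    (obs.sample(ixs), act.sample(ixs))
}
```

Because O and A are independent parameters, the observation batch can hold, for example, [f32; 2] rows while the action batch holds discrete i64 actions.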

Examples

let config = SimpleReplayBufferConfig {
    capacity: 10000,
    per_config: Some(PerConfig {
        alpha: 0.6,
        beta_0: 0.4,
        beta_final: 1.0,
        n_opts_final: 100000,
        normalize: true,
    }),
};

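// `build` and `batch` come from the ReplayBufferBase trait and `push` from
// ExperienceBufferBase; both traits must be in scope.
let mut buffer = SimpleReplayBuffer::<Tensor, Tensor>::build(&config);

// Add a transition. `transition` is assumed to be a
// GenericTransitionBatch<Tensor, Tensor> built elsewhere, and `?` requires
// an enclosing function that returns a Result.
buffer.push(transition)?;

// Sample a batch of 32 transitions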
let batch = buffer.batch(32)?;
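The PerConfig fields in the example above suggest that the importance-sampling exponent β starts at beta_0 and reaches beta_final after n_opts_final optimization steps. A minimal sketch of such a schedule follows, assuming linear interpolation that is then held at beta_final; this page does not document the exact schedule, and BetaSchedule is a hypothetical type.

```rust
/// Hypothetical holder for the three `PerConfig` fields that control beta.
pub struct BetaSchedule {
    pub beta_0: f32,
    pub beta_final: f32,
    pub n_opts_final: usize,
}

impl BetaSchedule {
    /// Beta after `n_opts` optimization steps: linear from `beta_0` to
    /// `beta_final` over `n_opts_final` steps, then constant (assumed schedule).
    pub fn beta(&self, n_opts: usize) -> f32 {
        if self.n_opts_final == 0 || n_opts >= self.n_opts_final {
            return self.beta_final;
        }
        let frac = n_opts as f32 / self.n_opts_final as f32;
        self.beta_0 + frac * (self.beta_final - self.beta_0)
    }
}
```

Under this assumed schedule and the example's values, β would be 0.4 at the first step, 0.7 halfway through, and 1.0 from step 100000 onward.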

Implementations

impl<O, A> SimpleReplayBuffer<O, A>
where O: BatchBase, A: BatchBase,

pub fn whole_actions(&self) -> A

Returns a batch containing all actions in the buffer.

Warning

This method should be used with caution on large replay buffers as it may consume significant memory.
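Before calling whole_actions on a large buffer, a rough estimate of the returned batch's size can help. This hypothetical helper assumes dense f32 actions of a fixed dimension and ignores any per-batch overhead; it returns None if the product overflows usize.

```rust
/// Approximate size in bytes of a dense action batch holding `len`
/// transitions, each an `action_dim`-dimensional f32 vector (assumed layout).
pub fn whole_actions_bytes(len: usize, action_dim: usize) -> Option<usize> {
    len.checked_mul(action_dim)?
        .checked_mul(std::mem::size_of::<f32>())
}
```

For example, one million stored transitions with 6-dimensional actions come to about 24 MB under these assumptions.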

pub fn num_terminated_flags(&self) -> usize

Returns the number of terminated episodes in the buffer.

pub fn num_truncated_flags(&self) -> usize

Returns the number of truncated episodes in the buffer.

pub fn sum_rewards(&self) -> f32

Returns the sum of all rewards in the buffer.
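The three statistics above are simple reductions over per-transition columns. The sketch assumes the buffer keeps rewards and terminated/truncated flags as parallel vectors, with flags encoded as 0/1 integers, plus a count of valid rows; the Columns type and its field names are hypothetical.

```rust
/// Hypothetical per-transition columns of a replay buffer.
pub struct Columns {
    pub reward: Vec<f32>,
    pub is_terminated: Vec<i8>,
    pub is_truncated: Vec<i8>,
    /// Number of valid rows; the vectors may be preallocated beyond this.
    pub size: usize,
}

impl Columns {
    /// Counts valid rows whose terminated flag is set.
    pub fn num_terminated_flags(&self) -> usize {
        self.is_terminated[..self.size].iter().filter(|&&f| f == 1).count()
    }

    /// Counts valid rows whose truncated flag is set.
    pub fn num_truncated_flags(&self) -> usize {
        self.is_truncated[..self.size].iter().filter(|&&f| f == 1).count()
    }

    /// Sums rewards over valid rows only.
    pub fn sum_rewards(&self) -> f32 {
        self.reward[..self.size].iter().sum()
    }
}
```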

Trait Implementations

impl Agent<TestEnv, SimpleReplayBuffer<TestObsBatch, TestActBatch>> for TestAgent

fn train(&mut self)

Switches the agent to training mode.

fn is_train(&self) -> bool

Returns whether the agent is currently in training mode.

fn eval(&mut self)

Switches the agent to evaluation mode.

fn opt_with_record(&mut self, _buffer: &mut SimpleReplayBuffer<TestObsBatch, TestActBatch>) -> Record

Performs an optimization step and returns training metrics.

fn save_params(&self, _path: &Path) -> Result<Vec<PathBuf>>

Saves the agent’s parameters to the specified directory.

fn load_params(&mut self, _path: &Path) -> Result<()>

Loads the agent’s parameters from the specified directory.

fn as_any_ref(&self) -> &dyn Any

Returns a reference to the agent as a type-erased Any value.

fn as_any_mut(&mut self) -> &mut dyn Any

Returns a mutable reference to the agent as a type-erased Any value.

fn opt(&mut self, buffer: &mut R)

Performs a single optimization step using experiences from the replay buffer. In this implementation, R is SimpleReplayBuffer<TestObsBatch, TestActBatch>.

impl<O, A> ExperienceBufferBase for SimpleReplayBuffer<O, A>
where O: BatchBase, A: BatchBase,

fn len(&self) -> usize

Returns the current number of transitions in the buffer.

fn push(&mut self, tr: Self::Item) -> Result<()>

Adds a new transition to the buffer.

Arguments
  • tr - The transition to add
Returns

Ok(()) if the transition was added successfully.

Errors

Returns an error if the buffer is full and cannot accept more transitions
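The documented contract (Ok(()) on success, an error when the buffer is full) can be sketched with a fixed-capacity store. This is a hypothetical illustration that follows the wording above; many replay buffers instead overwrite their oldest entries once full, and this page does not show which strategy SimpleReplayBuffer uses internally. BufferFull and BoundedBuffer are invented names.

```rust
use std::fmt;

/// Hypothetical error returned when a push would exceed capacity.
#[derive(Debug, PartialEq)]
pub struct BufferFull {
    pub capacity: usize,
}

impl fmt::Display for BufferFull {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "replay buffer is full (capacity {})", self.capacity)
    }
}

impl std::error::Error for BufferFull {}

/// Fixed-capacity store of items of type `T`.
pub struct BoundedBuffer<T> {
    items: Vec<T>,
    capacity: usize,
}

impl<T> BoundedBuffer<T> {
    pub fn new(capacity: usize) -> Self {
        BoundedBuffer { items: Vec::with_capacity(capacity), capacity }
    }

    /// Number of stored items, analogous to `len`.
    pub fn len(&self) -> usize {
        self.items.len()
    }

    /// Appends `tr`, or returns `BufferFull` and leaves the buffer unchanged.
    pub fn push(&mut self, tr: T) -> Result<(), BufferFull> {
        if self.items.len() >= self.capacity {
            return Err(BufferFull { capacity: self.capacity });
        }
        self.items.push(tr);
        Ok(())
    }
}
```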

type Item = GenericTransitionBatch<O, A>

The type of items stored in the buffer.

impl<O, A> ReplayBufferBase for SimpleReplayBuffer<O, A>
where O: BatchBase, A: BatchBase,

fn build(config: &Self::Config) -> Self

Creates a new replay buffer with the given configuration.

Arguments
  • config - Configuration for the replay buffer
Returns

A new instance of the replay buffer.

fn batch(&mut self, size: usize) -> Result<Self::Batch>

Samples a batch of transitions from the buffer.

If prioritized experience replay is enabled, samples are selected according to their priorities. Otherwise, uniform random sampling is used.
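In the common proportional variant of prioritized experience replay, transition i is drawn with probability P(i) = p_i^α / Σ_k p_k^α and its loss is scaled by the importance-sampling weight w_i = (N · P(i))^(−β), optionally divided by the largest weight. The sketch below implements that variant with a linear scan over priorities. It assumes, without confirmation from this page, that alpha, beta, and normalize play these roles in PerConfig, and that normalization uses the batch maximum; a production buffer would typically use a sum tree for O(log N) sampling. Both function names are hypothetical.

```rust
/// Draws one index with probability proportional to `priorities[i]^alpha`,
/// given a uniform sample `u` in [0, 1). Linear scan: O(N) per draw.
pub fn sample_proportional(priorities: &[f32], alpha: f32, u: f32) -> usize {
    assert!(!priorities.is_empty(), "need at least one priority");
    let total: f32 = priorities.iter().map(|p| p.powf(alpha)).sum();
    let mut target = u * total;
    for (i, p) in priorities.iter().enumerate() {
        let w = p.powf(alpha);
        if target < w {
            return i;
        }
        target -= w;
    }
    // Guard against floating-point round-off at the upper end.
    priorities.len() - 1
}

/// Importance-sampling weights w_i = (N * P(i))^(-beta) for the sampled
/// indices; if `normalize`, each weight is divided by the batch maximum.
pub fn is_weights(priorities: &[f32], alpha: f32, beta: f32, ixs: &[usize], normalize: bool) -> Vec<f32> {
    let n = priorities.len() as f32;
    let total: f32 = priorities.iter().map(|p| p.powf(alpha)).sum();
    let mut ws: Vec<f32> = ixs
        .iter()
        .map(|&i| (n * priorities[i].powf(alpha) / total).powf(-beta))
        .collect();
    if normalize {
        let max = ws.iter().cloned().fold(f32::MIN, f32::max);
        for w in ws.iter_mut() {
            *w /= max;
        }
    }
    ws
}
```

With priorities [1.0, 3.0] and α = 1, the second transition is drawn three times as often as the first, and its normalized weight is one third of the first's, which offsets the sampling bias when β = 1.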

Arguments
  • size - Number of transitions to sample
Returns

A batch of sampled transitions.

Errors

Returns an error if:

  • The buffer is empty
  • The requested batch size is larger than the buffer size
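One plausible reading of the second error condition is that uniform sampling draws distinct indices, since a batch larger than the buffer is only impossible without repetition; this is an assumption, not something the page states. The sketch validates the request against both documented errors and then uses a partial Fisher-Yates shuffle to pick size distinct indices. SampleError, sample_indices, and the next_u64 closure parameter are hypothetical.

```rust
/// Reasons a uniform batch request can fail, mirroring the documented errors.
#[derive(Debug, PartialEq)]
pub enum SampleError {
    Empty,
    TooLarge { requested: usize, available: usize },
}

/// Picks `size` distinct indices from `0..len` uniformly at random using a
/// partial Fisher-Yates shuffle; `next_u64` supplies random numbers.
/// (The modulo step has a negligible bias for small `len`.)
pub fn sample_indices(
    len: usize,
    size: usize,
    next_u64: &mut impl FnMut() -> u64,
) -> Result<Vec<usize>, SampleError> {
    if len == 0 {
        return Err(SampleError::Empty);
    }
    if size > len {
        return Err(SampleError::TooLarge { requested: size, available: len });
    }
    let mut pool: Vec<usize> = (0..len).collect();
    for i in 0..size {
        // Swap a random element from the not-yet-chosen tail into slot i.
        let j = i + (next_u64() % (len - i) as u64) as usize;
        pool.swap(i, j);
    }
    pool.truncate(size);
    Ok(pool)
}
```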

fn update_priority(&mut self, ixs: &Option<Vec<usize>>, td_errs: &Option<Vec<f32>>)

Updates the priorities of transitions in the buffer.

This method is used in prioritized experience replay to adjust the sampling probabilities based on TD errors.

Arguments
  • ixs - Optional indices of transitions to update
  • td_errs - Optional TD errors for the transitions
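A common way to turn TD errors into priorities is p_i = |δ_i| + ε, where the small constant ε keeps every transition sampleable. The sketch below assumes that rule and treats None for either argument as a no-op, which fits the Option types in the signature but is not confirmed by this page; the free function, its priorities slice, and eps are hypothetical.

```rust
/// Sets `priorities[ix] = |td_err| + eps` for each (index, TD error) pair.
/// Does nothing when either argument is `None`.
pub fn update_priority(
    priorities: &mut [f32],
    ixs: &Option<Vec<usize>>,
    td_errs: &Option<Vec<f32>>,
    eps: f32,
) {
    if let (Some(ixs), Some(td_errs)) = (ixs, td_errs) {
        assert_eq!(ixs.len(), td_errs.len(), "expected one TD error per index");
        for (&ix, &td) in ixs.iter().zip(td_errs.iter()) {
            priorities[ix] = td.abs() + eps;
        }
    }
}
```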

type Config = SimpleReplayBufferConfig

Configuration parameters for the replay buffer.

type Batch = GenericTransitionBatch<O, A>

The type of batch generated for training.

Auto Trait Implementations

impl<O, A> Freeze for SimpleReplayBuffer<O, A>
where O: Freeze, A: Freeze,

impl<O, A> RefUnwindSafe for SimpleReplayBuffer<O, A>

impl<O, A> Send for SimpleReplayBuffer<O, A>
where O: Send, A: Send,

impl<O, A> Sync for SimpleReplayBuffer<O, A>
where O: Sync, A: Sync,

impl<O, A> Unpin for SimpleReplayBuffer<O, A>
where O: Unpin, A: Unpin,

impl<O, A> UnwindSafe for SimpleReplayBuffer<O, A>
where O: UnwindSafe, A: UnwindSafe,

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V