pub struct SimpleReplayBuffer<O, A> { /* private fields */ }
A generic implementation of a replay buffer for reinforcement learning.
This buffer can store transitions of arbitrary observation and action types, making it suitable for a wide range of reinforcement learning tasks. It supports:
- Standard experience replay
- Prioritized experience replay (optional)
- Efficient sampling and storage
§Type Parameters
O - The type of observations, must implement BatchBase
A - The type of actions, must implement BatchBase
§Examples
let config = SimpleReplayBufferConfig {
    capacity: 10000,
    // Enable prioritized experience replay (optional; use `None` to disable).
    per_config: Some(PerConfig {
        alpha: 0.6,
        beta_0: 0.4,
        beta_final: 1.0,
        n_opts_final: 100000,
        normalize: true,
    }),
};
let mut buffer = SimpleReplayBuffer::<Tensor, Tensor>::build(&config);

// Add transitions
buffer.push(transition)?;

// Sample a batch
let batch = buffer.batch(32)?;
Implementations§
impl<O, A> SimpleReplayBuffer<O, A>
pub fn whole_actions(&self) -> A
Returns a batch containing all actions in the buffer.
§Warning
This method should be used with caution on large replay buffers as it may consume significant memory.
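A minimal sketch of a typical use, assuming `buffer` is the populated SimpleReplayBuffer from the example above:

// Pull every stored action at once, e.g. to inspect the action
// distribution offline; prefer batch() for training-time sampling.
let all_actions = buffer.whole_actions();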
pub fn num_terminated_flags(&self) -> usize
Returns the number of terminated episodes in the buffer.
pub fn num_truncated_flags(&self) -> usize
Returns the number of truncated episodes in the buffer.
pub fn sum_rewards(&self) -> f32
Returns the sum of all rewards in the buffer.
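Taken together, these counters allow quick diagnostics over the stored data. A hedged sketch, assuming `buffer` is a populated SimpleReplayBuffer whose oldest transitions have not yet been overwritten:

// Episodes end either by termination or by truncation.
let n_episodes = buffer.num_terminated_flags() + buffer.num_truncated_flags();
// Average return per completed episode, if any episodes have finished.
let avg_return = if n_episodes > 0 {
    Some(buffer.sum_rewards() / n_episodes as f32)
} else {
    None
};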
Trait Implementations§
impl Agent<TestEnv, SimpleReplayBuffer<TestObsBatch, TestActBatch>> for TestAgent
fn opt_with_record(&mut self, _buffer: &mut SimpleReplayBuffer<TestObsBatch, TestActBatch>) -> Record
fn save_params(&self, _path: &Path) -> Result<Vec<PathBuf>>
fn load_params(&mut self, _path: &Path) -> Result<()>
fn as_any_ref(&self) -> &dyn Any
Returns the agent as a &dyn Any value.
fn as_any_mut(&mut self) -> &mut dyn Any
Returns the agent as a &mut dyn Any value.
impl<O, A> ExperienceBufferBase for SimpleReplayBuffer<O, A>
impl<O, A> ReplayBufferBase for SimpleReplayBuffer<O, A>
fn batch(&mut self, size: usize) -> Result<Self::Batch>
Samples a batch of transitions from the buffer.
If prioritized experience replay is enabled, samples are selected according to their priorities. Otherwise, uniform random sampling is used.
§Arguments
size - Number of transitions to sample
§Returns
A batch of sampled transitions
§Errors
Returns an error if:
- The buffer is empty
- The requested batch size is larger than the buffer size
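A hedged usage sketch, assuming `buffer` has been filled by an agent-environment loop and that ExperienceBufferBase exposes len() reporting the current number of stored transitions:

let batch_size = 32;
// Only sample once enough transitions are stored, avoiding the error cases above.
if buffer.len() >= batch_size {
    let batch = buffer.batch(batch_size)?;
    // ... hand `batch` to the agent's optimization step
}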
fn update_priority(&mut self, ixs: &Option<Vec<usize>>, td_errs: &Option<Vec<f32>>)
Updates the priorities of transitions in the buffer.
This method is used in prioritized experience replay to adjust the sampling probabilities based on TD errors.
§Arguments
ixs - Optional indices of transitions to update
td_errs - Optional TD errors for the transitions
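A hedged sketch of the prioritized-replay update step; the index and TD-error vectors below are illustrative placeholders, and in practice they come from the sampled batch and the agent's loss computation:

// After sampling, compute absolute TD errors for the sampled transitions,
// then write the refreshed priorities back into the buffer.
let sampled_ixs: Vec<usize> = vec![3, 17, 42];
let td_errors: Vec<f32> = vec![0.8, 0.05, 1.3];
buffer.update_priority(&Some(sampled_ixs), &Some(td_errors));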