use crate::TransitionBatch;
/// Operations a column-like batch container must support so it can be stored
/// inside a [`GenericTransitionBatch`] and built via
/// [`GenericTransitionBatch::with_capacity`].
pub trait BatchBase {
    /// Creates an empty container sized for `capacity` entries.
    ///
    /// NOTE(review): whether this preallocates or merely records the capacity
    /// is up to the implementor.
    fn new(capacity: usize) -> Self;

    /// Writes `data` into `self` at index `ix`.
    ///
    /// NOTE(review): presumably `data` may hold several entries written
    /// starting at `ix`, and overwrite/wrap-around behavior is
    /// implementor-defined. Confirm against concrete implementations.
    fn push(&mut self, ix: usize, data: Self);

    /// Returns a new container holding the entries of `self` selected by `ixs`.
    ///
    /// NOTE(review): presumably the result follows the order of `ixs`.
    /// Behavior on out-of-range indices is implementor-defined.
    fn sample(&self, ixs: &Vec<usize>) -> Self;
}
/// A batch of transitions stored column-wise: one container per field.
///
/// `O` holds observations and `A` holds actions. The struct itself places no
/// trait bounds on them. Bounds (e.g. `BatchBase`) are declared on the impl
/// blocks that actually need them, so the bounds are not duplicated here.
///
/// The number of transitions in the batch is taken from the length of
/// `reward` (see the `TransitionBatch::len` implementation).
pub struct GenericTransitionBatch<O, A> {
    /// Observations for each transition.
    pub obs: O,
    /// Actions for each transition.
    pub act: A,
    /// Observations following each transition
    /// (NOTE(review): presumably the result of taking `act` in `obs`).
    pub next_obs: O,
    /// Reward for each transition. Its length defines the batch length.
    pub reward: Vec<f32>,
    /// Per-transition termination flag
    /// (NOTE(review): presumably 0/1 — confirm encoding).
    pub is_terminated: Vec<i8>,
    /// Per-transition truncation flag
    /// (NOTE(review): presumably 0/1 — confirm encoding).
    pub is_truncated: Vec<i8>,
    /// Optional per-transition weights
    /// (NOTE(review): presumably importance-sampling weights from a prioritized buffer).
    pub weight: Option<Vec<f32>>,
    /// Optional indices the transitions were sampled from
    /// (NOTE(review): presumably positions in the replay buffer).
    pub ix_sample: Option<Vec<usize>>,
}
/// Exposes a [`GenericTransitionBatch`] through the crate's `TransitionBatch`
/// interface.
impl<O, A> TransitionBatch for GenericTransitionBatch<O, A>
where
    O: BatchBase,
    A: BatchBase,
{
    type ObsBatch = O;
    type ActBatch = A;

    /// Consumes the batch and returns its columns as a tuple:
    /// `(obs, act, next_obs, reward, is_terminated, is_truncated, ix_sample, weight)`.
    fn unpack(
        self,
    ) -> (
        Self::ObsBatch,
        Self::ActBatch,
        Self::ObsBatch,
        Vec<f32>,
        Vec<i8>,
        Vec<i8>,
        Option<Vec<usize>>,
        Option<Vec<f32>>,
    ) {
        let Self {
            obs,
            act,
            next_obs,
            reward,
            is_terminated,
            is_truncated,
            weight,
            ix_sample,
        } = self;
        // The tuple lists the sample indices before the weights, which is the
        // reverse of the struct's field order.
        (
            obs,
            act,
            next_obs,
            reward,
            is_terminated,
            is_truncated,
            ix_sample,
            weight,
        )
    }

    /// Number of transitions, taken from the length of the reward column.
    fn len(&self) -> usize {
        let rewards = &self.reward;
        rewards.len()
    }

    /// Borrows the observation column.
    fn obs(&self) -> &Self::ObsBatch {
        let observations = &self.obs;
        observations
    }

    /// Borrows the action column.
    fn act(&self) -> &Self::ActBatch {
        let actions = &self.act;
        actions
    }
}
impl<O, A> GenericTransitionBatch<O, A>
where
    O: BatchBase,
    A: BatchBase,
{
    /// Creates an empty batch whose containers are all sized for `capacity`
    /// transitions.
    ///
    /// The observation and action containers come from `BatchBase::new`, and
    /// the scalar columns are empty `Vec`s with the given capacity. `weight`
    /// and `ix_sample` start out as `None`.
    pub fn with_capacity(capacity: usize) -> Self {
        // Keep construction order obs -> act -> next_obs, matching the
        // field order, in case `new` has observable side effects.
        let obs = O::new(capacity);
        let act = A::new(capacity);
        let next_obs = O::new(capacity);

        Self {
            obs,
            act,
            next_obs,
            reward: Vec::with_capacity(capacity),
            is_terminated: Vec::with_capacity(capacity),
            is_truncated: Vec::with_capacity(capacity),
            weight: None,
            ix_sample: None,
        }
    }
}