1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
use super::SubBatch;
use crate::Batch as BatchBase;
/// A batch of samples stored column-wise as owned data.
///
/// `O` is the observation container, used for both `obs` and `next_obs`.
/// `A` is the action container. Both are [`SubBatch`] types. The per-sample
/// scalar columns are plain vectors, and the two optional columns may be
/// absent. This type implements [`BatchBase`], whose `len` is taken from the
/// length of `reward`.
///
/// NOTE(review): nothing in this type enforces that every column holds the
/// same number of samples. Presumably whoever builds a `Batch` in the parent
/// module guarantees that; confirm.
pub struct Batch<O: SubBatch, A: SubBatch> {
    /// Observation for each sample.
    pub(super) obs: O,
    /// Action for each sample.
    pub(super) act: A,
    /// Next observation for each sample (meaning inferred from the name; confirm).
    pub(super) next_obs: O,
    /// Reward for each sample; its length defines the batch length.
    pub(super) reward: Vec<f32>,
    /// Per-sample done flag (presumably 0/1; confirm with the producer).
    pub(super) is_done: Vec<i8>,
    /// Optional per-sample weights (e.g. importance weights; assumption, verify).
    pub(super) weight: Option<Vec<f32>>,
    /// Optional indices of the sampled entries in their source (assumption, verify).
    pub(super) ix_sample: Option<Vec<usize>>,
}
/// [`BatchBase`] implementation for the column-wise [`Batch`].
///
/// `unpack` moves every column out by value. The remaining methods hand out
/// shared borrows of the stored fields without copying.
impl<O: SubBatch, A: SubBatch> BatchBase for Batch<O, A> {
    type ObsBatch = O;
    type ActBatch = A;

    /// Consumes the batch and returns all of its columns.
    ///
    /// Tuple order is `(obs, act, next_obs, reward, is_done, ix_sample, weight)`.
    /// `ix_sample` precedes `weight` here, which is the reverse of their order
    /// in the struct.
    fn unpack(
        self,
    ) -> (
        Self::ObsBatch,
        Self::ActBatch,
        Self::ObsBatch,
        Vec<f32>,
        Vec<i8>,
        Option<Vec<usize>>,
        Option<Vec<f32>>,
    ) {
        let Self {
            obs,
            act,
            next_obs,
            reward,
            is_done,
            weight,
            ix_sample,
        } = self;
        (obs, act, next_obs, reward, is_done, ix_sample, weight)
    }

    /// Number of samples, measured as the length of the `reward` column.
    fn len(&self) -> usize {
        let rewards = &self.reward;
        rewards.len()
    }

    /// Borrows the observation column.
    fn obs(&self) -> &Self::ObsBatch {
        &self.obs
    }

    /// Borrows the next-observation column.
    fn next_obs(&self) -> &Self::ObsBatch {
        &self.next_obs
    }

    /// Borrows the action column.
    fn act(&self) -> &Self::ActBatch {
        &self.act
    }

    /// Borrows the reward column.
    fn reward(&self) -> &Vec<f32> {
        &self.reward
    }

    /// Borrows the done-flag column.
    fn is_done(&self) -> &Vec<i8> {
        &self.is_done
    }

    /// Borrows the optional weight column.
    fn weight(&self) -> &Option<Vec<f32>> {
        &self.weight
    }

    /// Borrows the optional sample-index column.
    fn ix_sample(&self) -> &Option<Vec<usize>> {
        &self.ix_sample
    }
}