1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
use border_core::{Shape, replay_buffer::SubBatch};
use tch::{Tensor, Device};
use std::marker::PhantomData;
/// Element types for which a zero-filled [`Tensor`] can be created.
///
/// The implementations in this file (`u8`, `i32`, `f32`, `i64`) each pick the
/// matching `tch` kind and allocate the tensor on the CPU.
pub trait ZeroTensor {
/// Returns a tensor of the given `shape` with every element set to zero.
fn zeros(shape: &[i64]) -> Tensor;
}
impl ZeroTensor for u8 {
    /// Returns a zero-filled `Kind::Uint8` tensor of the given `shape` on the CPU.
    fn zeros(shape: &[i64]) -> Tensor {
        let options = (tch::kind::Kind::Uint8, Device::Cpu);
        Tensor::zeros(shape, options)
    }
}
impl ZeroTensor for i32 {
    /// Returns a zero-filled `Kind::Int` tensor of the given `shape` on the CPU.
    fn zeros(shape: &[i64]) -> Tensor {
        let options = (tch::kind::Kind::Int, Device::Cpu);
        Tensor::zeros(shape, options)
    }
}
impl ZeroTensor for f32 {
fn zeros(shape: &[i64]) -> Tensor {
Tensor::zeros(&shape, tch::kind::FLOAT_CPU)
}
}
impl ZeroTensor for i64 {
    /// Returns a zero-filled `Kind::Int64` tensor of the given `shape` on the CPU.
    fn zeros(shape: &[i64]) -> Tensor {
        let options = (tch::kind::Kind::Int64, Device::Cpu);
        Tensor::zeros(shape, options)
    }
}
/// A sub-batch of a replay buffer backed by a single [`Tensor`].
///
/// `S` supplies the per-item shape (read through `S::shape()` when a new
/// buffer is allocated) and `D` is the element type used to allocate the
/// zero-filled storage (through `D::zeros`). Neither is stored; both are only
/// carried as type-level markers.
pub struct TensorSubBatch<S, D> {
// Backing storage; its first dimension indexes the stored items.
buf: Tensor,
// Number of rows along the first dimension of `buf`; used as the modulus
// when writing rows in `push`.
capacity: i64,
// Ties the otherwise unused type parameters `S` and `D` to the struct.
phantom: PhantomData<(S, D)>,
}
impl<S, D> Clone for TensorSubBatch<S, D> {
    /// Duplicates the sub-batch.
    ///
    /// The backing tensor is duplicated with `Tensor::copy`, so the returned
    /// value holds its own tensor rather than the original one.
    fn clone(&self) -> Self {
        let buf = self.buf.copy();
        let capacity = self.capacity;
        Self {
            buf,
            capacity,
            phantom: PhantomData,
        }
    }
}
impl<S, D> TensorSubBatch<S, D>
where
    S: Shape,
    D: 'static + Copy + tch::kind::Element + ZeroTensor,
{
    /// Wraps an existing tensor as a sub-batch without copying it.
    ///
    /// The capacity is taken from the size of the tensor's first dimension.
    ///
    /// # Panics
    ///
    /// Panics if `t` has no dimensions, since the first dimension is indexed
    /// unconditionally.
    pub fn from_tensor(t: Tensor) -> Self {
        Self {
            // Evaluated before `t` is moved into `buf` below.
            capacity: t.size()[0] as i64,
            buf: t,
            phantom: PhantomData,
        }
    }
}
impl<S, D> SubBatch for TensorSubBatch<S, D>
where
    S: Shape,
    D: 'static + Copy + tch::kind::Element + ZeroTensor,
{
    /// Allocates a zero-filled buffer of shape `[capacity, S::shape()...]`.
    fn new(capacity: usize) -> Self {
        let capacity = capacity as i64;
        // Leading dimension is the capacity, followed by the per-item shape.
        let full_shape = std::iter::once(capacity)
            .chain(S::shape().iter().map(|&dim| dim as i64))
            .collect::<Vec<i64>>();
        Self {
            buf: D::zeros(&full_shape),
            capacity,
            phantom: PhantomData,
        }
    }

    /// Writes every row of `data` into this buffer, starting at row `index`.
    ///
    /// Destination rows are taken modulo the capacity, so writes that run past
    /// the end of the buffer continue from row 0.
    ///
    /// In debug builds, asserts that the per-item shape of `data` matches the
    /// per-item shape of this buffer.
    fn push(&mut self, index: usize, data: &Self) {
        let start = index as i64;
        // Work on a copy of the incoming tensor rather than on `data.buf` itself.
        let incoming: Tensor = data.buf.copy();
        let incoming_dims = incoming.size();
        let n_rows = incoming_dims[0];
        debug_assert_eq!(&incoming_dims[1..], &self.buf.size()[1..]);

        for src_row in 0..n_rows {
            let dst_row = (start + src_row) % self.capacity;
            let source = incoming.get(src_row);
            self.buf.get(dst_row).copy_(&source);
        }
    }

    /// Builds a new sub-batch from the rows at positions `ixs`.
    ///
    /// Rows are gathered along the first dimension in the order given; the
    /// capacity of the result equals the number of indices.
    fn sample(&self, ixs: &Vec<usize>) -> Self {
        let rows: Vec<i64> = ixs.iter().map(|&ix| ix as i64).collect();
        let row_selector = Tensor::of_slice(&rows);
        Self {
            buf: self.buf.index_select(0, &row_selector),
            capacity: rows.len() as i64,
            phantom: PhantomData,
        }
    }
}
impl<S, D> From<TensorSubBatch<S, D>> for Tensor {
fn from(b: TensorSubBatch<S, D>) -> Self {
b.buf
}
}