border_tch_agent/
tensor_batch.rs1use border_core::generic_replay_buffer::BatchBase;
2use tch::{Device, Tensor};
3
4pub trait ZeroTensor {
8 fn zeros(shape: &[i64]) -> Tensor;
10}
11
12impl ZeroTensor for u8 {
13 fn zeros(shape: &[i64]) -> Tensor {
14 Tensor::zeros(shape, (tch::kind::Kind::Uint8, Device::Cpu))
15 }
16}
17
18impl ZeroTensor for i32 {
19 fn zeros(shape: &[i64]) -> Tensor {
20 Tensor::zeros(shape, (tch::kind::Kind::Int, Device::Cpu))
21 }
22}
23
24impl ZeroTensor for f32 {
25 fn zeros(shape: &[i64]) -> Tensor {
26 Tensor::zeros(shape, tch::kind::FLOAT_CPU)
27 }
28}
29
30impl ZeroTensor for i64 {
31 fn zeros(shape: &[i64]) -> Tensor {
32 Tensor::zeros(shape, (tch::kind::Kind::Int64, Device::Cpu))
33 }
34}
35
36pub struct TensorBatch {
45 buf: Option<Tensor>,
46 capacity: i64,
47}
48
49impl Clone for TensorBatch {
50 fn clone(&self) -> Self {
51 let buf = match self.buf.is_none() {
52 true => None,
53 false => Some(self.buf.as_ref().unwrap().copy()),
54 };
55
56 Self {
57 buf,
58 capacity: self.capacity,
59 }
60 }
61}
62
63impl TensorBatch {
64 pub fn from_tensor(t: Tensor) -> Self {
65 let capacity = t.size()[0] as _;
66 Self {
67 buf: Some(t),
68 capacity,
69 }
70 }
71}
72
73impl BatchBase for TensorBatch {
74 fn new(capacity: usize) -> Self {
75 Self {
81 buf: None,
82 capacity: capacity as _,
83 }
84 }
85
86 fn push(&mut self, index: usize, data: Self) {
91 if data.buf.is_none() {
92 return;
93 }
94
95 let batch_size = data.buf.as_ref().unwrap().size()[0];
96 if batch_size == 0 {
97 return;
98 }
99
100 if self.buf.is_none() {
101 let mut shape = data.buf.as_ref().unwrap().size().clone();
102 shape[0] = self.capacity;
103 let kind = data.buf.as_ref().unwrap().kind();
104 let device = tch::Device::Cpu;
105 self.buf = Some(Tensor::zeros(&shape, (kind, device)));
106 }
107
108 let index = index as i64;
109 let val: Tensor = data.buf.as_ref().unwrap().copy();
110
111 for i_ in 0..batch_size {
112 let i = (i_ + index) % self.capacity;
113 self.buf.as_ref().unwrap().get(i).copy_(&val.get(i_));
114 }
115 }
116
117 fn sample(&self, ixs: &Vec<usize>) -> Self {
118 let ixs = ixs.iter().map(|&ix| ix as i64).collect::<Vec<_>>();
119 let batch_indexes = Tensor::from_slice(&ixs);
120 let buf = Some(self.buf.as_ref().unwrap().index_select(0, &batch_indexes));
121 Self {
122 buf,
123 capacity: ixs.len() as i64,
124 }
125 }
126}
127
128impl From<TensorBatch> for Tensor {
129 fn from(b: TensorBatch) -> Self {
130 b.buf.unwrap()
131 }
132}