border_tch_agent/
tensor_batch.rs1use border_core::generic_replay_buffer::BatchBase;
2use tch::{Device, Tensor};
3
4pub trait ZeroTensor {
8 fn zeros(shape: &[i64]) -> Tensor;
10}
11
12impl ZeroTensor for u8 {
13 fn zeros(shape: &[i64]) -> Tensor {
14 Tensor::zeros(shape, (tch::kind::Kind::Uint8, Device::Cpu))
15 }
16}
17
18impl ZeroTensor for i32 {
19 fn zeros(shape: &[i64]) -> Tensor {
20 Tensor::zeros(shape, (tch::kind::Kind::Int, Device::Cpu))
21 }
22}
23
24impl ZeroTensor for f32 {
25 fn zeros(shape: &[i64]) -> Tensor {
26 Tensor::zeros(shape, tch::kind::FLOAT_CPU)
27 }
28}
29
30impl ZeroTensor for i64 {
31 fn zeros(shape: &[i64]) -> Tensor {
32 Tensor::zeros(shape, (tch::kind::Kind::Int64, Device::Cpu))
33 }
34}
35
36pub struct TensorBatch {
45 buf: Option<Tensor>,
46 capacity: i64,
47}
48
49impl Clone for TensorBatch {
50 fn clone(&self) -> Self {
51 let buf = match self.buf.is_none() {
52 true => None,
53 false => Some(self.buf.as_ref().unwrap().copy()),
54 };
55
56 Self {
57 buf,
58 capacity: self.capacity,
59 }
60 }
61}
62
63impl TensorBatch {
64 pub fn from_tensor(t: Tensor) -> Self {
65 let capacity = t.size()[0] as _;
66 Self {
67 buf: Some(t),
68 capacity,
69 }
70 }
71}
72
73impl BatchBase for TensorBatch {
74 fn new(capacity: usize) -> Self {
75 Self {
76 buf: None,
77 capacity: capacity as _,
78 }
79 }
80
81 fn push(&mut self, index: usize, data: Self) {
86 if data.buf.is_none() {
87 return;
88 }
89
90 let batch_size = data.buf.as_ref().unwrap().size()[0];
91 if batch_size == 0 {
92 return;
93 }
94
95 if self.buf.is_none() {
96 let mut shape = data.buf.as_ref().unwrap().size().clone();
97 shape[0] = self.capacity;
98 let kind = data.buf.as_ref().unwrap().kind();
99 let device = tch::Device::Cpu;
100 self.buf = Some(Tensor::zeros(&shape, (kind, device)));
101 }
102
103 let index = index as i64;
104 let val: Tensor = data.buf.as_ref().unwrap().copy();
105
106 for i_ in 0..batch_size {
107 let i = (i_ + index) % self.capacity;
108 self.buf.as_ref().unwrap().get(i).copy_(&val.get(i_));
109 }
110 }
111
112 fn sample(&self, ixs: &Vec<usize>) -> Self {
113 let ixs = ixs.iter().map(|&ix| ix as i64).collect::<Vec<_>>();
114 let batch_indexes = Tensor::from_slice(&ixs);
115 let buf = Some(self.buf.as_ref().unwrap().index_select(0, &batch_indexes));
116 Self {
117 buf,
118 capacity: ixs.len() as i64,
119 }
120 }
121}
122
123impl From<TensorBatch> for Tensor {
124 fn from(b: TensorBatch) -> Self {
125 b.buf.unwrap()
126 }
127}