border_tch_agent/
tensor_batch.rs

1use border_core::generic_replay_buffer::BatchBase;
2use tch::{Device, Tensor};
3
4/// Adds capability of constructing [`Tensor`] with a static method.
5///
6/// [`Tensor`]: https://docs.rs/tch/0.16.0/tch/struct.Tensor.html
7pub trait ZeroTensor {
8    /// Constructs zero tensor.
9    fn zeros(shape: &[i64]) -> Tensor;
10}
11
12impl ZeroTensor for u8 {
13    fn zeros(shape: &[i64]) -> Tensor {
14        Tensor::zeros(shape, (tch::kind::Kind::Uint8, Device::Cpu))
15    }
16}
17
18impl ZeroTensor for i32 {
19    fn zeros(shape: &[i64]) -> Tensor {
20        Tensor::zeros(shape, (tch::kind::Kind::Int, Device::Cpu))
21    }
22}
23
24impl ZeroTensor for f32 {
25    fn zeros(shape: &[i64]) -> Tensor {
26        Tensor::zeros(shape, tch::kind::FLOAT_CPU)
27    }
28}
29
30impl ZeroTensor for i64 {
31    fn zeros(shape: &[i64]) -> Tensor {
32        Tensor::zeros(shape, (tch::kind::Kind::Int64, Device::Cpu))
33    }
34}
35
36/// A buffer consisting of a [`Tensor`].
37///
38/// The internal buffer of this struct has the shape of `[n_capacity, shape[1..]]`,
39/// where `shape` is obtained from the data pushed at the first time via
40/// [`TensorBatch::push`] method. `[1..]` means that the first axis of the
41/// given data is ignored as it might be batch size.
42///
43/// [`Tensor`]: https://docs.rs/tch/0.16.0/tch/struct.Tensor.html
44pub struct TensorBatch {
45    buf: Option<Tensor>,
46    capacity: i64,
47}
48
49impl Clone for TensorBatch {
50    fn clone(&self) -> Self {
51        let buf = match self.buf.is_none() {
52            true => None,
53            false => Some(self.buf.as_ref().unwrap().copy()),
54        };
55
56        Self {
57            buf,
58            capacity: self.capacity,
59        }
60    }
61}
62
63impl TensorBatch {
64    pub fn from_tensor(t: Tensor) -> Self {
65        let capacity = t.size()[0] as _;
66        Self {
67            buf: Some(t),
68            capacity,
69        }
70    }
71}
72
73impl BatchBase for TensorBatch {
74    fn new(capacity: usize) -> Self {
75        Self {
76            buf: None,
77            capacity: capacity as _,
78        }
79    }
80
81    /// Pushes given data.
82    ///
83    /// If the internal buffer is empty, it will be initialized with the shape
84    /// `[capacity, data.buf.size()[1..]]`.
85    fn push(&mut self, index: usize, data: Self) {
86        if data.buf.is_none() {
87            return;
88        }
89
90        let batch_size = data.buf.as_ref().unwrap().size()[0];
91        if batch_size == 0 {
92            return;
93        }
94
95        if self.buf.is_none() {
96            let mut shape = data.buf.as_ref().unwrap().size().clone();
97            shape[0] = self.capacity;
98            let kind = data.buf.as_ref().unwrap().kind();
99            let device = tch::Device::Cpu;
100            self.buf = Some(Tensor::zeros(&shape, (kind, device)));
101        }
102
103        let index = index as i64;
104        let val: Tensor = data.buf.as_ref().unwrap().copy();
105
106        for i_ in 0..batch_size {
107            let i = (i_ + index) % self.capacity;
108            self.buf.as_ref().unwrap().get(i).copy_(&val.get(i_));
109        }
110    }
111
112    fn sample(&self, ixs: &Vec<usize>) -> Self {
113        let ixs = ixs.iter().map(|&ix| ix as i64).collect::<Vec<_>>();
114        let batch_indexes = Tensor::from_slice(&ixs);
115        let buf = Some(self.buf.as_ref().unwrap().index_select(0, &batch_indexes));
116        Self {
117            buf,
118            capacity: ixs.len() as i64,
119        }
120    }
121}
122
123impl From<TensorBatch> for Tensor {
124    fn from(b: TensorBatch) -> Self {
125        b.buf.unwrap()
126    }
127}