// border_tch_agent/tensor_batch.rs

1use border_core::generic_replay_buffer::BatchBase;
2use tch::{Device, Tensor};
3
/// Adds capability of constructing [`Tensor`] with a static method.
///
/// Implementors map a Rust primitive to a matching tensor element kind so
/// generic code can allocate a zero-initialized tensor for that type.
/// All impls in this file allocate on the CPU device.
///
/// [`Tensor`]: https://docs.rs/tch/0.16.0/tch/struct.Tensor.html
pub trait ZeroTensor {
    /// Constructs a zero tensor with the given `shape`.
    fn zeros(shape: &[i64]) -> Tensor;
}
11
12impl ZeroTensor for u8 {
13    fn zeros(shape: &[i64]) -> Tensor {
14        Tensor::zeros(shape, (tch::kind::Kind::Uint8, Device::Cpu))
15    }
16}
17
18impl ZeroTensor for i32 {
19    fn zeros(shape: &[i64]) -> Tensor {
20        Tensor::zeros(shape, (tch::kind::Kind::Int, Device::Cpu))
21    }
22}
23
24impl ZeroTensor for f32 {
25    fn zeros(shape: &[i64]) -> Tensor {
26        Tensor::zeros(shape, tch::kind::FLOAT_CPU)
27    }
28}
29
30impl ZeroTensor for i64 {
31    fn zeros(shape: &[i64]) -> Tensor {
32        Tensor::zeros(shape, (tch::kind::Kind::Int64, Device::Cpu))
33    }
34}
35
/// A buffer consisting of a [`Tensor`].
///
/// The internal buffer of this struct has the shape of `[n_capacity, shape[1..]]`,
/// where `shape` is obtained from the data pushed at the first time via
/// [`TensorBatch::push`] method. `[1..]` means that the first axis of the
/// given data is ignored as it might be batch size.
///
/// [`Tensor`]: https://docs.rs/tch/0.16.0/tch/struct.Tensor.html
pub struct TensorBatch {
    /// Backing storage; `None` until the first push, because the element
    /// shape and kind are only known once data arrives.
    buf: Option<Tensor>,
    /// Number of rows (length of the first axis) the buffer can hold.
    capacity: i64,
}
48
49impl Clone for TensorBatch {
50    fn clone(&self) -> Self {
51        let buf = match self.buf.is_none() {
52            true => None,
53            false => Some(self.buf.as_ref().unwrap().copy()),
54        };
55
56        Self {
57            buf,
58            capacity: self.capacity,
59        }
60    }
61}
62
63impl TensorBatch {
64    pub fn from_tensor(t: Tensor) -> Self {
65        let capacity = t.size()[0] as _;
66        Self {
67            buf: Some(t),
68            capacity,
69        }
70    }
71}
72
73impl BatchBase for TensorBatch {
74    fn new(capacity: usize) -> Self {
75        // let capacity = capacity as i64;
76        // let mut shape: Vec<_> = S::shape().to_vec().iter().map(|e| *e as i64).collect();
77        // shape.insert(0, capacity);
78        // let buf = D::zeros(shape.as_slice());
79
80        Self {
81            buf: None,
82            capacity: capacity as _,
83        }
84    }
85
86    /// Pushes given data.
87    ///
88    /// If the internal buffer is empty, it will be initialized with the shape
89    /// `[capacity, data.buf.size()[1..]]`.
90    fn push(&mut self, index: usize, data: Self) {
91        if data.buf.is_none() {
92            return;
93        }
94
95        let batch_size = data.buf.as_ref().unwrap().size()[0];
96        if batch_size == 0 {
97            return;
98        }
99
100        if self.buf.is_none() {
101            let mut shape = data.buf.as_ref().unwrap().size().clone();
102            shape[0] = self.capacity;
103            let kind = data.buf.as_ref().unwrap().kind();
104            let device = tch::Device::Cpu;
105            self.buf = Some(Tensor::zeros(&shape, (kind, device)));
106        }
107
108        let index = index as i64;
109        let val: Tensor = data.buf.as_ref().unwrap().copy();
110
111        for i_ in 0..batch_size {
112            let i = (i_ + index) % self.capacity;
113            self.buf.as_ref().unwrap().get(i).copy_(&val.get(i_));
114        }
115    }
116
117    fn sample(&self, ixs: &Vec<usize>) -> Self {
118        let ixs = ixs.iter().map(|&ix| ix as i64).collect::<Vec<_>>();
119        let batch_indexes = Tensor::from_slice(&ixs);
120        let buf = Some(self.buf.as_ref().unwrap().index_select(0, &batch_indexes));
121        Self {
122            buf,
123            capacity: ixs.len() as i64,
124        }
125    }
126}
127
128impl From<TensorBatch> for Tensor {
129    fn from(b: TensorBatch) -> Self {
130        b.buf.unwrap()
131    }
132}