// border_candle_agent/tensor_batch.rs

1use border_core::generic_replay_buffer::BatchBase;
2use candle_core::{error::Result, DType, Device, IndexOp, Tensor};
3
4/// Adds capability of constructing [`Tensor`] with a static method.
5///
6/// [`Tensor`]: https://docs.rs/candle-core/0.4.1/candle_core/struct.Tensor.html
7pub trait ZeroTensor {
8    /// Constructs zero tensor.
9    fn zeros(shape: &[usize]) -> Result<Tensor>;
10}
11
12impl ZeroTensor for u8 {
13    fn zeros(shape: &[usize]) -> Result<Tensor> {
14        Tensor::zeros(shape, DType::U8, &Device::Cpu)
15    }
16}
17
18impl ZeroTensor for f32 {
19    fn zeros(shape: &[usize]) -> Result<Tensor> {
20        Tensor::zeros(shape, DType::F32, &Device::Cpu)
21    }
22}
23
24impl ZeroTensor for i64 {
25    fn zeros(shape: &[usize]) -> Result<Tensor> {
26        Tensor::zeros(shape, DType::I64, &Device::Cpu)
27    }
28}
29
30/// A buffer consisting of a [`Tensor`].
31///
32/// The internal buffer is `Vec<Tensor>`.
33///
34/// [`Tensor`]: https://docs.rs/candle-core/0.4.1/candle_core/struct.Tensor.html
35#[derive(Clone, Debug)]
36pub struct TensorBatch {
37    buf: Option<Tensor>,
38    capacity: usize,
39}
40
41impl TensorBatch {
42    pub fn from_tensor(t: Tensor) -> Self {
43        let capacity = t.dims()[0] as _;
44        Self {
45            buf: Some(t),
46            capacity,
47        }
48    }
49
50    pub fn to(&mut self, device: &Device) -> Result<()> {
51        if let Some(buf) = &self.buf {
52            self.buf = Some(buf.to_device(device)?);
53        }
54        Ok(())
55    }
56}
57
58impl BatchBase for TensorBatch {
59    fn new(capacity: usize) -> Self {
60        Self {
61            buf: None,
62            capacity: capacity,
63        }
64    }
65
66    /// Pushes given data.
67    ///
68    /// If the internal buffer is empty, it will be initialized with the shape
69    /// `[capacity, data.buf.dims()[1..]]`.
70    fn push(&mut self, index: usize, data: Self) {
71        if data.buf.is_none() {
72            return;
73        }
74
75        let batch_size = data.buf.as_ref().unwrap().dims()[0];
76        if batch_size == 0 {
77            return;
78        }
79
80        if self.buf.is_none() {
81            let mut shape = data.buf.as_ref().unwrap().dims().to_vec();
82            shape[0] = self.capacity;
83            let dtype = data.buf.as_ref().unwrap().dtype();
84            let device = Device::Cpu;
85            self.buf = Some(Tensor::zeros(shape, dtype, &device).unwrap());
86        }
87
88        if index + batch_size > self.capacity {
89            let batch_size = self.capacity - index;
90            let data = &data.buf.unwrap();
91            let data1 = data.i((..batch_size,)).unwrap();
92            let data2 = data.i((batch_size..,)).unwrap();
93            self.buf
94                .as_mut()
95                .unwrap()
96                .slice_set(&data1, 0, index)
97                .unwrap();
98            self.buf.as_mut().unwrap().slice_set(&data2, 0, 0).unwrap();
99        } else {
100            self.buf
101                .as_mut()
102                .unwrap()
103                .slice_set(&data.buf.unwrap(), 0, index)
104                .unwrap();
105        }
106    }
107
108    fn sample(&self, ixs: &Vec<usize>) -> Self {
109        let capacity = ixs.len();
110        let ixs = {
111            let device = self.buf.as_ref().unwrap().device();
112            let ixs = ixs.iter().map(|x| *x as u32).collect();
113            Tensor::from_vec(ixs, &[capacity], device).unwrap()
114        };
115        let buf = Some(self.buf.as_ref().unwrap().index_select(&ixs, 0).unwrap());
116        Self { buf, capacity }
117    }
118}
119
120impl From<TensorBatch> for Tensor {
121    fn from(b: TensorBatch) -> Self {
122        b.buf.unwrap()
123    }
124}