border_candle_agent/
tensor_batch.rs1use border_core::generic_replay_buffer::BatchBase;
2use candle_core::{error::Result, DType, Device, IndexOp, Tensor};
3
4pub trait ZeroTensor {
8 fn zeros(shape: &[usize]) -> Result<Tensor>;
10}
11
12impl ZeroTensor for u8 {
13 fn zeros(shape: &[usize]) -> Result<Tensor> {
14 Tensor::zeros(shape, DType::U8, &Device::Cpu)
15 }
16}
17
18impl ZeroTensor for f32 {
19 fn zeros(shape: &[usize]) -> Result<Tensor> {
20 Tensor::zeros(shape, DType::F32, &Device::Cpu)
21 }
22}
23
24impl ZeroTensor for i64 {
25 fn zeros(shape: &[usize]) -> Result<Tensor> {
26 Tensor::zeros(shape, DType::I64, &Device::Cpu)
27 }
28}
29
30#[derive(Clone, Debug)]
36pub struct TensorBatch {
37 buf: Option<Tensor>,
38 capacity: usize,
39}
40
41impl TensorBatch {
42 pub fn from_tensor(t: Tensor) -> Self {
43 let capacity = t.dims()[0] as _;
44 Self {
45 buf: Some(t),
46 capacity,
47 }
48 }
49
50 pub fn to(&mut self, device: &Device) -> Result<()> {
51 if let Some(buf) = &self.buf {
52 self.buf = Some(buf.to_device(device)?);
53 }
54 Ok(())
55 }
56}
57
58impl BatchBase for TensorBatch {
59 fn new(capacity: usize) -> Self {
60 Self {
61 buf: None,
62 capacity: capacity,
63 }
64 }
65
66 fn push(&mut self, index: usize, data: Self) {
71 if data.buf.is_none() {
72 return;
73 }
74
75 let batch_size = data.buf.as_ref().unwrap().dims()[0];
76 if batch_size == 0 {
77 return;
78 }
79
80 if self.buf.is_none() {
81 let mut shape = data.buf.as_ref().unwrap().dims().to_vec();
82 shape[0] = self.capacity;
83 let dtype = data.buf.as_ref().unwrap().dtype();
84 let device = Device::Cpu;
85 self.buf = Some(Tensor::zeros(shape, dtype, &device).unwrap());
86 }
87
88 if index + batch_size > self.capacity {
89 let batch_size = self.capacity - index;
90 let data = &data.buf.unwrap();
91 let data1 = data.i((..batch_size,)).unwrap();
92 let data2 = data.i((batch_size..,)).unwrap();
93 self.buf
94 .as_mut()
95 .unwrap()
96 .slice_set(&data1, 0, index)
97 .unwrap();
98 self.buf.as_mut().unwrap().slice_set(&data2, 0, 0).unwrap();
99 } else {
100 self.buf
101 .as_mut()
102 .unwrap()
103 .slice_set(&data.buf.unwrap(), 0, index)
104 .unwrap();
105 }
106 }
107
108 fn sample(&self, ixs: &Vec<usize>) -> Self {
109 let capacity = ixs.len();
110 let ixs = {
111 let device = self.buf.as_ref().unwrap().device();
112 let ixs = ixs.iter().map(|x| *x as u32).collect();
113 Tensor::from_vec(ixs, &[capacity], device).unwrap()
114 };
115 let buf = Some(self.buf.as_ref().unwrap().index_select(&ixs, 0).unwrap());
116 Self { buf, capacity }
117 }
118}
119
120impl From<TensorBatch> for Tensor {
121 fn from(b: TensorBatch) -> Self {
122 b.buf.unwrap()
123 }
124}