// border_candle_agent/tensor_batch.rs
use border_core::generic_replay_buffer::BatchBase;
use candle_core::{error::Result, DType, Device, Tensor};
/// Creates a zero-filled [`Tensor`] on the CPU whose `DType` corresponds to
/// the implementing Rust primitive (`u8` → `U8`, `f32` → `F32`, `i64` → `I64`).
pub trait ZeroTensor {
    /// Returns a zero tensor of the given shape on [`Device::Cpu`].
    ///
    /// # Errors
    /// Propagates any error from [`Tensor::zeros`].
    fn zeros(shape: &[usize]) -> Result<Tensor>;
}
/// Zero tensors with `DType::U8`, allocated on the CPU.
impl ZeroTensor for u8 {
    fn zeros(shape: &[usize]) -> Result<Tensor> {
        let device = Device::Cpu;
        Tensor::zeros(shape, DType::U8, &device)
    }
}
/// Zero tensors with `DType::F32`, allocated on the CPU.
impl ZeroTensor for f32 {
    fn zeros(shape: &[usize]) -> Result<Tensor> {
        let device = Device::Cpu;
        Tensor::zeros(shape, DType::F32, &device)
    }
}
/// Zero tensors with `DType::I64`, allocated on the CPU.
impl ZeroTensor for i64 {
    fn zeros(shape: &[usize]) -> Result<Tensor> {
        let device = Device::Cpu;
        Tensor::zeros(shape, DType::I64, &device)
    }
}
/// A batch of samples backed by a `Vec` of [`Tensor`]s, used as the storage
/// type of a generic replay buffer (see the [`BatchBase`] impl below).
///
/// Conversion into a single [`Tensor`] concatenates the stored tensors
/// along dimension 0 (see `From<TensorBatch> for Tensor`).
#[derive(Clone, Debug)]
pub struct TensorBatch {
    /// Stored samples; once `buf` is full, `push` overwrites entries
    /// in ring-buffer fashion.
    buf: Vec<Tensor>,
    /// Maximum number of entries `buf` may hold.
    capacity: usize,
}
impl TensorBatch {
    /// Creates a `TensorBatch` holding a single tensor, with capacity equal
    /// to the tensor's first dimension (i.e. `t` is treated as a batch of
    /// `t.dims()[0]` samples).
    ///
    /// # Panics
    /// Panics if `t` is a scalar (rank 0), since it has no first dimension.
    pub fn from_tensor(t: Tensor) -> Self {
        // `dims()[0]` is already `usize` — the original `as _` cast was a no-op.
        let capacity = t.dims()[0];
        Self {
            buf: vec![t],
            capacity,
        }
    }
}
impl BatchBase for TensorBatch {
    /// Creates an empty batch able to hold up to `capacity` samples.
    fn new(capacity: usize) -> Self {
        Self {
            buf: Vec::with_capacity(capacity),
            // Field-init shorthand instead of `capacity: capacity`.
            capacity,
        }
    }

    /// Inserts the samples of `data` starting at logical index `ix`,
    /// wrapping modulo `capacity` (ring-buffer semantics).
    ///
    /// While the buffer is still filling, samples are appended at the end
    /// (as in the original implementation, `ix` is ignored until the buffer
    /// is full); once full, samples overwrite existing entries at
    /// `(ix + i) % capacity`.
    ///
    /// # Panics
    /// Panics if the buffer already holds more entries than `capacity`,
    /// which indicates a broken invariant.
    fn push(&mut self, ix: usize, data: Self) {
        // The original split this into `len == capacity` and `len < capacity`
        // branches whose loops were identical except that the second also
        // appended while below capacity; the single loop below covers both.
        assert!(
            self.buf.len() <= self.capacity,
            "TensorBatch buffer length exceeds its capacity."
        );
        for (i, sample) in data.buf.into_iter().enumerate() {
            if self.buf.len() < self.capacity {
                // Still filling: append.
                self.buf.push(sample);
            } else {
                // Full: overwrite in ring-buffer fashion.
                let ix_ = (ix + i) % self.capacity;
                self.buf[ix_] = sample;
            }
        }
    }

    /// Returns a new batch containing clones of the samples at `ixs`;
    /// its capacity equals `ixs.len()`.
    ///
    /// # Panics
    /// Panics if any index in `ixs` is out of bounds.
    fn sample(&self, ixs: &Vec<usize>) -> Self {
        let buf = ixs.iter().map(|&ix| self.buf[ix].clone()).collect();
        Self {
            buf,
            capacity: ixs.len(),
        }
    }
}
impl From<TensorBatch> for Tensor {
    /// Concatenates all stored tensors along dimension 0 into one tensor.
    ///
    /// # Panics
    /// Panics if the batch is empty or the stored tensors cannot be
    /// concatenated (e.g. mismatched shapes or dtypes) — `From` cannot
    /// return a `Result`, so the failure is surfaced with an explanatory
    /// message instead of a bare `unwrap`.
    fn from(b: TensorBatch) -> Self {
        Tensor::cat(&b.buf[..], 0)
            .expect("TensorBatch must be non-empty with concatenable tensors along dim 0")
    }
}