border_candle_agent/
tensor_batch.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
use border_core::generic_replay_buffer::BatchBase;
use candle_core::{error::Result, DType, Device, Tensor};

/// Adds the capability of constructing a zero-filled [`Tensor`] through an
/// associated function of the implementing type.
///
/// The function takes no `self`, so it is invoked on the type itself, e.g.
/// `<f32 as ZeroTensor>::zeros(&[2, 3])`.
///
/// [`Tensor`]: https://docs.rs/candle-core/0.4.1/candle_core/struct.Tensor.html
pub trait ZeroTensor {
    /// Constructs a zero-filled tensor with the given `shape`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the implementation when the tensor
    /// cannot be created.
    fn zeros(shape: &[usize]) -> Result<Tensor>;
}

impl ZeroTensor for u8 {
    /// Builds a zero-filled tensor of dtype [`DType::U8`] with the given
    /// `shape` on [`Device::Cpu`].
    fn zeros(shape: &[usize]) -> Result<Tensor> {
        let dtype = DType::U8;
        let device = Device::Cpu;
        Tensor::zeros(shape, dtype, &device)
    }
}

impl ZeroTensor for f32 {
    /// Builds a zero-filled tensor of dtype [`DType::F32`] with the given
    /// `shape` on [`Device::Cpu`].
    fn zeros(shape: &[usize]) -> Result<Tensor> {
        let dtype = DType::F32;
        let device = Device::Cpu;
        Tensor::zeros(shape, dtype, &device)
    }
}

impl ZeroTensor for i64 {
    /// Builds a zero-filled tensor of dtype [`DType::I64`] with the given
    /// `shape` on [`Device::Cpu`].
    fn zeros(shape: &[usize]) -> Result<Tensor> {
        let dtype = DType::I64;
        let device = Device::Cpu;
        Tensor::zeros(shape, dtype, &device)
    }
}

/// A buffer of [`Tensor`]s.
///
/// The internal buffer is a `Vec<Tensor>`; the capacity is tracked in a
/// separate field rather than derived from the vector.
///
/// [`Tensor`]: https://docs.rs/candle-core/0.4.1/candle_core/struct.Tensor.html
#[derive(Clone, Debug)]
pub struct TensorBatch {
    /// Stored tensors, one entry per pushed item.
    buf: Vec<Tensor>,
    /// Upper bound on `buf.len()`; `push` compares against it to decide
    /// between appending a tensor and overwriting an existing slot.
    capacity: usize,
}

impl TensorBatch {
    /// Wraps a single tensor into a batch.
    ///
    /// The tensor becomes the only entry of the internal buffer, and the
    /// capacity is set to the size of the tensor's first dimension.
    ///
    /// # Panics
    ///
    /// Panics if `t` has no dimensions, since its first dimension is read.
    pub fn from_tensor(t: Tensor) -> Self {
        // Read the leading dimension before `t` is moved into the buffer.
        let leading_dim = t.dims()[0];
        Self {
            capacity: leading_dim,
            buf: vec![t],
        }
    }
}

impl BatchBase for TensorBatch {
    fn new(capacity: usize) -> Self {
        Self {
            buf: Vec::with_capacity(capacity),
            capacity: capacity,
        }
    }

    /// Pushes given data.
    ///
    /// if ix + data.buf.len() exceeds the self.capacity,
    /// the tail samples in data is placed in the head of the buffer of self.
    fn push(&mut self, ix: usize, data: Self) {
        if self.buf.len() == self.capacity {
            for (i, sample) in data.buf.into_iter().enumerate() {
                let ix_ = (ix + i) % self.capacity;
                self.buf[ix_] = sample;
            }
        } else if self.buf.len() < self.capacity {
            for (i, sample) in data.buf.into_iter().enumerate() {
                if self.buf.len() < self.capacity {
                    self.buf.push(sample);
                } else {
                    let ix_ = (ix + i) % self.capacity;
                    self.buf[ix_] = sample;
                }
            }
        } else {
            panic!("The length of the buffer is SubBatch is larger than its capacity.");
        }
    }

    fn sample(&self, ixs: &Vec<usize>) -> Self {
        let buf = ixs.iter().map(|&ix| self.buf[ix].clone()).collect();
        Self {
            buf,
            capacity: ixs.len(),
        }
    }
}

impl From<TensorBatch> for Tensor {
    fn from(b: TensorBatch) -> Self {
        Tensor::cat(&b.buf[..], 0).unwrap()
    }
}