use std::cell::RefCell;
use std::rc::Rc;
use crate::tensor::{Device, Result, Tensor};
/// A named handle to a [`Tensor`] stored behind `Rc<RefCell<_>>`.
///
/// `Clone` is derived, so cloning a `Buffer` clones the `Rc` (and copies the
/// name) rather than the tensor cell itself: every clone reads and writes the
/// same underlying slot.
#[derive(Clone)]
pub struct Buffer {
/// Shared, interiorly mutable slot that holds the current tensor.
pub(crate) inner: Rc<RefCell<Tensor>>,
/// Label for this buffer; it is included in the `Debug` output.
pub name: String,
}
impl Buffer {
    /// Creates a buffer that owns `tensor` inside a fresh shared cell and
    /// labels it with a copy of `name`.
    pub fn new(tensor: Tensor, name: &str) -> Self {
        let inner = Rc::new(RefCell::new(tensor));
        let name = name.to_string();
        Self { inner, name }
    }

    /// Returns a clone of the tensor currently stored in the buffer.
    ///
    /// NOTE(review): whether the clone shares storage with the stored tensor
    /// depends on `Tensor`'s clone semantics, which are not visible here.
    ///
    /// # Panics
    ///
    /// Panics if the inner cell is currently mutably borrowed.
    pub fn get(&self) -> Tensor {
        let current = self.inner.borrow();
        current.clone()
    }

    /// Replaces the stored tensor with `tensor`.
    ///
    /// The change is visible through every clone of this buffer, since they
    /// all point at the same cell.
    ///
    /// # Panics
    ///
    /// Panics if the inner cell is currently borrowed.
    pub fn set(&self, tensor: Tensor) {
        let mut slot = self.inner.borrow_mut();
        *slot = tensor;
    }

    /// Returns the shape reported by the stored tensor.
    ///
    /// # Panics
    ///
    /// Panics if the inner cell is currently mutably borrowed.
    pub fn shape(&self) -> Vec<i64> {
        let current = self.inner.borrow();
        current.shape()
    }

    /// Returns the device reported by the stored tensor.
    ///
    /// # Panics
    ///
    /// Panics if the inner cell is currently mutably borrowed.
    pub fn device(&self) -> Device {
        let current = self.inner.borrow();
        current.device()
    }

    /// Moves the stored tensor to `device`, replacing the cell's contents
    /// with the converted tensor.
    ///
    /// Does nothing when the tensor already reports `device`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Tensor::to_device`; in that case the stored
    /// tensor is left untouched.
    ///
    /// # Panics
    ///
    /// Panics if the inner cell is currently borrowed.
    pub fn to_device(&self, device: Device) -> Result<()> {
        if self.device() == device {
            return Ok(());
        }
        // The shared borrow ends with this statement, so the mutable borrow
        // taken by `set` below cannot conflict with it.
        let moved = self.inner.borrow().to_device(device)?;
        self.set(moved);
        Ok(())
    }
}
impl std::fmt::Debug for Buffer {
    /// Formats the buffer as `Buffer(<name>, <shape>)`.
    ///
    /// The inner cell is inspected with `try_borrow`, so formatting never
    /// panics: if the tensor is mutably borrowed at the time (for example
    /// while logging from code that holds `inner.borrow_mut()`), the shape is
    /// replaced with a `<borrowed>` placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.inner.try_borrow() {
            Ok(tensor) => write!(f, "Buffer({}, {:?})", self.name, tensor.shape()),
            // A Debug impl should not abort the program just because the value
            // is being mutated elsewhere on the call stack.
            Err(_) => write!(f, "Buffer({}, <borrowed>)", self.name),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A write made through a clone must be observable through the original
    /// handle, because both share one underlying cell.
    #[test]
    fn test_buffer_shared_via_clone() {
        let zeros = Tensor::zeros(&[3], crate::tensor::test_opts()).unwrap();
        let original = Buffer::new(zeros, "running_mean");
        let alias = original.clone();

        let ones = Tensor::ones(&[3], crate::tensor::test_opts()).unwrap();
        alias.set(ones);

        let observed = original.get().to_f32_vec().unwrap();
        assert_eq!(observed, vec![1.0, 1.0, 1.0]);
    }

    /// Moving a buffer to the device it already lives on succeeds and keeps
    /// its shape.
    #[test]
    fn test_buffer_to_device() {
        let zeros = Tensor::zeros(&[4], crate::tensor::test_opts()).unwrap();
        let stats = Buffer::new(zeros, "stats");

        assert_eq!(stats.device(), crate::tensor::test_device());

        stats.to_device(crate::tensor::test_device()).unwrap();
        assert_eq!(stats.shape(), vec![4]);
    }
}