use crate::Device;
use rlx_ir::{DType, Shape};
/// A tensor-like byte buffer tagged with a shape and a device.
///
/// The payload is always held in a host-side `Vec<u8>`; the `device` field is
/// a tag recording where the buffer is considered to live.
#[derive(Debug, Clone)]
pub struct Buffer {
/// Raw backing storage for the buffer contents.
bytes: Vec<u8>,
/// Shape of the buffer; also the source of its element type (`shape.dtype()`).
shape: Shape,
/// Device this buffer is currently tagged as residing on.
device: Device,
}
impl Buffer {
    /// Wraps caller-provided bytes as a buffer tagged [`Device::Cpu`].
    ///
    /// `data` is stored as-is: its length is *not* checked against `shape`
    /// here. The typed views ([`Buffer::as_f32`], [`Buffer::as_f32_mut`])
    /// validate the storage before reinterpreting it.
    pub fn new_host(shape: Shape, data: Vec<u8>) -> Self {
        Self {
            bytes: data,
            shape,
            device: Device::Cpu,
        }
    }

    /// Creates a zero-filled buffer tagged [`Device::Cpu`].
    ///
    /// The allocation is `shape.size_bytes()` bytes long; when the shape
    /// cannot report a byte size the buffer is created empty.
    pub fn zeros(shape: Shape) -> Self {
        let n = shape.size_bytes().unwrap_or(0);
        Self {
            bytes: vec![0u8; n],
            shape,
            device: Device::Cpu,
        }
    }

    /// Returns the shape this buffer was created with.
    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    /// Returns the device this buffer is currently tagged with.
    pub fn device(&self) -> Device {
        self.device
    }

    /// Returns the element type, as reported by the shape.
    pub fn dtype(&self) -> DType {
        self.shape.dtype()
    }

    /// Returns the element count reported by the shape, or `0` when the
    /// shape cannot report one.
    pub fn num_elements(&self) -> usize {
        self.shape.num_elements().unwrap_or(0)
    }

    /// Returns the length in bytes of the backing storage.
    ///
    /// This is the actual storage length, which can differ from what the
    /// shape implies for buffers built with [`Buffer::new_host`].
    pub fn byte_size(&self) -> usize {
        self.bytes.len()
    }

    /// Validates that the backing bytes can be viewed as `num_elements()`
    /// `f32` values and returns that element count.
    ///
    /// Panics if the storage is too short for the shape, or if a non-empty
    /// storage is not aligned for `f32`.
    fn checked_f32_len(&self) -> usize {
        let n = self.num_elements();
        let needed = n
            .checked_mul(std::mem::size_of::<f32>())
            .expect("f32 view byte length overflows usize");
        assert!(
            self.bytes.len() >= needed,
            "buffer holds {} bytes but its shape needs {} bytes for an f32 view",
            self.bytes.len(),
            needed
        );
        // An empty view never touches the pointer, so alignment only matters
        // when there is at least one element to read.
        if n != 0 {
            assert_eq!(
                self.bytes.as_ptr() as usize % std::mem::align_of::<f32>(),
                0,
                "buffer storage is not aligned for f32"
            );
        }
        n
    }

    /// Views the buffer contents as a slice of `f32`.
    ///
    /// # Panics
    ///
    /// Panics if the dtype is not [`DType::F32`], if the backing storage is
    /// shorter than the shape requires, or if the storage is not aligned
    /// for `f32`.
    pub fn as_f32(&self) -> &[f32] {
        assert_eq!(self.dtype(), DType::F32, "as_f32 on non-F32 buffer");
        let n = self.checked_f32_len();
        if n == 0 {
            // The storage pointer may be dangling and only u8-aligned when
            // empty, so do not build a slice from it.
            return &[];
        }
        // SAFETY: `checked_f32_len` verified that `bytes` holds at least
        // `n * size_of::<f32>()` initialized bytes and that the pointer is
        // aligned for `f32`. Every bit pattern is a valid `f32`, and the
        // returned slice borrows `self`, so the storage outlives it.
        unsafe { std::slice::from_raw_parts(self.bytes.as_ptr().cast::<f32>(), n) }
    }

    /// Views the buffer contents as a mutable slice of `f32`.
    ///
    /// # Panics
    ///
    /// Panics if the dtype is not [`DType::F32`], if the backing storage is
    /// shorter than the shape requires, or if the storage is not aligned
    /// for `f32`.
    pub fn as_f32_mut(&mut self) -> &mut [f32] {
        assert_eq!(self.dtype(), DType::F32, "as_f32_mut on non-F32 buffer");
        let n = self.checked_f32_len();
        if n == 0 {
            // Same reasoning as `as_f32`: never reinterpret an empty,
            // possibly dangling, u8-aligned pointer.
            return &mut [];
        }
        // SAFETY: `checked_f32_len` verified length and alignment (see
        // `as_f32`). The exclusive borrow of `self` guarantees no other
        // reference aliases the storage for the lifetime of the slice.
        unsafe { std::slice::from_raw_parts_mut(self.bytes.as_mut_ptr().cast::<f32>(), n) }
    }

    /// Retags the buffer as living on `device`, consuming `self`.
    ///
    /// Only the device tag changes; the bytes and shape are moved through
    /// untouched. Supported transitions are CPU -> Metal and Metal -> CPU.
    /// Requesting the current device, or any other pair, returns the buffer
    /// unchanged.
    pub fn to_device(self, device: Device) -> Self {
        match (self.device, device) {
            (a, b) if a == b => self,
            (Device::Cpu, Device::Metal) => Self {
                device: Device::Metal,
                ..self
            },
            (Device::Metal, Device::Cpu) => Self {
                device: Device::Cpu,
                ..self
            },
            // Unsupported transition: leave the buffer where it is.
            _ => self,
        }
    }

    /// Shorthand for `to_device(Device::Cpu)`.
    pub fn to_host(self) -> Self {
        self.to_device(Device::Cpu)
    }

    /// Returns the raw bytes of a CPU-tagged buffer.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is not tagged [`Device::Cpu`]; call
    /// [`Buffer::to_host`] first.
    pub fn host_bytes(&self) -> &[u8] {
        assert_eq!(
            self.device,
            Device::Cpu,
            "host_bytes on non-host buffer; call .to_host() first"
        );
        &self.bytes
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A zero-initialized 2x3 F32 buffer has 6 elements, 24 bytes, all 0.0.
    #[test]
    fn zeros_initializes() {
        let buf = Buffer::zeros(Shape::new(&[2, 3], DType::F32));
        assert_eq!(buf.num_elements(), 6);
        assert_eq!(buf.byte_size(), 24);
        assert!(buf.as_f32().iter().all(|&value| value == 0.0));
    }

    /// Requesting an f32 view of an I32 buffer must panic.
    #[test]
    fn dtype_mismatch_panics() {
        let buf = Buffer::zeros(Shape::new(&[4], DType::I32));
        let outcome = std::panic::catch_unwind(|| buf.as_f32());
        assert!(outcome.is_err());
    }

    /// Moving CPU -> Metal -> CPU updates the device tag at each hop.
    #[test]
    fn to_device_round_trip() {
        let moved = Buffer::zeros(Shape::new(&[4], DType::F32)).to_device(Device::Metal);
        assert_eq!(moved.device(), Device::Metal);

        let returned = moved.to_host();
        assert_eq!(returned.device(), Device::Cpu);
    }
}