use std::sync::Arc;
use crate::dtypes::DType;
use crate::errors::{EtensorError, EtensorResult};
/// Host-resident element storage, with one variant per supported dtype.
///
/// Every variant wraps its `Vec` in an `Arc`. The derived `Clone` therefore only bumps a
/// reference count and never copies the elements.
#[derive(Debug, Clone)]
pub enum CpuBuffer {
    /// 32-bit IEEE-754 floats.
    F32(Arc<Vec<f32>>),
    /// 32-bit signed integers.
    I32(Arc<Vec<i32>>),
    /// 8-bit signed integers.
    I8(Arc<Vec<i8>>),
    /// Half-precision values kept in `u16` storage.
    /// NOTE(review): presumably raw IEEE binary16 bit patterns; confirm against the
    /// conversion code.
    F16(Arc<Vec<u16>>),
    /// bfloat16 values kept in `u16` storage.
    /// NOTE(review): presumably raw bfloat16 bit patterns; confirm against the
    /// conversion code.
    BF16(Arc<Vec<u16>>),
}
/// Device-tagged tensor storage.
///
/// Only the `Cpu` variant carries data in this definition. The other variants exist only
/// when their cargo feature is enabled, and they have no payload here.
/// NOTE(review): they are presumably placeholders until device memory is modelled; confirm.
#[derive(Debug, Clone)]
pub enum Buffer {
    /// Host memory. Clones share the underlying `Arc` allocation.
    Cpu(CpuBuffer),
    /// Backend compiled in by the `cuda-native` feature.
    #[cfg(feature = "cuda-native")]
    CudaNative,
    /// Backend compiled in by the `torch` feature. The name suggests Torch-managed CUDA memory.
    #[cfg(feature = "torch")]
    CudaTorch,
}
impl Buffer {
    /// Allocates a zero-filled CPU buffer holding `size` elements of `dtype`.
    ///
    /// `F16` and `BF16` use `u16` storage, so their "zero" is the all-zero bit pattern.
    /// If those hold raw half/bfloat16 bits (presumably; confirm), that pattern is +0.0.
    pub fn new_cpu_zeros(size: usize, dtype: DType) -> Self {
        let storage = match dtype {
            DType::F32 => CpuBuffer::F32(Arc::new(vec![0.0_f32; size])),
            DType::I32 => CpuBuffer::I32(Arc::new(vec![0_i32; size])),
            DType::I8 => CpuBuffer::I8(Arc::new(vec![0_i8; size])),
            DType::F16 => CpuBuffer::F16(Arc::new(vec![0_u16; size])),
            DType::BF16 => CpuBuffer::BF16(Arc::new(vec![0_u16; size])),
        };
        Self::Cpu(storage)
    }

    /// Wraps an existing `Vec<f32>` as a CPU buffer.
    ///
    /// The vector is moved into the `Arc`, so its elements are not copied.
    pub fn from_f32_vec(data: Vec<f32>) -> Self {
        let shared = Arc::new(data);
        Self::Cpu(CpuBuffer::F32(shared))
    }

    /// Borrows the buffer's contents as an `f32` slice.
    ///
    /// # Errors
    ///
    /// - `EtensorError::DTypeMismatch` if the buffer is on the CPU but holds any dtype
    ///   other than `F32`.
    /// - `EtensorError::DeviceMismatch` for the feature-gated non-CPU variants.
    pub fn as_f32_slice(&self) -> EtensorResult<&[f32]> {
        match self {
            Self::Cpu(storage) => match storage {
                CpuBuffer::F32(values) => Ok(values.as_slice()),
                // Listed explicitly rather than with `_`, so adding a dtype forces a
                // revisit of this accessor.
                CpuBuffer::I32(_) | CpuBuffer::I8(_) | CpuBuffer::F16(_) | CpuBuffer::BF16(_) => {
                    Err(EtensorError::DTypeMismatch {
                        expected: "float32".to_string(),
                        got: "other CPU dtype".to_string(),
                    })
                }
            },
            #[cfg(feature = "cuda-native")]
            Self::CudaNative => Err(EtensorError::DeviceMismatch {
                expected: "cpu".to_string(),
                got: "cuda_native".to_string(),
            }),
            #[cfg(feature = "torch")]
            Self::CudaTorch => Err(EtensorError::DeviceMismatch {
                expected: "cpu".to_string(),
                got: "cuda_torch".to_string(),
            }),
        }
    }

    /// Returns how many strong `Arc` handles currently share this buffer's CPU storage.
    ///
    /// Useful for confirming that `clone()` shares data instead of copying it.
    ///
    /// # Errors
    ///
    /// Returns `EtensorError::InternalError` for the feature-gated non-CPU variants, which
    /// hold no `Arc` in this definition.
    pub fn strong_count(&self) -> EtensorResult<usize> {
        match self {
            Self::Cpu(storage) => Ok(match storage {
                CpuBuffer::F32(handle) => Arc::strong_count(handle),
                CpuBuffer::I32(handle) => Arc::strong_count(handle),
                CpuBuffer::I8(handle) => Arc::strong_count(handle),
                // Both half-precision variants share the `Arc<Vec<u16>>` payload type.
                CpuBuffer::F16(handle) | CpuBuffer::BF16(handle) => Arc::strong_count(handle),
            }),
            #[cfg(feature = "cuda-native")]
            Self::CudaNative => Err(EtensorError::InternalError(
                "Arc count unsupported on CUDA".to_string(),
            )),
            #[cfg(feature = "torch")]
            Self::CudaTorch => Err(EtensorError::InternalError(
                "Arc count unsupported on Torch".to_string(),
            )),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `data` has exactly `len` elements and every one equals `T::default()` (zero).
    fn is_zeroed<T: Default + PartialEq>(data: &[T], len: usize) -> bool {
        data.len() == len && data.iter().all(|x| *x == T::default())
    }

    #[test]
    fn test_cpu_buffer_initialization() {
        let buffer = Buffer::new_cpu_zeros(100, DType::F32);
        let slice = buffer.as_f32_slice().unwrap();
        assert_eq!(slice.len(), 100);
        assert_eq!(slice[0], 0.0);
        assert_eq!(slice[99], 0.0);
    }

    /// Each non-F32 dtype must land in its own variant, zero-filled, at the requested length.
    #[test]
    fn test_cpu_zeros_for_every_dtype() {
        let len = 7;
        assert!(matches!(
            Buffer::new_cpu_zeros(len, DType::I32),
            Buffer::Cpu(CpuBuffer::I32(ref v)) if is_zeroed(v.as_slice(), len)
        ));
        assert!(matches!(
            Buffer::new_cpu_zeros(len, DType::I8),
            Buffer::Cpu(CpuBuffer::I8(ref v)) if is_zeroed(v.as_slice(), len)
        ));
        assert!(matches!(
            Buffer::new_cpu_zeros(len, DType::F16),
            Buffer::Cpu(CpuBuffer::F16(ref v)) if is_zeroed(v.as_slice(), len)
        ));
        assert!(matches!(
            Buffer::new_cpu_zeros(len, DType::BF16),
            Buffer::Cpu(CpuBuffer::BF16(ref v)) if is_zeroed(v.as_slice(), len)
        ));
    }

    /// A zero-length allocation is valid and yields an empty slice.
    #[test]
    fn test_zero_length_buffer() {
        let buffer = Buffer::new_cpu_zeros(0, DType::F32);
        assert!(buffer.as_f32_slice().unwrap().is_empty());
    }

    #[test]
    fn test_dtype_mismatch_rejection() {
        let buffer = Buffer::new_cpu_zeros(50, DType::I32);
        let result = buffer.as_f32_slice();
        assert!(result.is_err());
        if let Err(EtensorError::DTypeMismatch { expected, got }) = result {
            assert_eq!(expected, "float32");
            assert_eq!(got, "other CPU dtype");
        } else {
            panic!("Expected DTypeMismatch error!");
        }
    }

    /// `as_f32_slice` must reject every non-F32 CPU dtype, not just I32.
    #[test]
    fn test_dtype_mismatch_for_every_non_f32_dtype() {
        for dtype in [DType::I32, DType::I8, DType::F16, DType::BF16] {
            let buffer = Buffer::new_cpu_zeros(4, dtype);
            assert!(matches!(
                buffer.as_f32_slice(),
                Err(EtensorError::DTypeMismatch { .. })
            ));
        }
    }

    #[test]
    fn test_arc_cloning_is_zero_copy() {
        let buffer_a = Buffer::from_f32_vec(vec![1.0, 2.0, 3.0]);
        assert_eq!(buffer_a.strong_count().unwrap(), 1);
        let buffer_b = buffer_a.clone();
        assert_eq!(buffer_a.strong_count().unwrap(), 2);
        assert_eq!(buffer_b.strong_count().unwrap(), 2);
        let slice_a = buffer_a.as_f32_slice().unwrap();
        let slice_b = buffer_b.as_f32_slice().unwrap();
        assert_eq!(slice_a[1], 2.0);
        assert_eq!(slice_b[1], 2.0);
        // Equal values alone would not prove sharing, so check both clones view one allocation.
        assert!(std::ptr::eq(slice_a.as_ptr(), slice_b.as_ptr()));
        // Releasing one handle must give the count back.
        drop(buffer_b);
        assert_eq!(buffer_a.strong_count().unwrap(), 1);
    }

    /// `strong_count` must track clones for every CPU dtype, not only F32.
    #[test]
    fn test_strong_count_tracks_clones_for_every_dtype() {
        for dtype in [DType::F32, DType::I32, DType::I8, DType::F16, DType::BF16] {
            let original = Buffer::new_cpu_zeros(3, dtype);
            assert_eq!(original.strong_count().unwrap(), 1);
            let shared = original.clone();
            assert_eq!(original.strong_count().unwrap(), 2);
            assert_eq!(shared.strong_count().unwrap(), 2);
            drop(shared);
            assert_eq!(original.strong_count().unwrap(), 1);
        }
    }
}