1use std::sync::Arc;
8use crate::dtypes::DType;
9use crate::errors::{EtensorError, EtensorResult};
10
/// Host-memory tensor storage, with one variant per supported element type.
///
/// Every variant owns its elements through an `Arc<Vec<_>>`. The derived
/// `Clone` therefore only bumps a reference count, and the clones share one
/// allocation.
#[derive(Debug, Clone)]
pub enum CpuBuffer {
    /// 32-bit floating-point elements.
    F32(Arc<Vec<f32>>),
    /// 32-bit signed integer elements.
    I32(Arc<Vec<i32>>),
    /// 8-bit signed integer elements.
    I8(Arc<Vec<i8>>),
    /// Half-precision elements held as `u16` words. No native f16 element
    /// type is used (NOTE(review): presumably raw bit patterns — confirm).
    F16(Arc<Vec<u16>>),
    /// bfloat16 elements held as `u16` words. No native bf16 element type is
    /// used (NOTE(review): presumably raw bit patterns — confirm).
    BF16(Arc<Vec<u16>>),
}
23
24#[derive(Debug, Clone)]
26pub enum Buffer {
27 Cpu(CpuBuffer),
28
29 #[cfg(feature = "cuda-native")]
32 CudaNative, #[cfg(feature = "torch")]
36 CudaTorch, }
38
39impl Buffer {
40 pub fn new_cpu_zeros(size: usize, dtype: DType) -> Self {
42 let cpu_buf = match dtype {
43 DType::F32 => CpuBuffer::F32(Arc::new(vec![0.0; size])),
44 DType::I32 => CpuBuffer::I32(Arc::new(vec![0; size])),
45 DType::I8 => CpuBuffer::I8(Arc::new(vec![0; size])),
46 DType::F16 => CpuBuffer::F16(Arc::new(vec![0; size])),
47 DType::BF16 => CpuBuffer::BF16(Arc::new(vec![0; size])),
48 };
49 Buffer::Cpu(cpu_buf)
50 }
51
52 pub fn from_f32_vec(data: Vec<f32>) -> Self {
54 Buffer::Cpu(CpuBuffer::F32(Arc::new(data)))
55 }
56
57 pub fn as_f32_slice(&self) -> EtensorResult<&[f32]> {
60 match self {
61 Buffer::Cpu(CpuBuffer::F32(arc_vec)) => Ok(arc_vec.as_slice()),
62 Buffer::Cpu(_) => Err(EtensorError::DTypeMismatch {
63 expected: "float32".to_string(),
64 got: "other CPU dtype".to_string(),
65 }),
66 #[cfg(feature = "cuda-native")]
67 Buffer::CudaNative => Err(EtensorError::DeviceMismatch {
68 expected: "cpu".to_string(),
69 got: "cuda_native".to_string(),
70 }),
71 #[cfg(feature = "torch")]
72 Buffer::CudaTorch => Err(EtensorError::DeviceMismatch {
73 expected: "cpu".to_string(),
74 got: "cuda_torch".to_string(),
75 }),
76 }
77 }
78
79 pub fn strong_count(&self) -> EtensorResult<usize> {
81 match self {
82 Buffer::Cpu(CpuBuffer::F32(arc)) => Ok(Arc::strong_count(arc)),
83 Buffer::Cpu(CpuBuffer::I32(arc)) => Ok(Arc::strong_count(arc)),
84 Buffer::Cpu(CpuBuffer::I8(arc)) => Ok(Arc::strong_count(arc)),
85 Buffer::Cpu(CpuBuffer::F16(arc)) => Ok(Arc::strong_count(arc)),
86 Buffer::Cpu(CpuBuffer::BF16(arc)) => Ok(Arc::strong_count(arc)),
87 #[cfg(feature = "cuda-native")]
88 Buffer::CudaNative => Err(EtensorError::InternalError("Arc count unsupported on CUDA".to_string())),
89 #[cfg(feature = "torch")]
90 Buffer::CudaTorch => Err(EtensorError::InternalError("Arc count unsupported on Torch".to_string())),
91 }
92 }
93}
94
95
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly allocated F32 buffer has the requested length and is all zeros.
    #[test]
    fn test_cpu_buffer_initialization() {
        let buffer = Buffer::new_cpu_zeros(100, DType::F32);
        let slice = buffer.as_f32_slice().unwrap();

        assert_eq!(slice.len(), 100);
        assert!(slice.iter().all(|&x| x == 0.0));
    }

    /// A zero-length request produces an empty, still-readable buffer.
    #[test]
    fn test_cpu_buffer_empty() {
        let buffer = Buffer::new_cpu_zeros(0, DType::F32);
        assert!(buffer.as_f32_slice().unwrap().is_empty());
    }

    /// Every dtype gets a CPU buffer with exactly one owner.
    #[test]
    fn test_new_cpu_zeros_supports_every_dtype() {
        for dtype in [DType::F32, DType::I32, DType::I8, DType::F16, DType::BF16] {
            let buffer = Buffer::new_cpu_zeros(8, dtype);
            assert_eq!(buffer.strong_count().unwrap(), 1);
        }
    }

    /// Reading an I32 buffer as f32 is rejected with the documented error payload.
    #[test]
    fn test_dtype_mismatch_rejection() {
        let buffer = Buffer::new_cpu_zeros(50, DType::I32);

        match buffer.as_f32_slice() {
            Err(EtensorError::DTypeMismatch { expected, got }) => {
                assert_eq!(expected, "float32");
                assert_eq!(got, "other CPU dtype");
            }
            other => panic!("Expected DTypeMismatch error, got {:?}", other),
        }
    }

    /// Every non-F32 CPU dtype is rejected, not just I32.
    #[test]
    fn test_non_f32_cpu_dtypes_are_rejected() {
        for dtype in [DType::I8, DType::F16, DType::BF16] {
            let buffer = Buffer::new_cpu_zeros(4, dtype);
            assert!(matches!(
                buffer.as_f32_slice(),
                Err(EtensorError::DTypeMismatch { .. })
            ));
        }
    }

    /// Cloning shares the same allocation instead of copying the elements.
    #[test]
    fn test_arc_cloning_is_zero_copy() {
        let buffer_a = Buffer::from_f32_vec(vec![1.0, 2.0, 3.0]);
        assert_eq!(buffer_a.strong_count().unwrap(), 1);

        let buffer_b = buffer_a.clone();
        assert_eq!(buffer_a.strong_count().unwrap(), 2);
        assert_eq!(buffer_b.strong_count().unwrap(), 2);

        let slice_a = buffer_a.as_f32_slice().unwrap();
        let slice_b = buffer_b.as_f32_slice().unwrap();

        // A refcount of 2 alone does not prove the views alias; the
        // pointer check pins that both handles read the very same memory.
        assert!(std::ptr::eq(slice_a.as_ptr(), slice_b.as_ptr()));
        assert_eq!(slice_a, &[1.0_f32, 2.0, 3.0][..]);
        assert_eq!(slice_b[1], 2.0);
    }

    /// Dropping a clone releases its reference on the shared storage.
    #[test]
    fn test_dropping_clone_releases_reference() {
        let buffer_a = Buffer::from_f32_vec(vec![0.5; 4]);
        let buffer_b = buffer_a.clone();
        assert_eq!(buffer_a.strong_count().unwrap(), 2);

        drop(buffer_b);
        assert_eq!(buffer_a.strong_count().unwrap(), 1);
    }
}