Skip to main content

etensor_core/
buffer.rs

1//! The Universal Memory Box.
2//! 
3//! The `Buffer` enum wraps physical hardware memory allocations. It strictly 
4//! acts as a flat, contiguous data container. All spatial geometry (dimensions 
5//! and strides) is explicitly managed by the `Shape` struct in `src/shape.rs`.
6
7use std::sync::Arc;
8use crate::dtypes::DType;
9use crate::errors::{EtensorError, EtensorResult};
10
/// Typed host-memory storage for tensor elements.
///
/// Every variant keeps its elements in a `Vec` behind an [`Arc`], so cloning a
/// `CpuBuffer` copies only the reference-counted pointer, never the elements.
#[derive(Clone, Debug)]
pub enum CpuBuffer {
    /// 32-bit floating point elements.
    F32(Arc<Vec<f32>>),
    /// 32-bit signed integer elements.
    I32(Arc<Vec<i32>>),
    /// 8-bit signed integer elements.
    I8(Arc<Vec<i8>>),
    /// Half-precision elements kept as raw `u16` bit patterns, which avoids an
    /// external half-precision crate dependency until kernels require one.
    F16(Arc<Vec<u16>>),
    /// bfloat16 elements, likewise kept as raw `u16` bit patterns.
    BF16(Arc<Vec<u16>>),
}
23
24/// The hardware-agnostic physical data container.
25#[derive(Debug, Clone)]
26pub enum Buffer {
27    Cpu(CpuBuffer),
28    
29    // FUTURE IMPLEMENTATION: Wave 2 (Surgical VRAM)
30    // Wrapped in a cfg flag so standard users don't need the CUDA Toolkit installed.
31    #[cfg(feature = "cuda-native")]
32    CudaNative, // Stub placeholder for cudarc::driver::CudaSlice<u8>
33    
34    // FUTURE IMPLEMENTATION: Wave 3 (Enterprise LibTorch)
35    #[cfg(feature = "torch")]
36    CudaTorch, // Stub placeholder for tch::Tensor
37}
38
39impl Buffer {
40    /// Allocates a new contiguous CPU Buffer initialized with zeros.
41    pub fn new_cpu_zeros(size: usize, dtype: DType) -> Self {
42        let cpu_buf = match dtype {
43            DType::F32 => CpuBuffer::F32(Arc::new(vec![0.0; size])),
44            DType::I32 => CpuBuffer::I32(Arc::new(vec![0; size])),
45            DType::I8 => CpuBuffer::I8(Arc::new(vec![0; size])),
46            DType::F16 => CpuBuffer::F16(Arc::new(vec![0; size])),
47            DType::BF16 => CpuBuffer::BF16(Arc::new(vec![0; size])),
48        };
49        Buffer::Cpu(cpu_buf)
50    }
51
52    /// Allocates a new CPU Buffer from an existing Vec of f32 data.
53    pub fn from_f32_vec(data: Vec<f32>) -> Self {
54        Buffer::Cpu(CpuBuffer::F32(Arc::new(data)))
55    }
56
57    /// Safely extracts a reference to the underlying CPU f32 slice.
58    /// Acts as an explicit boundary guard against DataType or Device mismatches.
59    pub fn as_f32_slice(&self) -> EtensorResult<&[f32]> {
60        match self {
61            Buffer::Cpu(CpuBuffer::F32(arc_vec)) => Ok(arc_vec.as_slice()),
62            Buffer::Cpu(_) => Err(EtensorError::DTypeMismatch {
63                expected: "float32".to_string(),
64                got: "other CPU dtype".to_string(),
65            }),
66            #[cfg(feature = "cuda-native")]
67            Buffer::CudaNative => Err(EtensorError::DeviceMismatch {
68                expected: "cpu".to_string(),
69                got: "cuda_native".to_string(),
70            }),
71            #[cfg(feature = "torch")]
72            Buffer::CudaTorch => Err(EtensorError::DeviceMismatch {
73                expected: "cpu".to_string(),
74                got: "cuda_torch".to_string(),
75            }),
76        }
77    }
78
79    /// Returns the active Arc pointer reference count for memory tracking.
80    pub fn strong_count(&self) -> EtensorResult<usize> {
81        match self {
82            Buffer::Cpu(CpuBuffer::F32(arc)) => Ok(Arc::strong_count(arc)),
83            Buffer::Cpu(CpuBuffer::I32(arc)) => Ok(Arc::strong_count(arc)),
84            Buffer::Cpu(CpuBuffer::I8(arc)) => Ok(Arc::strong_count(arc)),
85            Buffer::Cpu(CpuBuffer::F16(arc)) => Ok(Arc::strong_count(arc)),
86            Buffer::Cpu(CpuBuffer::BF16(arc)) => Ok(Arc::strong_count(arc)),
87            #[cfg(feature = "cuda-native")]
88            Buffer::CudaNative => Err(EtensorError::InternalError("Arc count unsupported on CUDA".to_string())),
89            #[cfg(feature = "torch")]
90            Buffer::CudaTorch => Err(EtensorError::InternalError("Arc count unsupported on Torch".to_string())),
91        }
92    }
93}
94
95
#[cfg(test)]
mod tests {
    use super::*;

    /// A zero-initialised F32 buffer exposes exactly the requested number of
    /// zeros.
    #[test]
    fn test_cpu_buffer_initialization() {
        let zeros = Buffer::new_cpu_zeros(100, DType::F32);
        let values = zeros.as_f32_slice().unwrap();

        assert_eq!(values.len(), 100);
        // Spot-check both ends of the allocation.
        assert_eq!(values[0], 0.0);
        assert_eq!(values[99], 0.0);
    }

    /// Reading an Int32 buffer as Float32 must yield an error value rather
    /// than panic.
    #[test]
    fn test_dtype_mismatch_rejection() {
        let ints = Buffer::new_cpu_zeros(50, DType::I32);

        let outcome = ints.as_f32_slice();
        assert!(outcome.is_err());

        match outcome {
            Err(EtensorError::DTypeMismatch { expected, got }) => {
                assert_eq!(expected, "float32");
                assert_eq!(got, "other CPU dtype");
            }
            _ => panic!("Expected DTypeMismatch error!"),
        }
    }

    /// Cloning a buffer bumps the reference count instead of duplicating the
    /// underlying data.
    #[test]
    fn test_arc_cloning_is_zero_copy() {
        let original = Buffer::from_f32_vec(vec![1.0, 2.0, 3.0]);

        // A freshly built buffer is the sole owner of its allocation.
        assert_eq!(original.strong_count().unwrap(), 1);

        // Explicit clone, simulating the creation of a tensor view.
        let view = original.clone();

        // Both handles now report two owners of the same allocation.
        assert_eq!(original.strong_count().unwrap(), 2);
        assert_eq!(view.strong_count().unwrap(), 2);

        // Each handle reads the same element values.
        let original_values = original.as_f32_slice().unwrap();
        let view_values = view.as_f32_slice().unwrap();

        assert_eq!(original_values[1], 2.0);
        assert_eq!(view_values[1], 2.0);
    }
}