// etensor-core 0.0.1
//
// The pure Rust tensor math and autograd engine.

//! The Universal Memory Box.
//! 
//! The `Buffer` enum wraps physical hardware memory allocations. It strictly 
//! acts as a flat, contiguous data container. All spatial geometry (dimensions 
//! and strides) is explicitly managed by the `Shape` struct in `src/shape.rs`.

use std::sync::Arc;
use crate::dtypes::DType;
use crate::errors::{EtensorError, EtensorResult};

/// Contiguous CPU memory blocks, one variant per supported element type.
///
/// Each payload is wrapped in an [`Arc`] (atomic reference counter), so cloning a
/// `CpuBuffer` bumps a reference count instead of copying the elements. This is what
/// enables zero-copy sharing of data across graph nodes and thread-safe execution.
#[derive(Debug, Clone)]
pub enum CpuBuffer {
    /// 32-bit floating-point elements.
    F32(Arc<Vec<f32>>),
    /// 32-bit signed integer elements.
    I32(Arc<Vec<i32>>),
    /// 8-bit signed integer elements.
    I8(Arc<Vec<i8>>),
    // F16 and BF16 are stored as raw u16 bits on the CPU to avoid forcing an external
    // half-precision crate dependency until mathematically required by kernels.
    /// Half-precision floats, kept as raw `u16` bit patterns.
    F16(Arc<Vec<u16>>),
    /// bfloat16 floats, kept as raw `u16` bit patterns.
    BF16(Arc<Vec<u16>>),
}

/// The hardware-agnostic physical data container.
///
/// Only the `Cpu` variant carries data today. The other variants are
/// feature-gated placeholders reserved for future accelerator backends.
#[derive(Debug, Clone)]
pub enum Buffer {
    /// Host memory; the element type is selected by the wrapped [`CpuBuffer`] variant.
    Cpu(CpuBuffer),

    // FUTURE IMPLEMENTATION: Wave 2 (Surgical VRAM)
    // Wrapped in a cfg flag so standard users don't need the CUDA Toolkit installed.
    /// Placeholder for a native CUDA allocation; it carries no data yet.
    /// Only exists when the `cuda-native` feature is enabled.
    #[cfg(feature = "cuda-native")]
    CudaNative, // Stub placeholder for cudarc::driver::CudaSlice<u8>

    // FUTURE IMPLEMENTATION: Wave 3 (Enterprise LibTorch)
    /// Placeholder for a LibTorch-backed allocation; it carries no data yet.
    /// Only exists when the `torch` feature is enabled.
    #[cfg(feature = "torch")]
    CudaTorch, // Stub placeholder for tch::Tensor
}

impl Buffer {
    /// Builds a CPU-resident buffer holding `size` elements of `dtype`, all zero.
    ///
    /// For `F16` and `BF16` the stored value is the all-zero `u16` bit pattern.
    pub fn new_cpu_zeros(size: usize, dtype: DType) -> Self {
        Buffer::Cpu(match dtype {
            DType::F32 => CpuBuffer::F32(Arc::new(vec![0.0_f32; size])),
            DType::I32 => CpuBuffer::I32(Arc::new(vec![0_i32; size])),
            DType::I8 => CpuBuffer::I8(Arc::new(vec![0_i8; size])),
            DType::F16 => CpuBuffer::F16(Arc::new(vec![0_u16; size])),
            DType::BF16 => CpuBuffer::BF16(Arc::new(vec![0_u16; size])),
        })
    }

    /// Takes ownership of `data` and wraps it as an `f32` CPU buffer.
    ///
    /// The vector's heap allocation is moved into the `Arc`, not copied.
    pub fn from_f32_vec(data: Vec<f32>) -> Self {
        let shared = Arc::new(data);
        Buffer::Cpu(CpuBuffer::F32(shared))
    }

    /// Borrows the buffer's contents as an `f32` slice.
    ///
    /// This is the checked entry point for reading float data: it refuses to
    /// reinterpret any other dtype or device.
    ///
    /// # Errors
    /// - `EtensorError::DTypeMismatch` if the buffer is on the CPU but is not `F32`.
    /// - `EtensorError::DeviceMismatch` if the buffer belongs to a non-CPU backend
    ///   (only possible when the `cuda-native` or `torch` feature is enabled).
    pub fn as_f32_slice(&self) -> EtensorResult<&[f32]> {
        match self {
            Buffer::Cpu(cpu) => match cpu {
                CpuBuffer::F32(values) => Ok(&values[..]),
                CpuBuffer::I32(_) | CpuBuffer::I8(_) | CpuBuffer::F16(_) | CpuBuffer::BF16(_) => {
                    Err(EtensorError::DTypeMismatch {
                        expected: String::from("float32"),
                        got: String::from("other CPU dtype"),
                    })
                }
            },
            #[cfg(feature = "cuda-native")]
            Buffer::CudaNative => Err(EtensorError::DeviceMismatch {
                expected: String::from("cpu"),
                got: String::from("cuda_native"),
            }),
            #[cfg(feature = "torch")]
            Buffer::CudaTorch => Err(EtensorError::DeviceMismatch {
                expected: String::from("cpu"),
                got: String::from("cuda_torch"),
            }),
        }
    }

    /// Reports how many strong `Arc` handles currently point at this buffer's storage.
    ///
    /// Useful for checking that clones share memory instead of copying it.
    ///
    /// # Errors
    /// `EtensorError::InternalError` for the feature-gated non-CPU backends,
    /// which are not `Arc`-backed.
    pub fn strong_count(&self) -> EtensorResult<usize> {
        match self {
            Buffer::Cpu(cpu) => Ok(match cpu {
                CpuBuffer::F32(shared) => Arc::strong_count(shared),
                CpuBuffer::I32(shared) => Arc::strong_count(shared),
                CpuBuffer::I8(shared) => Arc::strong_count(shared),
                CpuBuffer::F16(shared) => Arc::strong_count(shared),
                CpuBuffer::BF16(shared) => Arc::strong_count(shared),
            }),
            #[cfg(feature = "cuda-native")]
            Buffer::CudaNative => Err(EtensorError::InternalError(String::from(
                "Arc count unsupported on CUDA",
            ))),
            #[cfg(feature = "torch")]
            Buffer::CudaTorch => Err(EtensorError::InternalError(String::from(
                "Arc count unsupported on Torch",
            ))),
        }
    }
}


#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly zeroed F32 buffer has the requested length and every element is 0.0.
    #[test]
    fn test_cpu_buffer_initialization() {
        let buffer = Buffer::new_cpu_zeros(100, DType::F32);
        let slice = buffer.as_f32_slice().unwrap();

        assert_eq!(slice.len(), 100);
        // Check every element, not just the first and last ones.
        assert!(slice.iter().all(|&x| x == 0.0));
    }

    /// A zero-length allocation is valid and yields an empty slice.
    #[test]
    fn test_zero_length_buffer() {
        let buffer = Buffer::new_cpu_zeros(0, DType::F32);
        assert!(buffer.as_f32_slice().unwrap().is_empty());
    }

    /// Every non-f32 CPU dtype must be rejected by `as_f32_slice` with a
    /// `DTypeMismatch` error (graceful failure, not a panic).
    #[test]
    fn test_dtype_mismatch_rejection() {
        for dtype in [DType::I32, DType::I8, DType::F16, DType::BF16] {
            let buffer = Buffer::new_cpu_zeros(50, dtype);

            match buffer.as_f32_slice() {
                Err(EtensorError::DTypeMismatch { expected, got }) => {
                    assert_eq!(expected, "float32");
                    assert_eq!(got, "other CPU dtype");
                }
                other => panic!("Expected DTypeMismatch error, got {:?}", other),
            }
        }
    }

    /// Cloning a buffer must share the underlying allocation (refcount bump),
    /// not duplicate the data.
    #[test]
    fn test_arc_cloning_is_zero_copy() {
        let buffer_a = Buffer::from_f32_vec(vec![1.0, 2.0, 3.0]);

        // Initially, there is 1 pointer to this memory.
        assert_eq!(buffer_a.strong_count().unwrap(), 1);

        // Clone the buffer explicitly (simulating a Tensor view creation).
        let buffer_b = buffer_a.clone();

        // The data hasn't been duplicated. Instead, the pointer count is now 2.
        assert_eq!(buffer_a.strong_count().unwrap(), 2);
        assert_eq!(buffer_b.strong_count().unwrap(), 2);

        let slice_a = buffer_a.as_f32_slice().unwrap();
        let slice_b = buffer_b.as_f32_slice().unwrap();

        // Equal values alone would also hold for a deep copy; identical addresses
        // prove both buffers read the exact same underlying memory.
        assert_eq!(slice_a.as_ptr(), slice_b.as_ptr());
        assert_eq!(slice_a, &[1.0_f32, 2.0, 3.0][..]);

        // Dropping the clone releases its handle on the shared storage.
        drop(buffer_b);
        assert_eq!(buffer_a.strong_count().unwrap(), 1);
    }
}