// etensor-core 0.0.1
//
// The pure Rust tensor math and autograd engine.
// Documentation
//! Hardware abstraction and execution routing identifiers.

use crate::errors::{EtensorError, EtensorResult};

/// Marker for the hardware execution context in which a tensor's memory lives.
///
/// `Cpu` is always present; each GPU variant is compiled in only when its
/// Cargo feature (`cuda-native` or `torch`) is enabled.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Device {
    /// Host RAM, executed through the generic CPU backends (ndarray/Rayon).
    Cpu,

    /// GPU execution through the pure-Rust `cudarc` driver (PTX JIT/AOT).
    ///
    /// The payload is the ordinal GPU index, e.g. `0` for "cuda:0".
    #[cfg(feature = "cuda-native")]
    CudaNative(usize),

    /// GPU execution through the C++ PyTorch bindings exposed by `tch-rs`.
    ///
    /// The payload is the ordinal GPU index, e.g. `0` for "cuda:0".
    #[cfg(feature = "torch")]
    CudaTorch(usize),
}

impl Device {
    /// Reports whether this marker denotes host CPU memory.
    pub fn is_cpu(&self) -> bool {
        *self == Self::Cpu
    }

    /// Returns the ordinal GPU index carried by a GPU device marker.
    ///
    /// # Errors
    ///
    /// Fails with `EtensorError::InternalError` when called on `Device::Cpu`,
    /// which has no GPU index. Callers can rely on this to reject a CPU marker
    /// before handing an index to a hardware-specific allocator.
    pub fn index(&self) -> EtensorResult<usize> {
        match *self {
            #[cfg(feature = "cuda-native")]
            Self::CudaNative(ordinal) => Ok(ordinal),
            #[cfg(feature = "torch")]
            Self::CudaTorch(ordinal) => Ok(ordinal),
            Self::Cpu => {
                let reason =
                    String::from("Attempted to extract a GPU index from a CPU device marker.");
                Err(EtensorError::InternalError(reason))
            }
        }
    }
}

impl std::fmt::Display for Device {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Device::Cpu => write!(f, "cpu"),
            #[cfg(feature = "cuda-native")]
            Device::CudaNative(idx) => write!(f, "cuda_native:{}", idx),
            #[cfg(feature = "torch")]
            Device::CudaTorch(idx) => write!(f, "cuda_torch:{}", idx),
        }
    }
}


#[cfg(test)]
mod tests {
    use super::*;

    /// The CPU marker identifies as CPU, prints as `cpu`, and refuses to yield an index.
    #[test]
    fn test_cpu_device_properties() {
        let cpu = Device::Cpu;

        assert!(cpu.is_cpu());
        assert_eq!(cpu.to_string(), "cpu");

        // A CPU marker carries no GPU index, so `index` must report an error.
        let outcome = cpu.index();
        assert!(outcome.is_err());

        match outcome {
            Err(EtensorError::InternalError(msg)) => {
                assert!(msg.contains("CPU device marker"));
            }
            _ => panic!("Expected InternalError for CPU device index extraction!"),
        }
    }

    // The GPU tests below are compiled only when the crate is built with
    // `--features cuda-native` or `--features torch`; a plain `cargo test`
    // does not see them at all.

    /// A `CudaNative` marker exposes its index and prints as `cuda_native:<index>`.
    #[cfg(feature = "cuda-native")]
    #[test]
    fn test_cuda_native_properties() {
        let gpu = Device::CudaNative(1); // stands in for GPU 1
        assert!(!gpu.is_cpu());
        assert_eq!(gpu.index().unwrap(), 1);
        assert_eq!(gpu.to_string(), "cuda_native:1");
    }

    /// A `CudaTorch` marker exposes its index and prints as `cuda_torch:<index>`.
    #[cfg(feature = "torch")]
    #[test]
    fn test_cuda_torch_properties() {
        let gpu = Device::CudaTorch(0); // stands in for GPU 0
        assert!(!gpu.is_cpu());
        assert_eq!(gpu.index().unwrap(), 0);
        assert_eq!(gpu.to_string(), "cuda_torch:0");
    }
}