// etensor-core 0.0.1 — the pure Rust tensor math and autograd engine.
//! Spatial geometry and multi-dimensional index mapping for ETensor.

/// Represents the dimensional layout and physical memory strides of a tensor.
///
/// By separating dimensions from strides, ETensor can perform zero-copy view
/// manipulations (like transposing or slicing) without moving physical data.
///
/// Both fields are public, so nothing here enforces that `dims` and `strides`
/// have the same length; [`Shape::new`] always produces matching lengths.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Shape {
    /// The mathematical dimensions of the tensor (e.g., `[Batch, Channels, Height, Width]`).
    pub dims: Vec<usize>,
    /// For each dimension, the number of physical memory elements to skip to reach
    /// the next element along that dimension.
    pub strides: Vec<usize>,
}

impl Shape {
    /// Creates a new, contiguous shape layout.
    ///
    /// Computes the default row-major (C-contiguous) strides that map
    /// multi-dimensional coordinates onto a flat 1D memory buffer: the innermost
    /// (last) dimension has stride 1, and every outer stride is the product of all
    /// dimensions to its right. A rank-0 shape (empty `dims`) gets empty strides.
    ///
    /// # Panics
    ///
    /// If a stride does not fit in `usize`, the multiplication overflows: this
    /// panics when overflow checks are enabled (the default for debug builds).
    pub fn new(dims: Vec<usize>) -> Self {
        let mut strides = vec![1; dims.len()];
        // Walk from the innermost axis outward. The outermost dimension never
        // contributes to any stride, so it is deliberately never multiplied in;
        // doing so would only risk a spurious overflow.
        for axis in (1..dims.len()).rev() {
            strides[axis - 1] = strides[axis] * dims[axis];
        }
        Self { dims, strides }
    }

    /// Swaps the last two dimensions (and their strides) without touching element data.
    ///
    /// Only the layout metadata is copied, so no physical buffer is duplicated.
    /// The cost is O(rank) for cloning `dims` and `strides`, independent of the
    /// number of elements. Shapes with fewer than two dimensions are returned unchanged.
    #[must_use]
    pub fn transpose(&self) -> Self {
        let mut transposed = self.clone();
        let rank = transposed.rank();
        if rank >= 2 {
            transposed.dims.swap(rank - 2, rank - 1);
            transposed.strides.swap(rank - 2, rank - 1);
        }
        transposed
    }

    /// Returns the total number of dimensions (rank) of the tensor.
    pub fn rank(&self) -> usize {
        self.dims.len()
    }

    /// Returns the total number of mathematical elements contained in the shape.
    ///
    /// A rank-0 shape yields 1 (the empty product), and any zero-sized
    /// dimension yields 0.
    ///
    /// # Panics
    ///
    /// Panics on overflow when debug assertions are enabled (`Iterator::product`).
    pub fn num_elements(&self) -> usize {
        self.dims.iter().product()
    }

    /// Verifies if the underlying physical memory layout matches the row-major
    /// mathematical layout.
    ///
    /// This is crucial for backend hardware execution, as many optimized CUDA or CPU
    /// BLAS kernels require contiguous memory arrays to function correctly.
    ///
    /// The stride of a size-1 dimension is never used to step between elements,
    /// so it is ignored. A shape with no elements is trivially contiguous. Both
    /// rules match NumPy's C-contiguity semantics. A shape whose `strides` length
    /// differs from its `dims` length is malformed and reported as not contiguous.
    pub fn is_contiguous(&self) -> bool {
        if self.strides.len() != self.dims.len() {
            return false;
        }
        if self.dims.contains(&0) {
            return true;
        }

        // `None` records that the running element count no longer fits in `usize`.
        // No real stride can match that, so any later non-trivial axis fails.
        let mut expected_stride = Some(1usize);
        for (&dim, &stride) in self.dims.iter().zip(&self.strides).rev() {
            if dim == 1 {
                continue;
            }
            if expected_stride != Some(stride) {
                return false;
            }
            expected_stride = stride.checked_mul(dim);
        }
        true
    }
}


#[cfg(test)]
mod tests {
    use super::Shape;

    #[test]
    fn test_stride_calculation_c_contiguous() {
        // A 2x3 row-major matrix: moving down one row skips 3 elements,
        // moving right one column skips 1.
        let matrix = Shape::new(vec![2, 3]);

        assert_eq!(matrix.strides, [3, 1]);
        assert!(matrix.is_contiguous());
    }

    #[test]
    fn test_zero_copy_transpose() {
        let transposed = Shape::new(vec![2, 3]).transpose();

        // Only the metadata is permuted: dims and strides trade places...
        assert_eq!(transposed.dims, [3, 2]);
        assert_eq!(transposed.strides, [1, 3]);
        // ...which leaves the view non-contiguous over the same memory.
        assert!(!transposed.is_contiguous());
    }

    #[test]
    fn test_rank_and_elements() {
        let cube = Shape::new(vec![2, 3, 4]);

        assert_eq!(cube.rank(), 3);
        assert_eq!(cube.num_elements(), 2 * 3 * 4);
    }

    #[test]
    fn test_transpose_fallback_on_1d() {
        // With fewer than two axes there is nothing to swap; the layout must
        // come back untouched instead of panicking.
        let vector = Shape::new(vec![5]);
        let transposed = vector.transpose();

        assert_eq!(transposed.dims, [5]);
        assert_eq!(transposed.strides, [1]);
        assert!(transposed.is_contiguous());
    }
}