// etensor_core/shape.rs

//! Spatial geometry and multi-dimensional index mapping for ETensor.

3/// Represents the dimensional layout and physical memory strides of a tensor.
4/// 
5/// By separating dimensions from strides, ETensor can perform zero-copy view 
6/// manipulations (like transposing or slicing) without moving physical data.
7#[derive(Debug, Clone, PartialEq, Eq)]
8pub struct Shape {
9    /// The mathematical dimensions of the tensor (e.g., [Batch, Channels, Height, Width]).
10    pub dims: Vec<usize>,
11    /// The number of physical memory elements to skip to reach the next element in a given dimension.
12    pub strides: Vec<usize>,
13}
14
15impl Shape {
16    /// Creates a new, contiguous shape layout.
17    /// 
18    /// Automatically calculates the default row-major (C-contiguous) strides 
19    /// required to map multi-dimensional coordinates to a flat 1D memory buffer.
20    pub fn new(dims: Vec<usize>) -> Self {
21        let mut strides = vec![0; dims.len()];
22        let mut current_stride = 1;
23        
24        // Strides are calculated from the innermost dimension (right) to the outermost (left).
25        for i in (0..dims.len()).rev() {
26            strides[i] = current_stride;
27            current_stride *= dims[i];
28        }
29
30        Self { dims, strides }
31    }
32
33    /// Transposes the last two dimensions in O(1) time.
34    /// 
35    /// This returns a new `Shape` representation with swapped dimensions and strides, 
36    /// explicitly avoiding any duplication of the underlying physical buffer.
37    pub fn transpose(&self) -> Self {
38        let mut new_dims = self.dims.clone();
39        let mut new_strides = self.strides.clone();
40        let rank = new_dims.len();
41        
42        // A tensor must have at least 2 dimensions to be transposed.
43        if rank >= 2 {
44            new_dims.swap(rank - 1, rank - 2);
45            new_strides.swap(rank - 1, rank - 2);
46        }
47        
48        Self { dims: new_dims, strides: new_strides }
49    }
50
51    /// Returns the total number of dimensions (rank) of the tensor.
52    pub fn rank(&self) -> usize {
53        self.dims.len()
54    }
55
56    /// Returns the total number of mathematical elements contained in the shape.
57    pub fn num_elements(&self) -> usize {
58        self.dims.iter().product()
59    }
60
61    /// Verifies if the underlying physical memory layout perfectly matches the mathematical layout.
62    /// 
63    /// This is crucial for backend hardware execution, as many optimized CUDA or CPU 
64    /// BLAS kernels require contiguous memory arrays to function correctly.
65    pub fn is_contiguous(&self) -> bool {
66        let mut expected_stride = 1;
67        for i in (0..self.dims.len()).rev() {
68            if self.strides[i] != expected_stride {
69                return false;
70            }
71            expected_stride *= self.dims[i];
72        }
73        true
74    }
75}
76
77
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stride_calculation_c_contiguous() {
        // Example: A 2x3 matrix
        // Dims: [2, 3]
        // Expected Strides: [3, 1] (Jump 3 elements to go down a row, 1 element to go right)
        let shape = Shape::new(vec![2, 3]);
        assert_eq!(shape.strides, vec![3, 1]);
        assert!(shape.is_contiguous());
    }

    #[test]
    fn test_stride_calculation_rank_3() {
        // A [2, 3, 4] tensor: one step along axis 0 skips a whole 3x4 plane.
        let shape = Shape::new(vec![2, 3, 4]);
        assert_eq!(shape.strides, vec![12, 4, 1]);
        assert!(shape.is_contiguous());
    }

    #[test]
    fn test_zero_copy_transpose() {
        let shape = Shape::new(vec![2, 3]);
        let transposed = shape.transpose();

        // Mathematical dimensions are swapped
        assert_eq!(transposed.dims, vec![3, 2]);
        // Strides are swapped (memory does not move!)
        assert_eq!(transposed.strides, vec![1, 3]);
        // A transposed matrix is no longer contiguous in memory
        assert!(!transposed.is_contiguous());
    }

    #[test]
    fn test_transpose_rank_3_swaps_only_last_two_axes() {
        // The leading (batch-like) axis and its stride must stay in place.
        let transposed = Shape::new(vec![2, 3, 4]).transpose();
        assert_eq!(transposed.dims, vec![2, 4, 3]);
        assert_eq!(transposed.strides, vec![12, 1, 4]);
        assert_eq!(transposed.num_elements(), 24);
        assert!(!transposed.is_contiguous());
    }

    #[test]
    fn test_double_transpose_is_identity() {
        let shape = Shape::new(vec![2, 3]);
        assert_eq!(shape.transpose().transpose(), shape);
    }

    #[test]
    fn test_rank_and_elements() {
        let shape = Shape::new(vec![2, 3, 4]);
        assert_eq!(shape.rank(), 3);
        assert_eq!(shape.num_elements(), 24);
    }

    #[test]
    fn test_scalar_shape() {
        // Rank 0: no dims, no strides, and a single element (the empty product).
        let shape = Shape::new(vec![]);
        assert_eq!(shape.rank(), 0);
        assert!(shape.strides.is_empty());
        assert_eq!(shape.num_elements(), 1);
        assert!(shape.is_contiguous());
        assert_eq!(shape.transpose(), shape);
    }

    #[test]
    fn test_transpose_fallback_on_1d() {
        // Transposing a 1D vector should safely return the same layout without crashing
        let shape = Shape::new(vec![5]);
        let transposed = shape.transpose();

        assert_eq!(transposed.dims, vec![5]);
        assert_eq!(transposed.strides, vec![1]);
        assert!(transposed.is_contiguous());
    }
}