etensor_core/shape.rs
1//! Spatial geometry and multi-dimensional index mapping for ETensor.
2
/// Dimensional layout and physical memory strides of a tensor.
///
/// Keeping dimensions and strides separate lets ETensor express views such as
/// transposition purely by rewriting this metadata, without moving physical data.
///
/// The two vectors are expected to have the same length (one stride per
/// dimension). [`Shape::new`] always upholds this; code that builds a `Shape`
/// directly through the public fields is responsible for it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Shape {
    /// The mathematical dimensions of the tensor (e.g. `[Batch, Channels, Height, Width]`).
    pub dims: Vec<usize>,
    /// For each dimension, how many elements (not bytes) to advance in the flat
    /// buffer to move one step along that dimension.
    pub strides: Vec<usize>,
}

impl Shape {
    /// Creates a new shape with default row-major (C-contiguous) strides.
    ///
    /// The stride of each dimension is the product of all dimensions to its right,
    /// so the last dimension has stride 1 and the first is the slowest-varying.
    /// A rank-0 shape (empty `dims`) gets empty strides.
    ///
    /// # Panics
    ///
    /// Panics if the total number of elements does not fit in a `usize`. Wrapping
    /// silently (as unchecked arithmetic does in release builds) would produce
    /// strides that alias unrelated memory.
    pub fn new(dims: Vec<usize>) -> Self {
        let mut strides = vec![0; dims.len()];
        let mut current_stride: usize = 1;

        // Strides are calculated from the innermost dimension (right) to the outermost (left).
        for (stride, &dim) in strides.iter_mut().zip(&dims).rev() {
            *stride = current_stride;
            current_stride = current_stride
                .checked_mul(dim)
                .expect("Shape::new: tensor element count overflows usize");
        }

        Self { dims, strides }
    }

    /// Swaps the last two dimensions (and their strides) without touching tensor data.
    ///
    /// Only the shape metadata is copied, so the cost is O(rank) regardless of how
    /// many elements the tensor holds. The physical buffer is never duplicated or
    /// reordered; the result is typically no longer contiguous.
    ///
    /// Shapes with fewer than two dimensions are returned unchanged.
    pub fn transpose(&self) -> Self {
        let mut transposed = self.clone();
        let rank = transposed.dims.len();

        // A tensor must have at least 2 dimensions to be transposed.
        if rank >= 2 {
            transposed.dims.swap(rank - 1, rank - 2);
            transposed.strides.swap(rank - 1, rank - 2);
        }

        transposed
    }

    /// Returns the number of dimensions (rank) of the tensor.
    pub fn rank(&self) -> usize {
        self.dims.len()
    }

    /// Returns the total number of mathematical elements described by the shape.
    ///
    /// A rank-0 shape (empty `dims`) yields 1, the empty product, i.e. a scalar.
    pub fn num_elements(&self) -> usize {
        self.dims.iter().product()
    }

    /// Returns `true` if the strides describe a dense row-major (C-contiguous) layout.
    ///
    /// When this holds, every element sits at the same flat offset it would have in a
    /// freshly built [`Shape::new`] layout. The buffer can then be handed directly to
    /// backend kernels (e.g. CUDA or BLAS routines) that require contiguous memory.
    ///
    /// The following conventions apply:
    ///
    /// * Strides of size-1 dimensions are ignored. Those axes are only ever indexed
    ///   at 0, so their stride never contributes to an element's offset.
    /// * A shape with any zero-sized dimension holds no elements and is always
    ///   considered contiguous.
    /// * A shape whose `dims` and `strides` lengths disagree is malformed. It is
    ///   reported as non-contiguous instead of panicking on an out-of-bounds index.
    pub fn is_contiguous(&self) -> bool {
        if self.dims.len() != self.strides.len() {
            return false;
        }

        // Nothing to lay out: any strides trivially describe an empty tensor.
        if self.dims.contains(&0) {
            return true;
        }

        let mut expected_stride: usize = 1;
        for (&dim, &stride) in self.dims.iter().zip(&self.strides).rev() {
            // Only index 0 is valid along a size-1 axis, so its stride is irrelevant
            // (and multiplying the expected stride by 1 would be a no-op anyway).
            if dim == 1 {
                continue;
            }
            if stride != expected_stride {
                return false;
            }
            expected_stride *= dim;
        }
        true
    }
}
76
77
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stride_calculation_c_contiguous() {
        // Example: A 2x3 matrix
        // Dims: [2, 3]
        // Expected Strides: [3, 1] (Jump 3 elements to go down a row, 1 element to go right)
        let shape = Shape::new(vec![2, 3]);
        assert_eq!(shape.strides, vec![3, 1]);
        assert!(shape.is_contiguous());
    }

    #[test]
    fn test_zero_copy_transpose() {
        let shape = Shape::new(vec![2, 3]);
        let transposed = shape.transpose();

        // Mathematical dimensions are swapped
        assert_eq!(transposed.dims, vec![3, 2]);
        // Strides are swapped (memory does not move!)
        assert_eq!(transposed.strides, vec![1, 3]);
        // A transposed matrix is no longer contiguous in memory
        assert!(!transposed.is_contiguous());
    }

    #[test]
    fn test_transpose_only_swaps_last_two_axes() {
        // For a rank-3 tensor the leading (batch) axis must stay in place.
        let shape = Shape::new(vec![2, 3, 4]);
        assert_eq!(shape.strides, vec![12, 4, 1]);

        let transposed = shape.transpose();
        assert_eq!(transposed.dims, vec![2, 4, 3]);
        assert_eq!(transposed.strides, vec![12, 1, 4]);
        assert!(!transposed.is_contiguous());
    }

    #[test]
    fn test_rank_and_elements() {
        let shape = Shape::new(vec![2, 3, 4]);
        assert_eq!(shape.rank(), 3);
        assert_eq!(shape.num_elements(), 24);
    }

    #[test]
    fn test_scalar_shape() {
        // Rank 0: no dims, no strides, one element (the empty product).
        let shape = Shape::new(vec![]);
        assert_eq!(shape.rank(), 0);
        assert_eq!(shape.num_elements(), 1);
        assert!(shape.strides.is_empty());
        assert!(shape.is_contiguous());
        assert_eq!(shape.transpose(), shape);
    }

    #[test]
    fn test_zero_sized_dimension() {
        // Strides to the left of a zero-sized axis collapse to 0; the tensor is empty.
        let shape = Shape::new(vec![2, 0, 3]);
        assert_eq!(shape.strides, vec![0, 3, 1]);
        assert_eq!(shape.num_elements(), 0);
        assert!(shape.is_contiguous());
    }

    #[test]
    fn test_transpose_fallback_on_1d() {
        // Transposing a 1D vector should safely return the same layout without crashing
        let shape = Shape::new(vec![5]);
        let transposed = shape.transpose();

        assert_eq!(transposed.dims, vec![5]);
        assert_eq!(transposed.strides, vec![1]);
        assert!(transposed.is_contiguous());
    }
}