// trueno/backends/gpu/tensor_view/mod.rs
//! TensorView - GPU Memory Layout Abstraction
//!
//! Provides a view into GPU buffer memory with shape, stride, and layout information.
//! Enables zero-copy slicing and transposition operations.
//!
//! # cuda-tile-behavior.md References
//!
//! - Section 3.2: Two-Level Memory Hierarchy
//! - Falsification tests #31-40: TensorView correctness
//!
//! # Academic Foundation
//!
//! Based on Halide (PLDI 2013): Schedule/algorithm separation improves portability.

use std::marker::PhantomData;
use std::ops::Range;

/// Memory layout for tensor storage.
///
/// Used as an optimization hint by kernels; it does not change how
/// `TensorView` computes strides or linear indices.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MemoryLayout {
    /// Row-major (C-style): last dimension varies fastest
    #[default]
    RowMajor,
    /// Column-major (Fortran-style): first dimension varies fastest
    ColumnMajor,
    /// Tiled layout for GPU shared memory optimization
    Tiled {
        /// Tile dimensions
        tile_size: [usize; 2],
    },
}

33/// A view into a contiguous memory region with shape and stride information.
34///
35/// TensorView does not own the data - it provides a structured view over
36/// existing memory, enabling zero-copy operations like slicing and transposition.
37///
38/// # Type Parameters
39///
40/// * `T` - Element type (typically f32 for GPU compute)
41///
42/// # cuda-tile-behavior.md References
43///
44/// - Falsification test #31: TensorView preserves data integrity
45/// - Falsification test #32: Slicing produces correct views
46/// - Falsification test #33: Transpose swaps dimensions correctly
47#[derive(Debug)]
48pub struct TensorView<T> {
49    /// Shape of the tensor (up to 4 dimensions: N, C, H, W)
50    shape: [usize; 4],
51    /// Strides for each dimension (in elements, not bytes)
52    strides: [usize; 4],
53    /// Offset from the start of the buffer (in elements)
54    offset: usize,
55    /// Memory layout hint for optimization
56    layout: MemoryLayout,
57    /// Number of active dimensions (1-4)
58    ndim: usize,
59    /// Phantom data for type safety
60    _marker: PhantomData<T>,
61}
62
63impl<T> TensorView<T> {
64    /// Create a new TensorView with the given shape.
65    ///
66    /// Strides are computed automatically based on row-major layout.
67    ///
68    /// # Arguments
69    ///
70    /// * `shape` - Shape of the tensor (unused dimensions should be 1)
71    ///
72    /// # Examples
73    ///
74    /// ```ignore
75    /// let view = TensorView::<f32>::new([2, 3, 4, 1]); // 2x3x4 tensor
76    /// assert_eq!(view.numel(), 24);
77    /// ```
78    pub fn new(shape: [usize; 4]) -> Self {
79        let ndim = Self::compute_ndim(&shape);
80        let strides = Self::compute_row_major_strides(&shape);
81        Self {
82            shape,
83            strides,
84            offset: 0,
85            layout: MemoryLayout::RowMajor,
86            ndim,
87            _marker: PhantomData,
88        }
89    }
90
91    /// Create a TensorView with explicit strides.
92    ///
93    /// # Arguments
94    ///
95    /// * `shape` - Shape of the tensor
96    /// * `strides` - Strides for each dimension (in elements)
97    pub fn with_strides(shape: [usize; 4], strides: [usize; 4]) -> Self {
98        let ndim = Self::compute_ndim(&shape);
99        Self {
100            shape,
101            strides,
102            offset: 0,
103            layout: MemoryLayout::RowMajor,
104            ndim,
105            _marker: PhantomData,
106        }
107    }
108
109    /// Create a 1D TensorView.
110    pub fn new_1d(len: usize) -> Self {
111        Self::new([len, 1, 1, 1])
112    }
113
114    /// Create a 2D TensorView (matrix).
115    pub fn new_2d(rows: usize, cols: usize) -> Self {
116        Self::new([rows, cols, 1, 1])
117    }
118
119    /// Create a 3D TensorView.
120    pub fn new_3d(d0: usize, d1: usize, d2: usize) -> Self {
121        Self::new([d0, d1, d2, 1])
122    }
123
124    /// Create a 4D TensorView.
125    pub fn new_4d(d0: usize, d1: usize, d2: usize, d3: usize) -> Self {
126        Self::new([d0, d1, d2, d3])
127    }
128
129    /// Get the shape of the tensor.
130    pub fn shape(&self) -> &[usize; 4] {
131        &self.shape
132    }
133
134    /// Get the strides of the tensor.
135    pub fn strides(&self) -> &[usize; 4] {
136        &self.strides
137    }
138
139    /// Get the offset from the start of the buffer.
140    pub fn offset(&self) -> usize {
141        self.offset
142    }
143
144    /// Get the memory layout.
145    pub fn layout(&self) -> MemoryLayout {
146        self.layout
147    }
148
149    /// Get the number of active dimensions.
150    pub fn ndim(&self) -> usize {
151        self.ndim
152    }
153
154    /// Get the total number of elements.
155    pub fn numel(&self) -> usize {
156        self.shape.iter().product()
157    }
158
159    /// Check if the tensor is contiguous in memory.
160    ///
161    /// A tensor is contiguous if elements are stored without gaps
162    /// in row-major order.
163    pub fn is_contiguous(&self) -> bool {
164        let expected_strides = Self::compute_row_major_strides(&self.shape);
165        self.strides == expected_strides
166    }
167
168    /// Check if the tensor is empty (has zero elements).
169    pub fn is_empty(&self) -> bool {
170        self.numel() == 0
171    }
172
173    /// Get dimension size at the given index.
174    ///
175    /// # Panics
176    ///
177    /// Panics if `dim >= 4`.
178    pub fn dim(&self, dim: usize) -> usize {
179        self.shape[dim]
180    }
181
182    /// Get stride at the given dimension.
183    ///
184    /// # Panics
185    ///
186    /// Panics if `dim >= 4`.
187    pub fn stride(&self, dim: usize) -> usize {
188        self.strides[dim]
189    }
190
191    /// Create a slice of this tensor along the first dimension.
192    ///
193    /// # Arguments
194    ///
195    /// * `range` - Range of indices to include
196    ///
197    /// # Returns
198    ///
199    /// A new TensorView representing the slice.
200    ///
201    /// # cuda-tile-behavior.md References
202    ///
203    /// - Falsification test #32: Slicing produces correct views
204    pub fn slice(&self, range: Range<usize>) -> Self {
205        assert!(range.end <= self.shape[0], "Slice range out of bounds");
206        let new_offset = self.offset + range.start * self.strides[0];
207        let mut new_shape = self.shape;
208        new_shape[0] = range.end - range.start;
209
210        Self {
211            shape: new_shape,
212            strides: self.strides,
213            offset: new_offset,
214            layout: self.layout,
215            ndim: self.ndim,
216            _marker: PhantomData,
217        }
218    }
219
220    /// Create a slice along a specific dimension.
221    ///
222    /// # Arguments
223    ///
224    /// * `dim` - Dimension to slice along
225    /// * `range` - Range of indices to include
226    pub fn slice_dim(&self, dim: usize, range: Range<usize>) -> Self {
227        assert!(dim < 4, "Dimension out of bounds");
228        assert!(range.end <= self.shape[dim], "Slice range out of bounds");
229
230        let new_offset = self.offset + range.start * self.strides[dim];
231        let mut new_shape = self.shape;
232        new_shape[dim] = range.end - range.start;
233
234        Self {
235            shape: new_shape,
236            strides: self.strides,
237            offset: new_offset,
238            layout: self.layout,
239            ndim: self.ndim,
240            _marker: PhantomData,
241        }
242    }
243
244    /// Transpose the tensor by swapping two dimensions.
245    ///
246    /// # Arguments
247    ///
248    /// * `dim0` - First dimension to swap
249    /// * `dim1` - Second dimension to swap
250    ///
251    /// # Returns
252    ///
253    /// A new TensorView with swapped dimensions.
254    ///
255    /// # cuda-tile-behavior.md References
256    ///
257    /// - Falsification test #33: Transpose swaps dimensions correctly
258    pub fn transpose(&self, dim0: usize, dim1: usize) -> Self {
259        assert!(dim0 < 4 && dim1 < 4, "Dimension out of bounds");
260
261        let mut new_shape = self.shape;
262        let mut new_strides = self.strides;
263        new_shape.swap(dim0, dim1);
264        new_strides.swap(dim0, dim1);
265
266        Self {
267            shape: new_shape,
268            strides: new_strides,
269            offset: self.offset,
270            layout: self.layout,
271            ndim: self.ndim,
272            _marker: PhantomData,
273        }
274    }
275
276    /// Reshape the tensor to a new shape.
277    ///
278    /// # Arguments
279    ///
280    /// * `new_shape` - New shape (must have same number of elements)
281    ///
282    /// # Returns
283    ///
284    /// A new TensorView with the new shape, or None if reshape is invalid.
285    pub fn reshape(&self, new_shape: [usize; 4]) -> Option<Self> {
286        let new_numel: usize = new_shape.iter().product();
287        if new_numel != self.numel() {
288            return None;
289        }
290
291        // Reshape requires contiguous memory
292        if !self.is_contiguous() {
293            return None;
294        }
295
296        Some(Self::new(new_shape))
297    }
298
299    /// Squeeze dimensions of size 1.
300    ///
301    /// Returns a view with all size-1 dimensions removed.
302    pub fn squeeze(&self) -> Self {
303        let mut new_shape = [1usize; 4];
304        let mut new_strides = [1usize; 4];
305        let mut new_ndim = 0;
306
307        for i in 0..4 {
308            if self.shape[i] > 1 {
309                new_shape[new_ndim] = self.shape[i];
310                new_strides[new_ndim] = self.strides[i];
311                new_ndim += 1;
312            }
313        }
314
315        // If all dimensions were 1, keep at least one
316        if new_ndim == 0 {
317            new_ndim = 1;
318        }
319
320        Self {
321            shape: new_shape,
322            strides: new_strides,
323            offset: self.offset,
324            layout: self.layout,
325            ndim: new_ndim,
326            _marker: PhantomData,
327        }
328    }
329
330    /// Unsqueeze: add a dimension of size 1 at the specified position.
331    ///
332    /// # Arguments
333    ///
334    /// * `dim` - Position to insert the new dimension
335    pub fn unsqueeze(&self, dim: usize) -> Option<Self> {
336        if dim > self.ndim || self.ndim >= 4 {
337            return None;
338        }
339
340        let mut new_shape = [1usize; 4];
341        let mut new_strides = [1usize; 4];
342
343        // Copy dimensions before the insertion point
344        // Using manual loop since we're copying from two separate arrays to two separate arrays
345        #[allow(clippy::manual_memcpy)]
346        for i in 0..dim {
347            new_shape[i] = self.shape[i];
348            new_strides[i] = self.strides[i];
349        }
350
351        // Insert the new dimension
352        new_shape[dim] = 1;
353        new_strides[dim] = if dim < self.ndim { self.strides[dim] * self.shape[dim] } else { 1 };
354
355        // Copy remaining dimensions (offset by 1 for insertion)
356        #[allow(clippy::manual_memcpy)]
357        for i in dim..self.ndim {
358            new_shape[i + 1] = self.shape[i];
359            new_strides[i + 1] = self.strides[i];
360        }
361
362        Some(Self {
363            shape: new_shape,
364            strides: new_strides,
365            offset: self.offset,
366            layout: self.layout,
367            ndim: self.ndim + 1,
368            _marker: PhantomData,
369        })
370    }
371
372    /// Set the memory layout hint.
373    pub fn with_layout(mut self, layout: MemoryLayout) -> Self {
374        self.layout = layout;
375        self
376    }
377
378    /// Compute linear index from multi-dimensional indices.
379    ///
380    /// # Arguments
381    ///
382    /// * `indices` - Array of indices for each dimension
383    ///
384    /// # Returns
385    ///
386    /// Linear offset into the underlying buffer.
387    pub fn linear_index(&self, indices: [usize; 4]) -> usize {
388        self.offset
389            + indices[0] * self.strides[0]
390            + indices[1] * self.strides[1]
391            + indices[2] * self.strides[2]
392            + indices[3] * self.strides[3]
393    }
394
395    /// Compute row-major strides for a given shape.
396    fn compute_row_major_strides(shape: &[usize; 4]) -> [usize; 4] {
397        let mut strides = [1usize; 4];
398        // Strides: s[i] = product of shape[i+1..4]
399        strides[3] = 1;
400        strides[2] = shape[3];
401        strides[1] = shape[3] * shape[2];
402        strides[0] = shape[3] * shape[2] * shape[1];
403        strides
404    }
405
406    /// Compute the number of active dimensions.
407    fn compute_ndim(shape: &[usize; 4]) -> usize {
408        // Count from the end: find last dimension > 1
409        for i in (0..4).rev() {
410            if shape[i] > 1 {
411                return i + 1;
412            }
413        }
414        1 // At least 1 dimension
415    }
416}
417
418impl<T> Clone for TensorView<T> {
419    fn clone(&self) -> Self {
420        Self {
421            shape: self.shape,
422            strides: self.strides,
423            offset: self.offset,
424            layout: self.layout,
425            ndim: self.ndim,
426            _marker: PhantomData,
427        }
428    }
429}
430
431impl<T> Default for TensorView<T> {
432    fn default() -> Self {
433        Self::new([1, 1, 1, 1])
434    }
435}
436
#[cfg(test)]
mod tests;