trueno/backends/gpu/tensor_view/mod.rs
1//! TensorView - GPU Memory Layout Abstraction
2//!
3//! Provides a view into GPU buffer memory with shape, stride, and layout information.
4//! Enables zero-copy slicing and transposition operations.
5//!
6//! # cuda-tile-behavior.md References
7//!
8//! - Section 3.2: Two-Level Memory Hierarchy
9//! - Falsification tests #31-40: TensorView correctness
10//!
11//! # Academic Foundation
12//!
13//! Based on Halide (PLDI 2013): Schedule/algorithm separation improves portability.
14
15use std::marker::PhantomData;
16use std::ops::Range;
17
18/// Memory layout for tensor storage
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MemoryLayout {
    /// Row-major (C-style): last dimension varies fastest.
    /// This is the `Default` variant, matching the strides computed by
    /// `TensorView::new`.
    #[default]
    RowMajor,
    /// Column-major (Fortran-style): first dimension varies fastest
    ColumnMajor,
    /// Tiled layout for GPU shared memory optimization
    Tiled {
        /// Tile dimensions — presumably [rows, cols]; TODO confirm against
        /// the kernels that consume this hint.
        tile_size: [usize; 2],
    },
}
32
33/// A view into a contiguous memory region with shape and stride information.
34///
35/// TensorView does not own the data - it provides a structured view over
36/// existing memory, enabling zero-copy operations like slicing and transposition.
37///
38/// # Type Parameters
39///
40/// * `T` - Element type (typically f32 for GPU compute)
41///
42/// # cuda-tile-behavior.md References
43///
44/// - Falsification test #31: TensorView preserves data integrity
45/// - Falsification test #32: Slicing produces correct views
46/// - Falsification test #33: Transpose swaps dimensions correctly
#[derive(Debug)]
pub struct TensorView<T> {
    /// Shape of the tensor (up to 4 dimensions: N, C, H, W).
    /// Unused trailing dimensions are stored as 1.
    shape: [usize; 4],
    /// Strides for each dimension (in elements, not bytes)
    strides: [usize; 4],
    /// Offset from the start of the buffer (in elements); non-zero after
    /// slicing along a dimension with a non-zero range start.
    offset: usize,
    /// Memory layout hint for optimization (a hint only — it is not
    /// validated against `strides`)
    layout: MemoryLayout,
    /// Number of active dimensions (always in 1..=4, see `compute_ndim`)
    ndim: usize,
    /// Zero-sized marker tying the view to element type `T` without
    /// storing any `T` (the view never owns data)
    _marker: PhantomData<T>,
}
62
63impl<T> TensorView<T> {
64 /// Create a new TensorView with the given shape.
65 ///
66 /// Strides are computed automatically based on row-major layout.
67 ///
68 /// # Arguments
69 ///
70 /// * `shape` - Shape of the tensor (unused dimensions should be 1)
71 ///
72 /// # Examples
73 ///
74 /// ```ignore
75 /// let view = TensorView::<f32>::new([2, 3, 4, 1]); // 2x3x4 tensor
76 /// assert_eq!(view.numel(), 24);
77 /// ```
78 pub fn new(shape: [usize; 4]) -> Self {
79 let ndim = Self::compute_ndim(&shape);
80 let strides = Self::compute_row_major_strides(&shape);
81 Self {
82 shape,
83 strides,
84 offset: 0,
85 layout: MemoryLayout::RowMajor,
86 ndim,
87 _marker: PhantomData,
88 }
89 }
90
91 /// Create a TensorView with explicit strides.
92 ///
93 /// # Arguments
94 ///
95 /// * `shape` - Shape of the tensor
96 /// * `strides` - Strides for each dimension (in elements)
97 pub fn with_strides(shape: [usize; 4], strides: [usize; 4]) -> Self {
98 let ndim = Self::compute_ndim(&shape);
99 Self {
100 shape,
101 strides,
102 offset: 0,
103 layout: MemoryLayout::RowMajor,
104 ndim,
105 _marker: PhantomData,
106 }
107 }
108
109 /// Create a 1D TensorView.
110 pub fn new_1d(len: usize) -> Self {
111 Self::new([len, 1, 1, 1])
112 }
113
114 /// Create a 2D TensorView (matrix).
115 pub fn new_2d(rows: usize, cols: usize) -> Self {
116 Self::new([rows, cols, 1, 1])
117 }
118
119 /// Create a 3D TensorView.
120 pub fn new_3d(d0: usize, d1: usize, d2: usize) -> Self {
121 Self::new([d0, d1, d2, 1])
122 }
123
124 /// Create a 4D TensorView.
125 pub fn new_4d(d0: usize, d1: usize, d2: usize, d3: usize) -> Self {
126 Self::new([d0, d1, d2, d3])
127 }
128
129 /// Get the shape of the tensor.
130 pub fn shape(&self) -> &[usize; 4] {
131 &self.shape
132 }
133
134 /// Get the strides of the tensor.
135 pub fn strides(&self) -> &[usize; 4] {
136 &self.strides
137 }
138
139 /// Get the offset from the start of the buffer.
140 pub fn offset(&self) -> usize {
141 self.offset
142 }
143
144 /// Get the memory layout.
145 pub fn layout(&self) -> MemoryLayout {
146 self.layout
147 }
148
149 /// Get the number of active dimensions.
150 pub fn ndim(&self) -> usize {
151 self.ndim
152 }
153
154 /// Get the total number of elements.
155 pub fn numel(&self) -> usize {
156 self.shape.iter().product()
157 }
158
159 /// Check if the tensor is contiguous in memory.
160 ///
161 /// A tensor is contiguous if elements are stored without gaps
162 /// in row-major order.
163 pub fn is_contiguous(&self) -> bool {
164 let expected_strides = Self::compute_row_major_strides(&self.shape);
165 self.strides == expected_strides
166 }
167
168 /// Check if the tensor is empty (has zero elements).
169 pub fn is_empty(&self) -> bool {
170 self.numel() == 0
171 }
172
173 /// Get dimension size at the given index.
174 ///
175 /// # Panics
176 ///
177 /// Panics if `dim >= 4`.
178 pub fn dim(&self, dim: usize) -> usize {
179 self.shape[dim]
180 }
181
182 /// Get stride at the given dimension.
183 ///
184 /// # Panics
185 ///
186 /// Panics if `dim >= 4`.
187 pub fn stride(&self, dim: usize) -> usize {
188 self.strides[dim]
189 }
190
191 /// Create a slice of this tensor along the first dimension.
192 ///
193 /// # Arguments
194 ///
195 /// * `range` - Range of indices to include
196 ///
197 /// # Returns
198 ///
199 /// A new TensorView representing the slice.
200 ///
201 /// # cuda-tile-behavior.md References
202 ///
203 /// - Falsification test #32: Slicing produces correct views
204 pub fn slice(&self, range: Range<usize>) -> Self {
205 assert!(range.end <= self.shape[0], "Slice range out of bounds");
206 let new_offset = self.offset + range.start * self.strides[0];
207 let mut new_shape = self.shape;
208 new_shape[0] = range.end - range.start;
209
210 Self {
211 shape: new_shape,
212 strides: self.strides,
213 offset: new_offset,
214 layout: self.layout,
215 ndim: self.ndim,
216 _marker: PhantomData,
217 }
218 }
219
220 /// Create a slice along a specific dimension.
221 ///
222 /// # Arguments
223 ///
224 /// * `dim` - Dimension to slice along
225 /// * `range` - Range of indices to include
226 pub fn slice_dim(&self, dim: usize, range: Range<usize>) -> Self {
227 assert!(dim < 4, "Dimension out of bounds");
228 assert!(range.end <= self.shape[dim], "Slice range out of bounds");
229
230 let new_offset = self.offset + range.start * self.strides[dim];
231 let mut new_shape = self.shape;
232 new_shape[dim] = range.end - range.start;
233
234 Self {
235 shape: new_shape,
236 strides: self.strides,
237 offset: new_offset,
238 layout: self.layout,
239 ndim: self.ndim,
240 _marker: PhantomData,
241 }
242 }
243
244 /// Transpose the tensor by swapping two dimensions.
245 ///
246 /// # Arguments
247 ///
248 /// * `dim0` - First dimension to swap
249 /// * `dim1` - Second dimension to swap
250 ///
251 /// # Returns
252 ///
253 /// A new TensorView with swapped dimensions.
254 ///
255 /// # cuda-tile-behavior.md References
256 ///
257 /// - Falsification test #33: Transpose swaps dimensions correctly
258 pub fn transpose(&self, dim0: usize, dim1: usize) -> Self {
259 assert!(dim0 < 4 && dim1 < 4, "Dimension out of bounds");
260
261 let mut new_shape = self.shape;
262 let mut new_strides = self.strides;
263 new_shape.swap(dim0, dim1);
264 new_strides.swap(dim0, dim1);
265
266 Self {
267 shape: new_shape,
268 strides: new_strides,
269 offset: self.offset,
270 layout: self.layout,
271 ndim: self.ndim,
272 _marker: PhantomData,
273 }
274 }
275
276 /// Reshape the tensor to a new shape.
277 ///
278 /// # Arguments
279 ///
280 /// * `new_shape` - New shape (must have same number of elements)
281 ///
282 /// # Returns
283 ///
284 /// A new TensorView with the new shape, or None if reshape is invalid.
285 pub fn reshape(&self, new_shape: [usize; 4]) -> Option<Self> {
286 let new_numel: usize = new_shape.iter().product();
287 if new_numel != self.numel() {
288 return None;
289 }
290
291 // Reshape requires contiguous memory
292 if !self.is_contiguous() {
293 return None;
294 }
295
296 Some(Self::new(new_shape))
297 }
298
299 /// Squeeze dimensions of size 1.
300 ///
301 /// Returns a view with all size-1 dimensions removed.
302 pub fn squeeze(&self) -> Self {
303 let mut new_shape = [1usize; 4];
304 let mut new_strides = [1usize; 4];
305 let mut new_ndim = 0;
306
307 for i in 0..4 {
308 if self.shape[i] > 1 {
309 new_shape[new_ndim] = self.shape[i];
310 new_strides[new_ndim] = self.strides[i];
311 new_ndim += 1;
312 }
313 }
314
315 // If all dimensions were 1, keep at least one
316 if new_ndim == 0 {
317 new_ndim = 1;
318 }
319
320 Self {
321 shape: new_shape,
322 strides: new_strides,
323 offset: self.offset,
324 layout: self.layout,
325 ndim: new_ndim,
326 _marker: PhantomData,
327 }
328 }
329
330 /// Unsqueeze: add a dimension of size 1 at the specified position.
331 ///
332 /// # Arguments
333 ///
334 /// * `dim` - Position to insert the new dimension
335 pub fn unsqueeze(&self, dim: usize) -> Option<Self> {
336 if dim > self.ndim || self.ndim >= 4 {
337 return None;
338 }
339
340 let mut new_shape = [1usize; 4];
341 let mut new_strides = [1usize; 4];
342
343 // Copy dimensions before the insertion point
344 // Using manual loop since we're copying from two separate arrays to two separate arrays
345 #[allow(clippy::manual_memcpy)]
346 for i in 0..dim {
347 new_shape[i] = self.shape[i];
348 new_strides[i] = self.strides[i];
349 }
350
351 // Insert the new dimension
352 new_shape[dim] = 1;
353 new_strides[dim] = if dim < self.ndim { self.strides[dim] * self.shape[dim] } else { 1 };
354
355 // Copy remaining dimensions (offset by 1 for insertion)
356 #[allow(clippy::manual_memcpy)]
357 for i in dim..self.ndim {
358 new_shape[i + 1] = self.shape[i];
359 new_strides[i + 1] = self.strides[i];
360 }
361
362 Some(Self {
363 shape: new_shape,
364 strides: new_strides,
365 offset: self.offset,
366 layout: self.layout,
367 ndim: self.ndim + 1,
368 _marker: PhantomData,
369 })
370 }
371
372 /// Set the memory layout hint.
373 pub fn with_layout(mut self, layout: MemoryLayout) -> Self {
374 self.layout = layout;
375 self
376 }
377
378 /// Compute linear index from multi-dimensional indices.
379 ///
380 /// # Arguments
381 ///
382 /// * `indices` - Array of indices for each dimension
383 ///
384 /// # Returns
385 ///
386 /// Linear offset into the underlying buffer.
387 pub fn linear_index(&self, indices: [usize; 4]) -> usize {
388 self.offset
389 + indices[0] * self.strides[0]
390 + indices[1] * self.strides[1]
391 + indices[2] * self.strides[2]
392 + indices[3] * self.strides[3]
393 }
394
395 /// Compute row-major strides for a given shape.
396 fn compute_row_major_strides(shape: &[usize; 4]) -> [usize; 4] {
397 let mut strides = [1usize; 4];
398 // Strides: s[i] = product of shape[i+1..4]
399 strides[3] = 1;
400 strides[2] = shape[3];
401 strides[1] = shape[3] * shape[2];
402 strides[0] = shape[3] * shape[2] * shape[1];
403 strides
404 }
405
406 /// Compute the number of active dimensions.
407 fn compute_ndim(shape: &[usize; 4]) -> usize {
408 // Count from the end: find last dimension > 1
409 for i in (0..4).rev() {
410 if shape[i] > 1 {
411 return i + 1;
412 }
413 }
414 1 // At least 1 dimension
415 }
416}
417
418impl<T> Clone for TensorView<T> {
419 fn clone(&self) -> Self {
420 Self {
421 shape: self.shape,
422 strides: self.strides,
423 offset: self.offset,
424 layout: self.layout,
425 ndim: self.ndim,
426 _marker: PhantomData,
427 }
428 }
429}
430
431impl<T> Default for TensorView<T> {
432 fn default() -> Self {
433 Self::new([1, 1, 1, 1])
434 }
435}
436
437#[cfg(test)]
438mod tests;