Skip to main content

trueno/backends/gpu/partition_view/
mod.rs

1//! PartitionView - Tiling Strategy for GPU Compute
2//!
3//! Divides a TensorView into tiles for efficient GPU processing.
4//! Enables automatic work distribution across thread blocks.
5//!
6//! # cuda-tile-behavior.md References
7//!
8//! - Section 3.2: Two-Level Memory Hierarchy
9//! - Falsification tests #36-45: PartitionView correctness
10//!
11//! # Academic Foundation
12//!
13//! Based on Volkov & Demmel (2008): Power-of-two tiles achieve 95%+ peak throughput.
14
15use super::tensor_view::TensorView;
16use std::marker::PhantomData;
17
18/// A tiling strategy over a TensorView.
19///
20/// PartitionView divides a tensor into tiles of a specified shape,
21/// enabling efficient GPU processing with shared memory optimization.
22///
23/// # Type Parameters
24///
25/// * `T` - Element type of the underlying tensor
26///
27/// # cuda-tile-behavior.md References
28///
29/// - Falsification test #36: Tile count calculation is correct
30/// - Falsification test #37: Tile iteration covers all elements
31/// - Falsification test #38: Edge tiles are handled correctly
32#[derive(Debug)]
33pub struct PartitionView<T> {
34    /// The underlying tensor being partitioned
35    tensor: TensorView<T>,
36    /// Shape of each tile
37    tile_shape: [usize; 4],
38    /// Phantom data for type safety
39    _marker: PhantomData<T>,
40}
41
/// Information about a single tile within a partition.
///
/// `Hash` is derived alongside `Eq` so tiles can be used as keys in hashed
/// collections (e.g. to de-duplicate or look up per-tile results).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TileInfo {
    /// Tile index in each dimension
    pub tile_idx: [usize; 4],
    /// Starting element index in each dimension
    pub start: [usize; 4],
    /// Size of this tile in each dimension (may be smaller at edges)
    pub size: [usize; 4],
    /// Whether this is an edge tile (smaller than full tile size in at least one dimension)
    pub is_edge: bool,
}
54
55impl<T> PartitionView<T> {
56    /// Create a new PartitionView with the given tile shape.
57    ///
58    /// # Arguments
59    ///
60    /// * `tensor` - The tensor to partition
61    /// * `tile_shape` - Shape of each tile
62    ///
63    /// # Panics
64    ///
65    /// Panics if any tile dimension is zero.
66    ///
67    /// # cuda-tile-behavior.md References
68    ///
69    /// - Falsification test #36: Tile count calculation is correct
70    pub fn new(tensor: TensorView<T>, tile_shape: [usize; 4]) -> Self {
71        assert!(tile_shape.iter().all(|&d| d > 0), "Tile dimensions must be non-zero");
72        Self { tensor, tile_shape, _marker: PhantomData }
73    }
74
75    /// Create a PartitionView with power-of-two tile sizes.
76    ///
77    /// This is recommended for GPU compute as it enables efficient
78    /// memory coalescing and avoids bank conflicts.
79    ///
80    /// # Arguments
81    ///
82    /// * `tensor` - The tensor to partition
83    /// * `tile_log2` - Log2 of tile size for each dimension
84    ///
85    /// # cuda-tile-behavior.md References
86    ///
87    /// - Falsification test #1: Power-of-two tiles improve GPU occupancy
88    pub fn new_power_of_two(tensor: TensorView<T>, tile_log2: [usize; 4]) -> Self {
89        let tile_shape =
90            [1 << tile_log2[0], 1 << tile_log2[1], 1 << tile_log2[2], 1 << tile_log2[3]];
91        Self::new(tensor, tile_shape)
92    }
93
94    /// Create a PartitionView with 2D tiles (for matrix operations).
95    ///
96    /// # Arguments
97    ///
98    /// * `tensor` - The tensor to partition
99    /// * `tile_rows` - Number of rows per tile
100    /// * `tile_cols` - Number of columns per tile
101    pub fn new_2d(tensor: TensorView<T>, tile_rows: usize, tile_cols: usize) -> Self {
102        Self::new(tensor, [tile_rows, tile_cols, 1, 1])
103    }
104
105    /// Get the underlying tensor.
106    pub fn tensor(&self) -> &TensorView<T> {
107        &self.tensor
108    }
109
110    /// Get the tile shape.
111    pub fn tile_shape(&self) -> &[usize; 4] {
112        &self.tile_shape
113    }
114
115    /// Get the number of tiles in each dimension.
116    ///
117    /// # cuda-tile-behavior.md References
118    ///
119    /// - Falsification test #36: Tile count calculation is correct
120    pub fn tile_count(&self) -> [usize; 4] {
121        let tensor_shape = self.tensor.shape();
122        [
123            tensor_shape[0].div_ceil(self.tile_shape[0]),
124            tensor_shape[1].div_ceil(self.tile_shape[1]),
125            tensor_shape[2].div_ceil(self.tile_shape[2]),
126            tensor_shape[3].div_ceil(self.tile_shape[3]),
127        ]
128    }
129
130    /// Get the total number of tiles.
131    pub fn total_tiles(&self) -> usize {
132        let count = self.tile_count();
133        count.iter().product()
134    }
135
136    /// Get information about a specific tile.
137    ///
138    /// # Arguments
139    ///
140    /// * `tile_idx` - Index of the tile in each dimension
141    ///
142    /// # Returns
143    ///
144    /// TileInfo containing the tile's position and size.
145    ///
146    /// # cuda-tile-behavior.md References
147    ///
148    /// - Falsification test #38: Edge tiles are handled correctly
149    pub fn get_tile(&self, tile_idx: [usize; 4]) -> Option<TileInfo> {
150        let tile_count = self.tile_count();
151
152        // Validate tile index
153        for i in 0..4 {
154            if tile_idx[i] >= tile_count[i] {
155                return None;
156            }
157        }
158
159        let tensor_shape = self.tensor.shape();
160        let mut start = [0usize; 4];
161        let mut size = [0usize; 4];
162        let mut is_edge = false;
163
164        for i in 0..4 {
165            start[i] = tile_idx[i] * self.tile_shape[i];
166            let remaining = tensor_shape[i] - start[i];
167            size[i] = remaining.min(self.tile_shape[i]);
168
169            // Check if this is an edge tile
170            if size[i] < self.tile_shape[i] {
171                is_edge = true;
172            }
173        }
174
175        Some(TileInfo { tile_idx, start, size, is_edge })
176    }
177
178    /// Get a TensorView for a specific tile.
179    ///
180    /// # Arguments
181    ///
182    /// * `tile_idx` - Index of the tile in each dimension
183    ///
184    /// # Returns
185    ///
186    /// A TensorView representing the tile, or None if index is invalid.
187    pub fn get_tile_view(&self, tile_idx: [usize; 4]) -> Option<TensorView<T>> {
188        let info = self.get_tile(tile_idx)?;
189
190        // Create a sliced view for this tile
191        let mut view = self.tensor.clone();
192
193        for i in 0..4 {
194            if self.tensor.shape()[i] > 1 {
195                view = view.slice_dim(i, info.start[i]..info.start[i] + info.size[i]);
196            }
197        }
198
199        Some(view)
200    }
201
202    /// Iterate over all tiles.
203    ///
204    /// # cuda-tile-behavior.md References
205    ///
206    /// - Falsification test #37: Tile iteration covers all elements
207    pub fn iter_tiles(&self) -> TileIterator<'_, T> {
208        TileIterator { partition: self, current: [0, 0, 0, 0], done: false }
209    }
210
211    /// Check if tiles are power-of-two sized.
212    ///
213    /// Power-of-two tiles are preferred for GPU compute.
214    pub fn is_power_of_two_tiles(&self) -> bool {
215        self.tile_shape.iter().all(|&d| d.is_power_of_two())
216    }
217
218    /// Get the number of elements per tile (maximum).
219    pub fn elements_per_tile(&self) -> usize {
220        self.tile_shape.iter().product()
221    }
222
223    /// Get recommended workgroup size for GPU dispatch.
224    ///
225    /// Returns (x, y, z) workgroup dimensions based on tile shape.
226    pub fn recommended_workgroup_size(&self) -> (u32, u32, u32) {
227        // Common GPU workgroup limits
228        const MAX_WORKGROUP_SIZE: usize = 256;
229        const MAX_DIM: usize = 16;
230
231        let tile_2d = [self.tile_shape[0], self.tile_shape[1]];
232
233        // For 2D tiles, use 2D workgroups
234        if tile_2d[0] > 1 && tile_2d[1] > 1 {
235            let x = tile_2d[1].min(MAX_DIM) as u32;
236            let y = tile_2d[0].min(MAX_DIM) as u32;
237            let z = 1;
238            (x, y, z)
239        } else {
240            // 1D workgroup
241            let size = self.elements_per_tile().min(MAX_WORKGROUP_SIZE);
242            (size as u32, 1, 1)
243        }
244    }
245}
246
247impl<T> Clone for PartitionView<T> {
248    fn clone(&self) -> Self {
249        Self { tensor: self.tensor.clone(), tile_shape: self.tile_shape, _marker: PhantomData }
250    }
251}
252
253/// Iterator over tiles in a PartitionView.
254pub struct TileIterator<'a, T> {
255    partition: &'a PartitionView<T>,
256    current: [usize; 4],
257    done: bool,
258}
259
260impl<T> Iterator for TileIterator<'_, T> {
261    type Item = TileInfo;
262
263    fn next(&mut self) -> Option<Self::Item> {
264        if self.done {
265            return None;
266        }
267
268        let tile = self.partition.get_tile(self.current)?;
269        let tile_count = self.partition.tile_count();
270
271        // Advance to next tile (row-major order)
272        self.current[3] += 1;
273        for i in (0..4).rev() {
274            if self.current[i] >= tile_count[i] {
275                self.current[i] = 0;
276                if i > 0 {
277                    self.current[i - 1] += 1;
278                } else {
279                    self.done = true;
280                }
281            } else {
282                break;
283            }
284        }
285
286        Some(tile)
287    }
288
289    fn size_hint(&self) -> (usize, Option<usize>) {
290        let total = self.partition.total_tiles();
291        (total, Some(total))
292    }
293}
294
295impl<T> ExactSizeIterator for TileIterator<'_, T> {}
296
297#[cfg(test)]
298mod tests;