// trueno/backends/gpu/partition_view/mod.rs
1//! PartitionView - Tiling Strategy for GPU Compute
2//!
3//! Divides a TensorView into tiles for efficient GPU processing.
4//! Enables automatic work distribution across thread blocks.
5//!
6//! # cuda-tile-behavior.md References
7//!
8//! - Section 3.2: Two-Level Memory Hierarchy
9//! - Falsification tests #36-45: PartitionView correctness
10//!
11//! # Academic Foundation
12//!
13//! Based on Volkov & Demmel (2008): Power-of-two tiles achieve 95%+ peak throughput.
14
15use super::tensor_view::TensorView;
16use std::marker::PhantomData;
17
18/// A tiling strategy over a TensorView.
19///
20/// PartitionView divides a tensor into tiles of a specified shape,
21/// enabling efficient GPU processing with shared memory optimization.
22///
23/// # Type Parameters
24///
25/// * `T` - Element type of the underlying tensor
26///
27/// # cuda-tile-behavior.md References
28///
29/// - Falsification test #36: Tile count calculation is correct
30/// - Falsification test #37: Tile iteration covers all elements
31/// - Falsification test #38: Edge tiles are handled correctly
#[derive(Debug)]
pub struct PartitionView<T> {
    /// The underlying tensor being partitioned
    tensor: TensorView<T>,
    /// Shape of each tile, one extent per dimension (all non-zero, enforced by `new`)
    tile_shape: [usize; 4],
    /// Phantom data for type safety
    // NOTE(review): `tensor: TensorView<T>` already ties the struct to `T`,
    // so this marker looks redundant — consider removing it (would also
    // require updating `new` and the manual `Clone` impl).
    _marker: PhantomData<T>,
}
41
/// Information about a single tile within a partition.
///
/// Produced by [`PartitionView::get_tile`] and [`TileIterator`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TileInfo {
    /// Tile index in each dimension
    pub tile_idx: [usize; 4],
    /// Starting element index in each dimension (tile_idx * tile_shape)
    pub start: [usize; 4],
    /// Size of this tile in each dimension (may be smaller at edges,
    /// where the tensor extent is not a multiple of the tile extent)
    pub size: [usize; 4],
    /// Whether this is an edge tile (smaller than the full tile size
    /// in at least one dimension)
    pub is_edge: bool,
}
54
55impl<T> PartitionView<T> {
56 /// Create a new PartitionView with the given tile shape.
57 ///
58 /// # Arguments
59 ///
60 /// * `tensor` - The tensor to partition
61 /// * `tile_shape` - Shape of each tile
62 ///
63 /// # Panics
64 ///
65 /// Panics if any tile dimension is zero.
66 ///
67 /// # cuda-tile-behavior.md References
68 ///
69 /// - Falsification test #36: Tile count calculation is correct
70 pub fn new(tensor: TensorView<T>, tile_shape: [usize; 4]) -> Self {
71 assert!(tile_shape.iter().all(|&d| d > 0), "Tile dimensions must be non-zero");
72 Self { tensor, tile_shape, _marker: PhantomData }
73 }
74
75 /// Create a PartitionView with power-of-two tile sizes.
76 ///
77 /// This is recommended for GPU compute as it enables efficient
78 /// memory coalescing and avoids bank conflicts.
79 ///
80 /// # Arguments
81 ///
82 /// * `tensor` - The tensor to partition
83 /// * `tile_log2` - Log2 of tile size for each dimension
84 ///
85 /// # cuda-tile-behavior.md References
86 ///
87 /// - Falsification test #1: Power-of-two tiles improve GPU occupancy
88 pub fn new_power_of_two(tensor: TensorView<T>, tile_log2: [usize; 4]) -> Self {
89 let tile_shape =
90 [1 << tile_log2[0], 1 << tile_log2[1], 1 << tile_log2[2], 1 << tile_log2[3]];
91 Self::new(tensor, tile_shape)
92 }
93
94 /// Create a PartitionView with 2D tiles (for matrix operations).
95 ///
96 /// # Arguments
97 ///
98 /// * `tensor` - The tensor to partition
99 /// * `tile_rows` - Number of rows per tile
100 /// * `tile_cols` - Number of columns per tile
101 pub fn new_2d(tensor: TensorView<T>, tile_rows: usize, tile_cols: usize) -> Self {
102 Self::new(tensor, [tile_rows, tile_cols, 1, 1])
103 }
104
    /// Get the underlying tensor.
    ///
    /// Returns a borrow; callers clone if they need ownership.
    pub fn tensor(&self) -> &TensorView<T> {
        &self.tensor
    }
109
    /// Get the tile shape.
    ///
    /// This is the maximum tile size; edge tiles may be smaller
    /// (see `get_tile`).
    pub fn tile_shape(&self) -> &[usize; 4] {
        &self.tile_shape
    }
114
115 /// Get the number of tiles in each dimension.
116 ///
117 /// # cuda-tile-behavior.md References
118 ///
119 /// - Falsification test #36: Tile count calculation is correct
120 pub fn tile_count(&self) -> [usize; 4] {
121 let tensor_shape = self.tensor.shape();
122 [
123 tensor_shape[0].div_ceil(self.tile_shape[0]),
124 tensor_shape[1].div_ceil(self.tile_shape[1]),
125 tensor_shape[2].div_ceil(self.tile_shape[2]),
126 tensor_shape[3].div_ceil(self.tile_shape[3]),
127 ]
128 }
129
130 /// Get the total number of tiles.
131 pub fn total_tiles(&self) -> usize {
132 let count = self.tile_count();
133 count.iter().product()
134 }
135
136 /// Get information about a specific tile.
137 ///
138 /// # Arguments
139 ///
140 /// * `tile_idx` - Index of the tile in each dimension
141 ///
142 /// # Returns
143 ///
144 /// TileInfo containing the tile's position and size.
145 ///
146 /// # cuda-tile-behavior.md References
147 ///
148 /// - Falsification test #38: Edge tiles are handled correctly
149 pub fn get_tile(&self, tile_idx: [usize; 4]) -> Option<TileInfo> {
150 let tile_count = self.tile_count();
151
152 // Validate tile index
153 for i in 0..4 {
154 if tile_idx[i] >= tile_count[i] {
155 return None;
156 }
157 }
158
159 let tensor_shape = self.tensor.shape();
160 let mut start = [0usize; 4];
161 let mut size = [0usize; 4];
162 let mut is_edge = false;
163
164 for i in 0..4 {
165 start[i] = tile_idx[i] * self.tile_shape[i];
166 let remaining = tensor_shape[i] - start[i];
167 size[i] = remaining.min(self.tile_shape[i]);
168
169 // Check if this is an edge tile
170 if size[i] < self.tile_shape[i] {
171 is_edge = true;
172 }
173 }
174
175 Some(TileInfo { tile_idx, start, size, is_edge })
176 }
177
    /// Get a TensorView for a specific tile.
    ///
    /// # Arguments
    ///
    /// * `tile_idx` - Index of the tile in each dimension
    ///
    /// # Returns
    ///
    /// A TensorView representing the tile, or None if index is invalid.
    pub fn get_tile_view(&self, tile_idx: [usize; 4]) -> Option<TensorView<T>> {
        // Validates the index and computes start/size for this tile.
        let info = self.get_tile(tile_idx)?;

        // Create a sliced view for this tile
        let mut view = self.tensor.clone();

        for i in 0..4 {
            // NOTE(review): singleton dimensions are skipped — for a dim of
            // extent 1 the tile is start=0, size=1, so slicing would be a
            // no-op; presumably skipped so `slice_dim` is not called on
            // degenerate dims. Confirm against `TensorView::slice_dim`.
            if self.tensor.shape()[i] > 1 {
                view = view.slice_dim(i, info.start[i]..info.start[i] + info.size[i]);
            }
        }

        Some(view)
    }
201
202 /// Iterate over all tiles.
203 ///
204 /// # cuda-tile-behavior.md References
205 ///
206 /// - Falsification test #37: Tile iteration covers all elements
207 pub fn iter_tiles(&self) -> TileIterator<'_, T> {
208 TileIterator { partition: self, current: [0, 0, 0, 0], done: false }
209 }
210
211 /// Check if tiles are power-of-two sized.
212 ///
213 /// Power-of-two tiles are preferred for GPU compute.
214 pub fn is_power_of_two_tiles(&self) -> bool {
215 self.tile_shape.iter().all(|&d| d.is_power_of_two())
216 }
217
218 /// Get the number of elements per tile (maximum).
219 pub fn elements_per_tile(&self) -> usize {
220 self.tile_shape.iter().product()
221 }
222
223 /// Get recommended workgroup size for GPU dispatch.
224 ///
225 /// Returns (x, y, z) workgroup dimensions based on tile shape.
226 pub fn recommended_workgroup_size(&self) -> (u32, u32, u32) {
227 // Common GPU workgroup limits
228 const MAX_WORKGROUP_SIZE: usize = 256;
229 const MAX_DIM: usize = 16;
230
231 let tile_2d = [self.tile_shape[0], self.tile_shape[1]];
232
233 // For 2D tiles, use 2D workgroups
234 if tile_2d[0] > 1 && tile_2d[1] > 1 {
235 let x = tile_2d[1].min(MAX_DIM) as u32;
236 let y = tile_2d[0].min(MAX_DIM) as u32;
237 let z = 1;
238 (x, y, z)
239 } else {
240 // 1D workgroup
241 let size = self.elements_per_tile().min(MAX_WORKGROUP_SIZE);
242 (size as u32, 1, 1)
243 }
244 }
245}
246
/// Manual `Clone`: copies the tensor view handle and the tile shape.
// NOTE(review): hand-written presumably to avoid the `T: Clone` bound that
// `#[derive(Clone)]` would impose — confirm `TensorView<T>: Clone` holds
// without `T: Clone` before replacing this with a derive.
impl<T> Clone for PartitionView<T> {
    fn clone(&self) -> Self {
        Self { tensor: self.tensor.clone(), tile_shape: self.tile_shape, _marker: PhantomData }
    }
}
252
/// Iterator over tiles in a PartitionView.
///
/// Yields one [`TileInfo`] per tile in row-major order (the last
/// dimension varies fastest).
pub struct TileIterator<'a, T> {
    /// The partition being iterated.
    partition: &'a PartitionView<T>,
    /// Tile index of the next tile to yield.
    current: [usize; 4],
    /// Set once every tile has been yielded.
    done: bool,
}
259
260impl<T> Iterator for TileIterator<'_, T> {
261 type Item = TileInfo;
262
263 fn next(&mut self) -> Option<Self::Item> {
264 if self.done {
265 return None;
266 }
267
268 let tile = self.partition.get_tile(self.current)?;
269 let tile_count = self.partition.tile_count();
270
271 // Advance to next tile (row-major order)
272 self.current[3] += 1;
273 for i in (0..4).rev() {
274 if self.current[i] >= tile_count[i] {
275 self.current[i] = 0;
276 if i > 0 {
277 self.current[i - 1] += 1;
278 } else {
279 self.done = true;
280 }
281 } else {
282 break;
283 }
284 }
285
286 Some(tile)
287 }
288
289 fn size_hint(&self) -> (usize, Option<usize>) {
290 let total = self.partition.total_tiles();
291 (total, Some(total))
292 }
293}
294
// NOTE(review): `ExactSizeIterator` requires `size_hint` to report the
// exact number of REMAINING items at all times — verify the `size_hint`
// above accounts for tiles already consumed, not just the total.
impl<T> ExactSizeIterator for TileIterator<'_, T> {}
296
297#[cfg(test)]
298mod tests;