1use std::fmt;
20
/// Strategy used to divide the rows of a workload among tiles.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum WorkDistribution {
    /// Split rows as evenly as possible (`ceil(m / tiles)` rows per tile, with
    /// the final tile taking any remainder). This is the default strategy.
    #[default]
    EvenSplit,
    /// Hand each tile a fixed-size slab of rows; the final tile takes whatever
    /// remains.
    RowSlab {
        /// Number of rows given to each tile before the final one.
        rows_per_tile: usize,
    },
}
34
/// Description of a single sub-device (tile) that work can be dispatched to.
#[derive(Clone, Debug)]
pub struct SubDeviceInfo {
    /// Index identifying this sub-device within its dispatcher.
    pub index: usize,
    /// Human-readable device name.
    pub name: String,
    /// EU count reported for this sub-device (presumably GPU execution units).
    pub eu_count: u32,
}
47
/// A contiguous, half-open row range `row_start..row_end` assigned to one tile.
#[derive(Clone, Debug)]
pub struct TileWorkSlice {
    /// Index of the tile that processes this range.
    pub tile_index: usize,
    /// First row of the range (inclusive).
    pub row_start: usize,
    /// One past the last row of the range (exclusive).
    pub row_end: usize,
}
58
59impl TileWorkSlice {
60 #[inline]
62 pub fn rows(&self) -> usize {
63 self.row_end - self.row_start
64 }
65}
66
67#[derive(Debug, Clone)]
71pub struct MultiTileConfig {
72 pub strategy: WorkDistribution,
74 pub max_tiles: usize,
76 pub min_rows_for_multi_tile: usize,
79}
80
81impl Default for MultiTileConfig {
82 fn default() -> Self {
83 Self {
84 strategy: WorkDistribution::EvenSplit,
85 max_tiles: 0,
86 min_rows_for_multi_tile: 64,
87 }
88 }
89}
90
91#[derive(Debug)]
99pub struct MultiTileDispatcher {
100 pub sub_devices: Vec<SubDeviceInfo>,
102 pub config: MultiTileConfig,
104}
105
106impl MultiTileDispatcher {
107 pub fn new(sub_devices: Vec<SubDeviceInfo>, config: MultiTileConfig) -> Self {
112 Self {
113 sub_devices,
114 config,
115 }
116 }
117
118 pub fn single_device() -> Self {
120 Self::new(Vec::new(), MultiTileConfig::default())
121 }
122
123 pub fn tile_count(&self) -> usize {
127 let n = self.sub_devices.len().max(1);
128 if self.config.max_tiles == 0 {
129 n
130 } else {
131 n.min(self.config.max_tiles)
132 }
133 }
134
135 pub fn should_use_multi_tile(&self, m: usize) -> bool {
138 self.sub_devices.len() > 1 && m >= self.config.min_rows_for_multi_tile
139 }
140
141 pub fn partition(&self, m: usize) -> Vec<TileWorkSlice> {
150 if m == 0 {
151 return vec![TileWorkSlice {
152 tile_index: 0,
153 row_start: 0,
154 row_end: 0,
155 }];
156 }
157
158 let n_tiles = self.tile_count();
159
160 if n_tiles <= 1 {
161 return vec![TileWorkSlice {
162 tile_index: 0,
163 row_start: 0,
164 row_end: m,
165 }];
166 }
167
168 let rows_per_tile = match &self.config.strategy {
169 WorkDistribution::EvenSplit => m.div_ceil(n_tiles),
170 WorkDistribution::RowSlab { rows_per_tile } => *rows_per_tile,
171 };
172
173 let mut slices = Vec::with_capacity(n_tiles);
174 let mut row_start = 0usize;
175
176 for i in 0..n_tiles {
177 if row_start >= m {
178 break;
179 }
180 let row_end = if i == n_tiles - 1 {
181 m } else {
183 (row_start + rows_per_tile).min(m)
184 };
185 slices.push(TileWorkSlice {
186 tile_index: i,
187 row_start,
188 row_end,
189 });
190 row_start = row_end;
191 }
192
193 slices
194 }
195
196 pub fn from_synthetic(names: &[&str]) -> Self {
200 let sub_devices = names
201 .iter()
202 .enumerate()
203 .map(|(i, &name)| SubDeviceInfo {
204 index: i,
205 name: name.to_string(),
206 eu_count: 512,
207 })
208 .collect();
209 Self::new(sub_devices, MultiTileConfig::default())
210 }
211}
212
213impl fmt::Display for MultiTileDispatcher {
214 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
215 write!(
216 f,
217 "MultiTileDispatcher {{ tiles: {}, strategy: {:?} }}",
218 self.tile_count(),
219 self.config.strategy
220 )
221 }
222}
223
// Unit tests for tile counting, partitioning, thresholds and formatting.
#[cfg(test)]
mod tests {
    use super::*;

    /// Panic message used when a partition unexpectedly has no first/last slice.
    const SLICE_MSG: &str = "partition slice access should be valid in test context";

    /// Collects the row count of every slice, in order.
    fn row_counts(slices: &[TileWorkSlice]) -> Vec<usize> {
        slices.iter().map(TileWorkSlice::rows).collect()
    }

    #[test]
    fn single_device_returns_one_tile() {
        let dispatcher = MultiTileDispatcher::single_device();
        assert_eq!(dispatcher.tile_count(), 1);
        assert!(!dispatcher.should_use_multi_tile(1024));
    }

    #[test]
    fn four_tile_even_split() {
        let names = ["tile0", "tile1", "tile2", "tile3"];
        let dispatcher = MultiTileDispatcher::from_synthetic(&names);
        assert_eq!(dispatcher.tile_count(), 4);

        let slices = dispatcher.partition(256);
        assert_eq!(slices.len(), 4);
        for (i, slice) in slices.iter().enumerate() {
            assert_eq!(slice.tile_index, i);
            assert_eq!(slice.rows(), 64, "tile {i} expected 64 rows");
        }
        assert_eq!(slices.first().expect(SLICE_MSG).row_start, 0);
        assert_eq!(slices.last().expect(SLICE_MSG).row_end, 256);
    }

    #[test]
    fn uneven_split_last_tile_gets_remainder() {
        let dispatcher = MultiTileDispatcher::from_synthetic(&["t0", "t1", "t2"]);
        let slices = dispatcher.partition(100);
        assert_eq!(slices.len(), 3);
        assert_eq!(row_counts(&slices), vec![34, 34, 32]);
        assert_eq!(slices[2].row_end, 100);
    }

    #[test]
    fn row_slab_strategy() {
        let mut dispatcher = MultiTileDispatcher::from_synthetic(&["a", "b", "c"]);
        dispatcher.config.strategy = WorkDistribution::RowSlab { rows_per_tile: 50 };
        let slices = dispatcher.partition(120);
        assert_eq!(slices[0].rows(), 50);
        assert_eq!(slices[1].rows(), 50);
        assert_eq!(slices[2].rows(), 20);
    }

    #[test]
    fn max_tiles_cap() {
        let mut dispatcher = MultiTileDispatcher::from_synthetic(&["a", "b", "c", "d"]);
        dispatcher.config.max_tiles = 2;
        assert_eq!(dispatcher.tile_count(), 2);

        let slices = dispatcher.partition(200);
        assert_eq!(slices.len(), 2);
        assert_eq!(slices.last().expect(SLICE_MSG).row_end, 200);
    }

    #[test]
    fn zero_rows_returns_empty_slice() {
        let slices = MultiTileDispatcher::from_synthetic(&["a", "b"]).partition(0);
        assert_eq!(slices.len(), 1);
        assert_eq!(slices[0].rows(), 0);
    }

    #[test]
    fn should_use_multi_tile_threshold() {
        let dispatcher = MultiTileDispatcher::from_synthetic(&["a", "b"]);
        assert!(!dispatcher.should_use_multi_tile(32));
        for m in [64, 512] {
            assert!(dispatcher.should_use_multi_tile(m));
        }
    }

    #[test]
    fn display_format() {
        let rendered = MultiTileDispatcher::single_device().to_string();
        assert!(rendered.contains("MultiTileDispatcher"));
        assert!(rendered.contains("tiles: 1"));
    }

    #[test]
    fn sub_device_info_fields() {
        let info = SubDeviceInfo {
            index: 2,
            name: String::from("Intel Xe-HPC Tile 2"),
            eu_count: 448,
        };
        assert_eq!(info.index, 2);
        assert_eq!(info.eu_count, 448);
    }

    #[test]
    fn work_slice_rows_calculation() {
        let slice = TileWorkSlice {
            tile_index: 0,
            row_start: 100,
            row_end: 200,
        };
        assert_eq!(slice.rows(), 100);
    }
}