// oxicuda_levelzero/multi_tile.rs
//! Multi-tile / multi-device dispatch for Intel Level Zero GPUs.
//!
//! Intel Xe-HPC (Ponte Vecchio) and other large-tile Intel GPUs expose individual
//! compute tiles as Level Zero "sub-devices". This module discovers those
//! sub-devices and distributes matrix work across them.
//!
//! # Overview
//!
//! ```text
//! LevelZeroDevice (root)
//!   ├── Tile 0  (sub-device handle)
//!   ├── Tile 1
//!   └── Tile N-1
//! ```
//!
//! When sub-devices are not available (older GPUs, consumer Xe Arc), the
//! [`MultiTileDispatcher`] transparently falls back to single-device dispatch.

use std::fmt;

// ─── Work Distribution Strategy ───────────────────────────────────────────────

/// How to partition a matrix operation across multiple compute tiles.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum WorkDistribution {
    /// Divide the M dimension of the output matrix into equal-sized row slabs,
    /// one per tile. If `m` is not evenly divisible, the last tile gets the
    /// remainder rows.
    #[default]
    EvenSplit,
    /// Assign exactly `rows_per_tile` rows to each tile except the last,
    /// which receives whatever rows remain in the M dimension.
    RowSlab { rows_per_tile: usize },
}

// ─── Sub-device Discovery ─────────────────────────────────────────────────────

/// Metadata for a single Level Zero sub-device (compute tile).
///
/// Instances are produced by the backend's device-tree enumeration (or by
/// `MultiTileDispatcher::from_synthetic` in tests) and consumed read-only by
/// the dispatcher.
#[derive(Debug, Clone)]
pub struct SubDeviceInfo {
    /// Zero-based tile index within the parent device.
    pub index: usize,
    /// Human-readable device name (may include tile index suffix).
    pub name: String,
    /// Reported number of EU/XVE execution units on this tile.
    pub eu_count: u32,
}

/// Work assignment for one tile during a GEMM dispatch.
///
/// Describes the half-open row range `row_start..row_end` of the output
/// matrix that `tile_index` is responsible for computing.
#[derive(Debug, Clone)]
pub struct TileWorkSlice {
    /// Tile that will execute this slice.
    pub tile_index: usize,
    /// Starting row index in the M dimension (inclusive).
    pub row_start: usize,
    /// Exclusive end row (i.e., this tile processes rows `row_start..row_end`).
    pub row_end: usize,
}

impl TileWorkSlice {
    /// Number of rows assigned to this tile (zero for an empty slice).
    #[inline]
    pub fn rows(&self) -> usize {
        let Self { row_start, row_end, .. } = *self;
        row_end - row_start
    }
}

67// ─── MultiTileConfig ───────────────────────────────────────────────────────────
68
69/// Configuration for the multi-tile dispatcher.
70#[derive(Debug, Clone)]
71pub struct MultiTileConfig {
72    /// Work partitioning strategy.
73    pub strategy: WorkDistribution,
74    /// Maximum number of tiles to use (0 = use all available).
75    pub max_tiles: usize,
76    /// Minimum problem size (M rows) below which single-device dispatch is
77    /// preferred over multi-tile dispatch to avoid scheduling overhead.
78    pub min_rows_for_multi_tile: usize,
79}
80
81impl Default for MultiTileConfig {
82    fn default() -> Self {
83        Self {
84            strategy: WorkDistribution::EvenSplit,
85            max_tiles: 0,
86            min_rows_for_multi_tile: 64,
87        }
88    }
89}
90
// ─── MultiTileDispatcher ──────────────────────────────────────────────────────

/// Enumerates Level Zero sub-devices and partitions matrix work across tiles.
///
/// On devices without sub-devices (single-tile GPUs), the dispatcher degrades
/// gracefully to single-device behaviour — callers do not need to special-case
/// this.
#[derive(Debug)]
pub struct MultiTileDispatcher {
    /// Discovered sub-devices. Empty means single-device fallback.
    pub sub_devices: Vec<SubDeviceInfo>,
    /// Active configuration.
    pub config: MultiTileConfig,
}

106impl MultiTileDispatcher {
107    /// Construct a dispatcher with pre-discovered sub-device information.
108    ///
109    /// Typically called by the backend after enumerating the Level Zero
110    /// device tree. Pass an empty `sub_devices` vec for single-tile GPUs.
111    pub fn new(sub_devices: Vec<SubDeviceInfo>, config: MultiTileConfig) -> Self {
112        Self {
113            sub_devices,
114            config,
115        }
116    }
117
118    /// Construct a single-device dispatcher (no sub-device enumeration).
119    pub fn single_device() -> Self {
120        Self::new(Vec::new(), MultiTileConfig::default())
121    }
122
123    /// Return how many tiles are available for dispatch.
124    ///
125    /// Returns 1 when no sub-devices were discovered (single-tile path).
126    pub fn tile_count(&self) -> usize {
127        let n = self.sub_devices.len().max(1);
128        if self.config.max_tiles == 0 {
129            n
130        } else {
131            n.min(self.config.max_tiles)
132        }
133    }
134
135    /// Return `true` when multi-tile dispatch should be used for a problem of
136    /// size `m` (number of output matrix rows).
137    pub fn should_use_multi_tile(&self, m: usize) -> bool {
138        self.sub_devices.len() > 1 && m >= self.config.min_rows_for_multi_tile
139    }
140
141    /// Partition `m` rows across available tiles according to the configured
142    /// `WorkDistribution` strategy.
143    ///
144    /// Returns a `Vec<TileWorkSlice>` with one entry per active tile. The
145    /// slices are non-overlapping and together cover the full `[0, m)` range.
146    ///
147    /// If there are no sub-devices or `m == 0`, returns a single slice for
148    /// tile 0 covering the entire row range.
149    pub fn partition(&self, m: usize) -> Vec<TileWorkSlice> {
150        if m == 0 {
151            return vec![TileWorkSlice {
152                tile_index: 0,
153                row_start: 0,
154                row_end: 0,
155            }];
156        }
157
158        let n_tiles = self.tile_count();
159
160        if n_tiles <= 1 {
161            return vec![TileWorkSlice {
162                tile_index: 0,
163                row_start: 0,
164                row_end: m,
165            }];
166        }
167
168        let rows_per_tile = match &self.config.strategy {
169            WorkDistribution::EvenSplit => m.div_ceil(n_tiles),
170            WorkDistribution::RowSlab { rows_per_tile } => *rows_per_tile,
171        };
172
173        let mut slices = Vec::with_capacity(n_tiles);
174        let mut row_start = 0usize;
175
176        for i in 0..n_tiles {
177            if row_start >= m {
178                break;
179            }
180            let row_end = if i == n_tiles - 1 {
181                m // last tile gets remainder
182            } else {
183                (row_start + rows_per_tile).min(m)
184            };
185            slices.push(TileWorkSlice {
186                tile_index: i,
187                row_start,
188                row_end,
189            });
190            row_start = row_end;
191        }
192
193        slices
194    }
195
196    /// Simulate sub-device enumeration from a list of synthetic device names.
197    ///
198    /// This is useful for unit tests that cannot call into Level Zero hardware.
199    pub fn from_synthetic(names: &[&str]) -> Self {
200        let sub_devices = names
201            .iter()
202            .enumerate()
203            .map(|(i, &name)| SubDeviceInfo {
204                index: i,
205                name: name.to_string(),
206                eu_count: 512,
207            })
208            .collect();
209        Self::new(sub_devices, MultiTileConfig::default())
210    }
211}
212
213impl fmt::Display for MultiTileDispatcher {
214    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
215        write!(
216            f,
217            "MultiTileDispatcher {{ tiles: {}, strategy: {:?} }}",
218            self.tile_count(),
219            self.config.strategy
220        )
221    }
222}
223
// ─── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    // With no sub-devices the dispatcher reports one tile and never
    // recommends multi-tile dispatch, regardless of problem size.
    #[test]
    fn single_device_returns_one_tile() {
        let d = MultiTileDispatcher::single_device();
        assert_eq!(d.tile_count(), 1);
        assert!(!d.should_use_multi_tile(1024));
    }

    // 256 rows across 4 tiles divides evenly: 64 rows per tile.
    #[test]
    fn four_tile_even_split() {
        let d = MultiTileDispatcher::from_synthetic(&["tile0", "tile1", "tile2", "tile3"]);
        assert_eq!(d.tile_count(), 4);

        let slices = d.partition(256);
        assert_eq!(slices.len(), 4);
        // Each tile should get 64 rows.
        for (i, s) in slices.iter().enumerate() {
            assert_eq!(s.tile_index, i);
            assert_eq!(s.rows(), 64, "tile {i} expected 64 rows");
        }
        // Coverage must span [0, 256).
        assert_eq!(
            slices
                .first()
                .expect("partition slice access should be valid in test context")
                .row_start,
            0
        );
        assert_eq!(
            slices
                .last()
                .expect("partition slice access should be valid in test context")
                .row_end,
            256
        );
    }

    // ceil(100 / 3) = 34, so the first two tiles get 34 rows each and the
    // last tile absorbs the remaining 32.
    #[test]
    fn uneven_split_last_tile_gets_remainder() {
        let d = MultiTileDispatcher::from_synthetic(&["t0", "t1", "t2"]);
        let slices = d.partition(100); // 100 / 3 = 33 rem 1
        assert_eq!(slices.len(), 3);
        assert_eq!(slices[0].rows(), 34); // ceil(100/3)
        assert_eq!(slices[1].rows(), 34);
        assert_eq!(slices[2].rows(), 32); // remainder
        assert_eq!(slices[2].row_end, 100);
    }

    // Fixed-width slabs of 50 rows; the final tile takes whatever is left.
    #[test]
    fn row_slab_strategy() {
        let mut d = MultiTileDispatcher::from_synthetic(&["a", "b", "c"]);
        d.config.strategy = WorkDistribution::RowSlab { rows_per_tile: 50 };
        let slices = d.partition(120);
        assert_eq!(slices[0].rows(), 50);
        assert_eq!(slices[1].rows(), 50);
        assert_eq!(slices[2].rows(), 20); // remainder
    }

    // `max_tiles` caps the effective tile count even when more sub-devices
    // were discovered; partitioning still covers the full row range.
    #[test]
    fn max_tiles_cap() {
        let mut d = MultiTileDispatcher::from_synthetic(&["a", "b", "c", "d"]);
        d.config.max_tiles = 2;
        assert_eq!(d.tile_count(), 2);
        let slices = d.partition(200);
        assert_eq!(slices.len(), 2);
        assert_eq!(
            slices
                .last()
                .expect("partition slice access should be valid in test context")
                .row_end,
            200
        );
    }

    // m == 0 still yields exactly one (empty) slice, never an empty vec.
    #[test]
    fn zero_rows_returns_empty_slice() {
        let d = MultiTileDispatcher::from_synthetic(&["a", "b"]);
        let slices = d.partition(0);
        assert_eq!(slices.len(), 1);
        assert_eq!(slices[0].rows(), 0);
    }

    // The default config gates multi-tile dispatch at 64 rows (inclusive).
    #[test]
    fn should_use_multi_tile_threshold() {
        let d = MultiTileDispatcher::from_synthetic(&["a", "b"]);
        assert!(!d.should_use_multi_tile(32)); // below threshold (64)
        assert!(d.should_use_multi_tile(64));
        assert!(d.should_use_multi_tile(512));
    }

    #[test]
    fn display_format() {
        let d = MultiTileDispatcher::single_device();
        let s = format!("{d}");
        assert!(s.contains("MultiTileDispatcher"));
        assert!(s.contains("tiles: 1"));
    }

    #[test]
    fn sub_device_info_fields() {
        let info = SubDeviceInfo {
            index: 2,
            name: "Intel Xe-HPC Tile 2".to_string(),
            eu_count: 448,
        };
        assert_eq!(info.index, 2);
        assert_eq!(info.eu_count, 448);
    }

    #[test]
    fn work_slice_rows_calculation() {
        let slice = TileWorkSlice {
            tile_index: 0,
            row_start: 100,
            row_end: 200,
        };
        assert_eq!(slice.rows(), 100);
    }
}