use std::fmt;
/// Strategy used to divide row-parallel work across sub-device tiles.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum WorkDistribution {
    /// Split the rows as evenly as possible across all available tiles.
    #[default]
    EvenSplit,
    /// Give each tile a fixed-size slab of rows; the final tile absorbs
    /// whatever remainder is left over.
    RowSlab { rows_per_tile: usize },
}
/// Descriptive metadata for a single sub-device (tile) available for dispatch.
#[derive(Debug, Clone)]
pub struct SubDeviceInfo {
    /// Position of this sub-device within the dispatcher's device list.
    pub index: usize,
    /// Human-readable device name.
    pub name: String,
    /// Execution-unit count reported for this sub-device.
    pub eu_count: u32,
}
/// A contiguous, half-open row range `[row_start, row_end)` assigned to one tile.
#[derive(Debug, Clone)]
pub struct TileWorkSlice {
    /// The tile this slice is assigned to.
    pub tile_index: usize,
    /// First row of the slice (inclusive).
    pub row_start: usize,
    /// One past the last row of the slice (exclusive).
    pub row_end: usize,
}
impl TileWorkSlice {
#[inline]
pub fn rows(&self) -> usize {
self.row_end - self.row_start
}
}
/// Tunables controlling how work is split across tiles.
#[derive(Debug, Clone)]
pub struct MultiTileConfig {
    /// How rows are distributed across the tiles.
    pub strategy: WorkDistribution,
    /// Upper bound on the number of tiles used; `0` means no cap.
    pub max_tiles: usize,
    /// Minimum row count before multi-tile dispatch is considered worthwhile.
    pub min_rows_for_multi_tile: usize,
}
impl Default for MultiTileConfig {
fn default() -> Self {
Self {
strategy: WorkDistribution::EvenSplit,
max_tiles: 0,
min_rows_for_multi_tile: 64,
}
}
}
/// Splits row-parallel work across the available sub-device tiles.
#[derive(Debug)]
pub struct MultiTileDispatcher {
    /// Known sub-devices; empty is treated as a single monolithic device.
    pub sub_devices: Vec<SubDeviceInfo>,
    /// Distribution tunables (strategy, tile cap, multi-tile threshold).
    pub config: MultiTileConfig,
}
impl MultiTileDispatcher {
    /// Creates a dispatcher over the given sub-devices with an explicit config.
    pub fn new(sub_devices: Vec<SubDeviceInfo>, config: MultiTileConfig) -> Self {
        Self {
            sub_devices,
            config,
        }
    }

    /// Convenience constructor for the no-sub-device (single tile) case.
    pub fn single_device() -> Self {
        Self::new(Vec::new(), MultiTileConfig::default())
    }

    /// Number of tiles work can be split across: the sub-device count
    /// (treated as at least 1), optionally capped by `config.max_tiles`
    /// where `0` means "no cap".
    pub fn tile_count(&self) -> usize {
        let n = self.sub_devices.len().max(1);
        if self.config.max_tiles == 0 {
            n
        } else {
            n.min(self.config.max_tiles)
        }
    }

    /// Whether a problem with `m` rows is worth splitting across tiles.
    ///
    /// Checks `tile_count()` rather than the raw sub-device count so that a
    /// `max_tiles` cap of 1 correctly disables multi-tile dispatch, keeping
    /// this predicate consistent with what `partition` will actually do.
    pub fn should_use_multi_tile(&self, m: usize) -> bool {
        self.tile_count() > 1 && m >= self.config.min_rows_for_multi_tile
    }

    /// Splits `m` rows into contiguous, non-overlapping slices, one per tile.
    ///
    /// Guarantees:
    /// - at least one slice is always returned (a single empty slice for `m == 0`);
    /// - slices cover `0..m` exactly, in order, with no gaps or overlap;
    /// - the final slice absorbs any remainder.
    pub fn partition(&self, m: usize) -> Vec<TileWorkSlice> {
        if m == 0 {
            // Degenerate case: one empty slice keeps callers' loops uniform.
            return vec![TileWorkSlice {
                tile_index: 0,
                row_start: 0,
                row_end: 0,
            }];
        }
        let n_tiles = self.tile_count();
        if n_tiles <= 1 {
            return vec![TileWorkSlice {
                tile_index: 0,
                row_start: 0,
                row_end: m,
            }];
        }
        let rows_per_tile = match &self.config.strategy {
            // Ceiling division so the earlier tiles take the larger share.
            WorkDistribution::EvenSplit => m.div_ceil(n_tiles),
            // Guard against a zero slab size, which would otherwise yield a
            // zero-row slice for every tile except the last.
            WorkDistribution::RowSlab { rows_per_tile } => (*rows_per_tile).max(1),
        };
        let mut slices = Vec::with_capacity(n_tiles);
        let mut row_start = 0usize;
        for i in 0..n_tiles {
            if row_start >= m {
                break; // all rows assigned; remaining tiles get no slice
            }
            // The last tile takes everything that is left so coverage is exact.
            let row_end = if i == n_tiles - 1 {
                m
            } else {
                (row_start + rows_per_tile).min(m)
            };
            slices.push(TileWorkSlice {
                tile_index: i,
                row_start,
                row_end,
            });
            row_start = row_end;
        }
        slices
    }

    /// Builds a dispatcher from named synthetic sub-devices (test helper);
    /// every synthetic device reports a fixed EU count of 512.
    pub fn from_synthetic(names: &[&str]) -> Self {
        let sub_devices = names
            .iter()
            .enumerate()
            .map(|(i, &name)| SubDeviceInfo {
                index: i,
                name: name.to_string(),
                eu_count: 512,
            })
            .collect();
        Self::new(sub_devices, MultiTileConfig::default())
    }
}
impl fmt::Display for MultiTileDispatcher {
    /// One-line human-readable summary: tile count plus the distribution strategy.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let tiles = self.tile_count();
        let strategy = &self.config.strategy;
        write!(f, "MultiTileDispatcher {{ tiles: {tiles}, strategy: {strategy:?} }}")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn single_device_returns_one_tile() {
        let dispatcher = MultiTileDispatcher::single_device();
        assert_eq!(dispatcher.tile_count(), 1);
        assert!(!dispatcher.should_use_multi_tile(1024));
    }

    #[test]
    fn four_tile_even_split() {
        let dispatcher =
            MultiTileDispatcher::from_synthetic(&["tile0", "tile1", "tile2", "tile3"]);
        assert_eq!(dispatcher.tile_count(), 4);

        let slices = dispatcher.partition(256);
        assert_eq!(slices.len(), 4);
        // 256 rows across 4 tiles divides evenly: 64 rows apiece.
        for (i, slice) in slices.iter().enumerate() {
            assert_eq!(slice.tile_index, i);
            assert_eq!(slice.rows(), 64, "tile {i} expected 64 rows");
        }

        let first = slices
            .first()
            .expect("partition slice access should be valid in test context");
        assert_eq!(first.row_start, 0);
        let last = slices
            .last()
            .expect("partition slice access should be valid in test context");
        assert_eq!(last.row_end, 256);
    }

    #[test]
    fn uneven_split_last_tile_gets_remainder() {
        let dispatcher = MultiTileDispatcher::from_synthetic(&["t0", "t1", "t2"]);
        let slices = dispatcher.partition(100);
        assert_eq!(slices.len(), 3);
        // ceil(100 / 3) = 34, so the last tile absorbs the smaller remainder.
        assert_eq!(slices[0].rows(), 34);
        assert_eq!(slices[1].rows(), 34);
        assert_eq!(slices[2].rows(), 32);
        assert_eq!(slices[2].row_end, 100);
    }

    #[test]
    fn row_slab_strategy() {
        let mut dispatcher = MultiTileDispatcher::from_synthetic(&["a", "b", "c"]);
        dispatcher.config.strategy = WorkDistribution::RowSlab { rows_per_tile: 50 };
        let slices = dispatcher.partition(120);
        // Two full 50-row slabs, then the last tile takes the remaining 20.
        assert_eq!(slices[0].rows(), 50);
        assert_eq!(slices[1].rows(), 50);
        assert_eq!(slices[2].rows(), 20);
    }

    #[test]
    fn max_tiles_cap() {
        let mut dispatcher = MultiTileDispatcher::from_synthetic(&["a", "b", "c", "d"]);
        dispatcher.config.max_tiles = 2;
        assert_eq!(dispatcher.tile_count(), 2);

        let slices = dispatcher.partition(200);
        assert_eq!(slices.len(), 2);
        let last = slices
            .last()
            .expect("partition slice access should be valid in test context");
        assert_eq!(last.row_end, 200);
    }

    #[test]
    fn zero_rows_returns_empty_slice() {
        let dispatcher = MultiTileDispatcher::from_synthetic(&["a", "b"]);
        let slices = dispatcher.partition(0);
        assert_eq!(slices.len(), 1);
        assert_eq!(slices[0].rows(), 0);
    }

    #[test]
    fn should_use_multi_tile_threshold() {
        let dispatcher = MultiTileDispatcher::from_synthetic(&["a", "b"]);
        // Default threshold is 64 rows: below it single-tile, at/above multi-tile.
        assert!(!dispatcher.should_use_multi_tile(32));
        assert!(dispatcher.should_use_multi_tile(64));
        assert!(dispatcher.should_use_multi_tile(512));
    }

    #[test]
    fn display_format() {
        let dispatcher = MultiTileDispatcher::single_device();
        let rendered = format!("{dispatcher}");
        assert!(rendered.contains("MultiTileDispatcher"));
        assert!(rendered.contains("tiles: 1"));
    }

    #[test]
    fn sub_device_info_fields() {
        let info = SubDeviceInfo {
            index: 2,
            name: "Intel Xe-HPC Tile 2".to_string(),
            eu_count: 448,
        };
        assert_eq!(info.index, 2);
        assert_eq!(info.eu_count, 448);
    }

    #[test]
    fn work_slice_rows_calculation() {
        let slice = TileWorkSlice {
            tile_index: 0,
            row_start: 100,
            row_end: 200,
        };
        assert_eq!(slice.rows(), 100);
    }
}