use super::device::GpuDeviceInfo;
use super::runtime::GpuRuntime;
impl GpuRuntime {
    /// Returns the ordinal of every pooled device, in the same order as `self.devices`.
    #[must_use]
    pub fn device_ordinals(&self) -> Vec<usize> {
        let mut ordinals = Vec::with_capacity(self.devices.len());
        for device in &self.devices {
            ordinals.push(device.ordinal);
        }
        ordinals
    }

    /// Returns how many devices are in the pool.
    #[must_use]
    pub fn device_count(&self) -> usize {
        let pool = &self.devices;
        pool.len()
    }

    /// Returns the memory budget, in bytes, for the device with `ordinal`.
    ///
    /// The value comes from that device's `memory_budget_bytes()`. When no pooled
    /// device has this ordinal, the runtime-wide `memory_budget_bytes` field is
    /// returned instead.
    #[must_use]
    pub fn memory_budget_for(&self, ordinal: usize) -> usize {
        let matching = self.devices.iter().find(|device| device.ordinal == ordinal);
        match matching {
            Some(device) => GpuDeviceInfo::memory_budget_bytes(device),
            None => self.memory_budget_bytes,
        }
    }
}
/// Splits `0..n_units` into contiguous tiles sized in proportion to each
/// device's [`GpuDeviceInfo::score`].
///
/// Each returned element is `(device ordinal, unit range)`. Tiles appear in
/// `rt.devices` order and cover `0..n_units` exactly, with no gaps or overlaps.
/// Devices whose share rounds down to zero units get no tile, so every emitted
/// range is non-empty.
///
/// Apportionment uses the largest-remainder method (see
/// `largest_remainder_counts`):
/// - Scores are clamped to be non-negative; NaN also becomes `0.0`, because
///   `f64::max` returns the non-NaN operand.
/// - If the clamped scores do not sum to a positive, finite total, units are
///   split evenly.
///
/// Returns an empty vector when `n_units == 0` or the runtime has no devices.
/// A single-device runtime always gets one tile covering everything.
///
/// # Panics
///
/// Panics if the computed tiles fail to cover `0..n_units`. That would indicate
/// a bug in the apportionment arithmetic, not a caller error.
#[must_use]
pub fn balanced_partition(rt: &GpuRuntime, n_units: usize) -> Vec<(usize, std::ops::Range<usize>)> {
    if n_units == 0 || rt.devices.is_empty() {
        return Vec::new();
    }
    if rt.devices.len() == 1 {
        return vec![(rt.devices[0].ordinal, 0..n_units)];
    }

    let scores: Vec<f64> = rt
        .devices
        .iter()
        .map(|device| device.score().max(0.0))
        .collect();
    let counts = largest_remainder_counts(&scores, n_units);

    let mut tiles = Vec::with_capacity(counts.len());
    let mut start = 0usize;
    for (device, &count) in rt.devices.iter().zip(&counts) {
        // Devices with no units are skipped so callers never see empty tiles.
        if count == 0 {
            continue;
        }
        let end = start + count;
        tiles.push((device.ordinal, start..end));
        start = end;
    }
    assert_eq!(start, n_units, "balanced_partition tiles must cover 0..n");
    tiles
}

/// Apportions `n_units` whole units across `weights.len()` slots in proportion
/// to `weights`, using the largest-remainder method.
///
/// `weights` are expected to be non-negative. If they do not sum to a positive,
/// finite total, every slot receives an equal share instead. The returned
/// counts are index-aligned with `weights`.
///
/// Each slot first gets the floor of its ideal share. The remaining units then
/// go one each to the slots with the largest fractional remainders; ties favour
/// the lower index, so the result is deterministic.
fn largest_remainder_counts(weights: &[f64], n_units: usize) -> Vec<usize> {
    let total: f64 = weights.iter().sum();
    let even = !(total.is_finite() && total > 0.0);
    let n = n_units as f64;
    let slots = weights.len() as f64;

    let mut counts = Vec::with_capacity(weights.len());
    let mut remainders = Vec::with_capacity(weights.len());
    for (idx, &weight) in weights.iter().enumerate() {
        let ideal = if even {
            n / slots
        } else {
            // Normalise before scaling. For non-negative weights, `weight / total`
            // lies in [0, 1], so the product never exceeds `n`. Multiplying first
            // (`n * weight / total`) can overflow to infinity for huge finite
            // weights. That infinity would saturate the `as usize` cast and
            // corrupt the unit accounting.
            n * (weight / total)
        };
        let floor = ideal.floor();
        counts.push(floor as usize);
        remainders.push((idx, ideal - floor));
    }

    // In exact arithmetic the fractional parts sum to the number of units still
    // unplaced, which is less than the slot count. So at most one extra unit
    // per slot closes the gap.
    let assigned: usize = counts.iter().sum();
    let leftover = n_units.saturating_sub(assigned);
    if leftover > 0 {
        // Largest remainder first; ties go to the earlier slot.
        remainders.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
        for &(idx, _) in remainders.iter().take(leftover) {
            counts[idx] += 1;
        }
    }
    counts
}
#[cfg(target_os = "linux")]
#[must_use]
/// Runs `f` over `items` split across all pooled devices, one scoped thread per tile.
///
/// Tiles come from [`balanced_partition`]. Each worker thread:
/// 1. Obtains the context for its ordinal via `super::runtime::cuda_context_for`.
/// 2. Binds that context to itself with `bind_to_thread`.
/// 3. Calls `f(ordinal, tile_slice)`.
///
/// A missing context or a failed bind counts as that worker failing. So does a
/// worker that panics: the panic is caught by `join` and not propagated. Every
/// worker is joined before returning.
///
/// Returns `Some(())` only if every worker succeeded. Returns `None` if any
/// worker failed, or if there were no tiles at all (no items or no devices).
pub fn scatter_batched<T: Send>(
    rt: &GpuRuntime,
    items: &mut [T],
    f: impl Fn(usize, &mut [T]) -> Option<()> + Sync,
) -> Option<()> {
    let tiles = balanced_partition(rt, items.len());
    if tiles.is_empty() {
        return None;
    }

    // Carve `items` into disjoint mutable sub-slices, one per (contiguous) tile.
    let mut remaining = items;
    let mut work = Vec::with_capacity(tiles.len());
    for (ordinal, range) in &tiles {
        let (chunk, after) = remaining.split_at_mut(range.len());
        work.push((*ordinal, chunk));
        remaining = after;
    }

    let worker = &f;
    std::thread::scope(|scope| {
        let mut handles = Vec::with_capacity(work.len());
        for (ordinal, chunk) in work {
            handles.push(scope.spawn(move || {
                let ctx = super::runtime::cuda_context_for(ordinal)?;
                ctx.bind_to_thread().ok()?;
                worker(ordinal, chunk)
            }));
        }

        // `count` drives the iterator to completion, so every handle is joined
        // (no short-circuit). A join error (worker panic) counts as a failure.
        let failures = handles
            .into_iter()
            .map(|handle| handle.join())
            .filter(|outcome| !matches!(outcome, Ok(Some(()))))
            .count();
        (failures == 0).then_some(())
    })
}
#[cfg(not(target_os = "linux"))]
#[must_use]
/// Sequential variant of `scatter_batched` used on non-Linux targets.
///
/// Splits `items` into the tiles produced by [`balanced_partition`]. It then
/// calls `f(ordinal, tile_slice)` for each tile, in order, on the calling
/// thread. Every tile is processed even after an earlier call fails.
///
/// Returns `Some(())` only if every call returned `Some(())`. Returns `None` if
/// any call failed, or if there were no tiles (no items or no devices).
pub fn scatter_batched<T: Send>(
    rt: &GpuRuntime,
    items: &mut [T],
    f: impl Fn(usize, &mut [T]) -> Option<()> + Sync,
) -> Option<()> {
    let tiles = balanced_partition(rt, items.len());
    if tiles.is_empty() {
        return None;
    }

    let mut remaining = items;
    let mut failed = false;
    for (ordinal, range) in &tiles {
        let (chunk, after) = remaining.split_at_mut(range.len());
        // `|=` is deliberate: unlike `||`, it always evaluates (and so always runs) `f`.
        failed |= f(*ordinal, chunk).is_none();
        remaining = after;
    }

    if failed {
        None
    } else {
        Some(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gpu::device::{GpuCapability, GpuDeviceInfo};
    use crate::gpu::policy::GpuDispatchPolicy;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a synthetic compute-capability 7.0 device.
    ///
    /// The device has `sm_count` SMs and `free_gib` GiB of free memory. Total
    /// memory is twice the *integer* part of `free_gib`, because the `as usize`
    /// cast truncates.
    fn device_with(ordinal: usize, sm_count: i32, free_gib: f64) -> GpuDeviceInfo {
        GpuDeviceInfo {
            ordinal,
            name: format!("synthetic-{ordinal}"),
            capability: GpuCapability::from_compute_capability(7, 0),
            sm_count,
            max_threads_per_sm: 2048,
            max_shared_mem_per_block: 49_152,
            l2_cache_bytes: 6 * 1024 * 1024,
            total_mem_bytes: (free_gib as usize) * 1_073_741_824 * 2,
            free_mem_bytes: (free_gib * 1_073_741_824.0) as usize,
            ecc_enabled: false,
            integrated: false,
            mig_mode: false,
        }
    }

    /// Builds a runtime whose primary device is the first entry of `devices`.
    ///
    /// The runtime-wide budget is set to `min(free, total / 2)` of that device.
    /// NOTE(review): presumably this mirrors the real runtime's budget rule — confirm.
    fn runtime_with(devices: Vec<GpuDeviceInfo>) -> GpuRuntime {
        let device = devices
            .first()
            .cloned()
            .expect("test runtime needs ≥1 device");
        let memory_budget_bytes = device.free_mem_bytes.min(device.total_mem_bytes / 2);
        GpuRuntime {
            device,
            devices,
            policy: GpuDispatchPolicy::default(),
            memory_budget_bytes,
        }
    }

    /// Asserts that `tiles` are non-empty, gap-free, and cover exactly `0..n_units`.
    fn assert_covers(tiles: &[(usize, std::ops::Range<usize>)], n_units: usize) {
        let mut cursor = 0usize;
        for (_, range) in tiles {
            assert_eq!(range.start, cursor, "tile gap/overlap at {cursor}");
            assert!(range.end > range.start, "empty tile emitted");
            cursor = range.end;
        }
        assert_eq!(cursor, n_units, "tiles must cover the whole range");
    }

    #[test]
    fn single_device_one_full_tile() {
        let rt = runtime_with(vec![device_with(0, 80, 16.0)]);
        let tiles = balanced_partition(&rt, 100);
        assert_eq!(tiles, vec![(0, 0..100)]);
    }

    #[test]
    fn three_devices_even_split_when_scores_equal() {
        let rt = runtime_with(vec![
            device_with(0, 80, 16.0),
            device_with(1, 80, 16.0),
            device_with(2, 80, 16.0),
        ]);
        let tiles = balanced_partition(&rt, 99);
        assert_eq!(
            tiles,
            vec![(0, 0..33), (1, 33..66), (2, 66..99)],
            "equal scores must split evenly"
        );
        assert_covers(&tiles, 99);
        // The single leftover unit goes to the first device (tie broken by index).
        let tiles = balanced_partition(&rt, 100);
        assert_eq!(tiles, vec![(0, 0..34), (1, 34..67), (2, 67..100)]);
        assert_covers(&tiles, 100);
    }

    #[test]
    fn three_devices_weighted_by_unequal_scores() {
        let devices = vec![
            device_with(0, 132, 40.0),
            device_with(1, 40, 8.0),
            device_with(2, 40, 8.0),
        ];
        let rt = runtime_with(devices.clone());
        let n_units = 1000;
        let tiles = balanced_partition(&rt, n_units);
        assert_covers(&tiles, n_units);
        assert_eq!(tiles[0].0, 0);
        let widths: Vec<usize> = tiles.iter().map(|(_, r)| r.end - r.start).collect();
        assert!(
            widths[0] > widths[1] && widths[0] > widths[2],
            "highest-score device must get the largest tile, got {widths:?}"
        );
        assert_eq!(widths[1], widths[2]);
        let total_score: f64 = devices.iter().map(GpuDeviceInfo::score).sum();
        for (device, width) in devices.iter().zip(&widths) {
            let ideal = device.score() / total_score * n_units as f64;
            assert!(
                (*width as f64 - ideal).abs() <= 1.0,
                "width {width} not within 1 of ideal {ideal} for ordinal {}",
                device.ordinal
            );
        }
    }

    #[test]
    fn fewer_units_than_devices_drops_empty_tiles() {
        let rt = runtime_with(vec![
            device_with(0, 132, 40.0),
            device_with(1, 100, 24.0),
            device_with(2, 80, 16.0),
            device_with(3, 60, 12.0),
            device_with(4, 40, 8.0),
        ]);
        let tiles = balanced_partition(&rt, 2);
        assert_covers(&tiles, 2);
        assert_eq!(tiles.len(), 2, "one tile per unit when units < devices");
        assert_eq!(tiles[0].0, 0, "highest-score device served first");
        assert_eq!(tiles[1].0, 1);
    }

    #[test]
    fn zero_units_yields_no_tiles() {
        let rt = runtime_with(vec![device_with(0, 80, 16.0), device_with(1, 80, 16.0)]);
        assert!(balanced_partition(&rt, 0).is_empty());
    }

    #[test]
    fn weighted_partition_covers_every_size() {
        let rt = runtime_with(vec![
            device_with(0, 132, 40.0),
            device_with(1, 100, 24.0),
            device_with(2, 40, 8.0),
        ]);
        for n_units in 0..=128 {
            let tiles = balanced_partition(&rt, n_units);
            assert_covers(&tiles, n_units);
            assert!(tiles.len() <= rt.device_count(), "at most one tile per device");
            assert!(
                tiles.windows(2).all(|pair| pair[0].0 < pair[1].0),
                "tiles must follow device order, got {tiles:?}"
            );
            assert_eq!(
                tiles,
                balanced_partition(&rt, n_units),
                "partition must be deterministic"
            );
        }
    }

    #[test]
    fn empty_device_pool_yields_no_tiles_and_no_dispatch() {
        // `runtime_with` requires ≥1 device, so the empty pool is built by hand.
        let rt = GpuRuntime {
            device: device_with(0, 80, 16.0),
            devices: Vec::new(),
            policy: GpuDispatchPolicy::default(),
            memory_budget_bytes: 1024,
        };
        assert_eq!(rt.device_count(), 0);
        assert!(rt.device_ordinals().is_empty());
        assert!(balanced_partition(&rt, 10).is_empty());
        assert_eq!(rt.memory_budget_for(0), 1024, "unknown ordinal falls back to runtime budget");

        let calls = AtomicUsize::new(0);
        let mut items = vec![1u32, 2, 3];
        let outcome = scatter_batched(&rt, items.as_mut_slice(), |_, _| {
            calls.fetch_add(1, Ordering::SeqCst);
            Some(())
        });
        assert!(outcome.is_none(), "no devices means nothing can be dispatched");
        assert_eq!(calls.load(Ordering::SeqCst), 0, "worker must never be invoked");
        assert_eq!(items, vec![1, 2, 3]);
    }

    #[test]
    fn device_ordinals_and_count_track_pool() {
        let rt = runtime_with(vec![
            device_with(0, 80, 16.0),
            device_with(3, 80, 16.0),
            device_with(5, 80, 16.0),
        ]);
        assert_eq!(rt.device_count(), 3);
        assert_eq!(rt.device_ordinals(), vec![0, 3, 5]);
    }

    #[test]
    fn memory_budget_for_caps_free_at_half_total() {
        // NOTE(review): `device_with` always sets total = 2 × free, so free and
        // half-total coincide here (8 GiB). This checks the value, but it does
        // not by itself distinguish "free" from "half of total".
        let rt = runtime_with(vec![device_with(0, 80, 8.0)]);
        let gib = 1_073_741_824usize;
        assert_eq!(rt.memory_budget_for(0), 8 * gib);
        assert_eq!(rt.memory_budget_for(99), rt.memory_budget_bytes);
    }
}