1use super::device::GpuDeviceInfo;
29use super::device_runtime::GpuRuntime;
30
31impl GpuRuntime {
32 #[must_use]
38 pub fn device_ordinals(&self) -> Vec<usize> {
39 self.devices.iter().map(|device| device.ordinal).collect()
40 }
41
42 #[must_use]
44 pub fn device_count(&self) -> usize {
45 self.devices.len()
46 }
47
48 #[must_use]
54 pub fn memory_budget_for(&self, ordinal: usize) -> usize {
55 self.devices
56 .iter()
57 .find(|device| device.ordinal == ordinal)
58 .map_or(self.memory_budget_bytes, GpuDeviceInfo::memory_budget_bytes)
59 }
60}
61
62#[must_use]
76pub fn balanced_partition(rt: &GpuRuntime, n_units: usize) -> Vec<(usize, std::ops::Range<usize>)> {
77 if n_units == 0 || rt.devices.is_empty() {
78 return Vec::new();
79 }
80 if rt.devices.len() == 1 {
81 return vec![(rt.devices[0].ordinal, 0..n_units)];
82 }
83
84 let scores: Vec<f64> = rt
85 .devices
86 .iter()
87 .map(|device| device.score().max(0.0))
88 .collect();
89 let total_score: f64 = scores.iter().sum();
90
91 let even = !(total_score.is_finite() && total_score > 0.0);
94
95 let n = n_units as f64;
96 let mut counts: Vec<usize> = Vec::with_capacity(rt.devices.len());
97 let mut remainders: Vec<(usize, f64)> = Vec::with_capacity(rt.devices.len());
98 let mut assigned = 0usize;
99 for (idx, score) in scores.iter().enumerate() {
100 let ideal = if even {
101 n / rt.devices.len() as f64
102 } else {
103 n * score / total_score
104 };
105 let floor = ideal.floor();
106 let count = floor as usize;
107 counts.push(count);
108 assigned += count;
109 remainders.push((idx, ideal - floor));
110 }
111
112 let mut leftover = n_units.saturating_sub(assigned);
115 if leftover > 0 {
116 remainders.sort_by(|a, b| {
117 b.1.partial_cmp(&a.1)
118 .unwrap_or(std::cmp::Ordering::Equal)
119 .then(a.0.cmp(&b.0))
120 });
121 for (idx, _) in &remainders {
122 if leftover == 0 {
123 break;
124 }
125 counts[*idx] += 1;
126 leftover -= 1;
127 }
128 }
129
130 let mut tiles = Vec::with_capacity(rt.devices.len());
131 let mut start = 0usize;
132 for (idx, device) in rt.devices.iter().enumerate() {
133 let count = counts[idx];
134 if count == 0 {
135 continue;
136 }
137 let end = start + count;
138 tiles.push((device.ordinal, start..end));
139 start = end;
140 }
141 assert_eq!(start, n_units, "balanced_partition tiles must cover 0..n");
142 tiles
143}
144
145#[cfg(target_os = "linux")]
157#[must_use]
158pub fn scatter_batched<T: Send>(
159 rt: &GpuRuntime,
160 items: &mut [T],
161 f: impl Fn(usize, &mut [T]) -> Option<()> + Sync,
162) -> Option<()> {
163 let n_units = items.len();
164 let tiles = balanced_partition(rt, n_units);
165 if tiles.is_empty() {
166 return None;
167 }
168
169 let mut slices: Vec<(usize, &mut [T])> = Vec::with_capacity(tiles.len());
172 let mut rest = items;
173 let mut consumed = 0usize;
174 for (ordinal, range) in &tiles {
175 let take = range.end - consumed;
176 let (head, tail) = rest.split_at_mut(take);
177 slices.push((*ordinal, head));
178 rest = tail;
179 consumed = range.end;
180 }
181
182 let f = &f;
183 std::thread::scope(|scope| {
184 let handles: Vec<_> = slices
185 .into_iter()
186 .map(|(ordinal, slice)| {
187 scope.spawn(move || {
188 let ctx = super::device_runtime::cuda_context_for(ordinal)?;
191 ctx.bind_to_thread().ok()?;
192 f(ordinal, slice)
193 })
194 })
195 .collect();
196
197 let mut all_ok = true;
200 for handle in handles {
201 match handle.join() {
202 Ok(Some(())) => {}
203 _ => all_ok = false,
204 }
205 }
206 if all_ok { Some(()) } else { None }
207 })
208}
209
/// Sequential variant of `scatter_batched` compiled on non-Linux targets.
///
/// `items` is cut into consecutive sub-slices following the tiles from
/// [`balanced_partition`], and `f(ordinal, sub_slice)` is called for each tile in order on
/// the calling thread.
///
/// Returns `Some(())` only when `f` returned `Some(())` for every tile. Returns `None` when
/// the partition is empty (no devices or no items) or when any call returned `None`; the
/// remaining tiles are still processed after a failure.
#[cfg(not(target_os = "linux"))]
#[must_use]
pub fn scatter_batched<T: Send>(
    rt: &GpuRuntime,
    items: &mut [T],
    f: impl Fn(usize, &mut [T]) -> Option<()> + Sync,
) -> Option<()> {
    let tiles = balanced_partition(rt, items.len());
    if tiles.is_empty() {
        return None;
    }

    let mut remaining = items;
    let mut offset = 0usize;
    let mut every_tile_ok = true;
    for (ordinal, range) in &tiles {
        let width = range.end - offset;
        let (chunk, tail) = remaining.split_at_mut(width);
        // No short-circuit: later tiles still run after an earlier one fails.
        every_tile_ok &= f(*ordinal, chunk).is_some();
        remaining = tail;
        offset = range.end;
    }
    every_tile_ok.then_some(())
}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::{GpuCapability, GpuDeviceInfo};
    use crate::policy::GpuDispatchPolicy;

    /// Builds a synthetic device with `free_gib` GiB free and twice the whole-GiB part of
    /// `free_gib` as total memory. All other fields are fixed.
    fn device_with(ordinal: usize, sm_count: i32, free_gib: f64) -> GpuDeviceInfo {
        let total_mem_bytes = (free_gib as usize) * 1_073_741_824 * 2;
        let free_mem_bytes = (free_gib * 1_073_741_824.0) as usize;
        GpuDeviceInfo {
            ordinal,
            name: format!("synthetic-{ordinal}"),
            capability: GpuCapability::from_compute_capability(7, 0),
            sm_count,
            max_threads_per_sm: 2048,
            max_shared_mem_per_block: 49_152,
            l2_cache_bytes: 6 * 1024 * 1024,
            total_mem_bytes,
            free_mem_bytes,
            ecc_enabled: false,
            integrated: false,
            mig_mode: false,
        }
    }

    /// Builds one identical 80-SM / 16 GiB-free device per requested ordinal.
    fn identical_devices(ordinals: &[usize]) -> Vec<GpuDeviceInfo> {
        ordinals
            .iter()
            .map(|&ordinal| device_with(ordinal, 80, 16.0))
            .collect()
    }

    /// Wraps `devices` in a runtime whose primary device is the first entry and whose
    /// runtime-wide budget is `min(free, total / 2)` of that entry.
    fn runtime_with(devices: Vec<GpuDeviceInfo>) -> GpuRuntime {
        let primary = devices
            .first()
            .expect("test runtime needs ≥1 device")
            .clone();
        let memory_budget_bytes = primary.free_mem_bytes.min(primary.total_mem_bytes / 2);
        GpuRuntime {
            device: primary,
            devices,
            policy: GpuDispatchPolicy::default(),
            memory_budget_bytes,
        }
    }

    /// Asserts that `tiles` are non-empty, back to back, and end exactly at `n_units`.
    fn assert_covers(tiles: &[(usize, std::ops::Range<usize>)], n_units: usize) {
        let mut cursor = 0usize;
        for (_, span) in tiles {
            assert_eq!(span.start, cursor, "tile gap/overlap at {cursor}");
            assert!(span.end > span.start, "empty tile emitted");
            cursor = span.end;
        }
        assert_eq!(cursor, n_units, "tiles must cover the whole range");
    }

    #[test]
    fn single_device_one_full_tile() {
        let rt = runtime_with(vec![device_with(0, 80, 16.0)]);
        assert_eq!(balanced_partition(&rt, 100), vec![(0, 0..100)]);
    }

    #[test]
    fn three_devices_even_split_when_scores_equal() {
        let rt = runtime_with(identical_devices(&[0, 1, 2]));

        // 99 divides cleanly across three identical devices.
        let exact = balanced_partition(&rt, 99);
        assert_eq!(
            exact,
            vec![(0, 0..33), (1, 33..66), (2, 66..99)],
            "equal scores must split evenly"
        );
        assert_covers(&exact, 99);

        // With 100 units the single leftover unit goes to the first device.
        let uneven = balanced_partition(&rt, 100);
        assert_eq!(uneven, vec![(0, 0..34), (1, 34..67), (2, 67..100)]);
        assert_covers(&uneven, 100);
    }

    #[test]
    fn three_devices_weighted_by_unequal_scores() {
        // Device 0 is built larger than the two identical devices after it; the assertions
        // below rely on `score()` (defined elsewhere) ranking it highest.
        let rt = runtime_with(vec![
            device_with(0, 132, 40.0),
            device_with(1, 40, 8.0),
            device_with(2, 40, 8.0),
        ]);
        let n_units = 1000;
        let tiles = balanced_partition(&rt, n_units);
        assert_covers(&tiles, n_units);
        assert_eq!(tiles[0].0, 0);

        let widths: Vec<usize> = tiles.iter().map(|(_, span)| span.end - span.start).collect();
        assert!(
            widths[0] > widths[1] && widths[0] > widths[2],
            "highest-score device must get the largest tile, got {widths:?}"
        );
        assert_eq!(widths[1], widths[2]);

        // Every tile must sit within one unit of its exact proportional share.
        let total_score: f64 = rt.devices.iter().map(GpuDeviceInfo::score).sum();
        for (device, width) in rt.devices.iter().zip(&widths) {
            let ideal = device.score() / total_score * n_units as f64;
            assert!(
                (*width as f64 - ideal).abs() <= 1.0,
                "width {width} not within 1 of ideal {ideal} for ordinal {}",
                device.ordinal
            );
        }
    }

    #[test]
    fn fewer_units_than_devices_drops_empty_tiles() {
        // Five devices of decreasing size compete for just two units.
        let rt = runtime_with(vec![
            device_with(0, 132, 40.0),
            device_with(1, 100, 24.0),
            device_with(2, 80, 16.0),
            device_with(3, 60, 12.0),
            device_with(4, 40, 8.0),
        ]);
        let tiles = balanced_partition(&rt, 2);
        assert_covers(&tiles, 2);
        assert_eq!(tiles.len(), 2, "one tile per unit when units < devices");
        assert_eq!(tiles[0].0, 0, "highest-score device served first");
        assert_eq!(tiles[1].0, 1);
    }

    #[test]
    fn zero_units_yields_no_tiles() {
        let rt = runtime_with(identical_devices(&[0, 1]));
        let tiles = balanced_partition(&rt, 0);
        assert!(tiles.is_empty());
    }

    #[test]
    fn device_ordinals_and_count_track_pool() {
        // Ordinals need not be contiguous; the accessors must report them as stored.
        let rt = runtime_with(identical_devices(&[0, 3, 5]));
        assert_eq!(rt.device_count(), 3);
        assert_eq!(rt.device_ordinals(), vec![0, 3, 5]);
    }

    #[test]
    fn memory_budget_for_caps_free_at_half_total() {
        // `device_with(_, _, 8.0)` yields 8 GiB free and 16 GiB total.
        let rt = runtime_with(vec![device_with(0, 80, 8.0)]);
        let gib = 1_073_741_824usize;
        assert_eq!(rt.memory_budget_for(0), 8 * gib);
        // An ordinal missing from the pool falls back to the runtime-wide budget.
        assert_eq!(rt.memory_budget_for(99), rt.memory_budget_bytes);
    }
}