Skip to main content

gam_gpu/
pool.rs

1//! Multi-GPU device pool.
2//!
3//! The runtime probe (`super::device_runtime`) already discovers every usable CUDA
4//! device into `GpuRuntime::devices` (sorted by [`GpuDeviceInfo::score`] desc),
5//! but every dispatch path historically pinned its work to the single primary
6//! `GpuRuntime::device`. This module turns that pool into usable parallelism:
7//!
8//!   * [`GpuRuntime::device_ordinals`] / [`GpuRuntime::device_count`] expose the
9//!     full set of usable ordinals (highest-score first).
10//!   * [`GpuRuntime::memory_budget_for`] gives a per-device byte budget so each
11//!     tile can size its device buffers against the device it actually runs on.
12//!   * [`balanced_partition`] splits `n` independent work items across the pool
13//!     weighted by each device's [`GpuDeviceInfo::score`].
14//!   * [`scatter_batched`] runs an independent-per-item closure across every
15//!     device concurrently, binding each ordinal's context on its own worker
16//!     thread.
17//!
18//! ## Concurrency model
19//!
20//! Per-device fan-out uses [`std::thread::scope`], **not** rayon. A rayon
21//! `par_iter` worker that reaches a `OnceLock::get_or_init` whose closure itself
22//! does `into_par_iter` deadlocks the whole process (team-known hazard), and the
23//! cudarc context cache (`cuda_context_for`) is exactly such a lazily-initialized
24//! global. Scoped OS threads sidestep that entirely: each worker calls
25//! `ctx.bind_to_thread()` for its ordinal before issuing any CUDA work, so the
26//! thread-local current context is correct for every kernel launched on it.
27
28use super::device::GpuDeviceInfo;
29use super::device_runtime::GpuRuntime;
30
31impl GpuRuntime {
32    /// Ordinals of all usable devices, highest-score first.
33    ///
34    /// `self.devices` is already score-sorted at probe time, so this simply
35    /// projects out the ordinals. Empty only if the probe somehow produced no
36    /// devices (the public `probe()` guarantees at least one on `Ok(Some(_))`).
37    #[must_use]
38    pub fn device_ordinals(&self) -> Vec<usize> {
39        self.devices.iter().map(|device| device.ordinal).collect()
40    }
41
42    /// Number of usable devices in the pool.
43    #[must_use]
44    pub fn device_count(&self) -> usize {
45        self.devices.len()
46    }
47
48    /// Per-device byte budget: free memory capped at half of total, matching the
49    /// primary-device budget computed in `device_runtime::probe`. Falls back to the
50    /// primary `memory_budget_bytes` when the ordinal is not in the pool so a
51    /// caller that passes a stale ordinal still gets a usable (conservative)
52    /// budget rather than zero.
53    #[must_use]
54    pub fn memory_budget_for(&self, ordinal: usize) -> usize {
55        self.devices
56            .iter()
57            .find(|device| device.ordinal == ordinal)
58            .map_or(self.memory_budget_bytes, GpuDeviceInfo::memory_budget_bytes)
59    }
60}
61
62/// Partition `n_units` independent work items across all usable devices,
63/// weighted by [`GpuDeviceInfo::score`].
64///
65/// Returns `(ordinal, Range)` tiles that exactly cover `0..n_units` with no gaps
66/// or overlaps, largest-score device first. A single device yields one full-span
67/// tile. `n_units == 0` or no GPU yields an empty `Vec`.
68///
69/// Allocation is largest-remainder by score: each device's ideal share is
70/// `score_i / Σscore · n_units`; floors are assigned first, then the remaining
71/// units (from rounding) go to the devices with the largest fractional parts.
72/// This keeps the split proportional to capability while guaranteeing the tiles
73/// tile the whole range. Devices that round to a zero-width tile are dropped so
74/// no worker is spawned for empty work.
75#[must_use]
76pub fn balanced_partition(rt: &GpuRuntime, n_units: usize) -> Vec<(usize, std::ops::Range<usize>)> {
77    if n_units == 0 || rt.devices.is_empty() {
78        return Vec::new();
79    }
80    if rt.devices.len() == 1 {
81        return vec![(rt.devices[0].ordinal, 0..n_units)];
82    }
83
84    let scores: Vec<f64> = rt
85        .devices
86        .iter()
87        .map(|device| device.score().max(0.0))
88        .collect();
89    let total_score: f64 = scores.iter().sum();
90
91    // Degenerate weighting (all scores zero/non-finite): fall back to an even
92    // split so we never collapse the whole batch onto one device.
93    let even = !(total_score.is_finite() && total_score > 0.0);
94
95    let n = n_units as f64;
96    let mut counts: Vec<usize> = Vec::with_capacity(rt.devices.len());
97    let mut remainders: Vec<(usize, f64)> = Vec::with_capacity(rt.devices.len());
98    let mut assigned = 0usize;
99    for (idx, score) in scores.iter().enumerate() {
100        let ideal = if even {
101            n / rt.devices.len() as f64
102        } else {
103            n * score / total_score
104        };
105        let floor = ideal.floor();
106        let count = floor as usize;
107        counts.push(count);
108        assigned += count;
109        remainders.push((idx, ideal - floor));
110    }
111
112    // Distribute the leftover units (from flooring) to the largest fractional
113    // remainders, breaking ties toward the higher-score (earlier) device.
114    let mut leftover = n_units.saturating_sub(assigned);
115    if leftover > 0 {
116        remainders.sort_by(|a, b| {
117            b.1.partial_cmp(&a.1)
118                .unwrap_or(std::cmp::Ordering::Equal)
119                .then(a.0.cmp(&b.0))
120        });
121        for (idx, _) in &remainders {
122            if leftover == 0 {
123                break;
124            }
125            counts[*idx] += 1;
126            leftover -= 1;
127        }
128    }
129
130    let mut tiles = Vec::with_capacity(rt.devices.len());
131    let mut start = 0usize;
132    for (idx, device) in rt.devices.iter().enumerate() {
133        let count = counts[idx];
134        if count == 0 {
135            continue;
136        }
137        let end = start + count;
138        tiles.push((device.ordinal, start..end));
139        start = end;
140    }
141    assert_eq!(start, n_units, "balanced_partition tiles must cover 0..n");
142    tiles
143}
144
145/// Run independent work across ALL devices concurrently.
146///
147/// Splits `items` via [`balanced_partition`]; each tile runs on its own
148/// [`std::thread::scope`] thread that binds that ordinal's context
149/// (`cuda_context_for(ordinal).bind_to_thread()`) before calling
150/// `f(ordinal, &mut items[range])`. Returns `Some(())` only if EVERY tile's
151/// closure returned `Some(())`; if any tile fails, panics, or a context cannot
152/// be bound, returns `None` so the caller can run its deterministic whole-batch
153/// CPU fallback over the (still untouched-by-a-successful-result) `items`.
154///
155/// Non-linux builds have no CUDA contexts to bind and so always return `None`.
156#[cfg(target_os = "linux")]
157#[must_use]
158pub fn scatter_batched<T: Send>(
159    rt: &GpuRuntime,
160    items: &mut [T],
161    f: impl Fn(usize, &mut [T]) -> Option<()> + Sync,
162) -> Option<()> {
163    let n_units = items.len();
164    let tiles = balanced_partition(rt, n_units);
165    if tiles.is_empty() {
166        return None;
167    }
168
169    // Carve `items` into disjoint mutable sub-slices matching the tiles so each
170    // worker thread owns its range exclusively (no aliasing across threads).
171    let mut slices: Vec<(usize, &mut [T])> = Vec::with_capacity(tiles.len());
172    let mut rest = items;
173    let mut consumed = 0usize;
174    for (ordinal, range) in &tiles {
175        let take = range.end - consumed;
176        let (head, tail) = rest.split_at_mut(take);
177        slices.push((*ordinal, head));
178        rest = tail;
179        consumed = range.end;
180    }
181
182    let f = &f;
183    std::thread::scope(|scope| {
184        let handles: Vec<_> = slices
185            .into_iter()
186            .map(|(ordinal, slice)| {
187                scope.spawn(move || {
188                    // Bind this ordinal's cached context on this worker thread so
189                    // every CUDA launch the closure issues targets `ordinal`.
190                    let ctx = super::device_runtime::cuda_context_for(ordinal)?;
191                    ctx.bind_to_thread().ok()?;
192                    f(ordinal, slice)
193                })
194            })
195            .collect();
196
197        // A panicking worker yields `Err` from `join`; treat it like a tile
198        // failure so the caller falls back to CPU for the whole batch.
199        let mut all_ok = true;
200        for handle in handles {
201            match handle.join() {
202                Ok(Some(())) => {}
203                _ => all_ok = false,
204            }
205        }
206        if all_ok { Some(()) } else { None }
207    })
208}
209
/// `scatter_batched` for targets other than Linux, where there is no CUDA
/// context to bind and therefore no real device fan-out.
///
/// The function has to exist on every target because some callers live in
/// platform-independent code (the SAE manifold per-atom Gram/smoothness
/// scatters in `src/terms/sae/manifold/mod.rs`). Off Linux
/// `GpuRuntime::global()` is documented to return `None`, so the branch that
/// would reach this function is not expected to run; the body only has to
/// compile and honour the contract.
///
/// Behaviour: `items` is tiled with [`balanced_partition`]. No tiles (e.g. a
/// runtime with no devices) yields `None`, and the caller runs its whole-batch
/// CPU fallback. Otherwise each tile's closure is invoked, one after another
/// on the calling thread, over its own disjoint sub-slice — exactly the Linux
/// path minus the device binding — and the result is `Some(())` only if every
/// invocation returned `Some(())`.
#[cfg(not(target_os = "linux"))]
#[must_use]
pub fn scatter_batched<T: Send>(
    rt: &GpuRuntime,
    items: &mut [T],
    f: impl Fn(usize, &mut [T]) -> Option<()> + Sync,
) -> Option<()> {
    let tiles = balanced_partition(rt, items.len());
    if tiles.is_empty() {
        return None;
    }

    let mut remaining: &mut [T] = items;
    let mut cursor = 0usize;
    let mut failed_tiles = 0usize;
    for (ordinal, range) in &tiles {
        let width = range.end - cursor;
        let (chunk, tail) = std::mem::take(&mut remaining).split_at_mut(width);
        // Keep going after a failure: every tile is attempted regardless.
        if f(*ordinal, chunk).is_none() {
            failed_tiles += 1;
        }
        remaining = tail;
        cursor = range.end;
    }
    (failed_tiles == 0).then_some(())
}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::{GpuCapability, GpuDeviceInfo};
    use crate::policy::GpuDispatchPolicy;

    /// Bytes per GiB, used to turn the `free_gib` test parameter into bytes.
    const GIB: f64 = 1_073_741_824.0;

    /// Builds a synthetic device with the given ordinal, SM count and free
    /// memory. Total memory is always exactly twice the free memory, so the
    /// "free capped at half of total" budget equals the free memory.
    ///
    /// Both byte figures derive from the same float-to-byte conversion, so a
    /// fractional `free_gib` keeps `total == 2 * free` instead of truncating
    /// the total to a whole number of GiB.
    fn device_with(ordinal: usize, sm_count: i32, free_gib: f64) -> GpuDeviceInfo {
        let free_mem_bytes = (free_gib * GIB) as usize;
        GpuDeviceInfo {
            ordinal,
            name: format!("synthetic-{ordinal}"),
            capability: GpuCapability::from_compute_capability(7, 0),
            sm_count,
            max_threads_per_sm: 2048,
            max_shared_mem_per_block: 49_152,
            l2_cache_bytes: 6 * 1024 * 1024,
            total_mem_bytes: free_mem_bytes * 2,
            free_mem_bytes,
            ecc_enabled: false,
            integrated: false,
            mig_mode: false,
        }
    }

    /// Wraps `devices` in a runtime whose primary device is the first entry
    /// and whose primary budget is that device's free memory capped at half
    /// of its total.
    fn runtime_with(devices: Vec<GpuDeviceInfo>) -> GpuRuntime {
        let device = devices
            .first()
            .cloned()
            .expect("test runtime needs ≥1 device");
        let memory_budget_bytes = device.free_mem_bytes.min(device.total_mem_bytes / 2);
        GpuRuntime {
            device,
            devices,
            policy: GpuDispatchPolicy::default(),
            memory_budget_bytes,
        }
    }

    /// Tiles must exactly tile `0..n_units`: contiguous, gap-free, no overlap,
    /// and with no empty tile. Device ordering is NOT checked here; tests that
    /// care about it assert on the ordinals directly.
    fn assert_covers(tiles: &[(usize, std::ops::Range<usize>)], n_units: usize) {
        let mut cursor = 0usize;
        for (_, range) in tiles {
            assert_eq!(range.start, cursor, "tile gap/overlap at {cursor}");
            assert!(range.end > range.start, "empty tile emitted");
            cursor = range.end;
        }
        assert_eq!(cursor, n_units, "tiles must cover the whole range");
    }

    #[test]
    fn single_device_one_full_tile() {
        let rt = runtime_with(vec![device_with(0, 80, 16.0)]);
        let tiles = balanced_partition(&rt, 100);
        assert_eq!(tiles, vec![(0, 0..100)]);
    }

    #[test]
    fn three_devices_even_split_when_scores_equal() {
        // Identical devices → identical scores → even split (largest-remainder
        // pushes the +1 toward the earliest devices for the rounding leftover).
        let rt = runtime_with(vec![
            device_with(0, 80, 16.0),
            device_with(1, 80, 16.0),
            device_with(2, 80, 16.0),
        ]);
        let tiles = balanced_partition(&rt, 99);
        assert_eq!(
            tiles,
            vec![(0, 0..33), (1, 33..66), (2, 66..99)],
            "equal scores must split evenly"
        );
        assert_covers(&tiles, 99);

        // 100 units across 3 equal devices: 34/33/33, extra unit to the first.
        let tiles = balanced_partition(&rt, 100);
        assert_eq!(tiles, vec![(0, 0..34), (1, 34..67), (2, 67..100)]);
        assert_covers(&tiles, 100);
    }

    #[test]
    fn three_devices_weighted_by_unequal_scores() {
        // Device 0 has far more SMs and memory than 1 and 2, so it must take a
        // strictly larger tile; the split stays proportional and tiling holds.
        let devices = vec![
            device_with(0, 132, 40.0),
            device_with(1, 40, 8.0),
            device_with(2, 40, 8.0),
        ];
        let rt = runtime_with(devices.clone());
        let n_units = 1000;
        let tiles = balanced_partition(&rt, n_units);
        assert_covers(&tiles, n_units);
        // Every device must receive a tile; checking this up front turns a
        // dropped tile into a clear failure instead of an index panic below.
        assert_eq!(
            tiles.len(),
            devices.len(),
            "every device must get a tile, got {tiles:?}"
        );
        // Highest-score device first and its tile is the largest.
        assert_eq!(tiles[0].0, 0);
        let widths: Vec<usize> = tiles.iter().map(|(_, r)| r.end - r.start).collect();
        assert!(
            widths[0] > widths[1] && widths[0] > widths[2],
            "highest-score device must get the largest tile, got {widths:?}"
        );
        // Tiles 1 and 2 have equal scores → equal widths.
        assert_eq!(widths[1], widths[2]);
        // Proportionality: each width tracks score share within ±1 unit.
        let total_score: f64 = devices.iter().map(GpuDeviceInfo::score).sum();
        for (device, width) in devices.iter().zip(&widths) {
            let ideal = device.score() / total_score * n_units as f64;
            assert!(
                (*width as f64 - ideal).abs() <= 1.0,
                "width {width} not within 1 of ideal {ideal} for ordinal {}",
                device.ordinal
            );
        }
    }

    #[test]
    fn fewer_units_than_devices_drops_empty_tiles() {
        // 2 units, 5 devices: only the 2 highest-score devices get a tile and
        // no zero-width tile is emitted.
        let rt = runtime_with(vec![
            device_with(0, 132, 40.0),
            device_with(1, 100, 24.0),
            device_with(2, 80, 16.0),
            device_with(3, 60, 12.0),
            device_with(4, 40, 8.0),
        ]);
        let tiles = balanced_partition(&rt, 2);
        assert_covers(&tiles, 2);
        assert_eq!(tiles.len(), 2, "one tile per unit when units < devices");
        assert_eq!(tiles[0].0, 0, "highest-score device served first");
        assert_eq!(tiles[1].0, 1);
    }

    #[test]
    fn zero_units_yields_no_tiles() {
        let rt = runtime_with(vec![device_with(0, 80, 16.0), device_with(1, 80, 16.0)]);
        assert!(balanced_partition(&rt, 0).is_empty());
    }

    #[test]
    fn device_ordinals_and_count_track_pool() {
        let rt = runtime_with(vec![
            device_with(0, 80, 16.0),
            device_with(3, 80, 16.0),
            device_with(5, 80, 16.0),
        ]);
        assert_eq!(rt.device_count(), 3);
        assert_eq!(rt.device_ordinals(), vec![0, 3, 5]);
    }

    #[test]
    fn memory_budget_for_caps_free_at_half_total() {
        // free = 8 GiB, total = 16 GiB → budget = min(8, 8) = 8 GiB.
        let rt = runtime_with(vec![device_with(0, 80, 8.0)]);
        let gib = 1_073_741_824usize;
        assert_eq!(rt.memory_budget_for(0), 8 * gib);
        // Unknown ordinal falls back to the primary budget rather than zero.
        assert_eq!(rt.memory_budget_for(99), rt.memory_budget_bytes);
    }
}