Skip to main content

oximedia_gpu/
kernel_scheduler.rs

1//! GPU kernel scheduling simulation.
2//!
3//! Simulates the kernel dispatch pipeline found in modern GPU compute stacks.
4//! Key concepts modelled:
5//!
6//! * **Kernel dependency graph** – a directed acyclic graph where edges encode
7//!   "must finish before" relationships between kernels.
8//! * **Launch ordering** – topological ordering of the DAG that respects all
//!   dependencies, breaking ties by picking the smallest ready kernel ID for determinism.
10//! * **Occupancy estimation** – computes theoretical occupancy (0.0–1.0) from
11//!   active warps vs the SM warp limit.
12//! * **Warp utilisation** – tracks active vs stalled warps per kernel to
13//!   produce a utilisation metric.
14//!
15//! All structures are pure-Rust, CPU-side simulations that mirror GPU scheduler
16//! semantics without requiring actual GPU hardware.
17
18use std::collections::{BTreeMap, BTreeSet, VecDeque};
19use thiserror::Error;
20
21// ─── Error ────────────────────────────────────────────────────────────────────
22
23/// Errors returned by kernel scheduler operations.
24#[derive(Debug, Clone, PartialEq, Error)]
25pub enum SchedulerError {
26    /// A kernel with the specified ID does not exist in the graph.
27    #[error("Kernel not found: {0}")]
28    KernelNotFound(u32),
29    /// Adding the dependency edge would introduce a cycle.
30    #[error("Dependency would create a cycle between kernel {from} and kernel {to}")]
31    CyclicDependency { from: u32, to: u32 },
32    /// A kernel with this ID has already been registered.
33    #[error("Kernel already registered: {0}")]
34    DuplicateKernel(u32),
35    /// The graph contains a cycle (internal invariant violation).
36    #[error("Scheduler graph contains a cycle; cannot produce valid launch order")]
37    CycleDetected,
38    /// Requested warp count exceeds device limit.
39    #[error("Requested {requested} warps exceeds SM limit of {limit}")]
40    WarpLimitExceeded { requested: u32, limit: u32 },
41}
42
43// ─── KernelSpec ───────────────────────────────────────────────────────────────
44
/// Specification for a single compute kernel.
#[derive(Debug, Clone, PartialEq)]
pub struct KernelSpec {
    /// Unique kernel identifier within the scheduler.
    pub id: u32,
    /// Human-readable name (for profiling / debug output).
    pub name: String,
    /// Number of thread groups (work groups) to dispatch.
    pub work_groups: u32,
    /// Threads per work group.
    pub threads_per_group: u32,
    /// Estimated execution time in microseconds (for scheduling heuristics).
    pub estimated_us: u64,
}

impl KernelSpec {
    /// Construct a new `KernelSpec`.
    #[must_use]
    pub fn new(
        id: u32,
        name: impl Into<String>,
        work_groups: u32,
        threads_per_group: u32,
        estimated_us: u64,
    ) -> Self {
        Self {
            id,
            name: name.into(),
            work_groups,
            threads_per_group,
            estimated_us,
        }
    }

    /// Total number of threads this kernel launches (`work_groups * threads_per_group`).
    ///
    /// Computed in `u64`, so the product of two `u32` values cannot overflow.
    #[must_use]
    pub fn total_threads(&self) -> u64 {
        u64::from(self.work_groups) * u64::from(self.threads_per_group)
    }
}

// ─── OccupancyEstimate ────────────────────────────────────────────────────────

/// Occupancy estimate for a single kernel on a given SM configuration.
#[derive(Debug, Clone)]
pub struct OccupancyEstimate {
    /// Fraction of SM warp slots that would be active (0.0 – 1.0).
    pub theoretical_occupancy: f32,
    /// Number of warps the kernel uses (capped at the SM warp limit).
    pub active_warps: u32,
    /// Maximum warps the SM can hold concurrently (at least 1).
    pub max_warps: u32,
}

impl OccupancyEstimate {
    /// Compute occupancy for `kernel` on an SM with `sm_warp_limit` warp slots.
    ///
    /// The kernel's warp demand is `ceil(threads_per_group / warp_size) * work_groups`.
    /// `active_warps` is that demand capped at `sm_warp_limit`. The intermediate
    /// arithmetic uses `u64`, so very large grids and `threads_per_group` values
    /// near `u32::MAX` cannot overflow.
    ///
    /// `warp_size` is typically 32 on NVIDIA hardware and 64 on AMD; a value of 0
    /// is treated as 1. An `sm_warp_limit` of 0 yields zero active warps and a
    /// reported `max_warps` of 1, which keeps the occupancy division well-defined.
    #[must_use]
    pub fn compute(kernel: &KernelSpec, sm_warp_limit: u32, warp_size: u32) -> Self {
        let warp_size = u64::from(warp_size.max(1));
        // Ceil-divide in u64: `threads + warp_size - 1` may exceed u32::MAX.
        let warps_per_group = (u64::from(kernel.threads_per_group) + warp_size - 1) / warp_size;
        // Both factors are <= u32::MAX, and (2^32 - 1)^2 < 2^64, so no u64 overflow.
        let demanded_warps = warps_per_group * u64::from(kernel.work_groups);
        // The value is capped at `sm_warp_limit`, so it always fits back into u32.
        let active_warps =
            u32::try_from(demanded_warps.min(u64::from(sm_warp_limit))).unwrap_or(sm_warp_limit);
        let max_warps = sm_warp_limit.max(1);
        let theoretical_occupancy = active_warps as f32 / max_warps as f32;
        Self {
            theoretical_occupancy: theoretical_occupancy.clamp(0.0, 1.0),
            active_warps,
            max_warps,
        }
    }
}
120
121// ─── WarpStats ────────────────────────────────────────────────────────────────
122
/// Per-kernel warp utilisation statistics gathered after (simulated) execution.
#[derive(Debug, Clone)]
pub struct WarpStats {
    /// Kernel identifier this record belongs to.
    pub kernel_id: u32,
    /// Number of warps actively issuing instructions during the kernel.
    pub active_warps: u32,
    /// Number of warps stalled (waiting on memory / barriers).
    pub stalled_warps: u32,
    /// Warp utilisation: `active / (active + stalled)`.
    pub utilisation: f32,
}

impl WarpStats {
    /// Build `WarpStats` from active and stalled warp counts.
    ///
    /// `utilisation` is `active / (active + stalled)`, or 0.0 when both counts
    /// are zero. The denominator is summed in `u64`, so counts near `u32::MAX`
    /// cannot overflow (the previous `u32` sum panicked in debug builds).
    #[must_use]
    pub fn new(kernel_id: u32, active_warps: u32, stalled_warps: u32) -> Self {
        let total = u64::from(active_warps) + u64::from(stalled_warps);
        let utilisation = if total == 0 {
            0.0
        } else {
            // Divide in f64 (exact for these integer ranges), then narrow to f32.
            (f64::from(active_warps) / total as f64) as f32
        };
        Self {
            kernel_id,
            active_warps,
            stalled_warps,
            utilisation,
        }
    }
}
156
157// ─── KernelScheduler ──────────────────────────────────────────────────────────
158
159/// Kernel dependency graph and launch-order scheduler.
160///
161/// Kernels are registered via [`add_kernel`] and dependencies added via
162/// [`add_dependency`].  Once the graph is complete, [`launch_order`] returns
163/// a topological ordering that satisfies all constraints.
164///
165/// [`add_kernel`]: KernelScheduler::add_kernel
166/// [`add_dependency`]: KernelScheduler::add_dependency
167/// [`launch_order`]: KernelScheduler::launch_order
168pub struct KernelScheduler {
169    /// All registered kernels, keyed by their ID.
170    kernels: BTreeMap<u32, KernelSpec>,
171    /// Adjacency list: `deps[id]` = set of kernel IDs that `id` depends on.
172    /// An edge `a → b` means "kernel `a` must wait for kernel `b`".
173    deps: BTreeMap<u32, BTreeSet<u32>>,
174    /// Reverse adjacency: `rdeps[b]` = kernels that depend on `b`.
175    rdeps: BTreeMap<u32, BTreeSet<u32>>,
176}
177
178impl KernelScheduler {
179    /// Create an empty scheduler.
180    #[must_use]
181    pub fn new() -> Self {
182        Self {
183            kernels: BTreeMap::new(),
184            deps: BTreeMap::new(),
185            rdeps: BTreeMap::new(),
186        }
187    }
188
189    /// Register a kernel with the scheduler.
190    ///
191    /// # Errors
192    ///
193    /// Returns [`SchedulerError::DuplicateKernel`] if a kernel with the same ID
194    /// has already been registered.
195    pub fn add_kernel(&mut self, spec: KernelSpec) -> Result<(), SchedulerError> {
196        if self.kernels.contains_key(&spec.id) {
197            return Err(SchedulerError::DuplicateKernel(spec.id));
198        }
199        let id = spec.id;
200        self.kernels.insert(id, spec);
201        self.deps.entry(id).or_default();
202        self.rdeps.entry(id).or_default();
203        Ok(())
204    }
205
206    /// Declare that kernel `dependent` must not start until kernel `dependency`
207    /// has finished.
208    ///
209    /// # Errors
210    ///
211    /// * [`SchedulerError::KernelNotFound`] if either ID is not registered.
212    /// * [`SchedulerError::CyclicDependency`] if the edge would introduce a cycle.
213    pub fn add_dependency(
214        &mut self,
215        dependent: u32,
216        dependency: u32,
217    ) -> Result<(), SchedulerError> {
218        if !self.kernels.contains_key(&dependent) {
219            return Err(SchedulerError::KernelNotFound(dependent));
220        }
221        if !self.kernels.contains_key(&dependency) {
222            return Err(SchedulerError::KernelNotFound(dependency));
223        }
224        // Check for cycle: would `dependency` become reachable from itself
225        // through `dependent`?  i.e. is `dependency` an ancestor of `dependent`
226        // already (which means adding dep→dependent creates a cycle)?
227        if self.is_reachable(dependency, dependent) {
228            return Err(SchedulerError::CyclicDependency {
229                from: dependent,
230                to: dependency,
231            });
232        }
233        self.deps.entry(dependent).or_default().insert(dependency);
234        self.rdeps.entry(dependency).or_default().insert(dependent);
235        Ok(())
236    }
237
238    /// Return the IDs of all direct dependencies of `kernel_id`.
239    ///
240    /// # Errors
241    ///
242    /// Returns [`SchedulerError::KernelNotFound`] if the ID is not registered.
243    pub fn dependencies_of(&self, kernel_id: u32) -> Result<Vec<u32>, SchedulerError> {
244        if !self.kernels.contains_key(&kernel_id) {
245            return Err(SchedulerError::KernelNotFound(kernel_id));
246        }
247        let empty = BTreeSet::new();
248        let set = self.deps.get(&kernel_id).unwrap_or(&empty);
249        Ok(set.iter().copied().collect())
250    }
251
252    /// Compute a valid topological launch order for all registered kernels.
253    ///
254    /// Uses Kahn's algorithm with a min-heap (via `BTreeSet`) for deterministic
255    /// output: among ready kernels, the one with the smallest ID is picked first.
256    ///
257    /// # Errors
258    ///
259    /// Returns [`SchedulerError::CycleDetected`] if the graph contains a cycle
260    /// (which should not happen if [`add_dependency`] correctly enforces the
261    /// acyclicity invariant, but is checked defensively here).
262    ///
263    /// [`add_dependency`]: KernelScheduler::add_dependency
264    pub fn launch_order(&self) -> Result<Vec<u32>, SchedulerError> {
265        // in-degree for each kernel
266        let mut in_degree: BTreeMap<u32, usize> = self
267            .kernels
268            .keys()
269            .map(|&id| (id, self.deps[&id].len()))
270            .collect();
271
272        // Seeds: kernels with no dependencies.
273        let mut ready: BTreeSet<u32> = in_degree
274            .iter()
275            .filter_map(|(&id, &deg)| if deg == 0 { Some(id) } else { None })
276            .collect();
277
278        let mut order = Vec::with_capacity(self.kernels.len());
279
280        while let Some(&next) = ready.iter().next() {
281            ready.remove(&next);
282            order.push(next);
283            // Reduce in-degree of kernels that depend on `next`.
284            if let Some(dependents) = self.rdeps.get(&next) {
285                for &dep in dependents {
286                    let deg = in_degree.entry(dep).or_insert(0);
287                    *deg = deg.saturating_sub(1);
288                    if *deg == 0 {
289                        ready.insert(dep);
290                    }
291                }
292            }
293        }
294
295        if order.len() != self.kernels.len() {
296            return Err(SchedulerError::CycleDetected);
297        }
298        Ok(order)
299    }
300
301    /// Compute occupancy for a specific kernel.
302    ///
303    /// # Errors
304    ///
305    /// Returns [`SchedulerError::KernelNotFound`] if the ID is not registered.
306    pub fn occupancy(
307        &self,
308        kernel_id: u32,
309        sm_warp_limit: u32,
310        warp_size: u32,
311    ) -> Result<OccupancyEstimate, SchedulerError> {
312        let spec = self
313            .kernels
314            .get(&kernel_id)
315            .ok_or(SchedulerError::KernelNotFound(kernel_id))?;
316        Ok(OccupancyEstimate::compute(spec, sm_warp_limit, warp_size))
317    }
318
319    /// Simulate execution and return warp statistics for each kernel in launch
320    /// order.
321    ///
322    /// The simulation model:
323    /// * Active warps = `min(warps_per_group * work_groups, sm_warp_limit)`.
324    /// * Stalled warps = max(0, total_warps_launched − active_warps).
325    ///
326    /// # Errors
327    ///
328    /// Returns an error if a valid launch order cannot be produced.
329    pub fn simulate_warp_stats(
330        &self,
331        sm_warp_limit: u32,
332        warp_size: u32,
333    ) -> Result<Vec<WarpStats>, SchedulerError> {
334        let order = self.launch_order()?;
335        let warp_size = warp_size.max(1);
336        order
337            .iter()
338            .map(|&id| {
339                let spec = self
340                    .kernels
341                    .get(&id)
342                    .ok_or(SchedulerError::KernelNotFound(id))?;
343                let warps_per_group = (spec.threads_per_group + warp_size - 1) / warp_size;
344                let total_warps = warps_per_group * spec.work_groups;
345                let active = total_warps.min(sm_warp_limit);
346                let stalled = total_warps.saturating_sub(active);
347                Ok(WarpStats::new(id, active, stalled))
348            })
349            .collect()
350    }
351
352    /// Number of kernels registered in the scheduler.
353    #[must_use]
354    pub fn kernel_count(&self) -> usize {
355        self.kernels.len()
356    }
357
358    /// Retrieve the `KernelSpec` for a given ID, if registered.
359    #[must_use]
360    pub fn spec(&self, kernel_id: u32) -> Option<&KernelSpec> {
361        self.kernels.get(&kernel_id)
362    }
363
364    // ── Private helpers ───────────────────────────────────────────────────────
365
366    /// BFS/DFS reachability: can `target` be reached from `start` following
367    /// reverse-dependency edges (i.e. following "depends on" links)?
368    fn is_reachable(&self, start: u32, target: u32) -> bool {
369        if start == target {
370            return true;
371        }
372        let mut visited = BTreeSet::new();
373        let mut queue = VecDeque::new();
374        queue.push_back(start);
375        while let Some(current) = queue.pop_front() {
376            if visited.contains(&current) {
377                continue;
378            }
379            visited.insert(current);
380            if let Some(deps) = self.deps.get(&current) {
381                for &d in deps {
382                    if d == target {
383                        return true;
384                    }
385                    queue.push_back(d);
386                }
387            }
388        }
389        false
390    }
391}
392
393impl Default for KernelScheduler {
394    fn default() -> Self {
395        Self::new()
396    }
397}
398
399// ─── Tests ───────────────────────────────────────────────────────────────────
400
#[cfg(test)]
mod tests {
    use super::*;

    fn make_spec(id: u32, work_groups: u32, threads: u32) -> KernelSpec {
        KernelSpec::new(id, format!("kernel_{id}"), work_groups, threads, 100)
    }

    /// Build a scheduler holding one 1×64 kernel for every ID in `ids`.
    fn scheduler_with(ids: &[u32]) -> KernelScheduler {
        let mut sched = KernelScheduler::new();
        for &id in ids {
            sched.add_kernel(make_spec(id, 1, 64)).unwrap();
        }
        sched
    }

    // ── KernelSpec ────────────────────────────────────────────────────────────

    #[test]
    fn test_kernel_spec_total_threads() {
        let spec = make_spec(1, 4, 64);
        assert_eq!(spec.total_threads(), 256);
    }

    #[test]
    fn test_kernel_spec_zero_work_groups() {
        let spec = make_spec(2, 0, 64);
        assert_eq!(spec.total_threads(), 0);
    }

    // ── OccupancyEstimate ─────────────────────────────────────────────────────

    #[test]
    fn test_occupancy_full() {
        let spec = make_spec(1, 8, 256); // 8 warps per group (256/32), 8 groups → 64 warps
        let est = OccupancyEstimate::compute(&spec, 64, 32);
        assert_eq!(est.active_warps, 64);
        assert_eq!(est.max_warps, 64);
        assert!((est.theoretical_occupancy - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_occupancy_capped_at_sm_limit() {
        let spec = make_spec(1, 100, 1024); // many warps — exceeds SM limit
        let est = OccupancyEstimate::compute(&spec, 64, 32);
        assert_eq!(est.active_warps, 64);
        assert!((est.theoretical_occupancy - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_occupancy_partial() {
        let spec = make_spec(1, 2, 64); // 2 warps per group, 2 groups → 4 warps
        let est = OccupancyEstimate::compute(&spec, 32, 32);
        assert_eq!(est.active_warps, 4);
        assert!((est.theoretical_occupancy - 4.0 / 32.0).abs() < 1e-6);
    }

    // ── WarpStats ─────────────────────────────────────────────────────────────

    #[test]
    fn test_warp_stats_utilisation_all_active() {
        let ws = WarpStats::new(1, 32, 0);
        assert!((ws.utilisation - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_warp_stats_utilisation_half() {
        let ws = WarpStats::new(2, 16, 16);
        assert!((ws.utilisation - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_warp_stats_zero_warps() {
        let ws = WarpStats::new(3, 0, 0);
        assert_eq!(ws.utilisation, 0.0);
    }

    // ── KernelScheduler – add / basic queries ─────────────────────────────────

    #[test]
    fn test_add_kernel_and_count() {
        let mut sched = KernelScheduler::new();
        sched.add_kernel(make_spec(1, 4, 64)).unwrap();
        sched.add_kernel(make_spec(2, 4, 64)).unwrap();
        assert_eq!(sched.kernel_count(), 2);
    }

    #[test]
    fn test_add_duplicate_kernel_error() {
        let mut sched = KernelScheduler::new();
        sched.add_kernel(make_spec(1, 4, 64)).unwrap();
        let err = sched.add_kernel(make_spec(1, 8, 128));
        assert!(matches!(err, Err(SchedulerError::DuplicateKernel(1))));
        // The original registration must survive the rejected duplicate.
        assert_eq!(sched.kernel_count(), 1);
        assert_eq!(sched.spec(1).map(|s| s.work_groups), Some(4));
    }

    #[test]
    fn test_spec_lookup() {
        let sched = scheduler_with(&[1]);
        assert_eq!(sched.spec(1).map(|s| s.name.as_str()), Some("kernel_1"));
        assert!(sched.spec(2).is_none());
    }

    #[test]
    fn test_default_is_empty() {
        let sched = KernelScheduler::default();
        assert_eq!(sched.kernel_count(), 0);
        assert_eq!(sched.launch_order().unwrap(), Vec::<u32>::new());
    }

    // ── launch_order ──────────────────────────────────────────────────────────

    #[test]
    fn test_launch_order_single_kernel() {
        let mut sched = KernelScheduler::new();
        sched.add_kernel(make_spec(7, 1, 64)).unwrap();
        let order = sched.launch_order().unwrap();
        assert_eq!(order, vec![7]);
    }

    #[test]
    fn test_launch_order_linear_chain() {
        // 1 → 2 → 3  (1 must run before 2, 2 before 3)
        let mut sched = scheduler_with(&[1, 2, 3]);
        sched.add_dependency(2, 1).unwrap(); // 2 waits for 1
        sched.add_dependency(3, 2).unwrap(); // 3 waits for 2
        let order = sched.launch_order().unwrap();
        assert_eq!(order, vec![1, 2, 3]);
    }

    #[test]
    fn test_launch_order_diamond() {
        // 1 → 2, 1 → 3, 2 → 4, 3 → 4
        let mut sched = scheduler_with(&[1, 2, 3, 4]);
        sched.add_dependency(2, 1).unwrap();
        sched.add_dependency(3, 1).unwrap();
        sched.add_dependency(4, 2).unwrap();
        sched.add_dependency(4, 3).unwrap();
        let order = sched.launch_order().unwrap();
        // 1 must be first, 4 must be last; 2 and 3 sit between them, and the
        // smallest-ID tie-break makes 2 come before 3.
        assert_eq!(order, vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_launch_order_independent_kernels_sorted_by_id() {
        let sched = scheduler_with(&[5, 3, 1, 4, 2]);
        let order = sched.launch_order().unwrap();
        assert_eq!(order, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_launch_order_prefers_smallest_ready_id() {
        // Kernel 1 waits for 4, so the independent kernels 2, 3, 4 go first.
        let mut sched = scheduler_with(&[1, 2, 3, 4]);
        sched.add_dependency(1, 4).unwrap();
        assert_eq!(sched.launch_order().unwrap(), vec![2, 3, 4, 1]);
    }

    // ── add_dependency errors ─────────────────────────────────────────────────

    #[test]
    fn test_add_dependency_unknown_dependent() {
        let mut sched = KernelScheduler::new();
        sched.add_kernel(make_spec(1, 1, 64)).unwrap();
        let err = sched.add_dependency(99, 1);
        assert!(matches!(err, Err(SchedulerError::KernelNotFound(99))));
    }

    #[test]
    fn test_add_dependency_unknown_dependency() {
        let mut sched = KernelScheduler::new();
        sched.add_kernel(make_spec(1, 1, 64)).unwrap();
        let err = sched.add_dependency(1, 99);
        assert!(matches!(err, Err(SchedulerError::KernelNotFound(99))));
    }

    #[test]
    fn test_add_dependency_cycle_detected() {
        let mut sched = KernelScheduler::new();
        sched.add_kernel(make_spec(1, 1, 64)).unwrap();
        sched.add_kernel(make_spec(2, 1, 64)).unwrap();
        sched.add_dependency(2, 1).unwrap(); // 2 waits for 1
        // Trying to make 1 wait for 2 would create a cycle.
        let err = sched.add_dependency(1, 2);
        assert!(matches!(err, Err(SchedulerError::CyclicDependency { .. })));
    }

    #[test]
    fn test_add_dependency_self_cycle() {
        let mut sched = scheduler_with(&[1]);
        let err = sched.add_dependency(1, 1);
        assert_eq!(err, Err(SchedulerError::CyclicDependency { from: 1, to: 1 }));
    }

    #[test]
    fn test_add_dependency_transitive_cycle() {
        let mut sched = scheduler_with(&[1, 2, 3]);
        sched.add_dependency(2, 1).unwrap();
        sched.add_dependency(3, 2).unwrap();
        // 1 → 3 would close the cycle 1 → 3 → 2 → 1.
        let err = sched.add_dependency(1, 3);
        assert_eq!(err, Err(SchedulerError::CyclicDependency { from: 1, to: 3 }));
        // The rejected edge must not have been recorded.
        assert!(sched.dependencies_of(1).unwrap().is_empty());
        assert_eq!(sched.launch_order().unwrap(), vec![1, 2, 3]);
    }

    #[test]
    fn test_add_dependency_duplicate_edge_is_idempotent() {
        let mut sched = scheduler_with(&[1, 2]);
        sched.add_dependency(2, 1).unwrap();
        sched.add_dependency(2, 1).unwrap();
        assert_eq!(sched.dependencies_of(2).unwrap(), vec![1]);
        assert_eq!(sched.launch_order().unwrap(), vec![1, 2]);
    }

    // ── occupancy via scheduler ───────────────────────────────────────────────

    #[test]
    fn test_scheduler_occupancy() {
        let mut sched = KernelScheduler::new();
        sched.add_kernel(make_spec(1, 4, 128)).unwrap(); // 4 warps/group, 4 groups → 16 warps
        let est = sched.occupancy(1, 64, 32).unwrap();
        assert_eq!(est.active_warps, 16);
    }

    #[test]
    fn test_scheduler_occupancy_unknown_kernel() {
        let sched = KernelScheduler::new();
        let err = sched.occupancy(42, 64, 32);
        assert!(matches!(err, Err(SchedulerError::KernelNotFound(42))));
    }

    // ── simulate_warp_stats ───────────────────────────────────────────────────

    #[test]
    fn test_simulate_warp_stats_basic() {
        let mut sched = KernelScheduler::new();
        sched.add_kernel(make_spec(1, 2, 64)).unwrap(); // 4 warps total
        sched.add_kernel(make_spec(2, 1, 64)).unwrap(); // 2 warps total
        sched.add_dependency(2, 1).unwrap();
        let stats = sched.simulate_warp_stats(32, 32).unwrap();
        assert_eq!(stats.len(), 2);
        assert_eq!(stats[0].kernel_id, 1);
        assert_eq!(stats[1].kernel_id, 2);
        assert_eq!((stats[0].active_warps, stats[0].stalled_warps), (4, 0));
        assert_eq!((stats[1].active_warps, stats[1].stalled_warps), (2, 0));
        assert!((stats[0].utilisation - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_simulate_warp_stats_overflow_clamps() {
        let mut sched = KernelScheduler::new();
        // 1000 work groups × 256 threads/group → 8000 warps; SM limit = 64
        sched.add_kernel(make_spec(1, 1000, 256)).unwrap();
        let stats = sched.simulate_warp_stats(64, 32).unwrap();
        assert_eq!(stats[0].active_warps, 64);
        assert_eq!(stats[0].stalled_warps, 8000 - 64);
        assert!((stats[0].utilisation - 64.0 / 8000.0).abs() < 1e-6);
    }

    // ── dependencies_of ───────────────────────────────────────────────────────

    #[test]
    fn test_dependencies_of() {
        let mut sched = scheduler_with(&[1, 2, 3]);
        sched.add_dependency(3, 1).unwrap();
        sched.add_dependency(3, 2).unwrap();
        let mut deps = sched.dependencies_of(3).unwrap();
        deps.sort_unstable();
        assert_eq!(deps, vec![1, 2]);
    }

    #[test]
    fn test_dependencies_of_no_deps() {
        let mut sched = KernelScheduler::new();
        sched.add_kernel(make_spec(1, 1, 64)).unwrap();
        let deps = sched.dependencies_of(1).unwrap();
        assert!(deps.is_empty());
    }

    #[test]
    fn test_dependencies_of_unknown_kernel() {
        let sched = scheduler_with(&[1]);
        let err = sched.dependencies_of(7);
        assert!(matches!(err, Err(SchedulerError::KernelNotFound(7))));
    }
}