1use std::collections::{BTreeMap, BTreeSet, VecDeque};
19use thiserror::Error;
20
21#[derive(Debug, Clone, PartialEq, Error)]
25pub enum SchedulerError {
26 #[error("Kernel not found: {0}")]
28 KernelNotFound(u32),
29 #[error("Dependency would create a cycle between kernel {from} and kernel {to}")]
31 CyclicDependency { from: u32, to: u32 },
32 #[error("Kernel already registered: {0}")]
34 DuplicateKernel(u32),
35 #[error("Scheduler graph contains a cycle; cannot produce valid launch order")]
37 CycleDetected,
38 #[error("Requested {requested} warps exceeds SM limit of {limit}")]
40 WarpLimitExceeded { requested: u32, limit: u32 },
41}
42
/// Static description of a single kernel launch.
#[derive(Debug, Clone, PartialEq)]
pub struct KernelSpec {
    /// Identifier of the kernel.
    pub id: u32,
    /// Human-readable label.
    pub name: String,
    /// Number of work groups in the launch.
    pub work_groups: u32,
    /// Number of threads in each work group.
    pub threads_per_group: u32,
    /// Estimated run time (presumably microseconds, going by the field name;
    /// nothing in this file reads it — TODO confirm the unit).
    pub estimated_us: u64,
}

impl KernelSpec {
    /// Builds a spec from its parts; `name` accepts anything convertible into
    /// a `String`.
    #[must_use]
    pub fn new(
        id: u32,
        name: impl Into<String>,
        work_groups: u32,
        threads_per_group: u32,
        estimated_us: u64,
    ) -> Self {
        let name = name.into();
        Self {
            id,
            name,
            work_groups,
            threads_per_group,
            estimated_us,
        }
    }

    /// Total thread count: `work_groups * threads_per_group`.
    ///
    /// Both factors are widened to `u64` first, so the product cannot
    /// overflow.
    #[must_use]
    pub fn total_threads(&self) -> u64 {
        let groups = u64::from(self.work_groups);
        let per_group = u64::from(self.threads_per_group);
        groups * per_group
    }
}
85
86#[derive(Debug, Clone)]
90pub struct OccupancyEstimate {
91 pub theoretical_occupancy: f32,
93 pub active_warps: u32,
95 pub max_warps: u32,
97}
98
99impl OccupancyEstimate {
100 #[must_use]
107 pub fn compute(kernel: &KernelSpec, sm_warp_limit: u32, warp_size: u32) -> Self {
108 let warp_size = warp_size.max(1);
109 let warps_per_group = (kernel.threads_per_group + warp_size - 1) / warp_size;
110 let active_warps = (warps_per_group * kernel.work_groups).min(sm_warp_limit);
111 let max_warps = sm_warp_limit.max(1);
112 let theoretical_occupancy = active_warps as f32 / max_warps as f32;
113 Self {
114 theoretical_occupancy: theoretical_occupancy.clamp(0.0, 1.0),
115 active_warps,
116 max_warps,
117 }
118 }
119}
120
/// Warp activity figures for one kernel.
#[derive(Debug, Clone)]
pub struct WarpStats {
    /// Kernel these figures belong to.
    pub kernel_id: u32,
    /// Number of warps counted as active.
    pub active_warps: u32,
    /// Number of warps counted as stalled.
    pub stalled_warps: u32,
    /// `active_warps / (active_warps + stalled_warps)`, or 0.0 when both
    /// counts are zero.
    pub utilisation: f32,
}

impl WarpStats {
    /// Builds the stats and derives `utilisation` from the two counts.
    ///
    /// The total is summed in `u64`, so counts whose sum exceeds `u32::MAX`
    /// still produce the correct ratio instead of overflowing.
    #[must_use]
    pub fn new(kernel_id: u32, active_warps: u32, stalled_warps: u32) -> Self {
        let total = u64::from(active_warps) + u64::from(stalled_warps);
        let utilisation = if total == 0 {
            // No warps at all: define utilisation as zero rather than NaN.
            0.0
        } else {
            // Precision loss in the float conversion is acceptable for a ratio.
            active_warps as f32 / total as f32
        };
        Self {
            kernel_id,
            active_warps,
            stalled_warps,
            utilisation,
        }
    }
}
156
157pub struct KernelScheduler {
169 kernels: BTreeMap<u32, KernelSpec>,
171 deps: BTreeMap<u32, BTreeSet<u32>>,
174 rdeps: BTreeMap<u32, BTreeSet<u32>>,
176}
177
178impl KernelScheduler {
179 #[must_use]
181 pub fn new() -> Self {
182 Self {
183 kernels: BTreeMap::new(),
184 deps: BTreeMap::new(),
185 rdeps: BTreeMap::new(),
186 }
187 }
188
189 pub fn add_kernel(&mut self, spec: KernelSpec) -> Result<(), SchedulerError> {
196 if self.kernels.contains_key(&spec.id) {
197 return Err(SchedulerError::DuplicateKernel(spec.id));
198 }
199 let id = spec.id;
200 self.kernels.insert(id, spec);
201 self.deps.entry(id).or_default();
202 self.rdeps.entry(id).or_default();
203 Ok(())
204 }
205
206 pub fn add_dependency(
214 &mut self,
215 dependent: u32,
216 dependency: u32,
217 ) -> Result<(), SchedulerError> {
218 if !self.kernels.contains_key(&dependent) {
219 return Err(SchedulerError::KernelNotFound(dependent));
220 }
221 if !self.kernels.contains_key(&dependency) {
222 return Err(SchedulerError::KernelNotFound(dependency));
223 }
224 if self.is_reachable(dependency, dependent) {
228 return Err(SchedulerError::CyclicDependency {
229 from: dependent,
230 to: dependency,
231 });
232 }
233 self.deps.entry(dependent).or_default().insert(dependency);
234 self.rdeps.entry(dependency).or_default().insert(dependent);
235 Ok(())
236 }
237
238 pub fn dependencies_of(&self, kernel_id: u32) -> Result<Vec<u32>, SchedulerError> {
244 if !self.kernels.contains_key(&kernel_id) {
245 return Err(SchedulerError::KernelNotFound(kernel_id));
246 }
247 let empty = BTreeSet::new();
248 let set = self.deps.get(&kernel_id).unwrap_or(&empty);
249 Ok(set.iter().copied().collect())
250 }
251
252 pub fn launch_order(&self) -> Result<Vec<u32>, SchedulerError> {
265 let mut in_degree: BTreeMap<u32, usize> = self
267 .kernels
268 .keys()
269 .map(|&id| (id, self.deps[&id].len()))
270 .collect();
271
272 let mut ready: BTreeSet<u32> = in_degree
274 .iter()
275 .filter_map(|(&id, °)| if deg == 0 { Some(id) } else { None })
276 .collect();
277
278 let mut order = Vec::with_capacity(self.kernels.len());
279
280 while let Some(&next) = ready.iter().next() {
281 ready.remove(&next);
282 order.push(next);
283 if let Some(dependents) = self.rdeps.get(&next) {
285 for &dep in dependents {
286 let deg = in_degree.entry(dep).or_insert(0);
287 *deg = deg.saturating_sub(1);
288 if *deg == 0 {
289 ready.insert(dep);
290 }
291 }
292 }
293 }
294
295 if order.len() != self.kernels.len() {
296 return Err(SchedulerError::CycleDetected);
297 }
298 Ok(order)
299 }
300
301 pub fn occupancy(
307 &self,
308 kernel_id: u32,
309 sm_warp_limit: u32,
310 warp_size: u32,
311 ) -> Result<OccupancyEstimate, SchedulerError> {
312 let spec = self
313 .kernels
314 .get(&kernel_id)
315 .ok_or(SchedulerError::KernelNotFound(kernel_id))?;
316 Ok(OccupancyEstimate::compute(spec, sm_warp_limit, warp_size))
317 }
318
319 pub fn simulate_warp_stats(
330 &self,
331 sm_warp_limit: u32,
332 warp_size: u32,
333 ) -> Result<Vec<WarpStats>, SchedulerError> {
334 let order = self.launch_order()?;
335 let warp_size = warp_size.max(1);
336 order
337 .iter()
338 .map(|&id| {
339 let spec = self
340 .kernels
341 .get(&id)
342 .ok_or(SchedulerError::KernelNotFound(id))?;
343 let warps_per_group = (spec.threads_per_group + warp_size - 1) / warp_size;
344 let total_warps = warps_per_group * spec.work_groups;
345 let active = total_warps.min(sm_warp_limit);
346 let stalled = total_warps.saturating_sub(active);
347 Ok(WarpStats::new(id, active, stalled))
348 })
349 .collect()
350 }
351
352 #[must_use]
354 pub fn kernel_count(&self) -> usize {
355 self.kernels.len()
356 }
357
358 #[must_use]
360 pub fn spec(&self, kernel_id: u32) -> Option<&KernelSpec> {
361 self.kernels.get(&kernel_id)
362 }
363
364 fn is_reachable(&self, start: u32, target: u32) -> bool {
369 if start == target {
370 return true;
371 }
372 let mut visited = BTreeSet::new();
373 let mut queue = VecDeque::new();
374 queue.push_back(start);
375 while let Some(current) = queue.pop_front() {
376 if visited.contains(¤t) {
377 continue;
378 }
379 visited.insert(current);
380 if let Some(deps) = self.deps.get(¤t) {
381 for &d in deps {
382 if d == target {
383 return true;
384 }
385 queue.push_back(d);
386 }
387 }
388 }
389 false
390 }
391}
392
393impl Default for KernelScheduler {
394 fn default() -> Self {
395 Self::new()
396 }
397}
398
/// Unit tests for the kernel spec, occupancy estimate, warp stats and
/// scheduler graph operations.
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing `f32` ratios.
    const EPS: f32 = 1e-6;

    /// Builds a spec named `kernel_<id>` with a fixed 100 us estimate.
    fn make_spec(id: u32, work_groups: u32, threads: u32) -> KernelSpec {
        KernelSpec::new(id, format!("kernel_{id}"), work_groups, threads, 100)
    }

    /// Builds a scheduler holding one single-group, 64-thread kernel per id.
    fn scheduler_with(ids: &[u32]) -> Result<KernelScheduler, SchedulerError> {
        let mut sched = KernelScheduler::new();
        for &id in ids {
            sched.add_kernel(make_spec(id, 1, 64))?;
        }
        Ok(sched)
    }

    // ---- KernelSpec ------------------------------------------------------

    #[test]
    fn test_kernel_spec_total_threads() {
        let spec = make_spec(1, 4, 64);
        assert_eq!(spec.total_threads(), 256);
    }

    #[test]
    fn test_kernel_spec_zero_work_groups() {
        let spec = make_spec(2, 0, 64);
        assert_eq!(spec.total_threads(), 0);
    }

    // ---- OccupancyEstimate -----------------------------------------------

    #[test]
    fn test_occupancy_full() {
        // 256 threads / 32 = 8 warps per group, 8 groups -> 64 warps.
        let spec = make_spec(1, 8, 256);
        let est = OccupancyEstimate::compute(&spec, 64, 32);
        assert_eq!(est.active_warps, 64);
        assert_eq!(est.max_warps, 64);
        assert!((est.theoretical_occupancy - 1.0).abs() < EPS);
    }

    #[test]
    fn test_occupancy_capped_at_sm_limit() {
        // 3200 requested warps are capped at the limit of 64.
        let spec = make_spec(1, 100, 1024);
        let est = OccupancyEstimate::compute(&spec, 64, 32);
        assert_eq!(est.active_warps, 64);
        assert!((est.theoretical_occupancy - 1.0).abs() < EPS);
    }

    #[test]
    fn test_occupancy_partial() {
        // 2 warps per group, 2 groups -> 4 of 32 warps.
        let spec = make_spec(1, 2, 64);
        let est = OccupancyEstimate::compute(&spec, 32, 32);
        assert_eq!(est.active_warps, 4);
        assert!((est.theoretical_occupancy - 4.0 / 32.0).abs() < EPS);
    }

    // ---- WarpStats -------------------------------------------------------

    #[test]
    fn test_warp_stats_utilisation_all_active() {
        let ws = WarpStats::new(1, 32, 0);
        assert!((ws.utilisation - 1.0).abs() < EPS);
    }

    #[test]
    fn test_warp_stats_utilisation_half() {
        let ws = WarpStats::new(2, 16, 16);
        assert!((ws.utilisation - 0.5).abs() < EPS);
    }

    #[test]
    fn test_warp_stats_zero_warps() {
        let ws = WarpStats::new(3, 0, 0);
        assert!(ws.utilisation.abs() < EPS);
    }

    // ---- Registration ----------------------------------------------------

    #[test]
    fn test_add_kernel_and_count() -> Result<(), SchedulerError> {
        let mut sched = KernelScheduler::new();
        sched.add_kernel(make_spec(1, 4, 64))?;
        sched.add_kernel(make_spec(2, 4, 64))?;
        assert_eq!(sched.kernel_count(), 2);
        Ok(())
    }

    #[test]
    fn test_add_duplicate_kernel_error() -> Result<(), SchedulerError> {
        let mut sched = KernelScheduler::new();
        sched.add_kernel(make_spec(1, 4, 64))?;
        let err = sched.add_kernel(make_spec(1, 8, 128));
        assert_eq!(err, Err(SchedulerError::DuplicateKernel(1)));
        // The first registration must survive the rejected duplicate.
        assert_eq!(sched.spec(1), Some(&make_spec(1, 4, 64)));
        Ok(())
    }

    #[test]
    fn test_default_scheduler_is_empty() {
        let sched = KernelScheduler::default();
        assert_eq!(sched.kernel_count(), 0);
        assert_eq!(sched.launch_order(), Ok(Vec::new()));
    }

    // ---- Launch order ----------------------------------------------------

    #[test]
    fn test_launch_order_single_kernel() -> Result<(), SchedulerError> {
        let sched = scheduler_with(&[7])?;
        assert_eq!(sched.launch_order()?, vec![7]);
        Ok(())
    }

    #[test]
    fn test_launch_order_linear_chain() -> Result<(), SchedulerError> {
        let mut sched = scheduler_with(&[1, 2, 3])?;
        // 1 -> 2 -> 3
        sched.add_dependency(2, 1)?;
        sched.add_dependency(3, 2)?;
        assert_eq!(sched.launch_order()?, vec![1, 2, 3]);
        Ok(())
    }

    #[test]
    fn test_launch_order_diamond() -> Result<(), SchedulerError> {
        let mut sched = scheduler_with(&[1, 2, 3, 4])?;
        // 1 feeds 2 and 3, which both feed 4.
        sched.add_dependency(2, 1)?;
        sched.add_dependency(3, 1)?;
        sched.add_dependency(4, 2)?;
        sched.add_dependency(4, 3)?;
        // Ready kernels are emitted smallest id first, so the middle pair has
        // a fixed order as well.
        assert_eq!(sched.launch_order()?, vec![1, 2, 3, 4]);
        Ok(())
    }

    #[test]
    fn test_launch_order_independent_kernels_sorted_by_id() -> Result<(), SchedulerError> {
        let sched = scheduler_with(&[5, 3, 1, 4, 2])?;
        assert_eq!(sched.launch_order()?, vec![1, 2, 3, 4, 5]);
        Ok(())
    }

    // ---- Dependency validation -------------------------------------------

    #[test]
    fn test_add_dependency_unknown_dependent() -> Result<(), SchedulerError> {
        let mut sched = scheduler_with(&[1])?;
        let err = sched.add_dependency(99, 1);
        assert_eq!(err, Err(SchedulerError::KernelNotFound(99)));
        Ok(())
    }

    #[test]
    fn test_add_dependency_unknown_dependency() -> Result<(), SchedulerError> {
        let mut sched = scheduler_with(&[1])?;
        let err = sched.add_dependency(1, 99);
        assert_eq!(err, Err(SchedulerError::KernelNotFound(99)));
        Ok(())
    }

    #[test]
    fn test_add_dependency_cycle_detected() -> Result<(), SchedulerError> {
        let mut sched = scheduler_with(&[1, 2])?;
        sched.add_dependency(2, 1)?;
        // The reverse edge would close the loop 1 -> 2 -> 1.
        let err = sched.add_dependency(1, 2);
        assert_eq!(
            err,
            Err(SchedulerError::CyclicDependency { from: 1, to: 2 })
        );
        // The rejected edge must not have been recorded.
        assert!(sched.dependencies_of(1)?.is_empty());
        Ok(())
    }

    #[test]
    fn test_add_dependency_on_self_is_rejected() -> Result<(), SchedulerError> {
        let mut sched = scheduler_with(&[1])?;
        let err = sched.add_dependency(1, 1);
        assert_eq!(
            err,
            Err(SchedulerError::CyclicDependency { from: 1, to: 1 })
        );
        Ok(())
    }

    // ---- Occupancy through the scheduler ---------------------------------

    #[test]
    fn test_scheduler_occupancy() -> Result<(), SchedulerError> {
        let mut sched = KernelScheduler::new();
        // 128 threads / 32 = 4 warps per group, 4 groups -> 16 warps.
        sched.add_kernel(make_spec(1, 4, 128))?;
        let est = sched.occupancy(1, 64, 32)?;
        assert_eq!(est.active_warps, 16);
        Ok(())
    }

    #[test]
    fn test_scheduler_occupancy_unknown_kernel() -> Result<(), SchedulerError> {
        let sched = KernelScheduler::new();
        let err = sched.occupancy(42, 64, 32);
        assert!(matches!(err, Err(SchedulerError::KernelNotFound(42))));
        Ok(())
    }

    // ---- Warp statistics -------------------------------------------------

    #[test]
    fn test_simulate_warp_stats_basic() -> Result<(), SchedulerError> {
        let mut sched = KernelScheduler::new();
        sched.add_kernel(make_spec(1, 2, 64))?; // 4 warps
        sched.add_kernel(make_spec(2, 1, 64))?; // 2 warps
        sched.add_dependency(2, 1)?;
        let stats = sched.simulate_warp_stats(32, 32)?;
        assert_eq!(stats.len(), 2);
        assert_eq!(stats[0].kernel_id, 1);
        assert_eq!(stats[0].active_warps, 4);
        assert_eq!(stats[0].stalled_warps, 0);
        assert_eq!(stats[1].kernel_id, 2);
        assert_eq!(stats[1].active_warps, 2);
        assert_eq!(stats[1].stalled_warps, 0);
        Ok(())
    }

    #[test]
    fn test_simulate_warp_stats_overflow_clamps() -> Result<(), SchedulerError> {
        let mut sched = KernelScheduler::new();
        // 8 warps per group * 1000 groups = 8000 warps against a limit of 64.
        sched.add_kernel(make_spec(1, 1000, 256))?;
        let stats = sched.simulate_warp_stats(64, 32)?;
        assert_eq!(stats[0].active_warps, 64);
        assert_eq!(stats[0].stalled_warps, 8000 - 64);
        assert!((stats[0].utilisation - 64.0 / 8000.0).abs() < EPS);
        Ok(())
    }

    // ---- dependencies_of -------------------------------------------------

    #[test]
    fn test_dependencies_of() -> Result<(), SchedulerError> {
        let mut sched = scheduler_with(&[1, 2, 3])?;
        sched.add_dependency(3, 2)?;
        sched.add_dependency(3, 1)?;
        // Results come back in ascending id order regardless of insert order.
        assert_eq!(sched.dependencies_of(3)?, vec![1, 2]);
        Ok(())
    }

    #[test]
    fn test_dependencies_of_no_deps() -> Result<(), SchedulerError> {
        let sched = scheduler_with(&[1])?;
        assert!(sched.dependencies_of(1)?.is_empty());
        Ok(())
    }

    #[test]
    fn test_dependencies_of_unknown_kernel() {
        let sched = KernelScheduler::new();
        assert_eq!(
            sched.dependencies_of(7),
            Err(SchedulerError::KernelNotFound(7))
        );
    }
}