1use std::fmt;
55
/// Tuning knobs for the dynamic load-balancing scheduler.
#[derive(Debug, Clone)]
pub struct SchedulerConfig {
    /// Queue depth below which an actor counts as underloaded and may
    /// steal work (see `LoadEntry::is_underloaded`).
    pub steal_threshold: u32,
    /// Queue depth above which an actor counts as overloaded and may
    /// have work stolen from it (see `LoadEntry::is_overloaded`).
    pub share_threshold: u32,
    /// Maximum number of messages moved by a single steal operation.
    pub max_steal_batch: u32,
    /// Number of neighboring actors considered when stealing.
    /// NOTE(review): not read anywhere in this file — confirm semantics at call sites.
    pub steal_neighborhood: u32,
    /// Master switch; when false, no dynamic scheduling occurs.
    pub enabled: bool,
    /// Load-balancing strategy to apply.
    pub strategy: SchedulingStrategy,
}
72
73impl Default for SchedulerConfig {
74 fn default() -> Self {
75 Self {
76 steal_threshold: 4,
77 share_threshold: 64,
78 max_steal_batch: 16,
79 steal_neighborhood: 4,
80 enabled: true,
81 strategy: SchedulingStrategy::WorkStealing,
82 }
83 }
84}
85
86impl SchedulerConfig {
87 pub fn static_scheduling() -> Self {
92 Self {
93 enabled: false,
94 strategy: SchedulingStrategy::Static,
95 ..Default::default()
96 }
97 }
98
99 pub fn work_stealing(steal_threshold: u32) -> Self {
104 Self {
105 steal_threshold,
106 strategy: SchedulingStrategy::WorkStealing,
107 ..Default::default()
108 }
109 }
110
111 pub fn round_robin() -> Self {
116 Self {
117 strategy: SchedulingStrategy::RoundRobin,
118 ..Default::default()
119 }
120 }
121
122 pub fn priority(levels: u32) -> Self {
127 let levels = levels.clamp(1, 16);
128 Self {
129 strategy: SchedulingStrategy::Priority { levels },
130 ..Default::default()
131 }
132 }
133
134 pub fn with_steal_threshold(mut self, threshold: u32) -> Self {
136 self.steal_threshold = threshold;
137 self
138 }
139
140 pub fn with_share_threshold(mut self, threshold: u32) -> Self {
142 self.share_threshold = threshold;
143 self
144 }
145
146 pub fn with_max_steal_batch(mut self, batch: u32) -> Self {
148 self.max_steal_batch = batch;
149 self
150 }
151
152 pub fn with_steal_neighborhood(mut self, neighborhood: u32) -> Self {
154 self.steal_neighborhood = neighborhood;
155 self
156 }
157
158 pub fn with_enabled(mut self, enabled: bool) -> Self {
160 self.enabled = enabled;
161 self
162 }
163
164 pub fn is_dynamic(&self) -> bool {
166 self.enabled && self.strategy != SchedulingStrategy::Static
167 }
168}
169
/// A single unit of schedulable work: one message bound for one actor.
/// `#[repr(C)]` keeps the layout fixed (16 bytes; size asserted in tests).
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct WorkItem {
    /// Identifier of the message to process (8-byte field placed first for alignment).
    pub message_id: u64,
    /// Actor that should process the message.
    pub actor_id: u32,
    /// Scheduling priority.
    /// NOTE(review): whether higher or lower means more urgent is not visible here — confirm at the scheduler.
    pub priority: u32,
}
187
188impl WorkItem {
189 pub fn new(actor_id: u32, message_id: u64, priority: u32) -> Self {
191 Self {
192 message_id,
193 actor_id,
194 priority,
195 }
196 }
197}
198
199impl fmt::Display for WorkItem {
200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 write!(
202 f,
203 "WorkItem(actor={}, msg={}, pri={})",
204 self.actor_id, self.message_id, self.priority
205 )
206 }
207}
208
/// Configuration for the dedicated scheduler warp and its work queue.
#[derive(Debug, Clone)]
pub struct SchedulerWarpConfig {
    /// Index of the warp that runs the scheduler loop.
    pub scheduler_warp_id: u32,
    /// Scheduling policy and thresholds.
    pub scheduler: SchedulerConfig,
    /// Work-queue capacity; must be a power of two
    /// (debug-asserted in `with_work_queue_capacity`).
    pub work_queue_capacity: usize,
    /// Interval between scheduler polls, in nanoseconds.
    pub poll_interval_ns: u32,
}
228
229impl Default for SchedulerWarpConfig {
230 fn default() -> Self {
231 Self {
232 scheduler_warp_id: 0,
233 scheduler: SchedulerConfig::default(),
234 work_queue_capacity: 1024,
235 poll_interval_ns: 1000,
236 }
237 }
238}
239
240impl SchedulerWarpConfig {
241 pub fn new(scheduler: SchedulerConfig) -> Self {
243 Self {
244 scheduler,
245 ..Default::default()
246 }
247 }
248
249 pub fn disabled() -> Self {
252 Self {
253 scheduler: SchedulerConfig::static_scheduling(),
254 ..Default::default()
255 }
256 }
257
258 pub fn with_scheduler_warp(mut self, warp_id: u32) -> Self {
260 self.scheduler_warp_id = warp_id;
261 self
262 }
263
264 pub fn with_work_queue_capacity(mut self, capacity: usize) -> Self {
266 debug_assert!(
267 capacity.is_power_of_two(),
268 "Work queue capacity must be power of 2"
269 );
270 self.work_queue_capacity = capacity;
271 self
272 }
273
274 pub fn with_poll_interval_ns(mut self, ns: u32) -> Self {
276 self.poll_interval_ns = ns;
277 self
278 }
279
280 pub fn is_enabled(&self) -> bool {
282 self.scheduler.is_dynamic()
283 }
284}
285
/// Load-balancing strategy used by the scheduler.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchedulingStrategy {
    /// No dynamic rebalancing; work stays where it was placed.
    Static,
    /// Underloaded actors pull work from overloaded ones.
    WorkStealing,
    /// Overloaded actors push work to underloaded ones.
    /// NOTE(review): planned identically to stealing by `LoadTable::compute_steal_plan`.
    WorkSharing,
    /// Combination of stealing and sharing.
    Hybrid,
    /// Round-robin placement; produces no steal operations.
    RoundRobin,
    /// Priority-based scheduling; produces no steal operations.
    Priority {
        /// Number of priority levels (clamped to 1..=16 by `SchedulerConfig::priority`).
        levels: u32,
    },
}
307
308impl fmt::Display for SchedulingStrategy {
309 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
310 match self {
311 Self::Static => write!(f, "static"),
312 Self::WorkStealing => write!(f, "work-stealing"),
313 Self::WorkSharing => write!(f, "work-sharing"),
314 Self::Hybrid => write!(f, "hybrid"),
315 Self::RoundRobin => write!(f, "round-robin"),
316 Self::Priority { levels } => write!(f, "priority({})", levels),
317 }
318 }
319}
320
/// Per-actor load statistics, laid out as exactly 32 aligned bytes
/// (size asserted in tests).
#[repr(C, align(32))]
#[derive(Debug, Clone, Copy, Default)]
pub struct LoadEntry {
    /// Current number of queued messages.
    pub queue_depth: u32,
    /// Queue capacity; 0 marks the entry as inactive/unused.
    pub capacity: u32,
    /// Total messages processed.
    /// NOTE(review): never updated in this file — presumably written by the runtime.
    pub messages_processed: u64,
    /// Number of steal requests. NOTE(review): not updated in this file.
    pub steal_requests: u32,
    /// Number of work offers. NOTE(review): not updated in this file.
    pub offer_count: u32,
    /// Queue occupancy scaled to a 0..=255 range; see `compute_load_score`.
    pub load_score: u32,
    /// Explicit padding to reach the 32-byte size.
    pub _pad: u32,
}
343
344impl LoadEntry {
345 pub fn compute_load_score(&mut self) {
347 if self.capacity > 0 {
348 self.load_score = ((self.queue_depth as u64 * 255) / self.capacity as u64) as u32;
349 } else {
350 self.load_score = 0;
351 }
352 }
353
354 pub fn is_overloaded(&self, threshold: u32) -> bool {
356 self.queue_depth > threshold
357 }
358
359 pub fn is_underloaded(&self, threshold: u32) -> bool {
361 self.queue_depth < threshold
362 }
363}
364
/// Table of per-actor load entries, indexed by actor id.
pub struct LoadTable {
    // One entry per actor; index == actor id.
    entries: Vec<LoadEntry>,
}
371
372impl LoadTable {
373 pub fn new(num_actors: usize) -> Self {
375 Self {
376 entries: vec![LoadEntry::default(); num_actors],
377 }
378 }
379
380 pub fn get(&self, actor_id: u32) -> Option<&LoadEntry> {
382 self.entries.get(actor_id as usize)
383 }
384
385 pub fn get_mut(&mut self, actor_id: u32) -> Option<&mut LoadEntry> {
387 self.entries.get_mut(actor_id as usize)
388 }
389
390 pub fn most_loaded(&self) -> Option<(u32, &LoadEntry)> {
392 self.entries
393 .iter()
394 .enumerate()
395 .filter(|(_, e)| e.queue_depth > 0)
396 .max_by_key(|(_, e)| e.queue_depth)
397 .map(|(i, e)| (i as u32, e))
398 }
399
400 pub fn least_loaded(&self) -> Option<(u32, &LoadEntry)> {
402 self.entries
403 .iter()
404 .enumerate()
405 .filter(|(_, e)| e.capacity > 0)
406 .min_by_key(|(_, e)| e.queue_depth)
407 .map(|(i, e)| (i as u32, e))
408 }
409
410 pub fn imbalance_ratio(&self) -> f64 {
413 let active: Vec<&LoadEntry> = self.entries.iter().filter(|e| e.capacity > 0).collect();
414 if active.is_empty() {
415 return 1.0;
416 }
417
418 let max = active.iter().map(|e| e.queue_depth).max().unwrap_or(0);
419 let min = active.iter().map(|e| e.queue_depth).min().unwrap_or(0);
420
421 if min == 0 {
422 if max == 0 {
423 1.0
424 } else {
425 f64::INFINITY
426 }
427 } else {
428 max as f64 / min as f64
429 }
430 }
431
432 pub fn compute_steal_plan(&self, config: &SchedulerConfig) -> Vec<StealOp> {
438 if !config.enabled || config.strategy == SchedulingStrategy::Static {
439 return Vec::new();
440 }
441
442 if matches!(
444 config.strategy,
445 SchedulingStrategy::RoundRobin | SchedulingStrategy::Priority { .. }
446 ) {
447 return Vec::new();
448 }
449
450 let mut ops = Vec::new();
451
452 let thieves: Vec<u32> = self
454 .entries
455 .iter()
456 .enumerate()
457 .filter(|(_, e)| e.is_underloaded(config.steal_threshold) && e.capacity > 0)
458 .map(|(i, _)| i as u32)
459 .collect();
460
461 let mut victims: Vec<(u32, u32)> = self
463 .entries
464 .iter()
465 .enumerate()
466 .filter(|(_, e)| e.is_overloaded(config.share_threshold))
467 .map(|(i, e)| (i as u32, e.queue_depth - config.share_threshold))
468 .collect();
469
470 victims.sort_by_key(|v| std::cmp::Reverse(v.1));
472
473 let mut victim_idx = 0;
475 for thief in &thieves {
476 if victim_idx >= victims.len() {
477 break;
478 }
479
480 let (victim_id, available) = &mut victims[victim_idx];
481 if *available == 0 {
482 victim_idx += 1;
483 continue;
484 }
485
486 let steal_count = (*available).min(config.max_steal_batch);
487 ops.push(StealOp {
488 thief: *thief,
489 victim: *victim_id,
490 count: steal_count,
491 });
492
493 *available -= steal_count;
494 if *available == 0 {
495 victim_idx += 1;
496 }
497 }
498
499 ops
500 }
501
502 pub fn entries(&self) -> &[LoadEntry] {
504 &self.entries
505 }
506
507 pub fn len(&self) -> usize {
509 self.entries.len()
510 }
511
512 pub fn is_empty(&self) -> bool {
514 self.entries.is_empty()
515 }
516}
517
/// A single planned steal: move `count` messages from `victim`'s queue to `thief`'s.
#[derive(Debug, Clone, Copy)]
pub struct StealOp {
    /// Actor receiving the stolen messages.
    pub thief: u32,
    /// Actor the messages are taken from.
    pub victim: u32,
    /// Number of messages to move.
    pub count: u32,
}
528
529impl fmt::Display for StealOp {
530 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
531 write!(
532 f,
533 "steal {} msgs: actor {} ← actor {}",
534 self.count, self.thief, self.victim
535 )
536 }
537}
538
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand: a default `LoadEntry` with the given depth and capacity.
    fn entry(queue_depth: u32, capacity: u32) -> LoadEntry {
        LoadEntry {
            queue_depth,
            capacity,
            ..Default::default()
        }
    }

    #[test]
    fn test_scheduler_config_defaults() {
        let cfg = SchedulerConfig::default();
        assert_eq!(cfg.steal_threshold, 4);
        assert_eq!(cfg.share_threshold, 64);
        assert!(cfg.enabled);
    }

    #[test]
    fn test_load_entry_score() {
        // 50/100 occupancy → floor(50 * 255 / 100) = 127.
        let mut e = entry(50, 100);
        e.compute_load_score();
        assert_eq!(e.load_score, 127);
    }

    #[test]
    fn test_load_table_most_least_loaded() {
        let mut table = LoadTable::new(4);
        for (i, &depth) in [10u32, 90, 50, 5].iter().enumerate() {
            table.entries[i] = entry(depth, 100);
        }

        let (most_id, most) = table.most_loaded().unwrap();
        assert_eq!(most_id, 1);
        assert_eq!(most.queue_depth, 90);

        let (least_id, least) = table.least_loaded().unwrap();
        assert_eq!(least_id, 3);
        assert_eq!(least.queue_depth, 5);
    }

    #[test]
    fn test_imbalance_ratio() {
        let mut table = LoadTable::new(4);
        for e in table.entries.iter_mut() {
            e.queue_depth = 50;
            e.capacity = 100;
        }
        // All equal → ratio 1.0.
        assert!((table.imbalance_ratio() - 1.0).abs() < 0.01);

        table.entries[0].queue_depth = 10;
        table.entries[1].queue_depth = 100;
        // 100 / 10 = 10.0.
        assert!((table.imbalance_ratio() - 10.0).abs() < 0.01);
    }

    #[test]
    fn test_steal_plan_static_disabled() {
        let table = LoadTable::new(4);
        let cfg = SchedulerConfig {
            strategy: SchedulingStrategy::Static,
            ..Default::default()
        };
        assert!(table.compute_steal_plan(&cfg).is_empty());
    }

    #[test]
    fn test_steal_plan_work_stealing() {
        let mut table = LoadTable::new(4);
        // Actors 0 and 3 are underloaded; actor 1 is overloaded.
        for (i, &depth) in [2u32, 80, 30, 1].iter().enumerate() {
            table.entries[i] = entry(depth, 100);
        }

        let plan = table.compute_steal_plan(&SchedulerConfig::default());

        assert!(!plan.is_empty(), "Should produce steal operations");
        assert!(plan.iter().all(|op| op.victim == 1));
        assert!(plan.iter().any(|op| op.thief == 0 || op.thief == 3));
    }

    #[test]
    fn test_steal_plan_respects_max_batch() {
        let mut table = LoadTable::new(2);
        table.entries[0] = entry(0, 100);
        table.entries[1] = entry(100, 100);

        let cfg = SchedulerConfig {
            max_steal_batch: 8,
            ..Default::default()
        };
        let plan = table.compute_steal_plan(&cfg);

        assert!(!plan.is_empty());
        for op in &plan {
            assert!(
                op.count <= 8,
                "Steal count {} exceeds max batch 8",
                op.count
            );
        }
    }

    #[test]
    fn test_load_entry_size() {
        assert_eq!(
            std::mem::size_of::<LoadEntry>(),
            32,
            "LoadEntry must be 32 bytes for GPU cache efficiency"
        );
    }

    #[test]
    fn test_work_item_size() {
        assert_eq!(
            std::mem::size_of::<WorkItem>(),
            16,
            "WorkItem must be 16 bytes for GPU cache efficiency"
        );
    }

    #[test]
    fn test_work_item_display() {
        let rendered = WorkItem::new(3, 42, 2).to_string();
        assert!(rendered.contains("actor=3"));
        assert!(rendered.contains("msg=42"));
        assert!(rendered.contains("pri=2"));
    }

    #[test]
    fn test_scheduler_config_static() {
        let cfg = SchedulerConfig::static_scheduling();
        assert!(!cfg.enabled);
        assert_eq!(cfg.strategy, SchedulingStrategy::Static);
        assert!(!cfg.is_dynamic());
    }

    #[test]
    fn test_scheduler_config_work_stealing() {
        let cfg = SchedulerConfig::work_stealing(8);
        assert_eq!(cfg.steal_threshold, 8);
        assert_eq!(cfg.strategy, SchedulingStrategy::WorkStealing);
        assert!(cfg.is_dynamic());
    }

    #[test]
    fn test_scheduler_config_round_robin() {
        let cfg = SchedulerConfig::round_robin();
        assert_eq!(cfg.strategy, SchedulingStrategy::RoundRobin);
        assert!(cfg.is_dynamic());
    }

    #[test]
    fn test_scheduler_config_priority() {
        let cfg = SchedulerConfig::priority(4);
        assert_eq!(cfg.strategy, SchedulingStrategy::Priority { levels: 4 });
        assert!(cfg.is_dynamic());
    }

    #[test]
    fn test_scheduler_config_priority_clamped() {
        // 100 is above the supported range and must clamp to 16.
        let cfg = SchedulerConfig::priority(100);
        assert_eq!(cfg.strategy, SchedulingStrategy::Priority { levels: 16 });
    }

    #[test]
    fn test_scheduler_config_builder_chain() {
        let cfg = SchedulerConfig::work_stealing(10)
            .with_share_threshold(80)
            .with_max_steal_batch(32)
            .with_steal_neighborhood(6);

        assert_eq!(cfg.steal_threshold, 10);
        assert_eq!(cfg.share_threshold, 80);
        assert_eq!(cfg.max_steal_batch, 32);
        assert_eq!(cfg.steal_neighborhood, 6);
    }

    #[test]
    fn test_scheduler_warp_config_default() {
        let cfg = SchedulerWarpConfig::default();
        assert_eq!(cfg.scheduler_warp_id, 0);
        assert_eq!(cfg.work_queue_capacity, 1024);
        assert_eq!(cfg.poll_interval_ns, 1000);
        assert!(cfg.is_enabled());
    }

    #[test]
    fn test_scheduler_warp_config_disabled() {
        assert!(!SchedulerWarpConfig::disabled().is_enabled());
    }

    #[test]
    fn test_scheduler_warp_config_builder() {
        let cfg = SchedulerWarpConfig::new(SchedulerConfig::round_robin())
            .with_scheduler_warp(1)
            .with_work_queue_capacity(2048)
            .with_poll_interval_ns(500);

        assert_eq!(cfg.scheduler_warp_id, 1);
        assert_eq!(cfg.work_queue_capacity, 2048);
        assert_eq!(cfg.poll_interval_ns, 500);
        assert!(cfg.is_enabled());
    }

    #[test]
    fn test_strategy_display() {
        let cases = [
            (SchedulingStrategy::Static, "static"),
            (SchedulingStrategy::WorkStealing, "work-stealing"),
            (SchedulingStrategy::WorkSharing, "work-sharing"),
            (SchedulingStrategy::Hybrid, "hybrid"),
            (SchedulingStrategy::RoundRobin, "round-robin"),
            (SchedulingStrategy::Priority { levels: 4 }, "priority(4)"),
        ];
        for &(strategy, expected) in cases.iter() {
            assert_eq!(format!("{}", strategy), expected);
        }
    }

    #[test]
    fn test_steal_plan_round_robin_empty() {
        let mut table = LoadTable::new(4);
        table.entries[0] = entry(2, 100);
        table.entries[1] = entry(80, 100);

        let plan = table.compute_steal_plan(&SchedulerConfig::round_robin());
        assert!(plan.is_empty(), "Round-robin should not produce steal ops");
    }

    #[test]
    fn test_steal_plan_priority_empty() {
        let mut table = LoadTable::new(4);
        table.entries[0] = entry(2, 100);
        table.entries[1] = entry(80, 100);

        let plan = table.compute_steal_plan(&SchedulerConfig::priority(4));
        assert!(plan.is_empty(), "Priority should not produce steal ops");
    }
}