use std::fmt;
/// Tuning knobs for dynamic load balancing across actors.
#[derive(Debug, Clone)]
pub struct SchedulerConfig {
    /// Queue depth below which an actor counts as underloaded (a steal candidate).
    pub steal_threshold: u32,
    /// Queue depth above which an actor counts as overloaded and should shed work.
    pub share_threshold: u32,
    /// Upper bound on messages moved by a single steal operation.
    pub max_steal_batch: u32,
    /// Neighborhood radius for stealing.
    /// NOTE(review): not consulted by `LoadTable::compute_steal_plan` in this
    /// file — presumably used by another component; verify against callers.
    pub steal_neighborhood: u32,
    /// Master switch: when false, no dynamic scheduling occurs regardless of strategy.
    pub enabled: bool,
    /// Which balancing algorithm to apply.
    pub strategy: SchedulingStrategy,
}
impl Default for SchedulerConfig {
fn default() -> Self {
Self {
steal_threshold: 4,
share_threshold: 64,
max_steal_batch: 16,
steal_neighborhood: 4,
enabled: true,
strategy: SchedulingStrategy::WorkStealing,
}
}
}
impl SchedulerConfig {
    /// Configuration with all dynamic balancing turned off.
    pub fn static_scheduling() -> Self {
        let mut cfg = Self::default();
        cfg.enabled = false;
        cfg.strategy = SchedulingStrategy::Static;
        cfg
    }

    /// Work-stealing configuration with a caller-chosen steal threshold.
    pub fn work_stealing(steal_threshold: u32) -> Self {
        let mut cfg = Self::default();
        cfg.steal_threshold = steal_threshold;
        cfg.strategy = SchedulingStrategy::WorkStealing;
        cfg
    }

    /// Round-robin distribution (per-actor load is not consulted).
    pub fn round_robin() -> Self {
        let mut cfg = Self::default();
        cfg.strategy = SchedulingStrategy::RoundRobin;
        cfg
    }

    /// Priority scheduling with `levels` distinct priorities, clamped to 1..=16.
    pub fn priority(levels: u32) -> Self {
        let mut cfg = Self::default();
        cfg.strategy = SchedulingStrategy::Priority {
            levels: levels.clamp(1, 16),
        };
        cfg
    }

    /// Builder: set the underload (thief) threshold.
    pub fn with_steal_threshold(mut self, threshold: u32) -> Self {
        self.steal_threshold = threshold;
        self
    }

    /// Builder: set the overload (victim) threshold.
    pub fn with_share_threshold(mut self, threshold: u32) -> Self {
        self.share_threshold = threshold;
        self
    }

    /// Builder: cap how many messages a single steal may move.
    pub fn with_max_steal_batch(mut self, batch: u32) -> Self {
        self.max_steal_batch = batch;
        self
    }

    /// Builder: set the steal neighborhood radius.
    pub fn with_steal_neighborhood(mut self, neighborhood: u32) -> Self {
        self.steal_neighborhood = neighborhood;
        self
    }

    /// Builder: toggle dynamic scheduling on or off.
    pub fn with_enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }

    /// True when load-driven rebalancing can occur at all.
    pub fn is_dynamic(&self) -> bool {
        self.enabled && !matches!(self.strategy, SchedulingStrategy::Static)
    }
}
/// A unit of schedulable work: one message addressed to one actor.
///
/// `#[repr(C)]` with this exact field order keeps the layout fixed at
/// 16 bytes (pinned by `test_work_item_size`).
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct WorkItem {
    /// Identifier of the message to deliver.
    pub message_id: u64,
    /// Target actor.
    pub actor_id: u32,
    /// Scheduling priority carried by the message.
    pub priority: u32,
}

impl WorkItem {
    /// Build a work item; note the (actor, message, priority) argument order.
    pub fn new(actor_id: u32, message_id: u64, priority: u32) -> Self {
        WorkItem {
            actor_id,
            message_id,
            priority,
        }
    }
}

impl fmt::Display for WorkItem {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let WorkItem {
            message_id,
            actor_id,
            priority,
        } = *self;
        write!(
            f,
            "WorkItem(actor={}, msg={}, pri={})",
            actor_id, message_id, priority
        )
    }
}
/// Configuration for the dedicated scheduler warp that runs the balancing loop.
#[derive(Debug, Clone)]
pub struct SchedulerWarpConfig {
    /// Identifier of the warp reserved for scheduling.
    pub scheduler_warp_id: u32,
    /// The balancing policy this warp enforces.
    pub scheduler: SchedulerConfig,
    /// Capacity of the shared work queue; expected to be a power of two
    /// (enforced by `with_work_queue_capacity` via `debug_assert`).
    pub work_queue_capacity: usize,
    /// How often the scheduler polls load state, in nanoseconds.
    pub poll_interval_ns: u32,
}
impl Default for SchedulerWarpConfig {
fn default() -> Self {
Self {
scheduler_warp_id: 0,
scheduler: SchedulerConfig::default(),
work_queue_capacity: 1024,
poll_interval_ns: 1000,
}
}
}
impl SchedulerWarpConfig {
    /// Wrap the given scheduler policy in an otherwise-default warp config.
    pub fn new(scheduler: SchedulerConfig) -> Self {
        let mut cfg = Self::default();
        cfg.scheduler = scheduler;
        cfg
    }

    /// A warp config whose scheduler never rebalances.
    pub fn disabled() -> Self {
        Self::new(SchedulerConfig::static_scheduling())
    }

    /// Builder: choose which warp runs the scheduler.
    pub fn with_scheduler_warp(mut self, warp_id: u32) -> Self {
        self.scheduler_warp_id = warp_id;
        self
    }

    /// Builder: set the work queue capacity (must be a power of two).
    pub fn with_work_queue_capacity(mut self, capacity: usize) -> Self {
        debug_assert!(
            capacity.is_power_of_two(),
            "Work queue capacity must be power of 2"
        );
        self.work_queue_capacity = capacity;
        self
    }

    /// Builder: set the polling cadence in nanoseconds.
    pub fn with_poll_interval_ns(mut self, ns: u32) -> Self {
        self.poll_interval_ns = ns;
        self
    }

    /// True when the wrapped scheduler will actually move work around.
    pub fn is_enabled(&self) -> bool {
        self.scheduler.is_dynamic()
    }
}
/// Available load-balancing algorithms.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchedulingStrategy {
    /// No rebalancing: work stays where it was assigned.
    Static,
    /// Underloaded actors pull work from overloaded ones.
    WorkStealing,
    /// Overloaded actors push work to underloaded ones.
    WorkSharing,
    /// Combination of stealing and sharing.
    Hybrid,
    /// Work distributed in rotation, ignoring load.
    RoundRobin,
    /// Fixed number of priority levels.
    Priority {
        levels: u32,
    },
}

impl fmt::Display for SchedulingStrategy {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Priority { levels } => return write!(f, "priority({})", levels),
            Self::Static => "static",
            Self::WorkStealing => "work-stealing",
            Self::WorkSharing => "work-sharing",
            Self::Hybrid => "hybrid",
            Self::RoundRobin => "round-robin",
        };
        f.write_str(name)
    }
}
/// Per-actor load statistics, laid out as exactly 32 bytes
/// (`#[repr(C, align(32))]`; pinned by `test_load_entry_size`).
#[repr(C, align(32))]
#[derive(Debug, Clone, Copy, Default)]
pub struct LoadEntry {
    /// Current number of queued messages.
    pub queue_depth: u32,
    /// Queue capacity; 0 marks the slot inactive/unused.
    pub capacity: u32,
    /// Lifetime count of processed messages.
    /// NOTE(review): not updated in this file — presumably maintained elsewhere.
    pub messages_processed: u64,
    /// Steal-request counter (not updated in this file).
    pub steal_requests: u32,
    /// Work-offer counter (not updated in this file).
    pub offer_count: u32,
    /// Cached fullness score in 0..=255; see `compute_load_score`.
    pub load_score: u32,
    /// Explicit padding up to the 32-byte layout.
    pub _pad: u32,
}

impl LoadEntry {
    /// Recompute `load_score` as queue fullness scaled to 0..=255.
    ///
    /// Fix: the score now saturates at 255 when `queue_depth` exceeds
    /// `capacity`; previously it could silently leave the 0-255 scale.
    /// An inactive entry (`capacity == 0`) scores 0.
    pub fn compute_load_score(&mut self) {
        self.load_score = if self.capacity == 0 {
            0
        } else {
            // 64-bit intermediate avoids overflow of depth * 255.
            ((self.queue_depth as u64 * 255) / self.capacity as u64).min(255) as u32
        };
    }

    /// True when this actor holds strictly more than `threshold` messages.
    pub fn is_overloaded(&self, threshold: u32) -> bool {
        self.queue_depth > threshold
    }

    /// True when this actor holds strictly fewer than `threshold` messages.
    pub fn is_underloaded(&self, threshold: u32) -> bool {
        self.queue_depth < threshold
    }
}
/// Snapshot of per-actor load, indexed by actor id.
pub struct LoadTable {
    // entries[i] holds the load entry for actor id i.
    entries: Vec<LoadEntry>,
}
impl LoadTable {
    /// Create a table with one zeroed entry per actor.
    pub fn new(num_actors: usize) -> Self {
        let mut entries = Vec::with_capacity(num_actors);
        entries.resize(num_actors, LoadEntry::default());
        Self { entries }
    }

    /// Checked lookup of the entry for `actor_id`.
    pub fn get(&self, actor_id: u32) -> Option<&LoadEntry> {
        self.entries.get(actor_id as usize)
    }

    /// Checked mutable lookup of the entry for `actor_id`.
    pub fn get_mut(&mut self, actor_id: u32) -> Option<&mut LoadEntry> {
        self.entries.get_mut(actor_id as usize)
    }

    /// The busiest actor with any queued work, or `None` if all queues are
    /// empty. Ties resolve to the highest id (`max_by_key` keeps the last max).
    pub fn most_loaded(&self) -> Option<(u32, &LoadEntry)> {
        self.entries
            .iter()
            .enumerate()
            .filter(|(_, entry)| entry.queue_depth > 0)
            .max_by_key(|(_, entry)| entry.queue_depth)
            .map(|(idx, entry)| (idx as u32, entry))
    }

    /// The idlest *active* actor (capacity > 0), or `None` if none are active.
    /// Ties resolve to the lowest id (`min_by_key` keeps the first min).
    pub fn least_loaded(&self) -> Option<(u32, &LoadEntry)> {
        self.entries
            .iter()
            .enumerate()
            .filter(|(_, entry)| entry.capacity > 0)
            .min_by_key(|(_, entry)| entry.queue_depth)
            .map(|(idx, entry)| (idx as u32, entry))
    }

    /// Ratio of the deepest to the shallowest active queue.
    ///
    /// Returns 1.0 for "balanced" (no active actors, or all active queues
    /// empty) and `f64::INFINITY` when work exists but some active actor is
    /// completely idle.
    pub fn imbalance_ratio(&self) -> f64 {
        let mut seen_active = false;
        let mut deepest = 0u32;
        let mut shallowest = u32::MAX;
        for entry in self.entries.iter().filter(|e| e.capacity > 0) {
            seen_active = true;
            deepest = deepest.max(entry.queue_depth);
            shallowest = shallowest.min(entry.queue_depth);
        }
        if !seen_active {
            return 1.0;
        }
        if shallowest == 0 {
            if deepest == 0 {
                1.0
            } else {
                f64::INFINITY
            }
        } else {
            deepest as f64 / shallowest as f64
        }
    }

    /// Plan one round of work migrations from overloaded to underloaded actors.
    ///
    /// Only strategies that actually migrate work (work-stealing, work-sharing,
    /// hybrid) with scheduling enabled produce a non-empty plan. Each thief
    /// receives at most one steal of up to `max_steal_batch` messages, drawn
    /// from victims in order of decreasing surplus.
    pub fn compute_steal_plan(&self, config: &SchedulerConfig) -> Vec<StealOp> {
        let migrating = config.enabled
            && matches!(
                config.strategy,
                SchedulingStrategy::WorkStealing
                    | SchedulingStrategy::WorkSharing
                    | SchedulingStrategy::Hybrid
            );
        if !migrating {
            return Vec::new();
        }
        // Thieves: active actors whose queues sit below the steal threshold,
        // in ascending id order.
        let thieves: Vec<u32> = self
            .entries
            .iter()
            .enumerate()
            .filter_map(|(idx, entry)| {
                if entry.capacity > 0 && entry.is_underloaded(config.steal_threshold) {
                    Some(idx as u32)
                } else {
                    None
                }
            })
            .collect();
        // Victims: actors above the share threshold, paired with their surplus,
        // most-overloaded first (stable sort keeps id order on ties).
        let mut victims = Vec::new();
        for (idx, entry) in self.entries.iter().enumerate() {
            if entry.is_overloaded(config.share_threshold) {
                victims.push((idx as u32, entry.queue_depth - config.share_threshold));
            }
        }
        victims.sort_by_key(|&(_, surplus)| std::cmp::Reverse(surplus));
        // Hand each thief one batch from the current victim, moving to the
        // next victim once its surplus is exhausted.
        let mut ops = Vec::new();
        let mut cursor = 0;
        for &thief in &thieves {
            if cursor >= victims.len() {
                break;
            }
            let (victim, surplus) = &mut victims[cursor];
            if *surplus == 0 {
                // Defensive: spent victim — advance (this thief's turn is used).
                cursor += 1;
                continue;
            }
            let count = (*surplus).min(config.max_steal_batch);
            ops.push(StealOp {
                thief,
                victim: *victim,
                count,
            });
            *surplus -= count;
            if *surplus == 0 {
                cursor += 1;
            }
        }
        ops
    }

    /// Read-only view of all entries.
    pub fn entries(&self) -> &[LoadEntry] {
        &self.entries
    }

    /// Number of actor slots tracked by the table.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// True when the table tracks no actors at all.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
/// A single planned migration: `count` messages move from `victim` to `thief`.
#[derive(Debug, Clone, Copy)]
pub struct StealOp {
    /// Actor receiving the work.
    pub thief: u32,
    /// Actor giving up the work.
    pub victim: u32,
    /// Number of messages to transfer.
    pub count: u32,
}

impl fmt::Display for StealOp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let StealOp { thief, victim, count } = *self;
        write!(f, "steal {} msgs: actor {} ← actor {}", count, thief, victim)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Default config exposes the documented thresholds with stealing enabled.
    #[test]
    fn test_scheduler_config_defaults() {
        let config = SchedulerConfig::default();
        assert_eq!(config.steal_threshold, 4);
        assert_eq!(config.share_threshold, 64);
        assert!(config.enabled);
    }

    // 50/100 full scales to floor(50 * 255 / 100) = 127 on the 0-255 scale.
    #[test]
    fn test_load_entry_score() {
        let mut entry = LoadEntry {
            queue_depth: 50,
            capacity: 100,
            ..Default::default()
        };
        entry.compute_load_score();
        assert_eq!(entry.load_score, 127);
    }

    // most_loaded/least_loaded pick the extremes among active entries.
    #[test]
    fn test_load_table_most_least_loaded() {
        let mut table = LoadTable::new(4);
        table.entries[0] = LoadEntry {
            queue_depth: 10,
            capacity: 100,
            ..Default::default()
        };
        table.entries[1] = LoadEntry {
            queue_depth: 90,
            capacity: 100,
            ..Default::default()
        };
        table.entries[2] = LoadEntry {
            queue_depth: 50,
            capacity: 100,
            ..Default::default()
        };
        table.entries[3] = LoadEntry {
            queue_depth: 5,
            capacity: 100,
            ..Default::default()
        };
        let (most_id, most) = table.most_loaded().unwrap();
        assert_eq!(most_id, 1);
        assert_eq!(most.queue_depth, 90);
        let (least_id, least) = table.least_loaded().unwrap();
        assert_eq!(least_id, 3);
        assert_eq!(least.queue_depth, 5);
    }

    // Uniform load gives ratio 1.0; a 10-vs-100 split gives 10.0.
    #[test]
    fn test_imbalance_ratio() {
        let mut table = LoadTable::new(4);
        for e in &mut table.entries {
            e.queue_depth = 50;
            e.capacity = 100;
        }
        assert!((table.imbalance_ratio() - 1.0).abs() < 0.01);
        table.entries[0].queue_depth = 10;
        table.entries[1].queue_depth = 100;
        assert!((table.imbalance_ratio() - 10.0).abs() < 0.01);
    }

    // A static strategy must never produce steal operations.
    #[test]
    fn test_steal_plan_static_disabled() {
        let table = LoadTable::new(4);
        let config = SchedulerConfig {
            strategy: SchedulingStrategy::Static,
            ..Default::default()
        };
        let plan = table.compute_steal_plan(&config);
        assert!(plan.is_empty());
    }

    // With defaults, actors 0 and 3 are thieves and actor 1 (depth 80 > 64)
    // is the only victim.
    #[test]
    fn test_steal_plan_work_stealing() {
        let mut table = LoadTable::new(4);
        table.entries[0] = LoadEntry {
            queue_depth: 2,
            capacity: 100,
            ..Default::default()
        };
        table.entries[1] = LoadEntry {
            queue_depth: 80,
            capacity: 100,
            ..Default::default()
        };
        table.entries[2] = LoadEntry {
            queue_depth: 30,
            capacity: 100,
            ..Default::default()
        };
        table.entries[3] = LoadEntry {
            queue_depth: 1,
            capacity: 100,
            ..Default::default()
        };
        let config = SchedulerConfig::default();
        let plan = table.compute_steal_plan(&config);
        assert!(!plan.is_empty(), "Should produce steal operations");
        assert!(plan.iter().all(|op| op.victim == 1));
        assert!(plan.iter().any(|op| op.thief == 0 || op.thief == 3));
    }

    // No single operation may exceed max_steal_batch.
    #[test]
    fn test_steal_plan_respects_max_batch() {
        let mut table = LoadTable::new(2);
        table.entries[0] = LoadEntry {
            queue_depth: 0,
            capacity: 100,
            ..Default::default()
        };
        table.entries[1] = LoadEntry {
            queue_depth: 100,
            capacity: 100,
            ..Default::default()
        };
        let config = SchedulerConfig {
            max_steal_batch: 8,
            ..Default::default()
        };
        let plan = table.compute_steal_plan(&config);
        assert!(!plan.is_empty());
        for op in &plan {
            assert!(
                op.count <= 8,
                "Steal count {} exceeds max batch 8",
                op.count
            );
        }
    }

    // Layout guard: LoadEntry is repr(C, align(32)) and must stay 32 bytes.
    #[test]
    fn test_load_entry_size() {
        assert_eq!(
            std::mem::size_of::<LoadEntry>(),
            32,
            "LoadEntry must be 32 bytes for GPU cache efficiency"
        );
    }

    // Layout guard: WorkItem is repr(C) and must stay 16 bytes.
    #[test]
    fn test_work_item_size() {
        assert_eq!(
            std::mem::size_of::<WorkItem>(),
            16,
            "WorkItem must be 16 bytes for GPU cache efficiency"
        );
    }

    // Display formatting of a work item includes all three fields.
    #[test]
    fn test_work_item_display() {
        let item = WorkItem::new(3, 42, 2);
        let s = format!("{}", item);
        assert!(s.contains("actor=3"));
        assert!(s.contains("msg=42"));
        assert!(s.contains("pri=2"));
    }

    // static_scheduling() disables scheduling entirely.
    #[test]
    fn test_scheduler_config_static() {
        let config = SchedulerConfig::static_scheduling();
        assert!(!config.enabled);
        assert_eq!(config.strategy, SchedulingStrategy::Static);
        assert!(!config.is_dynamic());
    }

    // work_stealing(n) sets the threshold and is dynamic.
    #[test]
    fn test_scheduler_config_work_stealing() {
        let config = SchedulerConfig::work_stealing(8);
        assert_eq!(config.steal_threshold, 8);
        assert_eq!(config.strategy, SchedulingStrategy::WorkStealing);
        assert!(config.is_dynamic());
    }

    // round_robin() counts as dynamic even though it ignores load.
    #[test]
    fn test_scheduler_config_round_robin() {
        let config = SchedulerConfig::round_robin();
        assert_eq!(config.strategy, SchedulingStrategy::RoundRobin);
        assert!(config.is_dynamic());
    }

    // priority(n) keeps n when already within 1..=16.
    #[test]
    fn test_scheduler_config_priority() {
        let config = SchedulerConfig::priority(4);
        assert_eq!(config.strategy, SchedulingStrategy::Priority { levels: 4 });
        assert!(config.is_dynamic());
    }

    // priority(n) clamps out-of-range level counts to 16.
    #[test]
    fn test_scheduler_config_priority_clamped() {
        let config = SchedulerConfig::priority(100);
        assert_eq!(config.strategy, SchedulingStrategy::Priority { levels: 16 });
    }

    // Builder methods compose and each sets exactly one field.
    #[test]
    fn test_scheduler_config_builder_chain() {
        let config = SchedulerConfig::work_stealing(10)
            .with_share_threshold(80)
            .with_max_steal_batch(32)
            .with_steal_neighborhood(6);
        assert_eq!(config.steal_threshold, 10);
        assert_eq!(config.share_threshold, 80);
        assert_eq!(config.max_steal_batch, 32);
        assert_eq!(config.steal_neighborhood, 6);
    }

    // Warp config defaults: warp 0, 1024-slot queue, 1000 ns poll, enabled.
    #[test]
    fn test_scheduler_warp_config_default() {
        let config = SchedulerWarpConfig::default();
        assert_eq!(config.scheduler_warp_id, 0);
        assert_eq!(config.work_queue_capacity, 1024);
        assert_eq!(config.poll_interval_ns, 1000);
        assert!(config.is_enabled());
    }

    // disabled() wraps a static scheduler, so is_enabled() is false.
    #[test]
    fn test_scheduler_warp_config_disabled() {
        let config = SchedulerWarpConfig::disabled();
        assert!(!config.is_enabled());
    }

    // Warp config builders compose like the scheduler builders.
    #[test]
    fn test_scheduler_warp_config_builder() {
        let config = SchedulerWarpConfig::new(SchedulerConfig::round_robin())
            .with_scheduler_warp(1)
            .with_work_queue_capacity(2048)
            .with_poll_interval_ns(500);
        assert_eq!(config.scheduler_warp_id, 1);
        assert_eq!(config.work_queue_capacity, 2048);
        assert_eq!(config.poll_interval_ns, 500);
        assert!(config.is_enabled());
    }

    // Display strings for every strategy variant.
    #[test]
    fn test_strategy_display() {
        assert_eq!(format!("{}", SchedulingStrategy::Static), "static");
        assert_eq!(
            format!("{}", SchedulingStrategy::WorkStealing),
            "work-stealing"
        );
        assert_eq!(
            format!("{}", SchedulingStrategy::WorkSharing),
            "work-sharing"
        );
        assert_eq!(format!("{}", SchedulingStrategy::Hybrid), "hybrid");
        assert_eq!(format!("{}", SchedulingStrategy::RoundRobin), "round-robin");
        assert_eq!(
            format!("{}", SchedulingStrategy::Priority { levels: 4 }),
            "priority(4)"
        );
    }

    // Round-robin never plans steals, even with an obvious imbalance.
    #[test]
    fn test_steal_plan_round_robin_empty() {
        let mut table = LoadTable::new(4);
        table.entries[0] = LoadEntry {
            queue_depth: 2,
            capacity: 100,
            ..Default::default()
        };
        table.entries[1] = LoadEntry {
            queue_depth: 80,
            capacity: 100,
            ..Default::default()
        };
        let config = SchedulerConfig::round_robin();
        let plan = table.compute_steal_plan(&config);
        assert!(plan.is_empty(), "Round-robin should not produce steal ops");
    }

    // Priority scheduling never plans steals, even with an obvious imbalance.
    #[test]
    fn test_steal_plan_priority_empty() {
        let mut table = LoadTable::new(4);
        table.entries[0] = LoadEntry {
            queue_depth: 2,
            capacity: 100,
            ..Default::default()
        };
        table.entries[1] = LoadEntry {
            queue_depth: 80,
            capacity: 100,
            ..Default::default()
        };
        let config = SchedulerConfig::priority(4);
        let plan = table.compute_steal_plan(&config);
        assert!(plan.is_empty(), "Priority should not produce steal ops");
    }
}