use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, Ordering as AOrdering};
use std::sync::{Arc, Mutex};
use std::thread;
use crate::numa_scheduler::migration::worker_numa_node;
use crate::numa_scheduler::topology::{cores_in_node, nearest_numa_nodes};
use crate::numa_scheduler::types::{NumaTopology, SchedulerStats, Task, WorkStealingConfig};
/// Per-worker task queue: the owner pushes and pops at the back, while
/// thieves take from the front (see `push_local` / `pop_local` /
/// `steal_remote` below).
pub struct WorkerDeque {
    // Tasks queued for this worker; back = owner end, front = steal end.
    pub local_queue: VecDeque<Task<()>>,
    // Index of the worker that owns this deque.
    pub worker_id: usize,
    // NUMA node this worker is associated with.
    pub numa_node: usize,
}
impl WorkerDeque {
    /// Creates a deque for `worker_id` on `numa_node`, preallocating
    /// room for `capacity` tasks.
    pub fn new(worker_id: usize, numa_node: usize, capacity: usize) -> Self {
        let local_queue = VecDeque::with_capacity(capacity);
        Self {
            local_queue,
            worker_id,
            numa_node,
        }
    }

    /// Enqueues a task at the owner's end (back) of the deque.
    pub fn push_local(&mut self, task: Task<()>) {
        self.local_queue.push_back(task);
    }

    /// Owner pop: takes the most recently pushed task (LIFO, back).
    pub fn pop_local(&mut self) -> Option<Task<()>> {
        self.local_queue.pop_back()
    }

    /// Thief pop: takes the oldest task (FIFO, front) — the opposite end
    /// from the one the owner works on.
    pub fn steal_remote(&mut self) -> Option<Task<()>> {
        self.local_queue.pop_front()
    }

    /// Number of tasks currently queued.
    pub fn len(&self) -> usize {
        self.local_queue.len()
    }

    /// True when no tasks are queued.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Picks a worker to steal from, preferring NUMA locality.
///
/// Tries the stealer's own node first, then the remaining nodes in the
/// order given by `nearest_numa_nodes`. Within a node the most heavily
/// loaded non-empty queue wins. Returns `None` when no other worker has
/// queued work (or `queue_lengths` is empty).
pub fn choose_victim(
    worker_id: usize,
    topology: &NumaTopology,
    queue_lengths: &[usize],
) -> Option<usize> {
    let n_workers = queue_lengths.len();
    if n_workers == 0 {
        return None;
    }
    let home_node = worker_numa_node(worker_id, topology);
    // Workers map onto cores round-robin; `.max(1)` guards a zero-core topology.
    let n_cores = topology.cores.len().max(1);

    // Best victim on one node: busiest non-empty queue, excluding ourselves.
    let best_in_node = |node: usize| -> Option<usize> {
        let node_cores = cores_in_node(topology, node);
        (0..n_workers)
            .filter(|&w| w != worker_id)
            .filter(|&w| node_cores.contains(&(w % n_cores)))
            .filter(|&w| queue_lengths[w] > 0)
            .max_by_key(|&w| queue_lengths[w])
    };

    // Local node first: stealing here avoids cross-node traffic.
    if let Some(victim) = best_in_node(home_node) {
        return Some(victim);
    }
    // Fall back to remote nodes by increasing NUMA distance.
    for node in nearest_numa_nodes(topology, home_node) {
        if node == home_node {
            continue;
        }
        if let Some(victim) = best_in_node(node) {
            return Some(victim);
        }
    }
    None
}
/// Places `task` on the least-loaded worker queue, honoring NUMA affinity.
///
/// A valid node affinity (`node < topology.n_nodes`) restricts the search to
/// that node's workers; if the node has no workers, or the task carries no
/// affinity, the globally least-loaded worker is used. No-op when `queues`
/// is empty.
pub fn assign_task(task: Task<()>, topology: &NumaTopology, queues: &[Arc<Mutex<WorkerDeque>>]) {
    let n_workers = queues.len();
    if n_workers == 0 {
        return;
    }
    // Snapshot of queue depths; a poisoned lock counts as an empty queue.
    let lengths: Vec<usize> = queues
        .iter()
        .map(|q| q.lock().map(|g| g.len()).unwrap_or(0))
        .collect();
    let globally_least = (0..n_workers).min_by_key(|&w| lengths[w]);
    let wanted_node = task.affinity.filter(|&node| node < topology.n_nodes);
    let target = match wanted_node {
        Some(node) => {
            let node_cores = cores_in_node(topology, node);
            let n_cores = topology.cores.len().max(1);
            (0..n_workers)
                .filter(|&w| node_cores.contains(&(w % n_cores)))
                .min_by_key(|&w| lengths[w])
                .or(globally_least)
        }
        None => globally_least,
    };
    if let Some(idx) = target {
        if let Ok(mut deque) = queues[idx].lock() {
            deque.push_local(task);
        }
    }
}
/// NUMA-aware work-stealing scheduler: one queue and one OS thread per
/// worker, plus a shared shutdown flag and shared counters.
pub struct NumaWorkStealingScheduler {
    // One mutex-guarded deque per worker; shared with the worker threads.
    queues: Vec<Arc<Mutex<WorkerDeque>>>,
    // Join handles for the worker threads, consumed by `shutdown`.
    workers: Vec<thread::JoinHandle<()>>,
    // Set to true to ask workers to drain their queues and exit.
    shutdown_flag: Arc<AtomicBool>,
    // Shared counters (tasks executed/stolen, steal attempts).
    stats: Arc<Mutex<SchedulerStats>>,
    // Topology consulted for NUMA-aware task placement.
    topology: Arc<NumaTopology>,
}
impl NumaWorkStealingScheduler {
    /// Builds the scheduler: one deque and one named worker thread per
    /// worker (at least one, regardless of `config.n_workers`). Each
    /// worker's NUMA node is derived from its id via `worker_numa_node`.
    /// Worker threads start running immediately.
    ///
    /// # Panics
    /// Panics if a worker thread cannot be spawned.
    pub fn new(config: &WorkStealingConfig, topology: NumaTopology) -> Self {
        let worker_count = config.n_workers.max(1);
        let topology = Arc::new(topology);
        let shutdown_flag = Arc::new(AtomicBool::new(false));
        let stats = Arc::new(Mutex::new(SchedulerStats::default()));
        let queues: Vec<Arc<Mutex<WorkerDeque>>> = (0..worker_count)
            .map(|id| {
                let node = worker_numa_node(id, &topology);
                Arc::new(Mutex::new(WorkerDeque::new(
                    id,
                    node,
                    config.local_queue_size,
                )))
            })
            .collect();
        let workers: Vec<thread::JoinHandle<()>> = (0..worker_count)
            .map(|id| {
                // Each worker gets its own handles to the shared state.
                let worker_queues = queues.clone();
                let worker_topology = Arc::clone(&topology);
                let worker_stop = Arc::clone(&shutdown_flag);
                let worker_stats = Arc::clone(&stats);
                thread::Builder::new()
                    .name(format!("numa-ws-{}", id))
                    .spawn(move || {
                        worker_loop(
                            id,
                            &worker_queues,
                            &worker_topology,
                            &worker_stop,
                            &worker_stats,
                        );
                    })
                    .expect("failed to spawn worker thread")
            })
            .collect();
        Self {
            queues,
            workers,
            shutdown_flag,
            stats,
            topology,
        }
    }

    /// Dispatches one task via NUMA-aware assignment (`assign_task`).
    pub fn submit(&self, task: Task<()>) {
        assign_task(task, &self.topology, &self.queues);
    }

    /// Dispatches a batch of tasks, one placement decision per task.
    pub fn submit_many(&self, tasks: Vec<Task<()>>) {
        for task in tasks {
            self.submit(task);
        }
    }

    /// Snapshot of the scheduler counters; zeroed default if the stats
    /// mutex is poisoned.
    pub fn stats(&self) -> SchedulerStats {
        match self.stats.lock() {
            Ok(guard) => guard.clone(),
            Err(_) => SchedulerStats::default(),
        }
    }

    /// Signals shutdown and joins every worker. Workers drain their own
    /// queues before exiting, so already-submitted tasks still run.
    pub fn shutdown(self) {
        self.shutdown_flag.store(true, AOrdering::SeqCst);
        for handle in self.workers {
            let _ = handle.join();
        }
    }

    /// Number of worker threads (equals the number of queues).
    pub fn n_workers(&self) -> usize {
        self.queues.len()
    }

    /// Current depth of every worker queue; a poisoned lock reads as 0.
    pub fn queue_lengths(&self) -> Vec<usize> {
        self.queues
            .iter()
            .map(|q| match q.lock() {
                Ok(guard) => guard.len(),
                Err(_) => 0,
            })
            .collect()
    }
}
/// Main loop for one worker thread.
///
/// Priority order, each iteration:
/// 1. Pop a task from the worker's own queue (LIFO end) and run it.
/// 2. If the shutdown flag is set, drain any remaining local tasks and exit.
/// 3. Otherwise try to steal (FIFO end) from a victim chosen by `choose_victim`.
/// 4. If nothing was found anywhere, park briefly before retrying.
fn worker_loop(
    worker_id: usize,
    queues: &[Arc<Mutex<WorkerDeque>>],
    topology: &NumaTopology,
    stop: &AtomicBool,
    stats: &Mutex<SchedulerStats>,
) {
    loop {
        // Fast path: run local work first; `false` marks it as not stolen.
        let own_task = queues
            .get(worker_id)
            .and_then(|q| q.lock().ok())
            .and_then(|mut g| g.pop_local())
            .map(|t| (t, false));
        if let Some((task, stolen)) = own_task {
            execute_task(task, worker_id, stolen, stats);
            continue;
        }
        // Shutdown is only honored once the local pop above came up empty;
        // re-drain the local queue here to catch tasks that raced in between
        // the pop and the flag check, then exit.
        if stop.load(AOrdering::Relaxed) {
            while let Some(task) = queues
                .get(worker_id)
                .and_then(|q| q.lock().ok())
                .and_then(|mut g| g.pop_local())
            {
                execute_task(task, worker_id, false, stats);
            }
            break;
        }
        // Snapshot every queue length. The snapshot is racy by design: a
        // stale length only costs a failed steal attempt, never correctness.
        let lengths: Vec<usize> = queues
            .iter()
            .map(|q| q.lock().map(|g| g.len()).unwrap_or(0))
            .collect();
        // Count the attempt even if no victim is found below.
        if let Ok(mut s) = stats.lock() {
            s.steal_attempts += 1;
        }
        let victim = choose_victim(worker_id, topology, &lengths);
        if let Some(v) = victim {
            // The victim may have been drained since the snapshot, so the
            // steal itself can still come up empty.
            let stolen_task = queues
                .get(v)
                .and_then(|q| q.lock().ok())
                .and_then(|mut g| g.steal_remote());
            if let Some(task) = stolen_task {
                execute_task(task, worker_id, true, stats);
                continue;
            }
        }
        // Nothing runnable anywhere: back off briefly to avoid spinning.
        thread::sleep(std::time::Duration::from_micros(PARK_US));
    }
}
/// Runs `task` and records it in the shared stats.
///
/// `stolen` distinguishes locally-popped tasks from stolen ones so the
/// scheduler can report steal effectiveness. Stats updates are best-effort:
/// a poisoned stats mutex silently drops the count rather than panicking
/// the worker thread.
fn execute_task(task: Task<()>, worker_id: usize, stolen: bool, stats: &Mutex<SchedulerStats>) {
    (task.work)();
    if let Ok(mut s) = stats.lock() {
        s.tasks_executed += 1;
        if stolen {
            s.tasks_stolen += 1;
        }
    }
    // `worker_id` is currently unused but kept in the signature for future
    // per-worker accounting; the original also suppressed `stolen` here,
    // which is redundant since `stolen` is consumed above.
    let _ = worker_id;
}
/// Microseconds a worker sleeps after an iteration that found no work.
const PARK_US: u64 = 100;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::numa_scheduler::types::{NumaTopology, WorkStealingConfig};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{Duration, Instant};

    /// 2 NUMA nodes x 2 cores each.
    fn small_topo() -> NumaTopology {
        NumaTopology::from_config(2, 2)
    }

    fn small_config() -> WorkStealingConfig {
        WorkStealingConfig {
            n_workers: 4,
            steal_threshold: 1,
            local_queue_size: 32,
        }
    }

    /// Polls `cond` until it holds or `timeout` elapses. Replaces fixed
    /// sleeps, which made the scheduler tests flaky on slow or loaded
    /// machines and needlessly slow on fast ones.
    fn wait_until(timeout: Duration, mut cond: impl FnMut() -> bool) -> bool {
        let deadline = Instant::now() + timeout;
        while Instant::now() < deadline {
            if cond() {
                return true;
            }
            thread::sleep(Duration::from_millis(2));
        }
        cond()
    }

    #[test]
    fn test_worker_deque_push_pop() {
        let mut deque = WorkerDeque::new(0, 0, 8);
        let task = Task::new(|| ());
        deque.push_local(task);
        assert_eq!(deque.len(), 1);
        assert!(deque.pop_local().is_some());
        assert!(deque.is_empty());
    }

    #[test]
    fn test_steal_victim_prefers_local() {
        let t = small_topo();
        let lengths = vec![0usize, 5, 3, 8];
        let victim = choose_victim(0, &t, &lengths);
        assert_eq!(victim, Some(1), "Expected worker 1 (same NUMA)");
    }

    #[test]
    fn test_steal_victim_fallback_remote() {
        let t = small_topo();
        let lengths = vec![0usize, 0, 0, 7];
        let victim = choose_victim(0, &t, &lengths);
        assert_eq!(victim, Some(3), "Should fall back to remote worker 3");
    }

    #[test]
    fn test_steal_victim_no_tasks() {
        let t = small_topo();
        let lengths = vec![0, 0, 0, 0];
        assert!(choose_victim(0, &t, &lengths).is_none());
    }

    #[test]
    fn test_affinity_assignment() {
        let t = small_topo();
        let queues: Vec<Arc<Mutex<WorkerDeque>>> = (0..4)
            .map(|id| {
                let node = worker_numa_node(id, &t);
                Arc::new(Mutex::new(WorkerDeque::new(id, node, 8)))
            })
            .collect();
        let task = Task::new(|| ()).with_affinity(1);
        assign_task(task, &t, &queues);
        let q2 = queues[2].lock().map(|g| g.len()).unwrap_or(0);
        let q3 = queues[3].lock().map(|g| g.len()).unwrap_or(0);
        assert_eq!(
            q2 + q3,
            1,
            "Task with affinity=1 should be on node-1 workers"
        );
    }

    #[test]
    fn test_scheduler_submits_tasks() {
        let counter = Arc::new(AtomicUsize::new(0));
        let sched = NumaWorkStealingScheduler::new(&small_config(), small_topo());
        let n_tasks = 20usize;
        for _ in 0..n_tasks {
            let c = Arc::clone(&counter);
            sched.submit(Task::new(move || {
                c.fetch_add(1, Ordering::Relaxed);
            }));
        }
        // Wait for completion (bounded), then join the workers.
        wait_until(Duration::from_secs(2), || {
            counter.load(Ordering::SeqCst) == n_tasks
        });
        sched.shutdown();
        assert_eq!(
            counter.load(Ordering::SeqCst),
            n_tasks,
            "All tasks should have been executed"
        );
    }

    #[test]
    fn test_scheduler_stats_executed() {
        let sched = NumaWorkStealingScheduler::new(&small_config(), small_topo());
        let n_tasks = 10;
        for _ in 0..n_tasks {
            sched.submit(Task::new(|| {}));
        }
        wait_until(Duration::from_secs(2), || {
            sched.stats().tasks_executed == n_tasks as u64
        });
        let s = sched.stats();
        sched.shutdown();
        assert_eq!(
            s.tasks_executed, n_tasks as u64,
            "stats should count all executed tasks"
        );
    }

    #[test]
    fn test_scheduler_work_stealing() {
        // Single node, 4 cores: all stealing is intra-node.
        let topo = NumaTopology::from_config(1, 4);
        let config = WorkStealingConfig {
            n_workers: 4,
            steal_threshold: 1,
            local_queue_size: 64,
        };
        let sched = NumaWorkStealingScheduler::new(&config, topo);
        for _ in 0..40 {
            sched.submit(Task::new(|| {}));
        }
        // NOTE: the original also accumulated `tasks_stolen` into a local
        // AtomicUsize that was never asserted on; that dead code is removed.
        wait_until(Duration::from_secs(2), || sched.stats().tasks_executed == 40);
        let s = sched.stats();
        sched.shutdown();
        assert_eq!(s.tasks_executed, 40, "All 40 tasks should be executed");
    }

    #[test]
    fn test_scheduler_shutdown() {
        let sched = NumaWorkStealingScheduler::new(&small_config(), small_topo());
        for _ in 0..5 {
            sched.submit(Task::new(|| {}));
        }
        // Shutdown must drain the 5 queued tasks and join without hanging.
        sched.shutdown();
    }
}