use super::morsel::Morsel;
use crossbeam::deque::{Injector, Steal, Stealer, Worker};
use parking_lot::Mutex;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
/// Index of a NUMA node, as returned by [`NumaConfig::worker_node`].
pub type NumaNode = usize;

/// Description of how worker ids are grouped into NUMA nodes.
///
/// Worker ids are assigned to nodes in contiguous runs of `workers_per_node`.
/// The special value `usize::MAX` (used by [`Default`]) places every worker on
/// node 0.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NumaConfig {
    /// Number of nodes in the topology. Informational only: it is not
    /// consulted by [`NumaConfig::worker_node`].
    pub num_nodes: usize,
    /// Number of consecutive worker ids per node. `usize::MAX` means
    /// "unbounded" (a single node holds every worker).
    pub workers_per_node: usize,
}

impl Default for NumaConfig {
    /// A single node that holds every worker.
    fn default() -> Self {
        Self {
            num_nodes: 1,
            workers_per_node: usize::MAX,
        }
    }
}

impl NumaConfig {
    /// Builds a configuration from an explicit topology.
    ///
    /// A `workers_per_node` of `0` is treated like `usize::MAX` by
    /// [`NumaConfig::worker_node`]: every worker maps to node 0.
    #[must_use]
    pub fn with_topology(num_nodes: usize, workers_per_node: usize) -> Self {
        Self {
            num_nodes,
            workers_per_node,
        }
    }

    /// Picks a topology heuristically from the worker count.
    ///
    /// This does not query the hardware. With more than 8 workers it assumes
    /// two nodes, with node 0 receiving the rounded-up half of the workers.
    /// Otherwise it returns the single-node [`Default`].
    #[must_use]
    pub fn auto_detect(num_workers: usize) -> Self {
        if num_workers > 8 {
            Self {
                num_nodes: 2,
                workers_per_node: (num_workers + 1) / 2,
            }
        } else {
            Self::default()
        }
    }

    /// Returns the node that `worker_id` belongs to (`worker_id / workers_per_node`).
    ///
    /// Every worker maps to node 0 in two cases:
    /// - `workers_per_node` is `usize::MAX` (single node).
    /// - `workers_per_node` is `0` (degenerate topology). This case must not
    ///   divide by zero.
    ///
    /// Ids beyond the described topology are not clamped, so they may map to
    /// a node index `>= num_nodes`.
    #[must_use]
    pub fn worker_node(&self, worker_id: usize) -> NumaNode {
        match self.workers_per_node {
            0 | usize::MAX => 0,
            per_node => worker_id / per_node,
        }
    }
}
pub struct MorselScheduler {
num_workers: usize,
global_queue: Injector<Morsel>,
stealers: Mutex<Vec<Stealer<Morsel>>>,
active_morsels: AtomicUsize,
total_submitted: AtomicUsize,
submission_done: AtomicBool,
done: AtomicBool,
numa_config: NumaConfig,
}
impl MorselScheduler {
#[must_use]
pub fn new(num_workers: usize) -> Self {
Self::with_numa_config(num_workers, NumaConfig::auto_detect(num_workers))
}
#[must_use]
pub fn with_numa_config(num_workers: usize, numa_config: NumaConfig) -> Self {
Self {
num_workers,
global_queue: Injector::new(),
stealers: Mutex::new(Vec::with_capacity(num_workers)),
active_morsels: AtomicUsize::new(0),
total_submitted: AtomicUsize::new(0),
submission_done: AtomicBool::new(false),
done: AtomicBool::new(false),
numa_config,
}
}
#[must_use]
pub fn num_workers(&self) -> usize {
self.num_workers
}
pub fn submit(&self, morsel: Morsel) {
self.global_queue.push(morsel);
self.active_morsels.fetch_add(1, Ordering::Relaxed);
self.total_submitted.fetch_add(1, Ordering::Relaxed);
}
pub fn submit_batch(&self, morsels: Vec<Morsel>) {
let count = morsels.len();
for morsel in morsels {
self.global_queue.push(morsel);
}
self.active_morsels.fetch_add(count, Ordering::Relaxed);
self.total_submitted.fetch_add(count, Ordering::Relaxed);
}
pub fn finish_submission(&self) {
self.submission_done.store(true, Ordering::Release);
if self.active_morsels.load(Ordering::Acquire) == 0 {
self.done.store(true, Ordering::Release);
}
}
pub fn register_worker(&self, stealer: Stealer<Morsel>) -> usize {
let mut stealers = self.stealers.lock();
let worker_id = stealers.len();
stealers.push(stealer);
worker_id
}
pub fn get_global_work(&self) -> Option<Morsel> {
loop {
match self.global_queue.steal() {
Steal::Success(morsel) => return Some(morsel),
Steal::Empty => return None,
Steal::Retry => continue,
}
}
}
pub fn steal_work(&self, my_id: usize) -> Option<Morsel> {
let stealers = self.stealers.lock();
let num_stealers = stealers.len();
if num_stealers <= 1 {
return None;
}
let my_node = self.numa_config.worker_node(my_id);
for i in 1..num_stealers {
let victim = (my_id + i) % num_stealers;
let victim_node = self.numa_config.worker_node(victim);
if victim_node != my_node {
continue;
}
if let Some(morsel) = Self::try_steal_from(&stealers[victim]) {
return Some(morsel);
}
}
for i in 1..num_stealers {
let victim = (my_id + i) % num_stealers;
let victim_node = self.numa_config.worker_node(victim);
if victim_node == my_node {
continue;
}
if let Some(morsel) = Self::try_steal_from(&stealers[victim]) {
return Some(morsel);
}
}
None
}
fn try_steal_from(stealer: &Stealer<Morsel>) -> Option<Morsel> {
loop {
match stealer.steal() {
Steal::Success(morsel) => return Some(morsel),
Steal::Empty => return None,
Steal::Retry => continue,
}
}
}
#[must_use]
pub fn worker_node(&self, worker_id: usize) -> NumaNode {
self.numa_config.worker_node(worker_id)
}
pub fn complete_morsel(&self) {
let prev = self.active_morsels.fetch_sub(1, Ordering::Release);
if prev == 1 && self.submission_done.load(Ordering::Acquire) {
self.done.store(true, Ordering::Release);
}
}
#[must_use]
pub fn is_done(&self) -> bool {
self.done.load(Ordering::Acquire)
}
#[must_use]
pub fn is_submission_done(&self) -> bool {
self.submission_done.load(Ordering::Acquire)
}
#[must_use]
pub fn active_count(&self) -> usize {
self.active_morsels.load(Ordering::Relaxed)
}
#[must_use]
pub fn total_submitted(&self) -> usize {
self.total_submitted.load(Ordering::Relaxed)
}
}
impl std::fmt::Debug for MorselScheduler {
    /// Formats the scheduler's worker count plus its counters and flags.
    ///
    /// Each atomic is read independently with `Relaxed` ordering. Under
    /// concurrent use, the printed values need not form a consistent snapshot.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let relaxed = Ordering::Relaxed;
        let mut builder = f.debug_struct("MorselScheduler");
        builder.field("num_workers", &self.num_workers);
        builder.field("active_morsels", &self.active_morsels.load(relaxed));
        builder.field("total_submitted", &self.total_submitted.load(relaxed));
        builder.field("submission_done", &self.submission_done.load(relaxed));
        builder.field("done", &self.done.load(relaxed));
        builder.finish()
    }
}
/// Per-worker handle for obtaining morsels and reporting their completion.
///
/// Creating a handle registers a new FIFO local queue with the scheduler.
/// The handle's worker id is that queue's index in the scheduler's stealer
/// list.
pub struct WorkerHandle {
    /// Scheduler this worker belongs to.
    scheduler: Arc<MorselScheduler>,
    /// Id returned by `MorselScheduler::register_worker`.
    worker_id: usize,
    /// This worker's own deque; other workers steal from it via its stealer.
    local_queue: Worker<Morsel>,
}

impl WorkerHandle {
    /// Creates a FIFO local queue and registers its stealer with `scheduler`.
    #[must_use]
    pub fn new(scheduler: Arc<MorselScheduler>) -> Self {
        let local_queue = Worker::new_fifo();
        let worker_id = scheduler.register_worker(local_queue.stealer());
        Self {
            scheduler,
            worker_id,
            local_queue,
        }
    }

    /// Fetches the next morsel without blocking.
    ///
    /// Sources are tried in order:
    /// 1. this worker's local queue;
    /// 2. the scheduler's global queue;
    /// 3. stealing from other workers, same NUMA node first.
    ///
    /// `None` only means no work was found right now. Morsels may still be in
    /// flight on other workers, so use [`WorkerHandle::is_done`] to decide
    /// when to stop.
    pub fn get_work(&self) -> Option<Morsel> {
        self.local_queue
            .pop()
            .or_else(|| self.scheduler.get_global_work())
            .or_else(|| self.scheduler.steal_work(self.worker_id))
    }

    /// Pushes a morsel onto this worker's local queue and counts it as active.
    ///
    /// Unlike `MorselScheduler::submit`, this does not change the scheduler's
    /// `total_submitted` count.
    pub fn push_local(&self, morsel: Morsel) {
        // Increment before publishing. Once pushed, another worker can steal
        // and complete the morsel. Counting afterwards could let the active
        // count hit zero, and `done` be set, while morsels are still
        // outstanding.
        self.scheduler
            .active_morsels
            .fetch_add(1, Ordering::Relaxed);
        self.local_queue.push(morsel);
    }

    /// Reports one morsel as finished; forwards to the scheduler.
    pub fn complete_morsel(&self) {
        self.scheduler.complete_morsel();
    }

    /// This worker's id (its registration index in the scheduler).
    #[must_use]
    pub fn worker_id(&self) -> usize {
        self.worker_id
    }

    /// Whether the scheduler reports all work finished.
    #[must_use]
    pub fn is_done(&self) -> bool {
        self.scheduler.is_done()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scheduler_creation() {
        let scheduler = MorselScheduler::new(4);
        assert_eq!(scheduler.num_workers(), 4);
        assert_eq!(scheduler.active_count(), 0);
        assert!(!scheduler.is_done());
    }

    #[test]
    fn test_submit_and_get_work() {
        let scheduler = Arc::new(MorselScheduler::new(2));
        scheduler.submit(Morsel::new(0, 0, 0, 1000));
        scheduler.submit(Morsel::new(1, 0, 1000, 2000));
        assert_eq!(scheduler.total_submitted(), 2);
        assert_eq!(scheduler.active_count(), 2);
        let morsel = scheduler.get_global_work().unwrap();
        assert_eq!(morsel.id, 0);
        scheduler.complete_morsel();
        assert_eq!(scheduler.active_count(), 1);
        let morsel = scheduler.get_global_work().unwrap();
        assert_eq!(morsel.id, 1);
        scheduler.complete_morsel();
        scheduler.finish_submission();
        assert!(scheduler.is_done());
    }

    #[test]
    fn test_finish_submission_without_work() {
        let scheduler = MorselScheduler::new(2);
        assert!(!scheduler.is_submission_done());
        scheduler.finish_submission();
        assert!(scheduler.is_submission_done());
        assert!(scheduler.is_done());
    }

    #[test]
    fn test_submit_batch() {
        let scheduler = MorselScheduler::new(4);
        let morsels = vec![
            Morsel::new(0, 0, 0, 100),
            Morsel::new(1, 0, 100, 200),
            Morsel::new(2, 0, 200, 300),
        ];
        scheduler.submit_batch(morsels);
        assert_eq!(scheduler.total_submitted(), 3);
        assert_eq!(scheduler.active_count(), 3);
    }

    #[test]
    fn test_worker_handle() {
        let scheduler = Arc::new(MorselScheduler::new(2));
        let handle = WorkerHandle::new(Arc::clone(&scheduler));
        assert_eq!(handle.worker_id(), 0);
        assert!(!handle.is_done());
        scheduler.submit(Morsel::new(0, 0, 0, 100));
        let morsel = handle.get_work().unwrap();
        assert_eq!(morsel.id, 0);
        handle.complete_morsel();
        scheduler.finish_submission();
        assert!(handle.is_done());
    }

    #[test]
    fn test_worker_local_queue() {
        let scheduler = Arc::new(MorselScheduler::new(2));
        let handle = WorkerHandle::new(Arc::clone(&scheduler));
        handle.push_local(Morsel::new(0, 0, 0, 100));
        let morsel = handle.get_work().unwrap();
        assert_eq!(morsel.id, 0);
    }

    #[test]
    fn test_steal_work_requires_two_workers() {
        let scheduler = Arc::new(MorselScheduler::new(1));
        let handle = WorkerHandle::new(Arc::clone(&scheduler));
        handle.push_local(Morsel::new(0, 0, 0, 100));
        // With a single registered worker there is nobody to steal from.
        assert!(scheduler.steal_work(handle.worker_id()).is_none());
    }

    #[test]
    fn test_work_stealing() {
        let scheduler = Arc::new(MorselScheduler::new(2));
        let handle1 = WorkerHandle::new(Arc::clone(&scheduler));
        let handle2 = WorkerHandle::new(Arc::clone(&scheduler));
        for i in 0..5 {
            handle1.push_local(Morsel::new(i, 0, i * 100, (i + 1) * 100));
        }
        let _ = handle1.get_work().unwrap();
        let stolen = handle2.get_work();
        assert!(stolen.is_some());
    }

    #[test]
    fn test_concurrent_workers() {
        use std::thread;
        let scheduler = Arc::new(MorselScheduler::new(4));
        let total_morsels = 100;
        for i in 0..total_morsels {
            scheduler.submit(Morsel::new(i, 0, i * 100, (i + 1) * 100));
        }
        scheduler.finish_submission();
        let completed = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::new();
        for _ in 0..4 {
            let sched = Arc::clone(&scheduler);
            let completed = Arc::clone(&completed);
            handles.push(thread::spawn(move || {
                let handle = WorkerHandle::new(sched);
                let mut count = 0;
                while let Some(_morsel) = handle.get_work() {
                    count += 1;
                    handle.complete_morsel();
                }
                completed.fetch_add(count, Ordering::Relaxed);
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(completed.load(Ordering::Relaxed), total_morsels);
    }

    #[test]
    fn test_numa_config_default() {
        let config = NumaConfig::default();
        assert_eq!(config.num_nodes, 1);
        assert_eq!(config.worker_node(0), 0);
        assert_eq!(config.worker_node(100), 0);
    }

    #[test]
    fn test_numa_config_auto_detect() {
        let config = NumaConfig::auto_detect(4);
        assert_eq!(config.num_nodes, 1);
        let config = NumaConfig::auto_detect(16);
        assert_eq!(config.num_nodes, 2);
        assert_eq!(config.workers_per_node, 8);
    }

    #[test]
    fn test_numa_config_worker_node() {
        let config = NumaConfig::with_topology(2, 4);
        assert_eq!(config.worker_node(0), 0);
        assert_eq!(config.worker_node(1), 0);
        assert_eq!(config.worker_node(2), 0);
        assert_eq!(config.worker_node(3), 0);
        assert_eq!(config.worker_node(4), 1);
        assert_eq!(config.worker_node(5), 1);
        assert_eq!(config.worker_node(6), 1);
        assert_eq!(config.worker_node(7), 1);
    }

    #[test]
    fn test_scheduler_with_numa_config() {
        let config = NumaConfig::with_topology(2, 2);
        let scheduler = MorselScheduler::with_numa_config(4, config);
        assert_eq!(scheduler.num_workers(), 4);
        assert_eq!(scheduler.worker_node(0), 0);
        assert_eq!(scheduler.worker_node(1), 0);
        assert_eq!(scheduler.worker_node(2), 1);
        assert_eq!(scheduler.worker_node(3), 1);
    }

    #[test]
    fn test_numa_aware_stealing() {
        // Topology: workers 0 and 1 on node 0, workers 2 and 3 on node 1.
        let config = NumaConfig::with_topology(2, 2);
        let scheduler = Arc::new(MorselScheduler::with_numa_config(4, config));
        let handle0 = WorkerHandle::new(Arc::clone(&scheduler));
        let _handle1 = WorkerHandle::new(Arc::clone(&scheduler));
        let handle2 = WorkerHandle::new(Arc::clone(&scheduler));
        let handle3 = WorkerHandle::new(Arc::clone(&scheduler));
        // Work is available on both nodes. Worker 3's ring order visits
        // worker 0 (node 0) before worker 2 (node 1), so only the NUMA
        // preference makes it take worker 2's morsel first.
        handle0.push_local(Morsel::new(0, 0, 0, 100));
        handle2.push_local(Morsel::new(1, 0, 100, 200));
        let stolen = handle3.get_work().expect("worker 3 should find work");
        assert_eq!(stolen.id, 1, "same-node victim must be preferred");
        // Once its own node is drained, worker 3 falls back to the remote node.
        let stolen = handle3
            .get_work()
            .expect("cross-node stealing should still succeed");
        assert_eq!(stolen.id, 0);
    }
}