use std::sync::Arc;
use std::thread::{self, JoinHandle};
use std::time::Duration;
use super::{RawTask, handle::WakeChannel, queue::LocalQueue};
/// A pool of worker threads, each owning a local task queue and stealing from its
/// peers' queues when its own is empty.
///
/// Dropping the scheduler signals shutdown and then joins the worker threads.
pub struct WorkStealingScheduler {
    /// Per-worker bookkeeping: each worker's queue and thread handle.
    workers: Vec<WorkerContext>,
    /// Lifecycle flag holding one of the `STATE_*` constants; shared with every worker.
    state: Arc<std::sync::atomic::AtomicU8>,
}
/// Scheduler-side bookkeeping for one worker thread.
struct WorkerContext {
    /// The worker's local queue: `submit` pushes onto it, the worker pops from it,
    /// and peers may pop from it when stealing.
    queue: Arc<LocalQueue>,
    /// Join handle of the worker thread; `None` once `join` has taken it.
    thread_handle: Option<JoinHandle<()>>,
}
/// Workers keep polling for tasks only while the shared state holds this value.
const STATE_RUNNING: u8 = 0;
/// Stored by `WorkStealingScheduler::shutdown` to ask workers to leave their loop.
const STATE_SHUTTING_DOWN: u8 = STATE_RUNNING + 1;
/// Terminal value; like any value other than `STATE_RUNNING`, it makes workers
/// leave their polling loop.
const STATE_STOPPED: u8 = STATE_SHUTTING_DOWN + 1;
impl WorkStealingScheduler {
pub fn with_config(config: &WorkStealingConfig) -> std::io::Result<Self> {
let num_workers = config.num_workers;
let queue_size = config.queue_size;
let thread_name = config.thread_name.clone();
let state = Arc::new(std::sync::atomic::AtomicU8::new(STATE_RUNNING));
let mut workers = Vec::with_capacity(num_workers);
let mut worker_queues = Vec::with_capacity(num_workers);
for _worker_id in 0..num_workers {
let queue = Arc::new(LocalQueue::new(queue_size));
let wake = Arc::new(WakeChannel::new()?);
worker_queues.push((queue.clone(), wake));
}
for worker_id in 0..num_workers {
let (queue, _wake) = &worker_queues[worker_id];
let queues: Vec<_> = worker_queues.iter().map(|(q, _)| q.clone()).collect();
let state_clone = state.clone();
let thread_name = format!("{}-{}", thread_name, worker_id);
let thread_handle = thread::Builder::new().name(thread_name).spawn(move || {
Self::run_worker(worker_id, &queues, &state_clone);
})?;
workers.push(WorkerContext {
queue: queue.clone(),
thread_handle: Some(thread_handle),
});
}
Ok(Self { workers, state })
}
fn run_worker(
worker_id: usize,
queues: &[Arc<LocalQueue>],
state: &std::sync::atomic::AtomicU8,
) {
let my_queue = &queues[worker_id];
let num_workers = queues.len();
while state.load(std::sync::atomic::Ordering::Relaxed) == STATE_RUNNING {
let task = my_queue.pop().or_else(|| {
for i in 1..num_workers {
let target = (worker_id + i) % num_workers;
if let Some(task) = queues[target].pop() {
return Some(task);
}
}
None
});
if let Some(task) = task {
let completed = unsafe { crate::task::raw_task::poll_raw_task(task) };
if completed {
unsafe {
crate::task::raw_task::deallocate_completed_task(task);
}
}
} else {
thread::sleep(Duration::from_millis(1));
}
}
state.store(STATE_STOPPED, std::sync::atomic::Ordering::Release);
}
pub fn submit(&self, task: RawTask) -> Result<(), RawTask> {
static WORKER_INDEX: std::sync::atomic::AtomicUsize =
std::sync::atomic::AtomicUsize::new(0);
let index =
WORKER_INDEX.fetch_add(1, std::sync::atomic::Ordering::Relaxed) % self.workers.len();
if self.workers[index].queue.push(task) {
Ok(())
} else {
for worker in &self.workers {
if worker.queue.push(task) {
return Ok(());
}
}
Err(task)
}
}
pub fn shutdown(&self) {
self.state
.store(STATE_SHUTTING_DOWN, std::sync::atomic::Ordering::Release);
}
pub fn join(&mut self) -> thread::Result<()> {
for worker in &mut self.workers {
if let Some(handle) = worker.thread_handle.take() {
handle.join()?;
}
}
Ok(())
}
#[must_use]
pub fn num_workers(&self) -> usize {
self.workers.len()
}
}
impl Drop for WorkStealingScheduler {
    /// Signals the workers to stop, then waits on their threads.
    ///
    /// A worker panic reported by `join` cannot be propagated out of `drop`, so the
    /// payload is deliberately discarded.
    fn drop(&mut self) {
        self.shutdown();
        let _worker_panic = self.join().err();
    }
}
/// Tunables for [`WorkStealingScheduler`], usually assembled with the builder methods.
#[derive(Debug, Clone)]
pub struct WorkStealingConfig {
    /// Number of worker threads. The default is `0`; `worker_threads(0)` replaces it
    /// with the CPU count.
    pub num_workers: usize,
    /// Capacity passed to each worker's `LocalQueue::new`.
    pub queue_size: usize,
    /// Prefix for worker thread names; each worker is named `"{thread_name}-{index}"`.
    pub thread_name: String,
}

impl Default for WorkStealingConfig {
    /// `num_workers` of `0`, `queue_size` of 256, and thread-name prefix `hiver-worker`.
    fn default() -> Self {
        let thread_name = String::from("hiver-worker");
        Self {
            thread_name,
            queue_size: 256,
            num_workers: 0,
        }
    }
}
impl WorkStealingConfig {
    /// Creates a configuration holding the [`Default`] values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the number of worker threads.
    ///
    /// `0` is replaced by the CPU count reported by `num_cpus::get()`.
    #[must_use]
    pub fn worker_threads(mut self, count: usize) -> Self {
        self.num_workers = if count == 0 { num_cpus::get() } else { count };
        self
    }

    /// Sets the capacity passed to each worker's local queue.
    #[must_use]
    pub fn queue_size(mut self, size: usize) -> Self {
        self.queue_size = size;
        self
    }

    /// Sets the prefix used to name worker threads (`"{name}-{index}"`).
    #[must_use]
    pub fn thread_name(mut self, name: impl Into<String>) -> Self {
        self.thread_name = name.into();
        self
    }

    /// Builds and starts a [`WorkStealingScheduler`] from this configuration.
    ///
    /// # Errors
    ///
    /// Propagates the I/O error returned by [`WorkStealingScheduler::with_config`].
    pub fn build(self) -> std::io::Result<WorkStealingScheduler> {
        WorkStealingScheduler::with_config(&self)
    }
}
/// Cloneable submission handle over a set of worker queues.
///
/// It holds only the queues, so it can push tasks but cannot start, stop or join
/// worker threads.
#[derive(Clone)]
pub struct WorkStealingHandle {
    /// Queues that `submit` distributes tasks across.
    queues: Vec<Arc<LocalQueue>>,
}
impl WorkStealingHandle {
    /// Wraps the given worker queues in a handle.
    ///
    /// An empty `queues` is accepted; `submit` then always returns an error.
    // Presumably not yet called inside the crate, hence the allow.
    // NOTE(review): confirm it is still meant to be handed out by the scheduler.
    #[allow(dead_code)]
    pub(crate) fn new(queues: Vec<Arc<LocalQueue>>) -> Self {
        Self { queues }
    }

    /// Pushes `task` onto one of the queues, choosing the starting queue round-robin.
    ///
    /// If that queue is full, each remaining queue is tried once, continuing around the
    /// ring from the chosen one.
    ///
    /// # Errors
    ///
    /// * [`std::io::ErrorKind::WouldBlock`] when every queue rejected the task.
    /// * [`std::io::ErrorKind::NotConnected`] when the handle has no queues at all.
    pub fn submit(&self, task: RawTask) -> std::io::Result<()> {
        // Only spreads load, so Relaxed is enough.
        static WORKER_INDEX: std::sync::atomic::AtomicUsize =
            std::sync::atomic::AtomicUsize::new(0);
        let num_queues = self.queues.len();
        if num_queues == 0 {
            // Without this guard the `%` below panics with a remainder by zero.
            return Err(std::io::Error::new(
                std::io::ErrorKind::NotConnected,
                "No worker queues available",
            ));
        }
        let start =
            WORKER_INDEX.fetch_add(1, std::sync::atomic::Ordering::Relaxed) % num_queues;
        for offset in 0..num_queues {
            if self.queues[(start + offset) % num_queues].push(task) {
                return Ok(());
            }
        }
        Err(std::io::Error::new(
            std::io::ErrorKind::WouldBlock,
            "All worker queues are full",
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_work_stealing_config() {
        let config = WorkStealingConfig::new()
            .worker_threads(4)
            .queue_size(512)
            .thread_name("test-worker");
        assert_eq!(config.num_workers, 4);
        assert_eq!(config.queue_size, 512);
        assert_eq!(config.thread_name, "test-worker");
    }

    #[test]
    fn test_work_stealing_config_default() {
        let config = WorkStealingConfig::default();
        assert_eq!(config.num_workers, 0);
        assert_eq!(config.queue_size, 256);
        assert_eq!(config.thread_name, "hiver-worker");

        // `new` is documented as returning the defaults.
        let fresh = WorkStealingConfig::new();
        assert_eq!(fresh.num_workers, config.num_workers);
        assert_eq!(fresh.queue_size, config.queue_size);
        assert_eq!(fresh.thread_name, config.thread_name);
    }

    #[test]
    fn test_worker_threads_zero_uses_cpu_count() {
        let config = WorkStealingConfig::new().worker_threads(0);
        assert_eq!(config.num_workers, num_cpus::get());
        assert!(config.num_workers > 0);
    }

    #[test]
    fn test_builder_last_call_wins() {
        let config = WorkStealingConfig::new()
            .worker_threads(2)
            .worker_threads(3)
            .queue_size(8)
            .queue_size(16);
        assert_eq!(config.num_workers, 3);
        assert_eq!(config.queue_size, 16);
    }
}