use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use crate::error::{KernelError, Result};
/// Identifies a worker; the scheduler uses the raw value as an index into its
/// worker list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WorkerId(pub u32);
impl WorkerId {
    /// Wraps the raw worker index `id`.
    #[must_use]
    pub const fn new(id: u32) -> Self {
        WorkerId(id)
    }

    /// Returns the raw worker index.
    #[must_use]
    pub const fn as_u32(self) -> u32 {
        let WorkerId(raw) = self;
        raw
    }
}
/// Identifier handed back for a submitted task.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TaskId(pub u64);
impl TaskId {
    /// Wraps the raw task number `id`.
    #[must_use]
    pub const fn new(id: u64) -> Self {
        TaskId(id)
    }

    /// Returns the raw task number.
    #[must_use]
    pub const fn as_u64(self) -> u64 {
        let TaskId(raw) = self;
        raw
    }
}
#[derive(Debug)]
pub struct WorkStealingDeque<T> {
deque: Mutex<VecDeque<T>>,
len: AtomicUsize,
capacity: usize,
}
impl<T> WorkStealingDeque<T> {
#[must_use]
pub fn new(capacity: usize) -> Self {
Self {
deque: Mutex::new(VecDeque::with_capacity(capacity)),
len: AtomicUsize::new(0),
capacity,
}
}
pub fn push(&self, item: T) -> Result<()> {
let mut deque = self.deque.lock().map_err(|_| KernelError::WouldBlock)?;
if deque.len() >= self.capacity {
return Err(KernelError::UblkQueueFull);
}
deque.push_back(item);
self.len.fetch_add(1, Ordering::Release);
Ok(())
}
pub fn pop(&self) -> Option<T> {
let mut deque = self.deque.lock().ok()?;
let item = deque.pop_back()?;
self.len.fetch_sub(1, Ordering::Release);
Some(item)
}
pub fn steal(&self) -> Option<T> {
let mut deque = self.deque.lock().ok()?;
let item = deque.pop_front()?;
self.len.fetch_sub(1, Ordering::Release);
Some(item)
}
pub fn steal_half(&self) -> Vec<T> {
let mut deque = match self.deque.lock() {
Ok(d) => d,
Err(_) => return Vec::new(),
};
let steal_count = deque.len() / 2;
if steal_count == 0 {
return Vec::new();
}
let mut stolen = Vec::with_capacity(steal_count);
for _ in 0..steal_count {
if let Some(item) = deque.pop_front() {
stolen.push(item);
}
}
self.len.fetch_sub(stolen.len(), Ordering::Release);
stolen
}
#[must_use]
pub fn len(&self) -> usize {
self.len.load(Ordering::Acquire)
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
#[must_use]
pub const fn capacity(&self) -> usize {
self.capacity
}
}
impl<T> Default for WorkStealingDeque<T> {
fn default() -> Self {
Self::new(1024)
}
}
#[derive(Debug)]
struct WorkerState<T> {
deque: WorkStealingDeque<T>,
active: AtomicBool,
}
impl<T> WorkerState<T> {
fn new(capacity: usize) -> Self {
Self {
deque: WorkStealingDeque::new(capacity),
active: AtomicBool::new(true),
}
}
}
#[derive(Debug)]
pub struct Scheduler<T> {
workers: RwLock<Vec<Arc<WorkerState<T>>>>,
num_workers: AtomicUsize,
next_task_id: AtomicU64,
queue_capacity: usize,
submit_counter: AtomicUsize,
running: AtomicBool,
}
impl<T> Scheduler<T> {
#[must_use]
pub fn new(num_workers: usize) -> Self {
Self::with_capacity(num_workers, 1024)
}
#[must_use]
pub fn with_capacity(num_workers: usize, queue_capacity: usize) -> Self {
let workers: Vec<_> = (0..num_workers)
.map(|_| Arc::new(WorkerState::new(queue_capacity)))
.collect();
Self {
workers: RwLock::new(workers),
num_workers: AtomicUsize::new(num_workers),
next_task_id: AtomicU64::new(0),
queue_capacity,
submit_counter: AtomicUsize::new(0),
running: AtomicBool::new(true),
}
}
pub fn submit(&self, task: T) -> Option<TaskId> {
if !self.running.load(Ordering::Acquire) {
return None;
}
let workers = self.workers.read().ok()?;
if workers.is_empty() {
return None;
}
let start = self.submit_counter.fetch_add(1, Ordering::Relaxed) % workers.len();
let target_idx = (0..workers.len())
.map(|i| (start + i) % workers.len())
.find(|&idx| {
workers[idx].active.load(Ordering::Acquire)
&& workers[idx].deque.len() < workers[idx].deque.capacity()
})?;
if workers[target_idx].deque.push(task).is_ok() {
let task_id = self.next_task_id.fetch_add(1, Ordering::Relaxed);
return Some(TaskId(task_id));
}
None
}
pub fn pop(&self, worker_id: WorkerId) -> Option<T> {
let workers = self.workers.read().ok()?;
let idx = worker_id.as_u32() as usize;
if idx >= workers.len() {
return None;
}
workers[idx].deque.pop()
}
pub fn steal(&self, thief_id: WorkerId) -> Option<T> {
let workers = self.workers.read().ok()?;
let thief_idx = thief_id.as_u32() as usize;
if workers.len() <= 1 || thief_idx >= workers.len() {
return None;
}
for offset in 1..workers.len() {
let victim_idx = (thief_idx + offset) % workers.len();
if workers[victim_idx].active.load(Ordering::Acquire) {
if let Some(task) = workers[victim_idx].deque.steal() {
return Some(task);
}
}
}
None
}
pub fn steal_batch(&self, thief_id: WorkerId) -> Vec<T> {
let workers = match self.workers.read() {
Ok(w) => w,
Err(_) => return Vec::new(),
};
let thief_idx = thief_id.as_u32() as usize;
if workers.len() <= 1 || thief_idx >= workers.len() {
return Vec::new();
}
let mut busiest_idx = None;
let mut max_len = 0;
for (idx, worker) in workers.iter().enumerate() {
if idx != thief_idx && worker.active.load(Ordering::Acquire) {
let len = worker.deque.len();
if len > max_len {
max_len = len;
busiest_idx = Some(idx);
}
}
}
match busiest_idx {
Some(idx) => workers[idx].deque.steal_half(),
None => Vec::new(),
}
}
#[must_use]
pub fn num_workers(&self) -> usize {
self.num_workers.load(Ordering::Acquire)
}
#[must_use]
pub fn pending_tasks(&self) -> usize {
let workers = match self.workers.read() {
Ok(w) => w,
Err(_) => return 0,
};
workers.iter().map(|w| w.deque.len()).sum()
}
#[must_use]
pub fn worker_loads(&self) -> Vec<usize> {
let workers = match self.workers.read() {
Ok(w) => w,
Err(_) => return Vec::new(),
};
workers.iter().map(|w| w.deque.len()).collect()
}
#[must_use]
pub fn is_running(&self) -> bool {
self.running.load(Ordering::Acquire)
}
pub fn stop(&self) {
self.running.store(false, Ordering::Release);
}
pub fn add_worker(&self) -> Option<WorkerId> {
let mut workers = self.workers.write().ok()?;
let id = workers.len() as u32;
workers.push(Arc::new(WorkerState::new(self.queue_capacity)));
self.num_workers.fetch_add(1, Ordering::Release);
Some(WorkerId(id))
}
pub fn deactivate_worker(&self, worker_id: WorkerId) -> bool {
let workers = match self.workers.read() {
Ok(w) => w,
Err(_) => return false,
};
let idx = worker_id.as_u32() as usize;
if idx >= workers.len() {
return false;
}
workers[idx].active.store(false, Ordering::Release);
true
}
#[must_use]
pub fn is_worker_active(&self, worker_id: WorkerId) -> bool {
let workers = match self.workers.read() {
Ok(w) => w,
Err(_) => return false,
};
let idx = worker_id.as_u32() as usize;
if idx >= workers.len() {
return false;
}
workers[idx].active.load(Ordering::Acquire)
}
}
impl<T> Default for Scheduler<T> {
fn default() -> Self {
Self::new(4)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_worker_id_new() {
        let id = WorkerId::new(42);
        assert_eq!(id.as_u32(), 42);
    }

    #[test]
    fn test_worker_id_equality() {
        let a = WorkerId(1);
        let b = WorkerId(1);
        let c = WorkerId(2);
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn test_worker_id_hash() {
        use std::collections::HashSet;
        let mut set = HashSet::new();
        set.insert(WorkerId(1));
        set.insert(WorkerId(2));
        // Re-inserting an equal id must not grow the set.
        set.insert(WorkerId(1));
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_task_id_new() {
        let id = TaskId::new(100);
        assert_eq!(id.as_u64(), 100);
    }

    #[test]
    fn test_task_id_equality() {
        let a = TaskId(1);
        let b = TaskId(1);
        let c = TaskId(2);
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn test_deque_new() {
        let deque: WorkStealingDeque<u32> = WorkStealingDeque::new(100);
        assert_eq!(deque.capacity(), 100);
        assert_eq!(deque.len(), 0);
        assert!(deque.is_empty());
    }

    #[test]
    fn test_deque_push_pop() {
        let deque: WorkStealingDeque<u32> = WorkStealingDeque::new(100);
        deque.push(1).unwrap();
        deque.push(2).unwrap();
        deque.push(3).unwrap();
        assert_eq!(deque.len(), 3);
        // The owner's end is LIFO.
        assert_eq!(deque.pop(), Some(3));
        assert_eq!(deque.pop(), Some(2));
        assert_eq!(deque.pop(), Some(1));
        assert_eq!(deque.pop(), None);
    }

    #[test]
    fn test_deque_steal() {
        let deque: WorkStealingDeque<u32> = WorkStealingDeque::new(100);
        deque.push(1).unwrap();
        deque.push(2).unwrap();
        deque.push(3).unwrap();
        // The thief's end is FIFO.
        assert_eq!(deque.steal(), Some(1));
        assert_eq!(deque.steal(), Some(2));
        assert_eq!(deque.steal(), Some(3));
        assert_eq!(deque.steal(), None);
    }

    #[test]
    fn test_deque_steal_half() {
        let deque: WorkStealingDeque<u32> = WorkStealingDeque::new(100);
        for i in 0..10 {
            deque.push(i).unwrap();
        }
        let stolen = deque.steal_half();
        assert_eq!(stolen.len(), 5);
        assert_eq!(deque.len(), 5);
        assert_eq!(stolen, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_deque_capacity_limit() {
        let deque: WorkStealingDeque<u32> = WorkStealingDeque::new(3);
        assert!(deque.push(1).is_ok());
        assert!(deque.push(2).is_ok());
        assert!(deque.push(3).is_ok());
        assert!(deque.push(4).is_err());
    }

    #[test]
    fn test_deque_default() {
        let deque: WorkStealingDeque<u32> = WorkStealingDeque::default();
        assert_eq!(deque.capacity(), 1024);
    }

    #[test]
    fn test_scheduler_new() {
        let scheduler: Scheduler<u32> = Scheduler::new(4);
        assert_eq!(scheduler.num_workers(), 4);
        assert!(scheduler.is_running());
        assert_eq!(scheduler.pending_tasks(), 0);
    }

    #[test]
    fn test_scheduler_submit() {
        let scheduler: Scheduler<u32> = Scheduler::new(4);
        let task_id = scheduler.submit(42);
        assert!(task_id.is_some());
        assert_eq!(scheduler.pending_tasks(), 1);
    }

    #[test]
    fn test_scheduler_submit_multiple() {
        let scheduler: Scheduler<u32> = Scheduler::new(4);
        for i in 0..100 {
            let task_id = scheduler.submit(i);
            assert!(task_id.is_some());
        }
        assert_eq!(scheduler.pending_tasks(), 100);
    }

    #[test]
    fn test_scheduler_pop() {
        let scheduler: Scheduler<u32> = Scheduler::new(4);
        scheduler.submit(42);
        let loads = scheduler.worker_loads();
        let worker_idx = loads.iter().position(|&l| l > 0).unwrap();
        let task = scheduler.pop(WorkerId(worker_idx as u32));
        assert_eq!(task, Some(42));
    }

    #[test]
    fn test_scheduler_steal() {
        let scheduler: Scheduler<u32> = Scheduler::new(2);
        for i in 0..10 {
            scheduler.submit(i);
        }
        let stolen = scheduler.steal(WorkerId(1));
        assert!(stolen.is_some());
    }

    #[test]
    fn test_scheduler_steal_batch() {
        let scheduler: Scheduler<u32> = Scheduler::new(2);
        for i in 0..20 {
            scheduler.submit(i);
        }
        let stolen = scheduler.steal_batch(WorkerId(1));
        assert!(!stolen.is_empty());
    }

    #[test]
    fn test_scheduler_worker_loads() {
        let scheduler: Scheduler<u32> = Scheduler::new(4);
        for i in 0..20 {
            scheduler.submit(i);
        }
        let loads = scheduler.worker_loads();
        assert_eq!(loads.len(), 4);
        assert_eq!(loads.iter().sum::<usize>(), 20);
    }

    #[test]
    fn test_scheduler_stop() {
        let scheduler: Scheduler<u32> = Scheduler::new(4);
        assert!(scheduler.is_running());
        scheduler.stop();
        assert!(!scheduler.is_running());
        let task_id = scheduler.submit(42);
        assert!(task_id.is_none());
    }

    #[test]
    fn test_scheduler_add_worker() {
        let scheduler: Scheduler<u32> = Scheduler::new(2);
        assert_eq!(scheduler.num_workers(), 2);
        let new_id = scheduler.add_worker();
        assert!(new_id.is_some());
        assert_eq!(new_id.unwrap(), WorkerId(2));
        assert_eq!(scheduler.num_workers(), 3);
    }

    #[test]
    fn test_scheduler_deactivate_worker() {
        let scheduler: Scheduler<u32> = Scheduler::new(2);
        assert!(scheduler.is_worker_active(WorkerId(0)));
        assert!(scheduler.deactivate_worker(WorkerId(0)));
        assert!(!scheduler.is_worker_active(WorkerId(0)));
    }

    #[test]
    fn test_scheduler_deactivate_invalid_worker() {
        let scheduler: Scheduler<u32> = Scheduler::new(2);
        assert!(!scheduler.deactivate_worker(WorkerId(99)));
    }

    #[test]
    fn test_scheduler_pop_invalid_worker() {
        let scheduler: Scheduler<u32> = Scheduler::new(2);
        scheduler.submit(42);
        let task = scheduler.pop(WorkerId(99));
        assert!(task.is_none());
    }

    #[test]
    fn test_scheduler_default() {
        let scheduler: Scheduler<u32> = Scheduler::default();
        assert_eq!(scheduler.num_workers(), 4);
    }

    #[test]
    fn test_scheduler_concurrent_submit() {
        use std::thread;
        let scheduler = Arc::new(Scheduler::<u32>::new(4));
        let mut handles = Vec::new();
        for t in 0..4 {
            let s = Arc::clone(&scheduler);
            handles.push(thread::spawn(move || {
                for i in 0..100 {
                    s.submit(t * 100 + i);
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(scheduler.pending_tasks(), 400);
    }

    #[test]
    fn test_scheduler_concurrent_pop() {
        use std::thread;
        let scheduler = Arc::new(Scheduler::<u32>::new(4));
        for i in 0..100 {
            scheduler.submit(i);
        }
        let mut handles = Vec::new();
        let popped = Arc::new(AtomicUsize::new(0));
        for worker_id in 0..4 {
            let s = Arc::clone(&scheduler);
            let p = Arc::clone(&popped);
            handles.push(thread::spawn(move || {
                while s.pop(WorkerId(worker_id)).is_some() {
                    p.fetch_add(1, Ordering::Relaxed);
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(popped.load(Ordering::Relaxed), 100);
    }

    #[test]
    fn test_scheduler_work_stealing_correctness() {
        let scheduler: Scheduler<u32> = Scheduler::new(2);
        for _ in 0..10 {
            scheduler.submit(42);
        }
        let loads = scheduler.worker_loads();
        let total_initial = loads.iter().sum::<usize>();
        assert_eq!(total_initial, 10);
        // Take one task through the stealing path before the owners drain.
        let mut consumed = 0;
        if scheduler.steal(WorkerId(0)).is_some() {
            consumed += 1;
        }
        assert_eq!(consumed, 1);
        assert_eq!(scheduler.pending_tasks(), 9);
        while scheduler.pop(WorkerId(0)).is_some() {
            consumed += 1;
        }
        while scheduler.pop(WorkerId(1)).is_some() {
            consumed += 1;
        }
        assert_eq!(scheduler.pending_tasks(), 0);
        assert_eq!(consumed, 10);
    }

    #[test]
    fn test_invariant_task_count_preserved() {
        let scheduler: Scheduler<u32> = Scheduler::new(4);
        let n = 50;
        for i in 0..n {
            scheduler.submit(i);
        }
        let mut popped = 0;
        for worker_id in 0..4 {
            while scheduler.pop(WorkerId(worker_id)).is_some() {
                popped += 1;
            }
        }
        assert_eq!(popped, n);
    }

    #[test]
    fn test_invariant_steal_does_not_duplicate() {
        let scheduler: Scheduler<u32> = Scheduler::new(2);
        for i in 0..10 {
            assert!(scheduler.submit(i).is_some());
        }
        let mut collected = Vec::new();
        // Steal first so the stealing path actually runs; popping first would
        // leave nothing to steal and make this test vacuous.
        while let Some(task) = scheduler.steal(WorkerId(0)) {
            collected.push(task);
        }
        assert!(!collected.is_empty());
        for worker_id in 0..2 {
            while let Some(task) = scheduler.pop(WorkerId(worker_id)) {
                collected.push(task);
            }
        }
        // Every submitted task appears exactly once: nothing lost, nothing duplicated.
        collected.sort_unstable();
        assert_eq!(collected, (0..10).collect::<Vec<u32>>());
    }
}