use std::os::fd::RawFd;
use std::sync::Arc;
use std::sync::Mutex;
use std::thread::{self, JoinHandle};
use std::time::Duration;
use super::{RawTask, handle::SchedulerHandle, queue::LocalQueue};
/// Settings for a [`Scheduler`] and the worker thread it spawns.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SchedulerConfig {
    /// Size passed to `LocalQueue::new` for both the local queue and the inject queue.
    pub queue_size: usize,
    /// CPU core to pin the worker thread to; `None` leaves affinity untouched.
    /// On Linux the index is wrapped modulo `CPU_SETSIZE`; on other targets it is ignored.
    pub cpu_affinity: Option<usize>,
    /// Name given to the spawned worker thread.
    pub thread_name: String,
}

impl Default for SchedulerConfig {
    /// Queue size 256, no CPU pinning, worker thread named `hiver-worker`.
    fn default() -> Self {
        Self {
            queue_size: 256,
            cpu_affinity: None,
            thread_name: String::from("hiver-worker"),
        }
    }
}
/// Task scheduler that owns exactly one background worker thread.
///
/// Tasks enter through [`Scheduler::submit`] (local queue) or a [`SchedulerHandle`]
/// (inject queue). The worker polls the local queue before the inject queue.
/// Dropping the scheduler requests shutdown and joins the worker.
pub struct Scheduler {
/// Queue filled by `submit`; the worker drains it before `inject_queue`.
queue: Arc<LocalQueue>,
/// Queue shared with `SchedulerHandle`s; the worker pops it when `queue` is empty.
inject_queue: Arc<LocalQueue>,
/// Channel used to wake the worker while it waits for work.
wake: Arc<super::handle::WakeChannel>,
/// Lifecycle state: one of `STATE_RUNNING`, `STATE_SHUTTING_DOWN`, `STATE_STOPPED`.
state: Arc<std::sync::atomic::AtomicU8>,
/// Join handle of the worker thread; `None` once `join` has taken it.
thread_handle: Option<JoinHandle<()>>,
/// Wakers registered under caller-supplied `u64` ids; not read by the worker loop.
task_wakers: Arc<Mutex<std::collections::HashMap<u64, std::task::Waker>>>,
}
/// The worker thread is in its run loop, polling tasks.
const STATE_RUNNING: u8 = 0;
/// Shutdown was requested; the worker leaves its loop at the next state check.
const STATE_SHUTTING_DOWN: u8 = STATE_RUNNING + 1;
/// The worker has left its run loop.
const STATE_STOPPED: u8 = STATE_SHUTTING_DOWN + 1;
impl Scheduler {
pub fn new() -> std::io::Result<Self> {
Self::with_config(&SchedulerConfig::default())
}
pub fn with_config(config: &SchedulerConfig) -> std::io::Result<Self> {
let queue = Arc::new(LocalQueue::new(config.queue_size));
let inject_queue = Arc::new(LocalQueue::new(config.queue_size));
let wake = Arc::new(super::handle::WakeChannel::new()?);
let task_wakers = Arc::new(Mutex::new(std::collections::HashMap::new()));
let state = Arc::new(std::sync::atomic::AtomicU8::new(STATE_RUNNING));
let queue_clone = queue.clone();
let inject_queue_clone = inject_queue.clone();
let wake_clone = wake.clone();
let state_clone = state.clone();
let thread_name = config.thread_name.clone();
let cpu_affinity = config.cpu_affinity;
let thread_handle = thread::Builder::new().name(thread_name).spawn(move || {
if let Some(core) = cpu_affinity {
Self::set_cpu_affinity(core);
}
Self::run_scheduler(&queue_clone, &inject_queue_clone, &wake_clone, &state_clone);
})?;
Ok(Self {
queue,
inject_queue,
wake,
state,
thread_handle: Some(thread_handle),
task_wakers,
})
}
pub fn with_config_and_driver(
config: &SchedulerConfig,
_driver: Arc<dyn crate::driver::Driver>,
) -> std::io::Result<Self> {
let queue = Arc::new(LocalQueue::new(config.queue_size));
let inject_queue = Arc::new(LocalQueue::new(config.queue_size));
let wake = Arc::new(super::handle::WakeChannel::new()?);
let task_wakers = Arc::new(Mutex::new(std::collections::HashMap::new()));
let state = Arc::new(std::sync::atomic::AtomicU8::new(STATE_RUNNING));
let queue_clone = queue.clone();
let inject_queue_clone = inject_queue.clone();
let wake_clone = wake.clone();
let state_clone = state.clone();
let thread_name = config.thread_name.clone();
let cpu_affinity = config.cpu_affinity;
let thread_handle = thread::Builder::new().name(thread_name).spawn(move || {
if let Some(core) = cpu_affinity {
Self::set_cpu_affinity(core);
}
Self::run_scheduler(&queue_clone, &inject_queue_clone, &wake_clone, &state_clone);
})?;
Ok(Self {
queue,
inject_queue,
wake,
state,
thread_handle: Some(thread_handle),
task_wakers,
})
}
#[must_use]
pub fn handle(&self) -> SchedulerHandle {
SchedulerHandle::new(self.inject_queue.clone(), self.wake.clone())
}
pub fn shutdown(&self) {
self.state
.store(STATE_SHUTTING_DOWN, std::sync::atomic::Ordering::Release);
self.wake.notify();
}
pub fn join(&mut self) -> thread::Result<()> {
if let Some(handle) = self.thread_handle.take() {
handle.join()
} else {
Ok(())
}
}
fn run_scheduler(
local_queue: &LocalQueue,
inject_queue: &LocalQueue,
wake: &super::handle::WakeChannel,
state: &std::sync::atomic::AtomicU8,
) {
while state.load(std::sync::atomic::Ordering::Relaxed) == STATE_RUNNING {
let task = local_queue.pop().or_else(|| {
inject_queue.pop()
});
if let Some(task) = task {
let completed = unsafe { crate::task::raw_task::poll_raw_task(task) };
if completed {
unsafe {
crate::task::raw_task::deallocate_completed_task(task);
}
}
} else {
wake.recv_timeout(Duration::from_millis(10));
}
}
state.store(STATE_STOPPED, std::sync::atomic::Ordering::Release);
}
#[cfg(target_os = "linux")]
fn set_cpu_affinity(core: usize) {
unsafe {
let mut cpu_set: libc::cpu_set_t = std::mem::zeroed();
libc::CPU_ZERO(&mut cpu_set);
libc::CPU_SET(core % libc::CPU_SETSIZE as usize, &mut cpu_set);
let _ = libc::sched_setaffinity(0, size_of::<libc::cpu_set_t>(), &cpu_set);
}
}
#[cfg(not(target_os = "linux"))]
fn set_cpu_affinity(_core: usize) {
}
pub fn submit(&self, task: RawTask) -> Result<(), RawTask> {
if self.queue.push(task) {
self.wake.notify();
Ok(())
} else {
Err(task)
}
}
#[must_use]
pub fn wake_fd(&self) -> RawFd {
self.wake.raw_fd()
}
pub fn get_task_waker(&self, id: u64) -> Option<std::task::Waker> {
let wakers = self.task_wakers.lock().unwrap();
wakers.get(&id).cloned()
}
pub fn register_task_waker(&self, id: u64, waker: std::task::Waker) {
let mut wakers = self.task_wakers.lock().unwrap();
wakers.insert(id, waker);
}
pub fn remove_task_waker(&self, id: u64) -> Option<std::task::Waker> {
let mut wakers = self.task_wakers.lock().unwrap();
wakers.remove(&id)
}
}
impl Drop for Scheduler {
    /// Requests shutdown, then blocks until the worker thread has exited.
    ///
    /// A worker panic reported by `join` is intentionally discarded. Re-raising it
    /// from `drop` would abort the process if the scheduler is dropped while
    /// already unwinding.
    fn drop(&mut self) {
        self.shutdown();
        std::mem::drop(self.join());
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Starts a scheduler, then stops and joins its worker thread.
    ///
    /// The tests below submit placeholder `RawTask` values (e.g. `0x1000`) that are
    /// not real tasks. With a live worker, those would race into `poll_raw_task` on a
    /// bogus address, so the worker must be gone before anything is submitted.
    /// NOTE(review): assumes `LocalQueue` does not poll or free leftover entries on drop.
    fn stopped_scheduler() -> Scheduler {
        let mut scheduler = Scheduler::new().expect("scheduler should start");
        scheduler.shutdown();
        scheduler
            .join()
            .expect("worker thread should exit without panicking");
        scheduler
    }

    #[test]
    fn test_scheduler_creation() {
        assert!(Scheduler::new().is_ok());
        let scheduler = stopped_scheduler();
        let handle = scheduler.handle();
        assert!(handle.submit(0x1000 as RawTask).is_ok());
    }

    #[test]
    fn test_scheduler_config() {
        let config = SchedulerConfig {
            queue_size: 512,
            cpu_affinity: Some(0),
            thread_name: "test-worker".to_string(),
        };
        let scheduler = Scheduler::with_config(&config);
        assert!(scheduler.is_ok());
    }

    #[test]
    fn test_scheduler_submit_and_handle() {
        let scheduler = stopped_scheduler();
        let handle = scheduler.handle();
        assert!(handle.submit(0x1000 as RawTask).is_ok());
        assert!(handle.submit(0x2000 as RawTask).is_ok());
        assert!(handle.wake_fd() >= 0);
    }

    #[test]
    fn test_scheduler_waker_store_empty() {
        let scheduler = Scheduler::new().unwrap();
        assert!(scheduler.get_task_waker(9999).is_none());
        assert!(scheduler.remove_task_waker(9999).is_none());
    }

    #[test]
    fn test_scheduler_shutdown() {
        let mut scheduler = Scheduler::new().unwrap();
        scheduler.shutdown();
        // A repeated shutdown request must be harmless.
        scheduler.shutdown();
        assert!(scheduler.join().is_ok());
        // Joining again is a no-op once the handle has been taken.
        assert!(scheduler.join().is_ok());
    }
}