use core::future::Future;
use core::time::Duration;
use alloc::boxed::Box;
use alloc::collections::BTreeMap;
use alloc::collections::BinaryHeap;
use alloc::collections::VecDeque;
use alloc::sync::Arc;
use crate::os::sync::spinlock::IrqSpinlock;
use super::task::{TaskHandle, TaskId, TaskMetadata, TaskPriority, TaskRef};
/// FIFO of task ids whose wakers have fired and that still have to be moved
/// into an executor's priority queue (drained by `process_wake_queue`).
static GLOBAL_WAKEUP_QUEUE: IrqSpinlock<VecDeque<TaskId>> = IrqSpinlock::new(VecDeque::new());

/// Records that the task identified by `task_id` wants to be polled again.
///
/// An id already waiting in the queue is not appended a second time.
pub fn enqueue_task_wakeup(task_id: TaskId) {
    let mut pending = GLOBAL_WAKEUP_QUEUE.lock();
    let already_queued = pending.iter().any(|queued| queued == &task_id);
    if already_queued {
        return;
    }
    pending.push_back(task_id);
}
pub struct SingleCpuExecutor {
priority_task_queue: IrqSpinlock<BinaryHeap<PriorityTaskWrapper>>,
active_task_registry: IrqSpinlock<alloc::collections::BTreeMap<TaskId, Arc<TaskRef>>>,
executor_running: IrqSpinlock<bool>,
task_timeout_milliseconds: u64,
}
/// Heap entry pairing a task with the priority it had when it was queued.
///
/// Ordering considers only `task_priority` and is reversed, so the max-heap
/// `BinaryHeap` pops the entry whose `TaskPriority` compares *smallest* first.
/// NOTE(review): this presumably means smaller `TaskPriority` values are more
/// urgent — confirm against `TaskPriority`'s `Ord` implementation.
#[derive(Debug)]
struct PriorityTaskWrapper {
    /// Snapshot of `task_reference.priority()` taken at construction time.
    task_priority: TaskPriority,
    /// The task this entry schedules.
    task_reference: Arc<TaskRef>,
}

impl PriorityTaskWrapper {
    /// Wraps `task_ref`, caching its current priority for heap ordering.
    fn new(task_ref: Arc<TaskRef>) -> Self {
        let task_priority = task_ref.priority();
        Self {
            task_priority,
            task_reference: task_ref,
        }
    }
}

impl PartialEq for PriorityTaskWrapper {
    /// Equality agrees with `Ord`: only the cached priority is compared.
    fn eq(&self, other: &Self) -> bool {
        self.task_priority == other.task_priority
    }
}

impl Eq for PriorityTaskWrapper {}

impl PartialOrd for PriorityTaskWrapper {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for PriorityTaskWrapper {
    /// Reverse of the natural priority order (see the type-level note).
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.task_priority.cmp(&other.task_priority).reverse()
    }
}
impl SingleCpuExecutor {
pub fn new() -> Self {
Self::with_timeout(Duration::from_secs(1))
}
pub fn with_timeout(timeout: Duration) -> Self {
Self {
priority_task_queue: IrqSpinlock::new(BinaryHeap::new()),
active_task_registry: IrqSpinlock::new(BTreeMap::new()),
executor_running: IrqSpinlock::new(false),
task_timeout_milliseconds: timeout.as_millis() as u64,
}
}
pub fn global() -> &'static Self {
use core::sync::atomic::{AtomicPtr, Ordering};
static EXECUTOR_PTR: AtomicPtr<SingleCpuExecutor> = AtomicPtr::new(core::ptr::null_mut());
let ptr = EXECUTOR_PTR.load(Ordering::Acquire);
if ptr.is_null() {
let executor = Box::leak(Box::new(SingleCpuExecutor::new()));
match EXECUTOR_PTR.compare_exchange(
core::ptr::null_mut(),
executor,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => executor,
Err(existing) => {
unsafe {
let _ = Box::from_raw(executor);
}
unsafe { &*existing }
}
}
} else {
unsafe { &*ptr }
}
}
pub fn spawn<F, T>(&self, future: F) -> TaskHandle
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
let task_id = TaskId::new();
let metadata = TaskMetadata::new(task_id);
let wrapped_future = async move {
let _ = future.await;
};
let task_ref = Arc::new(TaskRef::new(wrapped_future, metadata));
let handle = TaskHandle::with_ref(task_id, task_ref.clone());
{
let mut registry = self.active_task_registry.lock();
registry.insert(task_id, task_ref);
}
self.add_task_to_queue(task_id);
debug!("Spawned task {task_id:?}",);
handle
}
fn add_task_to_queue(&self, task_id: TaskId) {
let registry = self.active_task_registry.lock();
if let Some(task_ref) = registry.get(&task_id) {
let priority_task = PriorityTaskWrapper::new(task_ref.clone());
let mut queue = self.priority_task_queue.lock();
queue.push(priority_task);
}
}
pub fn wake_by_id(&self, task_id: TaskId) -> bool {
let registry = self.active_task_registry.lock();
if let Some(task_ref) = registry.get(&task_id) {
task_ref.metadata.lock().mark_woken();
self.add_task_to_queue(task_id);
true
} else {
false
}
}
fn process_one_task(&self) -> bool {
let priority_task = {
let mut queue = self.priority_task_queue.lock();
queue.pop()
};
if let Some(priority_task) = priority_task {
let task_ref = priority_task.task_reference.clone();
let task_id = task_ref.id();
if task_ref.is_completed() {
let mut registry = self.active_task_registry.lock();
registry.remove(&task_id);
debug!("Task {task_id:?} completed and cleaned up");
return true;
}
let should_execute = {
let mut metadata = task_ref.metadata.lock();
if metadata.state == super::task::TaskState::Woken {
true
} else if metadata.is_expired(self.task_timeout_milliseconds) {
metadata.mark_woken();
debug!("Task {task_id:?} expired, promoting priority");
true
} else {
false
}
};
if !should_execute {
return true;
}
let waker = ExecutorWaker::new(task_id);
let waker = waker.into_waker();
match task_ref.poll(&waker) {
core::task::Poll::Ready(()) => {
let mut registry = self.active_task_registry.lock();
registry.remove(&task_id);
debug!("Task {task_id:?} completed");
true
}
core::task::Poll::Pending => {
debug!("Task {task_id:?} pending, waiting for wake");
true
}
}
} else {
false }
}
pub fn tick(&self) {
self.process_wake_queue();
let mut processed = 0;
const MAX_TASKS_PER_TICK: usize = 10;
while processed < MAX_TASKS_PER_TICK && self.process_one_task() {
processed += 1;
}
if processed == 0 {
log::debug!("No tasks to process in this tick");
}
}
fn process_wake_queue(&self) {
loop {
let task_id = {
let mut queue = GLOBAL_WAKEUP_QUEUE.lock();
queue.pop_front()
};
if let Some(task_id) = task_id {
let registry = self.active_task_registry.lock();
if let Some(task_ref) = registry.get(&task_id) {
task_ref.metadata.lock().mark_woken();
let priority_task = PriorityTaskWrapper::new(task_ref.clone());
let mut queue = self.priority_task_queue.lock();
queue.push(priority_task);
}
} else {
break;
}
}
}
pub fn run_until_completion(&self) {
*self.executor_running.lock() = true;
debug!("Executor started, running until completion");
while self.has_pending_tasks() {
self.tick();
for _ in 0..1000 {
core::hint::spin_loop();
}
}
*self.executor_running.lock() = false;
debug!("Executor finished, all tasks completed");
}
pub fn has_pending_tasks(&self) -> bool {
if !GLOBAL_WAKEUP_QUEUE.lock().is_empty() {
return true;
}
if !self.priority_task_queue.lock().is_empty() {
return true;
}
let registry = self.active_task_registry.lock();
!registry.is_empty()
}
pub fn task_count(&self) -> usize {
self.active_task_registry.lock().len()
}
pub fn queued_task_count(&self) -> usize {
self.priority_task_queue.lock().len()
}
pub fn is_running(&self) -> bool {
*self.executor_running.lock()
}
}
impl Default for SingleCpuExecutor {
fn default() -> Self {
Self::new()
}
}
/// Spawns `future` on the global executor and returns its handle.
///
/// Shorthand for `SingleCpuExecutor::global().spawn(future)`.
pub fn spawn<F, T>(future: F) -> TaskHandle
where
    T: Send + 'static,
    F: Future<Output = T> + Send + 'static,
{
    let executor = SingleCpuExecutor::global();
    executor.spawn(future)
}
/// Spawns `future` on the global executor, then drives the executor until no
/// task remains.
///
/// This waits for *every* task known to the executor, not only `future`. The
/// `TaskHandle` returned by `spawn` is dropped immediately.
pub fn block_on<F>(future: F)
where
    F: Future<Output = ()> + Send + 'static,
{
    let global_executor = SingleCpuExecutor::global();
    let _ = global_executor.spawn(future);
    global_executor.run_until_completion();
}
/// Runs one scheduling round on the global executor.
pub fn tick() {
    let executor = SingleCpuExecutor::global();
    executor.tick();
}

/// Whether the global executor still has wake-ups, queued entries, or live tasks.
pub fn has_pending_tasks() -> bool {
    let executor = SingleCpuExecutor::global();
    executor.has_pending_tasks()
}

/// Number of live tasks registered with the global executor.
pub fn task_count() -> usize {
    let executor = SingleCpuExecutor::global();
    executor.task_count()
}
/// Waker payload for a single task.
///
/// Waking it records the task's id in the global wake-up queue via
/// `enqueue_task_wakeup`.
#[derive(Debug)]
struct ExecutorWaker {
    task_id: TaskId,
}

impl ExecutorWaker {
    /// Creates a waker payload for `task_id`.
    fn new(task_id: TaskId) -> Self {
        ExecutorWaker { task_id }
    }

    /// Converts this payload into an `Arc`-backed `core::task::Waker`.
    fn into_waker(self) -> core::task::Waker {
        let shared: Arc<Self> = Arc::new(self);
        shared.into()
    }
}

impl alloc::task::Wake for ExecutorWaker {
    /// Consuming wake: enqueue the task id.
    fn wake(self: Arc<Self>) {
        let id = self.task_id;
        enqueue_task_wakeup(id);
    }

    /// Non-consuming wake: same effect, without cloning the `Arc`.
    fn wake_by_ref(self: &Arc<Self>) {
        let id = self.task_id;
        enqueue_task_wakeup(id);
    }
}