use std::collections::BinaryHeap;
use log::debug;
use std::sync::Arc;
use tokio::sync::Mutex;
use async_trait::async_trait;
use tokio::sync::Notify;
use tokio::task::yield_now;
#[async_trait]
/// Asynchronous interface of a cycle-based clock that tasks can wait on.
///
/// Waiters suspend themselves with `wait_until` / `wait_for`; a driver task
/// repeatedly calls `forward_cycle` to advance the clock and release them.
pub trait ClockTrait {
/// Returns the clock's current cycle.
async fn get_current_cycle(&self) -> u64;
/// Suspends the caller until the clock is forwarded to `cycle`.
async fn wait_until(&self, cycle: u64);
/// Suspends the caller based on `cycle`.
///
/// NOTE(review): the name suggests `cycle` is a duration relative to the
/// current cycle rather than an absolute cycle number — confirm against
/// each implementor.
async fn wait_for(&self, cycle: u64);
/// Advances the clock by one step and wakes the waiter for that step.
///
/// Returns a `ClockError` when the clock cannot be advanced.
async fn forward_cycle(&self) -> Result<(), ClockError>;
}
/// Reasons why [`ClockTrait::forward_cycle`] could not advance the clock.
///
/// The type is `Debug` so that `Result<_, ClockError>` can be used with
/// `unwrap`/`expect`, and it implements [`std::error::Error`] so it can be
/// propagated with `?` into boxed or wrapped error types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClockError {
    /// No task is waiting on the clock, so there is no cycle to advance to.
    NoMoreTasks,
    /// The clock reached its maximum cycle.
    ///
    /// NOTE(review): no code in this file produces this variant.
    MaxCycleReached,
}

impl std::fmt::Display for ClockError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::NoMoreTasks => "no more tasks are waiting on the clock",
            Self::MaxCycleReached => "maximum cycle reached",
        };
        f.write_str(message)
    }
}

impl std::error::Error for ClockError {}
#[derive(Debug)]
/// A suspended waiter: the cycle it is waiting for and the handle used to wake it.
struct Task {
/// Cycle this waiter registered for; the only field used by the comparison impls.
cycle: u64,
/// Notification handle shared with the waiting future; `forward_cycle`
/// calls `notify_one` on it to release the waiter.
notify: Arc<Notify>,
}
impl PartialEq for Task {
    /// Two tasks compare equal when they wait for the same cycle; the
    /// `notify` handle takes no part in the comparison.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.cycle, other.cycle);
        lhs == rhs
    }
}
// Marker impl: `Task` equality only compares the `u64` `cycle` field.
impl Eq for Task {}
impl PartialOrd for Task {
    /// Delegates to [`Ord::cmp`] so the partial and total orderings of `Task`
    /// are defined in exactly one place and can never disagree.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Task {
    /// Orders tasks by *descending* cycle.
    ///
    /// `BinaryHeap` is a max-heap, so reversing the natural `u64` order makes
    /// the task with the smallest cycle the one that is popped first.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.cycle.cmp(&other.cycle).reverse()
    }
}
#[derive(Debug, Default)]
/// A cycle counter together with the queue of tasks waiting for specific cycles.
pub struct Clock {
/// Current cycle; written by `forward_cycle` and read by `get_current_cycle`.
current_cycle: Mutex<u64>,
/// Pending waiters. `Task` orders itself by descending cycle, so popping
/// this heap yields the waiter with the smallest cycle.
waiting_tasks: Mutex<BinaryHeap<Task>>,
}
impl Clock {
    /// Creates a clock through its derived `Default`: the cycle counter holds
    /// the `u64` default (`0`) and the waiting queue is empty.
    pub fn new() -> Self {
        Self::default()
    }
}
#[async_trait]
impl ClockTrait for Clock {
async fn get_current_cycle(&self) -> u64 {
*self.current_cycle.lock().await
}
async fn wait_until(&self, cycle: u64) {
debug!("try to read the waiting tasks");
let mut waiting_tasks = self.waiting_tasks.lock().await;
let notify = Arc::new(Notify::new());
waiting_tasks.push(Task {
cycle,
notify: notify.clone(),
});
debug!("finished adding waiting tasks");
drop(waiting_tasks);
debug!("start to wait for notification");
notify.notified().await;
debug!("finished waiting for notification");
}
async fn wait_for(&self, cycle: u64) {
debug!("try to read the waiting tasks");
let mut waiting_tasks = self.waiting_tasks.lock().await;
let notify = Arc::new(Notify::new());
waiting_tasks.push(Task {
cycle,
notify: notify.clone(),
});
debug!("finished adding waiting tasks");
drop(waiting_tasks);
debug!("start to wait for notification");
notify.notified().await;
debug!("finished waiting for notification");
}
async fn forward_cycle(&self) -> Result<(), ClockError> {
debug!("start to forward cycle");
let mut waiting_tasks = self.waiting_tasks.lock().await;
debug!("get task: {:?}", waiting_tasks);
let next_task = waiting_tasks.pop().ok_or(ClockError::NoMoreTasks)?;
debug!("next tasks is {:?}", next_task);
let mut current_cycle = self.current_cycle.lock().await;
*current_cycle = next_task.cycle;
debug!("nodify the task");
next_task.notify.notify_one();
yield_now().await;
debug!("end to forward cycle and finished execute the notifiyed task, start new loop!");
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::init_log;

    /// Spawns two waiters (for cycles 2 and 1, in that spawn order) and a
    /// scheduler that keeps forwarding the clock until `forward_cycle`
    /// reports an error. Only the scheduler task is awaited.
    #[tokio::test]
    async fn test() {
        init_log("debug");
        let clock = Arc::new(Clock::new());

        let first_waiter = {
            let clock = Arc::clone(&clock);
            async move {
                println!("task1 start to wait");
                clock.wait_until(1).await;
                println!("task1 end to wait");
            }
        };
        let second_waiter = {
            let clock = Arc::clone(&clock);
            async move {
                println!("task2 start to wait");
                clock.wait_until(2).await;
                println!("task2 end to wait");
            }
        };
        let scheduler = async move {
            loop {
                if clock.forward_cycle().await.is_err() {
                    break;
                }
                println!("task scheduler");
            }
        };

        // Spawn order is part of the scenario: the later cycle registers first.
        tokio::spawn(second_waiter);
        tokio::spawn(first_waiter);
        tokio::spawn(scheduler).await.unwrap();
    }

    /// Runs two generators that repeatedly wait on the same clock (every 10
    /// cycles and every 3 cycles) while a scheduler forwards the clock until
    /// `forward_cycle` reports an error. Only the scheduler task is awaited.
    #[tokio::test]
    async fn looptest() {
        init_log("info");
        let clock = Arc::new(Clock::new());

        let every_ten = {
            let clock = Arc::clone(&clock);
            async move {
                for i in 0..10 {
                    clock.wait_until(10 * i).await;
                    println!("1: {}", i);
                }
            }
        };
        let every_three = {
            let clock = Arc::clone(&clock);
            async move {
                for i in 0..10 {
                    clock.wait_until(3 * i).await;
                    println!("2: {}", i);
                }
            }
        };
        let scheduler = async move {
            loop {
                if clock.forward_cycle().await.is_err() {
                    break;
                }
            }
        };

        tokio::spawn(every_ten);
        tokio::spawn(every_three);
        tokio::spawn(scheduler).await.unwrap();
    }
}