// phala_scheduler/task_scheduler.rs

1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::hash::Hash;
4use std::sync::{Arc, Mutex, Weak};
5use std::time::Instant;
6
7use rbtree::RBTree;
8use std::task;
9
/// Unit of the scheduler's virtual clock and of each task's accumulated weighted runtime.
pub type VirtualTime = u128;

/// Bounds every task identifier must satisfy: cloneable, hashable, comparable, debuggable,
/// sendable across threads and free of borrowed data. Blanket-implemented for all such types.
pub trait TaskIdType: 'static + Debug + Hash + Eq + Send + Clone {}
impl<T> TaskIdType for T where T: 'static + Debug + Hash + Eq + Send + Clone {}
14
15type WeakScheduler<TaskId> = Weak<Mutex<SchedulerInner<TaskId>>>;
16
17#[derive(Clone)]
18pub struct TaskScheduler<TaskId: TaskIdType> {
19    inner: Arc<Mutex<SchedulerInner<TaskId>>>,
20}
21
22impl<TaskId: TaskIdType> TaskScheduler<TaskId> {
23    pub fn new(virtual_cores: u32) -> Self {
24        Self {
25            inner: Arc::new_cyclic(|weak_inner| {
26                Mutex::new(SchedulerInner::new(virtual_cores, weak_inner.clone()))
27            }),
28        }
29    }
30
31    pub fn poll_resume(
32        &self,
33        cx: &task::Context<'_>,
34        task_id: &TaskId,
35        weight: u32,
36    ) -> task::Poll<RunningGuard<TaskId>> {
37        self.inner.lock().unwrap().poll_resume(cx, task_id, weight)
38    }
39
40    pub fn reset(&self, task_id: &TaskId) {
41        self.exit(task_id)
42    }
43
44    pub fn exit(&self, task_id: &TaskId) {
45        self.inner.lock().unwrap().exit(task_id)
46    }
47}
48
/// Lifecycle of a task inside the scheduler: `Idle -> Ready -> ToRun -> Running -> Idle`.
#[derive(Eq, PartialEq, Debug)]
enum TaskState {
    /// Known to the scheduler but neither queued nor holding a core.
    Idle,
    /// Queued in the ready queue, waiting for a free core.
    Ready,
    /// Granted a core and woken; its next `poll_resume` hands out a guard.
    ToRun,
    /// A `RunningGuard` for the task is alive.
    Running,
}
56
57struct Task {
58    state: TaskState,
59    virtual_runtime: VirtualTime,
60}
61
/// Ready-queue entry: the waiting task and the waker to notify once it is granted a core.
struct ReadyTask<Id> {
    /// Identifier of the queued task.
    id: Id,
    /// Waker from the task's most recent `poll_resume`.
    waker: task::Waker,
}
66
67pub struct RunningGuard<TaskId: TaskIdType> {
68    queue: WeakScheduler<TaskId>,
69    task_id: TaskId,
70    start_time: Instant,
71    actual_cost: Option<VirtualTime>,
72    weight: u32,
73}
74
75impl<TaskId: TaskIdType> RunningGuard<TaskId> {
76    pub fn set_cost(&mut self, cost: VirtualTime) {
77        self.actual_cost = Some(cost);
78    }
79}
80
81struct SchedulerInner<TaskId: TaskIdType> {
82    weak_self: WeakScheduler<TaskId>,
83    tasks: HashMap<TaskId, Task>,
84    ready_tasks: RBTree<VirtualTime, ReadyTask<TaskId>>,
85    virtual_clock: VirtualTime,
86    virtual_cores: u32,
87    running_tasks: u32,
88}
89
90unsafe impl<T: TaskIdType> Send for SchedulerInner<T> {}
91
92impl<TaskId: TaskIdType> SchedulerInner<TaskId> {
93    fn new(virtual_cores: u32, weak_self: WeakScheduler<TaskId>) -> Self {
94        Self {
95            weak_self,
96            tasks: Default::default(),
97            ready_tasks: RBTree::new(),
98            virtual_cores,
99            virtual_clock: 0,
100            running_tasks: 0,
101        }
102    }
103
104    fn exit(&mut self, id: &TaskId) {
105        let task = match self.tasks.remove(id) {
106            Some(task) => task,
107            None => return,
108        };
109        if matches!(task.state, TaskState::ToRun | TaskState::Running) {
110            self.running_tasks -= 1;
111            self.schedule();
112        }
113    }
114
115    fn poll_resume(
116        &mut self,
117        cx: &task::Context<'_>,
118        id: &TaskId,
119        weight: u32,
120    ) -> task::Poll<RunningGuard<TaskId>> {
121        self.maybe_reset_clock();
122
123        let task = self.tasks.entry(id.clone()).or_insert_with(|| Task {
124            state: TaskState::Idle,
125            virtual_runtime: self.virtual_clock,
126        });
127
128        match task.state {
129            TaskState::Idle => {
130                let ready = ReadyTask {
131                    id: id.clone(),
132                    waker: cx.waker().clone(),
133                };
134                task.virtual_runtime = task.virtual_runtime.max(self.virtual_clock);
135                task.state = TaskState::Ready;
136                self.ready_tasks.insert(task.virtual_runtime, ready);
137                self.schedule();
138                task::Poll::Pending
139            }
140            TaskState::Ready => {
141                for task in self.ready_tasks.values_mut() {
142                    if task.id == *id {
143                        let waker = cx.waker();
144                        if !task.waker.will_wake(waker) {
145                            task.waker = waker.clone();
146                        }
147                        break;
148                    }
149                }
150                task::Poll::Pending
151            }
152            TaskState::ToRun => {
153                let guard = RunningGuard {
154                    queue: self.weak_self.clone(),
155                    task_id: id.clone(),
156                    start_time: Instant::now(),
157                    actual_cost: None,
158                    weight,
159                };
160                task.state = TaskState::Running;
161                task::Poll::Ready(guard)
162            }
163            TaskState::Running => panic!("BUG: resuming a running task"),
164        }
165    }
166
167    /// Marks a task as finished and maybe schedule more tasks
168    fn park(&mut self, task_id: &TaskId, vruntime: VirtualTime) {
169        let task = match self.tasks.get_mut(task_id) {
170            Some(task) => task,
171            None => return,
172        };
173
174        assert_eq!(
175            task.state,
176            TaskState::Running,
177            "BUG: parking a non-running task"
178        );
179
180        task.virtual_runtime += vruntime.max(1);
181        task.state = TaskState::Idle;
182        self.running_tasks -= 1;
183        self.schedule();
184    }
185
186    /// Tries to wake up ready tasks in the schedule queue untils all the cores are occupied
187    fn schedule(&mut self) {
188        while self.running_tasks < self.virtual_cores {
189            let (vruntime, ready_task) = match self.ready_tasks.pop_first() {
190                Some(v) => v,
191                None => break,
192            };
193            let task = match self.tasks.get_mut(&ready_task.id) {
194                Some(task) => task,
195                // The task has already been dropped.
196                None => continue,
197            };
198            self.running_tasks += 1;
199            self.virtual_clock = vruntime;
200
201            task.state = TaskState::ToRun;
202            ready_task.waker.wake();
203        }
204    }
205
206    fn maybe_reset_clock(&mut self) {
207        if self.virtual_clock > VirtualTime::MAX / 2 {
208            for task in self.tasks.values_mut() {
209                task.virtual_runtime = task.virtual_runtime.saturating_sub(self.virtual_clock);
210            }
211            self.virtual_clock = 0;
212        }
213    }
214}
215
216impl<TaskId: TaskIdType> Drop for RunningGuard<TaskId> {
217    fn drop(&mut self) {
218        if let Some(inner) = self.queue.upgrade() {
219            let actual_cost = self.actual_cost.unwrap_or_else(|| {
220                let cost = self.start_time.elapsed().as_nanos() as VirtualTime;
221                // Scale it in order to avoid underflow while dividing the cost by the weight.
222                cost << 32
223            });
224            let vruntime = actual_cost / self.weight.max(1) as VirtualTime;
225            inner.lock().unwrap().park(&self.task_id, vruntime.max(1));
226        }
227    }
228}
229
#[cfg(test)]
mod test {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;
    use std::time::Duration;

    /// Future that, each time it is granted a core, blocks the thread for `cost`
    /// milliseconds to simulate CPU-bound work, then logs its accumulated runtime.
    struct TestTask {
        scheduler: TaskScheduler<u32>,
        id: u32,
        weight: u32,
        cost: VirtualTime,
        realtime: VirtualTime,
    }

    impl Future for TestTask {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
            let resumed = self.scheduler.poll_resume(cx, &self.id, self.weight);
            if let task::Poll::Ready(_slot) = resumed {
                // Blocking on purpose: the sleep stands in for CPU work done while holding
                // the core. `_slot` is dropped (releasing the core) at the end of this block.
                std::thread::sleep(Duration::from_millis(self.cost as _));
                self.realtime += self.cost;
                let weighted = self.realtime / VirtualTime::from(self.weight);
                println!(
                    "Task [{}] w={}, t={}, r={}",
                    self.id, self.weight, self.realtime, weighted,
                );
                task::Poll::Ready(())
            } else {
                task::Poll::Pending
            }
        }
    }

    /// Manual demo: nine tasks with weights 1..=9 compete for three cores forever. Watch the
    /// printed `r=` column converge across tasks.
    #[tokio::test]
    #[ignore]
    async fn it_works() {
        let scheduler = TaskScheduler::new(3);
        let handles: Vec<_> = (1..=9_u32)
            .map(|id| {
                let scheduler = scheduler.clone();
                tokio::spawn(async move {
                    let task = TestTask {
                        scheduler,
                        id,
                        weight: id,
                        cost: 100,
                        realtime: 0,
                    };
                    tokio::pin!(task);
                    loop {
                        task.as_mut().await;
                    }
                })
            })
            .collect();
        for handle in handles {
            let _ = handle.await;
        }
    }
}