1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::hash::Hash;
4use std::sync::{Arc, Mutex, Weak};
5use std::time::Instant;
6
7use rbtree::RBTree;
8use std::task;
9
/// Virtual-runtime value used as the ordering key for ready tasks
/// (the smallest key is scheduled first).
pub type VirtualTime = u128;

/// Requirements for a task identifier.
///
/// Blanket-implemented for every type that satisfies the bounds, so callers
/// never implement it by hand.
pub trait TaskIdType: Clone + Send + Eq + Hash + Debug + 'static {}

impl<T> TaskIdType for T where T: Clone + Send + Eq + Hash + Debug + 'static {}
14
15type WeakScheduler<TaskId> = Weak<Mutex<SchedulerInner<TaskId>>>;
16
17#[derive(Clone)]
18pub struct TaskScheduler<TaskId: TaskIdType> {
19 inner: Arc<Mutex<SchedulerInner<TaskId>>>,
20}
21
22impl<TaskId: TaskIdType> TaskScheduler<TaskId> {
23 pub fn new(virtual_cores: u32) -> Self {
24 Self {
25 inner: Arc::new_cyclic(|weak_inner| {
26 Mutex::new(SchedulerInner::new(virtual_cores, weak_inner.clone()))
27 }),
28 }
29 }
30
31 pub fn poll_resume(
32 &self,
33 cx: &task::Context<'_>,
34 task_id: &TaskId,
35 weight: u32,
36 ) -> task::Poll<RunningGuard<TaskId>> {
37 self.inner.lock().unwrap().poll_resume(cx, task_id, weight)
38 }
39
40 pub fn reset(&self, task_id: &TaskId) {
41 self.exit(task_id)
42 }
43
44 pub fn exit(&self, task_id: &TaskId) {
45 self.inner.lock().unwrap().exit(task_id)
46 }
47}
48
/// Lifecycle of a task inside the scheduler.
#[derive(PartialEq, Eq, Debug)]
enum TaskState {
    /// Known to the scheduler but neither queued nor holding a core.
    Idle,
    /// Queued in `ready_tasks`, waiting for a free core.
    Ready,
    /// Granted a core and woken, but not yet polled to pick up its guard.
    ToRun,
    /// Holding a core through a live `RunningGuard`.
    Running,
}
56
57struct Task {
58 state: TaskState,
59 virtual_runtime: VirtualTime,
60}
61
/// Queue entry for a task waiting for a core.
struct ReadyTask<Id> {
    /// Which task this entry belongs to.
    id: Id,
    /// Woken when the task is granted a core.
    waker: task::Waker,
}
66
67pub struct RunningGuard<TaskId: TaskIdType> {
68 queue: WeakScheduler<TaskId>,
69 task_id: TaskId,
70 start_time: Instant,
71 actual_cost: Option<VirtualTime>,
72 weight: u32,
73}
74
75impl<TaskId: TaskIdType> RunningGuard<TaskId> {
76 pub fn set_cost(&mut self, cost: VirtualTime) {
77 self.actual_cost = Some(cost);
78 }
79}
80
/// Mutable scheduler state. Every instance created in this file lives inside
/// the `Arc<Mutex<..>>` owned by [`TaskScheduler`].
struct SchedulerInner<TaskId: TaskIdType> {
    /// Weak handle to the `Arc<Mutex<Self>>` that owns this value; cloned into
    /// every `RunningGuard` so the guard can park its task on drop.
    weak_self: WeakScheduler<TaskId>,
    /// Bookkeeping for every task the scheduler currently knows, by id.
    tasks: HashMap<TaskId, Task>,
    /// Tasks waiting for a core, keyed by virtual runtime; the smallest key is
    /// popped first by `schedule`. Several tasks may share a key.
    ready_tasks: RBTree<VirtualTime, ReadyTask<TaskId>>,
    /// Key of the most recently scheduled task; tasks entering the queue are
    /// lifted to at least this value.
    virtual_clock: VirtualTime,
    /// Maximum number of tasks allowed in `ToRun`/`Running` at once.
    virtual_cores: u32,
    /// Number of tasks currently in `ToRun` or `Running`.
    running_tasks: u32,
}

// SAFETY: NOTE(review) — this impl is presumably required because `RBTree`
// stores raw node pointers and is therefore not auto-`Send`. Sending the whole
// state to another thread is sound only if the tree uniquely owns its nodes
// and keeps no thread-affine state — confirm against the `rbtree` crate.
// The other contents are `Send` on their own (`TaskId: Send` via
// `TaskIdType`, `Waker: Send`), and all access in this file goes through a
// `Mutex`.
unsafe impl<T: TaskIdType> Send for SchedulerInner<T> {}
91
92impl<TaskId: TaskIdType> SchedulerInner<TaskId> {
93 fn new(virtual_cores: u32, weak_self: WeakScheduler<TaskId>) -> Self {
94 Self {
95 weak_self,
96 tasks: Default::default(),
97 ready_tasks: RBTree::new(),
98 virtual_cores,
99 virtual_clock: 0,
100 running_tasks: 0,
101 }
102 }
103
104 fn exit(&mut self, id: &TaskId) {
105 let task = match self.tasks.remove(id) {
106 Some(task) => task,
107 None => return,
108 };
109 if matches!(task.state, TaskState::ToRun | TaskState::Running) {
110 self.running_tasks -= 1;
111 self.schedule();
112 }
113 }
114
115 fn poll_resume(
116 &mut self,
117 cx: &task::Context<'_>,
118 id: &TaskId,
119 weight: u32,
120 ) -> task::Poll<RunningGuard<TaskId>> {
121 self.maybe_reset_clock();
122
123 let task = self.tasks.entry(id.clone()).or_insert_with(|| Task {
124 state: TaskState::Idle,
125 virtual_runtime: self.virtual_clock,
126 });
127
128 match task.state {
129 TaskState::Idle => {
130 let ready = ReadyTask {
131 id: id.clone(),
132 waker: cx.waker().clone(),
133 };
134 task.virtual_runtime = task.virtual_runtime.max(self.virtual_clock);
135 task.state = TaskState::Ready;
136 self.ready_tasks.insert(task.virtual_runtime, ready);
137 self.schedule();
138 task::Poll::Pending
139 }
140 TaskState::Ready => {
141 for task in self.ready_tasks.values_mut() {
142 if task.id == *id {
143 let waker = cx.waker();
144 if !task.waker.will_wake(waker) {
145 task.waker = waker.clone();
146 }
147 break;
148 }
149 }
150 task::Poll::Pending
151 }
152 TaskState::ToRun => {
153 let guard = RunningGuard {
154 queue: self.weak_self.clone(),
155 task_id: id.clone(),
156 start_time: Instant::now(),
157 actual_cost: None,
158 weight,
159 };
160 task.state = TaskState::Running;
161 task::Poll::Ready(guard)
162 }
163 TaskState::Running => panic!("BUG: resuming a running task"),
164 }
165 }
166
167 fn park(&mut self, task_id: &TaskId, vruntime: VirtualTime) {
169 let task = match self.tasks.get_mut(task_id) {
170 Some(task) => task,
171 None => return,
172 };
173
174 assert_eq!(
175 task.state,
176 TaskState::Running,
177 "BUG: parking a non-running task"
178 );
179
180 task.virtual_runtime += vruntime.max(1);
181 task.state = TaskState::Idle;
182 self.running_tasks -= 1;
183 self.schedule();
184 }
185
186 fn schedule(&mut self) {
188 while self.running_tasks < self.virtual_cores {
189 let (vruntime, ready_task) = match self.ready_tasks.pop_first() {
190 Some(v) => v,
191 None => break,
192 };
193 let task = match self.tasks.get_mut(&ready_task.id) {
194 Some(task) => task,
195 None => continue,
197 };
198 self.running_tasks += 1;
199 self.virtual_clock = vruntime;
200
201 task.state = TaskState::ToRun;
202 ready_task.waker.wake();
203 }
204 }
205
206 fn maybe_reset_clock(&mut self) {
207 if self.virtual_clock > VirtualTime::MAX / 2 {
208 for task in self.tasks.values_mut() {
209 task.virtual_runtime = task.virtual_runtime.saturating_sub(self.virtual_clock);
210 }
211 self.virtual_clock = 0;
212 }
213 }
214}
215
216impl<TaskId: TaskIdType> Drop for RunningGuard<TaskId> {
217 fn drop(&mut self) {
218 if let Some(inner) = self.queue.upgrade() {
219 let actual_cost = self.actual_cost.unwrap_or_else(|| {
220 let cost = self.start_time.elapsed().as_nanos() as VirtualTime;
221 cost << 32
223 });
224 let vruntime = actual_cost / self.weight.max(1) as VirtualTime;
225 inner.lock().unwrap().park(&self.task_id, vruntime.max(1));
226 }
227 }
228}
229
#[cfg(test)]
mod test {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;
    use std::time::Duration;

    /// Future that, on each poll, asks the scheduler for a core and, once it
    /// has one, performs `cost` milliseconds of blocking work.
    struct TestTask {
        scheduler: TaskScheduler<u32>,
        id: u32,
        weight: u32,
        /// Simulated work per completed poll, in milliseconds.
        cost: VirtualTime,
        /// Total simulated work so far, in milliseconds.
        realtime: VirtualTime,
    }

    impl Future for TestTask {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
            let polled = self.scheduler.poll_resume(cx, &self.id, self.weight);
            // The guard must live until the end of this poll so that the work
            // below is charged to this task.
            let _running = match polled {
                task::Poll::Pending => return task::Poll::Pending,
                task::Poll::Ready(guard) => guard,
            };

            // Deliberately blocking: the scheduler measures wall-clock time.
            std::thread::sleep(Duration::from_millis(self.cost as _));
            let cost = self.cost;
            self.realtime += cost;

            let normalized = self.realtime / VirtualTime::from(self.weight);
            println!(
                "Task [{}] w={}, t={}, r={}",
                self.id, self.weight, self.realtime, normalized,
            );
            task::Poll::Ready(())
        }
    }

    /// Runs nine tasks with weights 1..=9 on three virtual cores and prints
    /// their progress for manual inspection. The task loops never end, so this
    /// test never finishes and is ignored by default.
    #[tokio::test]
    #[ignore]
    async fn it_works() {
        let scheduler = TaskScheduler::new(3);
        let handles: Vec<_> = (1..=9_u32)
            .map(|id| {
                let scheduler = scheduler.clone();
                tokio::spawn(async move {
                    let task = TestTask {
                        scheduler,
                        id,
                        weight: id,
                        cost: 100,
                        realtime: 0,
                    };
                    tokio::pin!(task);
                    loop {
                        (&mut task).await;
                    }
                })
            })
            .collect();
        for handle in handles {
            let _ = handle.await;
        }
    }
}