1use log;
4use tokio::macros::support::Future;
5use tokio::sync::mpsc;
6
7use crate::task::{spawn_inner, TaskError, TaskHandle};
8
/// Configuration for a [`TaskManager`], obtained from [`TaskManager::builder`].
///
/// The builder is `Copy`, so the `with_*` setters (which take `&mut self`) can be
/// chained and finished with `build()` in a single expression, e.g.
/// `TaskManager::builder().with_max_tasks(1).build()`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct TaskManagerBuilder {
    /// Upper bound on the number of tasks attached to the manager at once.
    max_tasks: usize,
    /// Number of task slots pre-allocated when the manager is built.
    capacity: usize,
    /// Capacity of the bounded channel that carries task-completion events.
    completion_events_buffer_size: usize,
}
16
17impl TaskManagerBuilder {
18 pub fn with_max_tasks(&mut self, max_tasks: usize) -> &mut TaskManagerBuilder {
20 self.max_tasks = max_tasks;
21 return self;
22 }
23
24 pub fn with_capacity(&mut self, capacity: usize) -> &mut TaskManagerBuilder {
26 self.capacity = capacity;
27 return self;
28 }
29
30 pub fn with_completion_event_buffer_size(
33 &mut self,
34 completion_event_buffer_size: usize,
35 ) -> &mut TaskManagerBuilder {
36 self.completion_events_buffer_size = completion_event_buffer_size;
37 return self;
38 }
39
40 pub fn build(self) -> TaskManager {
42 TaskManager::new(self.max_tasks, self.capacity, self.completion_events_buffer_size)
43 }
44}
45
/// Errors returned by [`TaskManager`] operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskManagerError {
    /// No task is currently attached to the manager under the given key.
    TaskNotFound,
    /// The manager is at its configured task limit.
    ///
    /// Note: `TaskManager::try_spawn` currently reports this condition as `None`
    /// rather than with this variant.
    TaskManagerIsFull,
}

impl std::fmt::Display for TaskManagerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            TaskManagerError::TaskNotFound => "task not found",
            TaskManagerError::TaskManagerIsFull => "task manager is full",
        };
        f.write_str(message)
    }
}

impl std::error::Error for TaskManagerError {}
54
55pub struct TaskManager {
58 tasks: slab::Slab<TaskHandle<()>>,
59 completion_event_queue_sender: mpsc::Sender<usize>,
60 completion_event_queue_receiver: mpsc::Receiver<usize>,
61 max_tasks: usize,
62}
63
64impl TaskManager {
65 pub fn builder() -> TaskManagerBuilder {
67 TaskManagerBuilder {
68 max_tasks: 1024,
69 capacity: 32,
70 completion_events_buffer_size: 256,
71 }
72 }
73
74 pub fn new(max_tasks: usize, capacity: usize, completion_events_buffer_size: usize) -> TaskManager {
76 let (completion_event_queue_sender, completion_event_queue_receiver) =
77 mpsc::channel(completion_events_buffer_size);
78
79 TaskManager {
80 tasks: slab::Slab::with_capacity(capacity),
81 completion_event_queue_sender,
82 completion_event_queue_receiver,
83 max_tasks,
84 }
85 }
86
87 pub fn size(&self) -> usize {
89 self.tasks.len()
90 }
91
92 pub fn try_spawn<F>(&mut self, future: F) -> Option<usize>
96 where
97 F: Future<Output = ()> + Send + 'static,
98 {
99 if self.tasks.len() == self.max_tasks {
100 return None;
101 }
102
103 let task_entry = self.tasks.vacant_entry();
104 let task_key = task_entry.key();
105
106 let completion_event_queue_sender = self.completion_event_queue_sender.clone();
107 let task_handle = spawn_inner(future, async move {
108 let _ = completion_event_queue_sender.send(task_key).await;
109 });
110 task_entry.insert(task_handle);
111
112 return Some(task_key);
113 }
114
115 pub async fn process(&mut self, resume_panic: bool) {
121 loop {
122 let task_key = self
123 .completion_event_queue_receiver
124 .recv()
125 .await
126 .expect("channel unexpectedly closed");
127
128 match self.tasks.try_remove(task_key) {
129 None => log::debug!("task {} is not longer attached to the manager", task_key),
130 Some(task_handle) => match task_handle.await {
131 Err(TaskError::Panicked(reason)) => {
132 if resume_panic {
133 std::panic::resume_unwind(reason);
134 }
135 }
136 _ => {}
137 },
138 }
139 }
140 }
141
142 pub fn detach(&mut self, task_key: usize) -> Result<TaskHandle<()>, TaskManagerError> {
144 match self.tasks.try_remove(task_key) {
145 Some(task_handle) => Ok(task_handle),
146 None => Err(TaskManagerError::TaskNotFound),
147 }
148 }
149
150 pub fn cancel(mut self) {
152 for (_, task_handle) in self.tasks.iter_mut() {
153 task_handle.cancel();
154 }
155 }
156
157 pub fn abort(mut self) {
159 for (_, task_handle) in self.tasks.iter_mut() {
160 task_handle.abort();
161 }
162 }
163
164 pub async fn join(mut self, resume_panic: bool) {
169 for (_, task_handle) in std::mem::take(&mut self.tasks) {
170 match task_handle.await {
171 Err(TaskError::Panicked(reason)) => {
172 if resume_panic {
173 std::panic::resume_unwind(reason);
174 }
175 }
176 _ => {}
177 }
178 }
179 }
180
181 pub fn cancel_task(&mut self, task_key: usize) -> Result<(), TaskManagerError> {
185 match self.tasks.get_mut(task_key) {
186 Some(task_handle) => {
187 task_handle.cancel();
188 Ok(())
189 }
190 None => Err(TaskManagerError::TaskNotFound),
191 }
192 }
193
194 pub fn abort_task(&mut self, task_key: usize) -> Result<(), TaskManagerError> {
197 match self.tasks.try_remove(task_key) {
198 Some(task_handle) => {
199 task_handle.abort();
200 Ok(())
201 }
202 None => Err(TaskManagerError::TaskNotFound),
203 }
204 }
205}
206
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::TaskManager;
    use crate::try_await;

    /// Task body that never finishes on its own.
    async fn sleep_forever() {
        tokio::time::sleep(Duration::from_secs(u64::MAX)).await;
    }

    /// Lets freshly spawned tasks run once, then drives `process` for a
    /// zero-length window. `process` never returns by itself, so the timeout is
    /// expected to elapse.
    async fn reap_ready(manager: &mut TaskManager) {
        tokio::task::yield_now().await;
        tokio::time::timeout(Duration::from_millis(0), manager.process(true))
            .await
            .unwrap_err();
    }

    #[tokio::test]
    async fn test_task_manager_overflow() {
        let mut manager = TaskManager::builder().with_max_tasks(1).build();

        assert!(manager.try_spawn(async {}).is_some());
        assert!(manager.try_spawn(async {}).is_none());
    }

    #[tokio::test]
    #[should_panic(expected = "test panic")]
    async fn test_task_unwinding_enabled() {
        let mut manager = TaskManager::builder().build();
        manager.try_spawn(async { panic!("test panic") }).unwrap();
        manager.join(true).await;
    }

    #[tokio::test]
    async fn test_task_unwinding_disabled() {
        let mut manager = TaskManager::builder().build();
        manager.try_spawn(async { panic!("test panic") }).unwrap();
        manager.join(false).await;
    }

    #[tokio::test]
    async fn test_task_abortion() {
        let mut manager = TaskManager::builder().build();
        let key = manager.try_spawn(sleep_forever()).unwrap();
        manager.abort_task(key).unwrap();
        manager.join(true).await;
    }

    #[tokio::test]
    async fn test_task_manager_abortion() {
        let mut manager = TaskManager::builder().build();
        for _ in 0..2 {
            manager.try_spawn(sleep_forever()).unwrap();
        }

        manager.abort();
    }

    #[tokio::test]
    async fn test_task_manager_cancellation() {
        let mut manager = TaskManager::builder().build();
        for _ in 0..2 {
            manager
                .try_spawn(async move {
                    try_await!(tokio::time::sleep(Duration::from_secs(u64::MAX)));
                })
                .unwrap();
        }

        manager.cancel();
    }

    #[tokio::test]
    async fn test_processing_loop() {
        let mut manager = TaskManager::builder().build();
        manager.try_spawn(async {}).unwrap();
        manager.try_spawn(async {}).unwrap();
        assert_eq!(manager.size(), 2);

        reap_ready(&mut manager).await;
        assert_eq!(manager.size(), 0);

        manager.try_spawn(async {}).unwrap();
        assert_eq!(manager.size(), 1);

        reap_ready(&mut manager).await;
        assert_eq!(manager.size(), 0);
    }

    #[tokio::test]
    async fn test_task_detach() {
        let mut manager = TaskManager::builder().build();
        let key = manager
            .try_spawn(async move {
                try_await!(tokio::time::sleep(Duration::from_secs(u64::MAX)));
            })
            .unwrap();

        let mut handle = manager.detach(key).unwrap();
        assert_eq!(manager.size(), 0);

        handle.cancel();
        let _ = handle.await;
    }
}