gstthreadshare/runtime/executor/
task.rs1use async_task::Runnable;
11use concurrent_queue::ConcurrentQueue;
12
13use futures::future::BoxFuture;
14use futures::prelude::*;
15
16use pin_project_lite::pin_project;
17
18use slab::Slab;
19
20use std::cell::Cell;
21use std::collections::VecDeque;
22use std::fmt;
23use std::pin::Pin;
24use std::sync::{Arc, Mutex};
25use std::task::Poll;
26
27use super::CallOnDrop;
28use crate::runtime::RUNTIME_CAT;
29
30thread_local! {
31 static CURRENT_TASK_ID: Cell<Option<TaskId>> = const { Cell::new(None) };
32}
33
34#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
35pub struct TaskId(pub(super) usize);
36
37impl TaskId {
38 pub(super) fn current() -> Option<TaskId> {
39 CURRENT_TASK_ID.try_with(Cell::get).ok().flatten()
40 }
41}
42
43pub type SubTaskOutput = Result<(), gst::FlowError>;
44
45pin_project! {
46 pub(super) struct TaskFuture<F: Future> {
47 id: TaskId,
48 #[pin]
49 future: F,
50 }
51
52}
53
54impl<F: Future> Future for TaskFuture<F> {
55 type Output = F::Output;
56
57 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
58 struct TaskIdGuard {
59 prev_task_id: Option<TaskId>,
60 }
61
62 impl Drop for TaskIdGuard {
63 fn drop(&mut self) {
64 let _ = CURRENT_TASK_ID.try_with(|cur| cur.replace(self.prev_task_id.take()));
65 }
66 }
67
68 let task_id = self.id;
69 let project = self.project();
70
71 let _guard = TaskIdGuard {
72 prev_task_id: CURRENT_TASK_ID.with(|cur| cur.replace(Some(task_id))),
73 };
74
75 project.future.poll(cx)
76 }
77}
78
79pub(super) struct Task {
80 id: TaskId,
81 sub_tasks: VecDeque<BoxFuture<'static, SubTaskOutput>>,
82}
83
84impl Task {
85 fn new(id: TaskId) -> Self {
86 Task {
87 id,
88 sub_tasks: VecDeque::new(),
89 }
90 }
91
92 fn add_sub_task<T>(&mut self, sub_task: T)
93 where
94 T: Future<Output = SubTaskOutput> + Send + 'static,
95 {
96 self.sub_tasks.push_back(sub_task.boxed());
97 }
98}
99
100impl fmt::Debug for Task {
101 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
102 fmt.debug_struct("Task")
103 .field("id", &self.id)
104 .field("sub_tasks len", &self.sub_tasks.len())
105 .finish()
106 }
107}
108
109#[derive(Debug, Clone)]
110pub(super) struct TaskQueue {
111 runnables: Arc<ConcurrentQueue<Runnable>>,
112 tasks: Arc<Mutex<Slab<Task>>>,
113}
114
115impl Default for TaskQueue {
116 fn default() -> Self {
117 TaskQueue {
118 runnables: Arc::new(ConcurrentQueue::unbounded()),
119 tasks: Arc::new(Mutex::new(Slab::new())),
120 }
121 }
122}
123
124impl TaskQueue {
125 pub fn add<F>(&self, future: F) -> (TaskId, async_task::Task<<F as Future>::Output>)
126 where
127 F: Future + Send + 'static,
128 F::Output: Send + 'static,
129 {
130 let tasks_weak = Arc::downgrade(&self.tasks);
131 let mut tasks = self.tasks.lock().unwrap();
132 let task_id = TaskId(tasks.vacant_entry().key());
133
134 let task_fut = async move {
135 gst::trace!(RUNTIME_CAT, "Running {task_id:?}");
136
137 let _guard = CallOnDrop::new(move || {
138 if let Some(task) = tasks_weak
139 .upgrade()
140 .and_then(|tasks| tasks.lock().unwrap().try_remove(task_id.0))
141 && !task.sub_tasks.is_empty()
142 {
143 gst::warning!(
144 RUNTIME_CAT,
145 "Task {task_id:?} has {} pending sub tasks",
146 task.sub_tasks.len(),
147 );
148 }
149
150 gst::trace!(RUNTIME_CAT, "Done {task_id:?}",);
151 });
152
153 TaskFuture {
154 id: task_id,
155 future,
156 }
157 .await
158 };
159
160 let runnables = Arc::clone(&self.runnables);
161 let (runnable, task) = async_task::spawn(task_fut, move |runnable| {
162 runnables.push(runnable).unwrap();
163 });
164 tasks.insert(Task::new(task_id));
165 drop(tasks);
166
167 runnable.schedule();
168
169 (task_id, task)
170 }
171
172 pub unsafe fn add_sync<F, O>(&self, f: F) -> async_task::Task<O>
179 where
180 F: FnOnce() -> O + Send,
181 O: Send,
182 {
183 unsafe {
184 let tasks_clone = Arc::clone(&self.tasks);
185 let mut tasks = self.tasks.lock().unwrap();
186 let task_id = TaskId(tasks.vacant_entry().key());
187
188 let task_fut = async move {
189 gst::trace!(RUNTIME_CAT, "Executing sync function as {task_id:?}");
190
191 let _guard = CallOnDrop::new(move || {
192 let _ = tasks_clone.lock().unwrap().try_remove(task_id.0);
193
194 gst::trace!(RUNTIME_CAT, "Done executing sync function as {task_id:?}");
195 });
196
197 f()
198 };
199
200 let runnables = Arc::clone(&self.runnables);
201 let (runnable, task) = async_task::spawn_unchecked(task_fut, move |runnable| {
204 runnables.push(runnable).unwrap();
205 });
206 tasks.insert(Task::new(task_id));
207 drop(tasks);
208
209 runnable.schedule();
210
211 task
212 }
213 }
214
215 pub fn pop_runnable(&self) -> Result<Runnable, concurrent_queue::PopError> {
216 self.runnables.pop()
217 }
218
219 pub fn add_sub_task<T>(&self, task_id: TaskId, sub_task: T) -> Result<(), T>
220 where
221 T: Future<Output = SubTaskOutput> + Send + 'static,
222 {
223 let mut state = self.tasks.lock().unwrap();
224 match state.get_mut(task_id.0) {
225 Some(task) => {
226 gst::trace!(RUNTIME_CAT, "Adding subtask to {task_id:?}");
227 task.add_sub_task(sub_task);
228 Ok(())
229 }
230 None => {
231 gst::trace!(RUNTIME_CAT, "Task was removed in the meantime");
232 Err(sub_task)
233 }
234 }
235 }
236
237 pub async fn drain_sub_tasks(&self, task_id: TaskId) -> SubTaskOutput {
238 loop {
239 let mut sub_tasks = match self.tasks.lock().unwrap().get_mut(task_id.0) {
240 Some(task) if !task.sub_tasks.is_empty() => std::mem::take(&mut task.sub_tasks),
241 _ => return Ok(()),
242 };
243
244 gst::trace!(
245 RUNTIME_CAT,
246 "Draining {} sub tasks from {task_id:?}",
247 sub_tasks.len(),
248 );
249
250 for sub_task in sub_tasks.drain(..) {
251 sub_task.await?;
252 }
253 }
254 }
255}