gstthreadshare/runtime/executor/
task.rs1use async_task::Runnable;
11use concurrent_queue::ConcurrentQueue;
12
13use futures::future::BoxFuture;
14use futures::prelude::*;
15
16use pin_project_lite::pin_project;
17
18use slab::Slab;
19
20use std::cell::Cell;
21use std::collections::VecDeque;
22use std::fmt;
23use std::pin::Pin;
24use std::sync::{Arc, Mutex};
25use std::task::Poll;
26
27use super::CallOnDrop;
28use crate::runtime::RUNTIME_CAT;
29
30thread_local! {
31 static CURRENT_TASK_ID: Cell<Option<TaskId>> = const { Cell::new(None) };
32}
33
34#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
35pub struct TaskId(pub(super) usize);
36
37impl TaskId {
38 pub(super) fn current() -> Option<TaskId> {
39 CURRENT_TASK_ID.try_with(Cell::get).ok().flatten()
40 }
41}
42
43pub type SubTaskOutput = Result<(), gst::FlowError>;
44
45pin_project! {
46 pub(super) struct TaskFuture<F: Future> {
47 id: TaskId,
48 #[pin]
49 future: F,
50 }
51
52}
53
54impl<F: Future> Future for TaskFuture<F> {
55 type Output = F::Output;
56
57 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
58 struct TaskIdGuard {
59 prev_task_id: Option<TaskId>,
60 }
61
62 impl Drop for TaskIdGuard {
63 fn drop(&mut self) {
64 let _ = CURRENT_TASK_ID.try_with(|cur| cur.replace(self.prev_task_id.take()));
65 }
66 }
67
68 let task_id = self.id;
69 let project = self.project();
70
71 let _guard = TaskIdGuard {
72 prev_task_id: CURRENT_TASK_ID.with(|cur| cur.replace(Some(task_id))),
73 };
74
75 project.future.poll(cx)
76 }
77}
78
79pub(super) struct Task {
80 id: TaskId,
81 sub_tasks: VecDeque<BoxFuture<'static, SubTaskOutput>>,
82}
83
84impl Task {
85 fn new(id: TaskId) -> Self {
86 Task {
87 id,
88 sub_tasks: VecDeque::new(),
89 }
90 }
91
92 fn add_sub_task<T>(&mut self, sub_task: T)
93 where
94 T: Future<Output = SubTaskOutput> + Send + 'static,
95 {
96 self.sub_tasks.push_back(sub_task.boxed());
97 }
98}
99
100impl fmt::Debug for Task {
101 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
102 fmt.debug_struct("Task")
103 .field("id", &self.id)
104 .field("sub_tasks len", &self.sub_tasks.len())
105 .finish()
106 }
107}
108
109#[derive(Debug, Clone)]
110pub(super) struct TaskQueue {
111 runnables: Arc<ConcurrentQueue<Runnable>>,
112 tasks: Arc<Mutex<Slab<Task>>>,
113}
114
115impl Default for TaskQueue {
116 fn default() -> Self {
117 TaskQueue {
118 runnables: Arc::new(ConcurrentQueue::unbounded()),
119 tasks: Arc::new(Mutex::new(Slab::new())),
120 }
121 }
122}
123
124impl TaskQueue {
125 pub fn add<F>(&self, future: F) -> (TaskId, async_task::Task<<F as Future>::Output>)
126 where
127 F: Future + Send + 'static,
128 F::Output: Send + 'static,
129 {
130 let tasks_weak = Arc::downgrade(&self.tasks);
131 let mut tasks = self.tasks.lock().unwrap();
132 let task_id = TaskId(tasks.vacant_entry().key());
133
134 let task_fut = async move {
135 gst::trace!(RUNTIME_CAT, "Running {task_id:?}");
136
137 let _guard = CallOnDrop::new(move || {
138 if let Some(task) = tasks_weak
139 .upgrade()
140 .and_then(|tasks| tasks.lock().unwrap().try_remove(task_id.0))
141 {
142 if !task.sub_tasks.is_empty() {
143 gst::warning!(
144 RUNTIME_CAT,
145 "Task {task_id:?} has {} pending sub tasks",
146 task.sub_tasks.len(),
147 );
148 }
149 }
150
151 gst::trace!(RUNTIME_CAT, "Done {task_id:?}",);
152 });
153
154 TaskFuture {
155 id: task_id,
156 future,
157 }
158 .await
159 };
160
161 let runnables = Arc::clone(&self.runnables);
162 let (runnable, task) = async_task::spawn(task_fut, move |runnable| {
163 runnables.push(runnable).unwrap();
164 });
165 tasks.insert(Task::new(task_id));
166 drop(tasks);
167
168 runnable.schedule();
169
170 (task_id, task)
171 }
172
173 pub unsafe fn add_sync<F, O>(&self, f: F) -> async_task::Task<O>
180 where
181 F: FnOnce() -> O + Send,
182 O: Send,
183 {
184 let tasks_clone = Arc::clone(&self.tasks);
185 let mut tasks = self.tasks.lock().unwrap();
186 let task_id = TaskId(tasks.vacant_entry().key());
187
188 let task_fut = async move {
189 gst::trace!(RUNTIME_CAT, "Executing sync function as {task_id:?}");
190
191 let _guard = CallOnDrop::new(move || {
192 let _ = tasks_clone.lock().unwrap().try_remove(task_id.0);
193
194 gst::trace!(RUNTIME_CAT, "Done executing sync function as {task_id:?}");
195 });
196
197 f()
198 };
199
200 let runnables = Arc::clone(&self.runnables);
201 let (runnable, task) = async_task::spawn_unchecked(task_fut, move |runnable| {
204 runnables.push(runnable).unwrap();
205 });
206 tasks.insert(Task::new(task_id));
207 drop(tasks);
208
209 runnable.schedule();
210
211 task
212 }
213
214 pub fn pop_runnable(&self) -> Result<Runnable, concurrent_queue::PopError> {
215 self.runnables.pop()
216 }
217
218 pub fn add_sub_task<T>(&self, task_id: TaskId, sub_task: T) -> Result<(), T>
219 where
220 T: Future<Output = SubTaskOutput> + Send + 'static,
221 {
222 let mut state = self.tasks.lock().unwrap();
223 match state.get_mut(task_id.0) {
224 Some(task) => {
225 gst::trace!(RUNTIME_CAT, "Adding subtask to {task_id:?}");
226 task.add_sub_task(sub_task);
227 Ok(())
228 }
229 None => {
230 gst::trace!(RUNTIME_CAT, "Task was removed in the meantime");
231 Err(sub_task)
232 }
233 }
234 }
235
236 pub async fn drain_sub_tasks(&self, task_id: TaskId) -> SubTaskOutput {
237 loop {
238 let mut sub_tasks = match self.tasks.lock().unwrap().get_mut(task_id.0) {
239 Some(task) if !task.sub_tasks.is_empty() => std::mem::take(&mut task.sub_tasks),
240 _ => return Ok(()),
241 };
242
243 gst::trace!(
244 RUNTIME_CAT,
245 "Draining {} sub tasks from {task_id:?}",
246 sub_tasks.len(),
247 );
248
249 for sub_task in sub_tasks.drain(..) {
250 sub_task.await?;
251 }
252 }
253 }
254}