Skip to main content

gstthreadshare/runtime/executor/
task.rs

1// Copyright (C) 2018-2020 Sebastian Dröge <sebastian@centricular.com>
2// Copyright (C) 2019-2022 François Laignel <fengalin@free.fr>
3//
4// This Source Code Form is subject to the terms of the Mozilla Public License, v2.0.
5// If a copy of the MPL was not distributed with this file, You can obtain one at
6// <https://mozilla.org/MPL/2.0/>.
7//
8// SPDX-License-Identifier: MPL-2.0
9
10use async_task::Runnable;
11use concurrent_queue::ConcurrentQueue;
12
13use futures::future::BoxFuture;
14use futures::prelude::*;
15
16use pin_project_lite::pin_project;
17
18use slab::Slab;
19
20use std::cell::Cell;
21use std::collections::VecDeque;
22use std::fmt;
23use std::pin::Pin;
24use std::sync::{Arc, Mutex};
25use std::task::Poll;
26
27use super::CallOnDrop;
28use crate::runtime::RUNTIME_CAT;
29
30thread_local! {
31    static CURRENT_TASK_ID: Cell<Option<TaskId>> = const { Cell::new(None) };
32}
33
34#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
35pub struct TaskId(pub(super) usize);
36
37impl TaskId {
38    pub(super) fn current() -> Option<TaskId> {
39        CURRENT_TASK_ID.try_with(Cell::get).ok().flatten()
40    }
41}
42
43pub type SubTaskOutput = Result<(), gst::FlowError>;
44
45pin_project! {
46    pub(super) struct TaskFuture<F: Future> {
47        id: TaskId,
48        #[pin]
49        future: F,
50    }
51
52}
53
54impl<F: Future> Future for TaskFuture<F> {
55    type Output = F::Output;
56
57    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
58        struct TaskIdGuard {
59            prev_task_id: Option<TaskId>,
60        }
61
62        impl Drop for TaskIdGuard {
63            fn drop(&mut self) {
64                let _ = CURRENT_TASK_ID.try_with(|cur| cur.replace(self.prev_task_id.take()));
65            }
66        }
67
68        let task_id = self.id;
69        let project = self.project();
70
71        let _guard = TaskIdGuard {
72            prev_task_id: CURRENT_TASK_ID.with(|cur| cur.replace(Some(task_id))),
73        };
74
75        project.future.poll(cx)
76    }
77}
78
79pub(super) struct Task {
80    id: TaskId,
81    sub_tasks: VecDeque<BoxFuture<'static, SubTaskOutput>>,
82}
83
84impl Task {
85    fn new(id: TaskId) -> Self {
86        Task {
87            id,
88            sub_tasks: VecDeque::new(),
89        }
90    }
91
92    fn add_sub_task<T>(&mut self, sub_task: T)
93    where
94        T: Future<Output = SubTaskOutput> + Send + 'static,
95    {
96        self.sub_tasks.push_back(sub_task.boxed());
97    }
98}
99
100impl fmt::Debug for Task {
101    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
102        fmt.debug_struct("Task")
103            .field("id", &self.id)
104            .field("sub_tasks len", &self.sub_tasks.len())
105            .finish()
106    }
107}
108
109#[derive(Debug, Clone)]
110pub(super) struct TaskQueue {
111    runnables: Arc<ConcurrentQueue<Runnable>>,
112    tasks: Arc<Mutex<Slab<Task>>>,
113}
114
115impl Default for TaskQueue {
116    fn default() -> Self {
117        TaskQueue {
118            runnables: Arc::new(ConcurrentQueue::unbounded()),
119            tasks: Arc::new(Mutex::new(Slab::new())),
120        }
121    }
122}
123
124impl TaskQueue {
125    pub fn add<F>(&self, future: F) -> (TaskId, async_task::Task<<F as Future>::Output>)
126    where
127        F: Future + Send + 'static,
128        F::Output: Send + 'static,
129    {
130        let tasks_weak = Arc::downgrade(&self.tasks);
131        let mut tasks = self.tasks.lock().unwrap();
132        let task_id = TaskId(tasks.vacant_entry().key());
133
134        let task_fut = async move {
135            gst::trace!(RUNTIME_CAT, "Running {task_id:?}");
136
137            let _guard = CallOnDrop::new(move || {
138                if let Some(task) = tasks_weak
139                    .upgrade()
140                    .and_then(|tasks| tasks.lock().unwrap().try_remove(task_id.0))
141                    && !task.sub_tasks.is_empty()
142                {
143                    gst::warning!(
144                        RUNTIME_CAT,
145                        "Task {task_id:?} has {} pending sub tasks",
146                        task.sub_tasks.len(),
147                    );
148                }
149
150                gst::trace!(RUNTIME_CAT, "Done {task_id:?}",);
151            });
152
153            TaskFuture {
154                id: task_id,
155                future,
156            }
157            .await
158        };
159
160        let runnables = Arc::clone(&self.runnables);
161        let (runnable, task) = async_task::spawn(task_fut, move |runnable| {
162            runnables.push(runnable).unwrap();
163        });
164        tasks.insert(Task::new(task_id));
165        drop(tasks);
166
167        runnable.schedule();
168
169        (task_id, task)
170    }
171
172    /// Adds a task to be blocked on immediately.
173    ///
174    /// # Safety
175    ///
176    /// The function and its output must outlive the execution
177    /// of the resulting task and the retrieval of the result.
178    pub unsafe fn add_sync<F, O>(&self, f: F) -> async_task::Task<O>
179    where
180        F: FnOnce() -> O + Send,
181        O: Send,
182    {
183        unsafe {
184            let tasks_clone = Arc::clone(&self.tasks);
185            let mut tasks = self.tasks.lock().unwrap();
186            let task_id = TaskId(tasks.vacant_entry().key());
187
188            let task_fut = async move {
189                gst::trace!(RUNTIME_CAT, "Executing sync function as {task_id:?}");
190
191                let _guard = CallOnDrop::new(move || {
192                    let _ = tasks_clone.lock().unwrap().try_remove(task_id.0);
193
194                    gst::trace!(RUNTIME_CAT, "Done executing sync function as {task_id:?}");
195                });
196
197                f()
198            };
199
200            let runnables = Arc::clone(&self.runnables);
201            // This is the unsafe call for which the lifetime must hold
202            // until the the Future is Ready and its Output retrieved.
203            let (runnable, task) = async_task::spawn_unchecked(task_fut, move |runnable| {
204                runnables.push(runnable).unwrap();
205            });
206            tasks.insert(Task::new(task_id));
207            drop(tasks);
208
209            runnable.schedule();
210
211            task
212        }
213    }
214
215    pub fn pop_runnable(&self) -> Result<Runnable, concurrent_queue::PopError> {
216        self.runnables.pop()
217    }
218
219    pub fn add_sub_task<T>(&self, task_id: TaskId, sub_task: T) -> Result<(), T>
220    where
221        T: Future<Output = SubTaskOutput> + Send + 'static,
222    {
223        let mut state = self.tasks.lock().unwrap();
224        match state.get_mut(task_id.0) {
225            Some(task) => {
226                gst::trace!(RUNTIME_CAT, "Adding subtask to {task_id:?}");
227                task.add_sub_task(sub_task);
228                Ok(())
229            }
230            None => {
231                gst::trace!(RUNTIME_CAT, "Task was removed in the meantime");
232                Err(sub_task)
233            }
234        }
235    }
236
237    pub async fn drain_sub_tasks(&self, task_id: TaskId) -> SubTaskOutput {
238        loop {
239            let mut sub_tasks = match self.tasks.lock().unwrap().get_mut(task_id.0) {
240                Some(task) if !task.sub_tasks.is_empty() => std::mem::take(&mut task.sub_tasks),
241                _ => return Ok(()),
242            };
243
244            gst::trace!(
245                RUNTIME_CAT,
246                "Draining {} sub tasks from {task_id:?}",
247                sub_tasks.len(),
248            );
249
250            for sub_task in sub_tasks.drain(..) {
251                sub_task.await?;
252            }
253        }
254    }
255}