gstthreadshare/runtime/executor/
task.rs

// Copyright (C) 2018-2020 Sebastian Dröge <sebastian@centricular.com>
// Copyright (C) 2019-2022 François Laignel <fengalin@free.fr>
//
// This Source Code Form is subject to the terms of the Mozilla Public License, v2.0.
// If a copy of the MPL was not distributed with this file, You can obtain one at
// <https://mozilla.org/MPL/2.0/>.
//
// SPDX-License-Identifier: MPL-2.0
9
10use async_task::Runnable;
11use concurrent_queue::ConcurrentQueue;
12
13use futures::future::BoxFuture;
14use futures::prelude::*;
15
16use pin_project_lite::pin_project;
17
18use slab::Slab;
19
20use std::cell::Cell;
21use std::collections::VecDeque;
22use std::fmt;
23use std::pin::Pin;
24use std::sync::{Arc, Mutex};
25use std::task::Poll;
26
27use super::CallOnDrop;
28use crate::runtime::RUNTIME_CAT;
29
thread_local! {
    // Id of the `TaskFuture` currently being polled on this thread, if any.
    // `TaskFuture::poll` sets it for the duration of a poll and restores the
    // previous value afterwards; `TaskId::current` reads it.
    static CURRENT_TASK_ID: Cell<Option<TaskId>> = const { Cell::new(None) };
}
33
/// Identifier of a task spawned on a `TaskQueue`.
///
/// The wrapped value is the key of the task's entry in the queue's task slab.
/// Being a slab key, it may be handed out again to a later task once the entry
/// has been removed.
#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
pub struct TaskId(pub(super) usize);
36
37impl TaskId {
38    pub(super) fn current() -> Option<TaskId> {
39        CURRENT_TASK_ID.try_with(Cell::get).ok().flatten()
40    }
41}
42
/// Result of a sub task.
///
/// An `Err` stops `TaskQueue::drain_sub_tasks` from awaiting the remaining sub
/// tasks of the batch being drained.
pub type SubTaskOutput = Result<(), gst::FlowError>;
44
pin_project! {
    // Wraps a spawned future so that `TaskId::current()` returns `id` while
    // `future` is being polled.
    pub(super) struct TaskFuture<F: Future> {
        // Id of the task this future was spawned as.
        id: TaskId,
        // The wrapped future, structurally pinned so it can be polled in place.
        #[pin]
        future: F,
    }

}
53
54impl<F: Future> Future for TaskFuture<F> {
55    type Output = F::Output;
56
57    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
58        struct TaskIdGuard {
59            prev_task_id: Option<TaskId>,
60        }
61
62        impl Drop for TaskIdGuard {
63            fn drop(&mut self) {
64                let _ = CURRENT_TASK_ID.try_with(|cur| cur.replace(self.prev_task_id.take()));
65            }
66        }
67
68        let task_id = self.id;
69        let project = self.project();
70
71        let _guard = TaskIdGuard {
72            prev_task_id: CURRENT_TASK_ID.with(|cur| cur.replace(Some(task_id))),
73        };
74
75        project.future.poll(cx)
76    }
77}
78
/// Book-keeping entry for a task spawned on a `TaskQueue`.
pub(super) struct Task {
    /// Id under which this entry is stored in the task slab.
    id: TaskId,
    /// Sub tasks attached to this task, in the order they were added; they are
    /// run by `TaskQueue::drain_sub_tasks`.
    sub_tasks: VecDeque<BoxFuture<'static, SubTaskOutput>>,
}
83
84impl Task {
85    fn new(id: TaskId) -> Self {
86        Task {
87            id,
88            sub_tasks: VecDeque::new(),
89        }
90    }
91
92    fn add_sub_task<T>(&mut self, sub_task: T)
93    where
94        T: Future<Output = SubTaskOutput> + Send + 'static,
95    {
96        self.sub_tasks.push_back(sub_task.boxed());
97    }
98}
99
100impl fmt::Debug for Task {
101    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
102        fmt.debug_struct("Task")
103            .field("id", &self.id)
104            .field("sub_tasks len", &self.sub_tasks.len())
105            .finish()
106    }
107}
108
/// Queue of scheduled runnables together with the book-keeping of the tasks
/// spawned on it.
///
/// Both fields are behind `Arc`s, so clones share the same queue and task slab.
#[derive(Debug, Clone)]
pub(super) struct TaskQueue {
    /// Runnables ready to run: pushed by the tasks' schedule functions and
    /// retrieved with `pop_runnable`.
    runnables: Arc<ConcurrentQueue<Runnable>>,
    /// Registered tasks, keyed by `TaskId`, holding their pending sub tasks.
    tasks: Arc<Mutex<Slab<Task>>>,
}
114
115impl Default for TaskQueue {
116    fn default() -> Self {
117        TaskQueue {
118            runnables: Arc::new(ConcurrentQueue::unbounded()),
119            tasks: Arc::new(Mutex::new(Slab::new())),
120        }
121    }
122}
123
124impl TaskQueue {
125    pub fn add<F>(&self, future: F) -> (TaskId, async_task::Task<<F as Future>::Output>)
126    where
127        F: Future + Send + 'static,
128        F::Output: Send + 'static,
129    {
130        let tasks_weak = Arc::downgrade(&self.tasks);
131        let mut tasks = self.tasks.lock().unwrap();
132        let task_id = TaskId(tasks.vacant_entry().key());
133
134        let task_fut = async move {
135            gst::trace!(RUNTIME_CAT, "Running {task_id:?}");
136
137            let _guard = CallOnDrop::new(move || {
138                if let Some(task) = tasks_weak
139                    .upgrade()
140                    .and_then(|tasks| tasks.lock().unwrap().try_remove(task_id.0))
141                {
142                    if !task.sub_tasks.is_empty() {
143                        gst::warning!(
144                            RUNTIME_CAT,
145                            "Task {task_id:?} has {} pending sub tasks",
146                            task.sub_tasks.len(),
147                        );
148                    }
149                }
150
151                gst::trace!(RUNTIME_CAT, "Done {task_id:?}",);
152            });
153
154            TaskFuture {
155                id: task_id,
156                future,
157            }
158            .await
159        };
160
161        let runnables = Arc::clone(&self.runnables);
162        let (runnable, task) = async_task::spawn(task_fut, move |runnable| {
163            runnables.push(runnable).unwrap();
164        });
165        tasks.insert(Task::new(task_id));
166        drop(tasks);
167
168        runnable.schedule();
169
170        (task_id, task)
171    }
172
173    /// Adds a task to be blocked on immediately.
174    ///
175    /// # Safety
176    ///
177    /// The function and its output must outlive the execution
178    /// of the resulting task and the retrieval of the result.
179    pub unsafe fn add_sync<F, O>(&self, f: F) -> async_task::Task<O>
180    where
181        F: FnOnce() -> O + Send,
182        O: Send,
183    {
184        let tasks_clone = Arc::clone(&self.tasks);
185        let mut tasks = self.tasks.lock().unwrap();
186        let task_id = TaskId(tasks.vacant_entry().key());
187
188        let task_fut = async move {
189            gst::trace!(RUNTIME_CAT, "Executing sync function as {task_id:?}");
190
191            let _guard = CallOnDrop::new(move || {
192                let _ = tasks_clone.lock().unwrap().try_remove(task_id.0);
193
194                gst::trace!(RUNTIME_CAT, "Done executing sync function as {task_id:?}");
195            });
196
197            f()
198        };
199
200        let runnables = Arc::clone(&self.runnables);
201        // This is the unsafe call for which the lifetime must hold
202        // until the the Future is Ready and its Output retrieved.
203        let (runnable, task) = async_task::spawn_unchecked(task_fut, move |runnable| {
204            runnables.push(runnable).unwrap();
205        });
206        tasks.insert(Task::new(task_id));
207        drop(tasks);
208
209        runnable.schedule();
210
211        task
212    }
213
214    pub fn pop_runnable(&self) -> Result<Runnable, concurrent_queue::PopError> {
215        self.runnables.pop()
216    }
217
218    pub fn add_sub_task<T>(&self, task_id: TaskId, sub_task: T) -> Result<(), T>
219    where
220        T: Future<Output = SubTaskOutput> + Send + 'static,
221    {
222        let mut state = self.tasks.lock().unwrap();
223        match state.get_mut(task_id.0) {
224            Some(task) => {
225                gst::trace!(RUNTIME_CAT, "Adding subtask to {task_id:?}");
226                task.add_sub_task(sub_task);
227                Ok(())
228            }
229            None => {
230                gst::trace!(RUNTIME_CAT, "Task was removed in the meantime");
231                Err(sub_task)
232            }
233        }
234    }
235
236    pub async fn drain_sub_tasks(&self, task_id: TaskId) -> SubTaskOutput {
237        loop {
238            let mut sub_tasks = match self.tasks.lock().unwrap().get_mut(task_id.0) {
239                Some(task) if !task.sub_tasks.is_empty() => std::mem::take(&mut task.sub_tasks),
240                _ => return Ok(()),
241            };
242
243            gst::trace!(
244                RUNTIME_CAT,
245                "Draining {} sub tasks from {task_id:?}",
246                sub_tasks.len(),
247            );
248
249            for sub_task in sub_tasks.drain(..) {
250                sub_task.await?;
251            }
252        }
253    }
254}