gstthreadshare/runtime/executor/context.rs

// Copyright (C) 2018-2020 Sebastian Dröge <sebastian@centricular.com>
// Copyright (C) 2019-2022 François Laignel <fengalin@free.fr>
//
// This Source Code Form is subject to the terms of the Mozilla Public License, v2.0.
// If a copy of the MPL was not distributed with this file, You can obtain one at
// <https://mozilla.org/MPL/2.0/>.
//
// SPDX-License-Identifier: MPL-2.0

10use futures::prelude::*;
11
12use std::sync::LazyLock;
13
14use std::collections::HashMap;
15use std::io;
16use std::pin::Pin;
17use std::sync::{Arc, Mutex};
18use std::task::{self, Poll};
19use std::time::Duration;
20
21use super::{JoinHandle, SubTaskOutput, TaskId, scheduler};
22use crate::runtime::RUNTIME_CAT;
23
24// We are bound to using `sync` for the `runtime` `Mutex`es. Attempts to use `async` `Mutex`es
25// lead to the following issues:
26//
27// * `CONTEXTS`: can't `spawn` a `Future` when called from a `Context` thread via `ffi`.
28// * `timers`: can't automatically `remove` the timer from `BinaryHeap` because `async drop`
29//    is not available.
30// * `task_queues`: can't `add` a pending task when called from a `Context` thread via `ffi`.
31//
32// Also, we want to be able to `acquire` a `Context` outside of an `async` context.
33// These `Mutex`es must be `lock`ed for a short period.
34static CONTEXTS: LazyLock<Mutex<HashMap<Arc<str>, ContextWeak>>> =
35    LazyLock::new(|| Mutex::new(HashMap::new()));
36
37/// Blocks on `future` in one way or another if possible.
38///
39/// IO & time related `Future`s must be handled within their own [`Context`].
40/// Wait for the result using a [`JoinHandle`] or a `channel`.
41///
42/// If there's currently an active `Context` with a task, then the future is only queued up as a
43/// pending sub task for that task.
44///
45/// Otherwise the current thread is blocking and the passed in future is executed.
46///
47/// Note that you must not pass any futures here that wait for the currently active task in one way
48/// or another as this would deadlock!
49#[track_caller]
50pub fn block_on_or_add_subtask<Fut>(future: Fut) -> Option<Fut::Output>
51where
52    Fut: Future + Send + 'static,
53    Fut::Output: Send + 'static,
54{
55    if let Some((cur_context, cur_task_id)) = Context::current_task() {
56        gst::debug!(
57            RUNTIME_CAT,
58            "Adding subtask to task {:?} on context {}",
59            cur_task_id,
60            cur_context.name()
61        );
62        let _ = cur_context.add_sub_task(cur_task_id, async move {
63            future.await;
64            Ok(())
65        });
66        return None;
67    }
68
69    // Not running in a Context thread so we can block
70    Some(block_on(future))
71}
72
73/// Blocks on `future`.
74///
75/// IO & time related `Future`s must be handled within their own [`Context`].
76/// Wait for the result using a [`JoinHandle`] or a `channel`.
77///
78/// The current thread is blocking and the passed in future is executed.
79///
80/// # Panics
81///
82/// This function panics if called within a [`Context`] thread.
83#[track_caller]
84pub fn block_on<Fut>(future: Fut) -> Fut::Output
85where
86    Fut: Future + Send + 'static,
87    Fut::Output: Send + 'static,
88{
89    gst::log!(RUNTIME_CAT, "Blocking on local thread");
90    scheduler::Blocking::block_on(future)
91}
92
93/// Yields execution back to the runtime.
94#[inline]
95pub fn yield_now() -> YieldNow {
96    YieldNow::default()
97}
98
99#[derive(Debug, Default)]
100#[must_use = "futures do nothing unless you `.await` or poll them"]
101pub struct YieldNow(bool);
102
103impl Future for YieldNow {
104    type Output = ();
105
106    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
107        if !self.0 {
108            self.0 = true;
109            cx.waker().wake_by_ref();
110            Poll::Pending
111        } else {
112            Poll::Ready(())
113        }
114    }
115}
116
117#[derive(Clone, Debug)]
118pub struct ContextWeak(scheduler::ThrottlingHandleWeak);
119
120impl ContextWeak {
121    pub fn upgrade(&self) -> Option<Context> {
122        self.0.upgrade().map(Context)
123    }
124}
125
126/// A `threadshare` `runtime` `Context`.
127///
128/// The `Context` provides low-level asynchronous processing features to
129/// multiplex task execution on a single thread.
130///
131/// `Element` implementations should use [`PadSrc`] and [`PadSink`] which
132///  provide high-level features.
133///
134/// [`PadSrc`]: ../pad/struct.PadSrc.html
135/// [`PadSink`]: ../pad/struct.PadSink.html
136#[derive(Clone, Debug)]
137pub struct Context(scheduler::ThrottlingHandle);
138
139impl PartialEq for Context {
140    fn eq(&self, other: &Self) -> bool {
141        self.0.eq(&other.0)
142    }
143}
144
145impl Eq for Context {}
146
147impl Context {
148    pub fn acquire(context_name: &str, wait: Duration) -> Result<Self, io::Error> {
149        let mut contexts = CONTEXTS.lock().unwrap();
150
151        if let Some(context_weak) = contexts.get(context_name)
152            && let Some(context) = context_weak.upgrade()
153        {
154            gst::debug!(RUNTIME_CAT, "Joining Context '{}'", context.name());
155            return Ok(context);
156        }
157
158        let context = Context(scheduler::Throttling::start(context_name, wait));
159        contexts.insert(context_name.into(), context.downgrade());
160
161        gst::debug!(
162            RUNTIME_CAT,
163            "New Context '{}' throttling {wait:?}",
164            context.name(),
165        );
166        Ok(context)
167    }
168
169    pub fn downgrade(&self) -> ContextWeak {
170        ContextWeak(self.0.downgrade())
171    }
172
173    pub fn name(&self) -> &str {
174        self.0.context_name()
175    }
176
177    // FIXME this could be renamed as max_throttling
178    // but then, all elements should also change their
179    // wait variables and properties to max_throttling.
180    pub fn wait_duration(&self) -> Duration {
181        self.0.max_throttling()
182    }
183
184    /// Total duration the scheduler spent parked.
185    ///
186    /// This is only useful for performance evaluation.
187    #[cfg(feature = "tuning")]
188    pub fn parked_duration(&self) -> Duration {
189        self.0.parked_duration()
190    }
191
192    /// Returns `true` if a `Context` is running on current thread.
193    pub fn is_context_thread() -> bool {
194        scheduler::Throttling::is_throttling_thread()
195    }
196
197    /// Returns the `Context` running on current thread, if any.
198    pub fn current() -> Option<Context> {
199        scheduler::Throttling::current().map(Context)
200    }
201
202    /// Returns the `TaskId` running on current thread, if any.
203    pub fn current_task() -> Option<(Context, TaskId)> {
204        Option::zip(
205            scheduler::Throttling::current().map(Context),
206            TaskId::current(),
207        )
208    }
209
210    /// Executes the provided function relatively to this [`Context`].
211    ///
212    /// Useful to initialize i/o sources and timers from outside
213    /// of a [`Context`].
214    ///
215    /// # Panic
216    ///
217    /// This will block current thread and would panic if run
218    /// from the [`Context`].
219    #[track_caller]
220    pub fn enter<'a, F, O>(&'a self, f: F) -> O
221    where
222        F: FnOnce() -> O + Send + 'a,
223        O: Send + 'a,
224    {
225        match Context::current().as_ref() {
226            Some(cur) => {
227                if cur == self {
228                    panic!(
229                        "Attempt to enter Context {} within itself, this would deadlock",
230                        self.name()
231                    );
232                } else {
233                    gst::warning!(
234                        RUNTIME_CAT,
235                        "Entering Context {} within {}",
236                        self.name(),
237                        cur.name()
238                    );
239                }
240            }
241            _ => {
242                gst::debug!(RUNTIME_CAT, "Entering Context {}", self.name());
243            }
244        }
245
246        self.0.enter(f)
247    }
248
249    pub fn spawn<Fut>(&self, future: Fut) -> JoinHandle<Fut::Output>
250    where
251        Fut: Future + Send + 'static,
252        Fut::Output: Send + 'static,
253    {
254        self.0.spawn(future)
255    }
256
257    pub fn spawn_and_unpark<Fut>(&self, future: Fut) -> JoinHandle<Fut::Output>
258    where
259        Fut: Future + Send + 'static,
260        Fut::Output: Send + 'static,
261    {
262        self.0.spawn_and_unpark(future)
263    }
264
265    /// Forces the scheduler to unpark.
266    ///
267    /// This is not needed by elements implementors as they are
268    /// supposed to call [`Self::spawn_and_unpark`] when needed.
269    /// However, it's useful for lower level implementations such as
270    /// `runtime::Task` so as to make sure the iteration loop yields
271    /// as soon as possible when a transition is requested.
272    pub(in crate::runtime) fn unpark(&self) {
273        self.0.unpark();
274    }
275
276    pub fn add_sub_task<T>(&self, task_id: TaskId, sub_task: T) -> Result<(), T>
277    where
278        T: Future<Output = SubTaskOutput> + Send + 'static,
279    {
280        self.0.add_sub_task(task_id, sub_task)
281    }
282
283    pub async fn drain_sub_tasks() -> SubTaskOutput {
284        let (ctx, task_id) = match Context::current_task() {
285            Some(task) => task,
286            None => return Ok(()),
287        };
288
289        ctx.0.drain_sub_tasks(task_id).await
290    }
291}
292
293impl From<scheduler::ThrottlingHandle> for Context {
294    fn from(handle: scheduler::ThrottlingHandle) -> Self {
295        Context(handle)
296    }
297}
298
299#[cfg(test)]
300mod tests {
301    use futures::channel::mpsc;
302    use futures::lock::Mutex;
303    use futures::prelude::*;
304
305    use std::net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket};
306    use std::sync::Arc;
307    use std::time::{Duration, Instant};
308
309    use super::Context;
310    use crate::runtime::Async;
311
312    type Item = i32;
313
314    const SLEEP_DURATION_MS: u64 = 2;
315    const SLEEP_DURATION: Duration = Duration::from_millis(SLEEP_DURATION_MS);
316    const DELAY: Duration = Duration::from_millis(SLEEP_DURATION_MS * 10);
317
318    #[test]
319    fn block_on_timer() {
320        gst::init().unwrap();
321
322        let elapsed = crate::runtime::executor::block_on(async {
323            let now = Instant::now();
324            crate::runtime::timer::delay_for(DELAY).await;
325            now.elapsed()
326        });
327
328        assert!(elapsed >= DELAY);
329    }
330
331    #[test]
332    fn context_task_id() {
333        use super::TaskId;
334
335        gst::init().unwrap();
336
337        let context = Context::acquire("context_task_id", SLEEP_DURATION).unwrap();
338        let join_handle = context.spawn(async {
339            let (ctx, task_id) = Context::current_task().unwrap();
340            assert_eq!(ctx.name(), "context_task_id");
341            assert_eq!(task_id, TaskId(0));
342        });
343        futures::executor::block_on(join_handle).unwrap();
344        // TaskId(0) is vacant again
345
346        let ctx_weak = context.downgrade();
347        let join_handle = context.spawn(async move {
348            let (ctx, task_id) = Context::current_task().unwrap();
349            assert_eq!(task_id, TaskId(0));
350
351            let res = ctx.add_sub_task(task_id, async move {
352                let (_ctx, task_id) = Context::current_task().unwrap();
353                assert_eq!(task_id, TaskId(0));
354                Ok(())
355            });
356            assert!(res.is_ok());
357
358            ctx_weak
359                .upgrade()
360                .unwrap()
361                .spawn(async {
362                    let (ctx, task_id) = Context::current_task().unwrap();
363                    assert_eq!(task_id, TaskId(1));
364
365                    let res = ctx.add_sub_task(task_id, async move {
366                        let (_ctx, task_id) = Context::current_task().unwrap();
367                        assert_eq!(task_id, TaskId(1));
368                        Ok(())
369                    });
370                    assert!(res.is_ok());
371                    assert!(Context::drain_sub_tasks().await.is_ok());
372
373                    let (_ctx, task_id) = Context::current_task().unwrap();
374                    assert_eq!(task_id, TaskId(1));
375                })
376                .await
377                .unwrap();
378
379            assert!(Context::drain_sub_tasks().await.is_ok());
380
381            let (_ctx, task_id) = Context::current_task().unwrap();
382            assert_eq!(task_id, TaskId(0));
383        });
384        futures::executor::block_on(join_handle).unwrap();
385    }
386
387    #[test]
388    fn drain_sub_tasks() {
389        // Setup
390        gst::init().unwrap();
391
392        let context = Context::acquire("drain_sub_tasks", SLEEP_DURATION).unwrap();
393
394        let join_handle = context.spawn(async {
395            let (sender, mut receiver) = mpsc::channel(1);
396            let sender: Arc<Mutex<mpsc::Sender<Item>>> = Arc::new(Mutex::new(sender));
397
398            let add_sub_task = move |item| {
399                let sender = sender.clone();
400                Context::current_task()
401                    .ok_or(())
402                    .and_then(|(ctx, task_id)| {
403                        ctx.add_sub_task(task_id, async move {
404                            sender
405                                .lock()
406                                .await
407                                .send(item)
408                                .await
409                                .map_err(|_| gst::FlowError::Error)
410                        })
411                        .map_err(drop)
412                    })
413            };
414
415            // Tests
416
417            // Drain empty queue
418            let drain_fut = Context::drain_sub_tasks();
419            drain_fut.await.unwrap();
420
421            // Add a subtask
422            add_sub_task(0).unwrap();
423
424            // Check that it was not executed yet
425            receiver.try_recv().unwrap_err();
426
427            // Drain it now and check that it was executed
428            let drain_fut = Context::drain_sub_tasks();
429            drain_fut.await.unwrap();
430            assert_eq!(receiver.try_recv(), Ok(0));
431
432            // Add another task and check that it's not executed yet
433            add_sub_task(1).unwrap();
434            receiver.try_recv().unwrap_err();
435
436            // Return the receiver
437            receiver
438        });
439
440        let mut receiver = futures::executor::block_on(join_handle).unwrap();
441
442        // The last sub task should be simply dropped at this point
443        match receiver.try_recv() {
444            Err(_) => (),
445            other => panic!("Unexpected {other:?}"),
446        }
447    }
448
449    #[test]
450    fn block_on_from_sync() {
451        gst::init().unwrap();
452
453        let context = Context::acquire("block_on_from_sync", SLEEP_DURATION).unwrap();
454
455        let bytes_sent = crate::runtime::executor::block_on(context.spawn(async {
456            let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5001);
457            let socket = Async::<UdpSocket>::bind(saddr).unwrap();
458            let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4001);
459            socket.send_to(&[0; 10], saddr).await.unwrap()
460        }))
461        .unwrap();
462        assert_eq!(bytes_sent, 10);
463
464        let elapsed = crate::runtime::executor::block_on(context.spawn(async {
465            let start = Instant::now();
466            crate::runtime::timer::delay_for(DELAY).await;
467            start.elapsed()
468        }))
469        .unwrap();
470        // Due to throttling, `Delay` may be fired earlier
471        assert!(elapsed + SLEEP_DURATION / 2 >= DELAY);
472    }
473
474    #[test]
475    #[should_panic]
476    fn block_on_from_context() {
477        gst::init().unwrap();
478
479        let context = Context::acquire("block_on_from_context", SLEEP_DURATION).unwrap();
480
481        // Panic: attempt to `runtime::executor::block_on` within a `Context` thread
482        let join_handle = context.spawn(async {
483            crate::runtime::executor::block_on(crate::runtime::timer::delay_for(DELAY));
484        });
485
486        // Panic: task has failed
487        // (enforced by `async-task`, see comment in `Future` impl for `JoinHandle`).
488        futures::executor::block_on(join_handle).unwrap_err();
489    }
490
491    #[test]
492    fn enter_context_from_scheduler() {
493        gst::init().unwrap();
494
495        let elapsed = crate::runtime::executor::block_on(async {
496            let context = Context::acquire("enter_context_from_executor", SLEEP_DURATION).unwrap();
497            let socket = context
498                .enter(|| {
499                    let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5002);
500                    Async::<UdpSocket>::bind(saddr)
501                })
502                .unwrap();
503
504            let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4002);
505            let bytes_sent = socket.send_to(&[0; 10], saddr).await.unwrap();
506            assert_eq!(bytes_sent, 10);
507
508            let (start, timer) =
509                context.enter(|| (Instant::now(), crate::runtime::timer::delay_for(DELAY)));
510            timer.await;
511            start.elapsed()
512        });
513
514        // Due to throttling, `Delay` may be fired earlier
515        assert!(elapsed + SLEEP_DURATION / 2 >= DELAY);
516    }
517
518    #[test]
519    fn enter_context_from_sync() {
520        gst::init().unwrap();
521
522        let context = Context::acquire("enter_context_from_sync", SLEEP_DURATION).unwrap();
523        let socket = context
524            .enter(|| {
525                let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5003);
526                Async::<UdpSocket>::bind(saddr)
527            })
528            .unwrap();
529
530        let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4003);
531        let bytes_sent = futures::executor::block_on(socket.send_to(&[0; 10], saddr)).unwrap();
532        assert_eq!(bytes_sent, 10);
533
534        let (start, timer) =
535            context.enter(|| (Instant::now(), crate::runtime::timer::delay_for(DELAY)));
536        let elapsed = crate::runtime::executor::block_on(async move {
537            timer.await;
538            start.elapsed()
539        });
540        // Due to throttling, `Delay` may be fired earlier
541        assert!(elapsed + SLEEP_DURATION / 2 >= DELAY);
542    }
543}