// gstthreadshare/runtime/executor/context.rs
use futures::prelude::*;

use std::sync::LazyLock;

use std::collections::HashMap;
use std::io;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{self, Poll};
use std::time::Duration;

use super::{JoinHandle, SubTaskOutput, TaskId, scheduler};
use crate::runtime::RUNTIME_CAT;

/// Registry of all named `Context`s, keyed by name.
///
/// Entries are weak handles so a `Context` can shut down once every
/// strong handle is dropped; `Context::acquire` joins a live entry or
/// inserts a fresh one over a dead entry.
static CONTEXTS: LazyLock<Mutex<HashMap<Arc<str>, ContextWeak>>> =
    LazyLock::new(|| Mutex::new(HashMap::new()));

37#[track_caller]
50pub fn block_on_or_add_subtask<Fut>(future: Fut) -> Option<Fut::Output>
51where
52 Fut: Future + Send + 'static,
53 Fut::Output: Send + 'static,
54{
55 if let Some((cur_context, cur_task_id)) = Context::current_task() {
56 gst::debug!(
57 RUNTIME_CAT,
58 "Adding subtask to task {:?} on context {}",
59 cur_task_id,
60 cur_context.name()
61 );
62 let _ = cur_context.add_sub_task(cur_task_id, async move {
63 future.await;
64 Ok(())
65 });
66 return None;
67 }
68
69 Some(block_on(future))
71}
72
/// Blocks the current thread until `future` completes.
///
/// # Panics
///
/// Panics when invoked from a `Context` thread — blocking there would
/// stall the scheduler. The check is enforced downstream (see the
/// `block_on_from_context` test, which is `#[should_panic]`).
#[track_caller]
pub fn block_on<Fut>(future: Fut) -> Fut::Output
where
    Fut: Future + Send + 'static,
    Fut::Output: Send + 'static,
{
    gst::log!(RUNTIME_CAT, "Blocking on local thread");
    scheduler::Blocking::block_on(future)
}

93#[inline]
95pub fn yield_now() -> YieldNow {
96 YieldNow::default()
97}
98
/// Future which yields to the executor once, then completes.
///
/// The inner flag records whether the yield already happened.
#[derive(Debug, Default)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct YieldNow(bool);

impl Future for YieldNow {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        if self.0 {
            // Second poll: we already yielded once, complete now.
            return Poll::Ready(());
        }
        // First poll: schedule an immediate wake-up, then hand control
        // back to the executor so other tasks can make progress.
        self.0 = true;
        cx.waker().wake_by_ref();
        Poll::Pending
    }
}

/// A weak handle on a [`Context`].
///
/// Does not keep the underlying scheduler alive.
#[derive(Clone, Debug)]
pub struct ContextWeak(scheduler::ThrottlingHandleWeak);

impl ContextWeak {
    /// Attempts to upgrade to a strong [`Context`] handle.
    ///
    /// Returns `None` if the underlying scheduler is no longer alive.
    pub fn upgrade(&self) -> Option<Context> {
        self.0.upgrade().map(Context)
    }
}

/// A handle on a named throttling scheduler.
///
/// Clones refer to the same underlying scheduler; handles compare equal
/// when they wrap the same scheduler (see the `PartialEq` impl).
#[derive(Clone, Debug)]
pub struct Context(scheduler::ThrottlingHandle);

139impl PartialEq for Context {
140 fn eq(&self, other: &Self) -> bool {
141 self.0.eq(&other.0)
142 }
143}
144
145impl Eq for Context {}
146
impl Context {
    /// Acquires the `Context` named `context_name`.
    ///
    /// If a live `Context` with this name is already registered, it is
    /// joined and returned; otherwise a new throttling scheduler is
    /// started with `wait` as its max throttling duration.
    pub fn acquire(context_name: &str, wait: Duration) -> Result<Self, io::Error> {
        let mut contexts = CONTEXTS.lock().unwrap();

        // Join an existing Context with the same name if it is still alive.
        if let Some(context_weak) = contexts.get(context_name)
            && let Some(context) = context_weak.upgrade()
        {
            gst::debug!(RUNTIME_CAT, "Joining Context '{}'", context.name());
            return Ok(context);
        }

        // No live entry: start a fresh scheduler and register a weak
        // handle (overwriting a dead entry if there was one).
        let context = Context(scheduler::Throttling::start(context_name, wait));
        contexts.insert(context_name.into(), context.downgrade());

        gst::debug!(
            RUNTIME_CAT,
            "New Context '{}' throttling {wait:?}",
            context.name(),
        );
        Ok(context)
    }

    /// Returns a weak handle which doesn't keep the scheduler alive.
    pub fn downgrade(&self) -> ContextWeak {
        ContextWeak(self.0.downgrade())
    }

    /// The name this `Context` was acquired under.
    pub fn name(&self) -> &str {
        self.0.context_name()
    }

    /// The max throttling duration this `Context` was acquired with.
    pub fn wait_duration(&self) -> Duration {
        self.0.max_throttling()
    }

    /// Duration the scheduler spent parked.
    // NOTE(review): exact semantics (cumulative vs. since-reset) inferred
    // from the name — confirm against the scheduler implementation.
    #[cfg(feature = "tuning")]
    pub fn parked_duration(&self) -> Duration {
        self.0.parked_duration()
    }

    /// Whether the current thread is a throttling `Context` thread.
    pub fn is_context_thread() -> bool {
        scheduler::Throttling::is_throttling_thread()
    }

    /// The `Context` running on the current thread, if any.
    pub fn current() -> Option<Context> {
        scheduler::Throttling::current().map(Context)
    }

    /// The `Context` and `TaskId` of the task executing on the current
    /// thread, or `None` when not called from a `Context` task.
    pub fn current_task() -> Option<(Context, TaskId)> {
        Option::zip(
            scheduler::Throttling::current().map(Context),
            TaskId::current(),
        )
    }

    /// Executes `f` within this `Context`.
    ///
    /// # Panics
    ///
    /// Panics when called from this very `Context`, which would deadlock.
    #[track_caller]
    pub fn enter<'a, F, O>(&'a self, f: F) -> O
    where
        F: FnOnce() -> O + Send + 'a,
        O: Send + 'a,
    {
        match Context::current().as_ref() {
            Some(cur) => {
                if cur == self {
                    panic!(
                        "Attempt to enter Context {} within itself, this would deadlock",
                        self.name()
                    );
                } else {
                    // Entering another Context is allowed, but unusual
                    // enough to be worth a warning.
                    gst::warning!(
                        RUNTIME_CAT,
                        "Entering Context {} within {}",
                        self.name(),
                        cur.name()
                    );
                }
            }
            _ => {
                gst::debug!(RUNTIME_CAT, "Entering Context {}", self.name());
            }
        }

        self.0.enter(f)
    }

    /// Spawns `future` as a new task on this `Context`.
    pub fn spawn<Fut>(&self, future: Fut) -> JoinHandle<Fut::Output>
    where
        Fut: Future + Send + 'static,
        Fut::Output: Send + 'static,
    {
        self.0.spawn(future)
    }

    /// Spawns `future` on this `Context` and unparks the scheduler so
    /// the task gets a chance to run without waiting for the current
    /// throttling period to elapse.
    pub fn spawn_and_unpark<Fut>(&self, future: Fut) -> JoinHandle<Fut::Output>
    where
        Fut: Future + Send + 'static,
        Fut::Output: Send + 'static,
    {
        self.0.spawn_and_unpark(future)
    }

    /// Wakes the scheduler up if it is currently parked.
    pub(in crate::runtime) fn unpark(&self) {
        self.0.unpark();
    }

    /// Queues `sub_task` to be executed when task `task_id` drains its
    /// sub tasks (see [`Context::drain_sub_tasks`]).
    ///
    /// On failure the sub task is handed back as `Err(sub_task)`.
    // NOTE(review): the exact failure condition (e.g. unknown task_id)
    // is decided by the scheduler — confirm there.
    pub fn add_sub_task<T>(&self, task_id: TaskId, sub_task: T) -> Result<(), T>
    where
        T: Future<Output = SubTaskOutput> + Send + 'static,
    {
        self.0.add_sub_task(task_id, sub_task)
    }

    /// Drains the sub tasks of the current `Context` task.
    ///
    /// Immediately returns `Ok(())` when not called from a `Context`
    /// task.
    pub async fn drain_sub_tasks() -> SubTaskOutput {
        let (ctx, task_id) = match Context::current_task() {
            Some(task) => task,
            None => return Ok(()),
        };

        ctx.0.drain_sub_tasks(task_id).await
    }
}

293impl From<scheduler::ThrottlingHandle> for Context {
294 fn from(handle: scheduler::ThrottlingHandle) -> Self {
295 Context(handle)
296 }
297}
298
#[cfg(test)]
mod tests {
    use futures::channel::mpsc;
    use futures::lock::Mutex;
    use futures::prelude::*;

    use std::net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    use super::Context;
    use crate::runtime::Async;

    // Payload type used by the sub-task tests.
    type Item = i32;

    const SLEEP_DURATION_MS: u64 = 2;
    const SLEEP_DURATION: Duration = Duration::from_millis(SLEEP_DURATION_MS);
    // Keep DELAY well above the throttling period so elapsed-time
    // assertions aren't flaky.
    const DELAY: Duration = Duration::from_millis(SLEEP_DURATION_MS * 10);

    // Blocking on a timer from a plain (non-Context) thread waits at
    // least the requested duration.
    #[test]
    fn block_on_timer() {
        gst::init().unwrap();

        let elapsed = crate::runtime::executor::block_on(async {
            let now = Instant::now();
            crate::runtime::timer::delay_for(DELAY).await;
            now.elapsed()
        });

        assert!(elapsed >= DELAY);
    }

    // TaskIds are assigned per Context: the first spawned task gets
    // TaskId(0), a concurrently spawned one gets TaskId(1), and sub
    // tasks run under the id of the task that drains them.
    #[test]
    fn context_task_id() {
        use super::TaskId;

        gst::init().unwrap();

        let context = Context::acquire("context_task_id", SLEEP_DURATION).unwrap();
        let join_handle = context.spawn(async {
            let (ctx, task_id) = Context::current_task().unwrap();
            assert_eq!(ctx.name(), "context_task_id");
            assert_eq!(task_id, TaskId(0));
        });
        futures::executor::block_on(join_handle).unwrap();
        let ctx_weak = context.downgrade();
        let join_handle = context.spawn(async move {
            let (ctx, task_id) = Context::current_task().unwrap();
            assert_eq!(task_id, TaskId(0));

            let res = ctx.add_sub_task(task_id, async move {
                let (_ctx, task_id) = Context::current_task().unwrap();
                assert_eq!(task_id, TaskId(0));
                Ok(())
            });
            assert!(res.is_ok());

            // Spawn a nested task on the same Context: it gets its own id.
            ctx_weak
                .upgrade()
                .unwrap()
                .spawn(async {
                    let (ctx, task_id) = Context::current_task().unwrap();
                    assert_eq!(task_id, TaskId(1));

                    let res = ctx.add_sub_task(task_id, async move {
                        let (_ctx, task_id) = Context::current_task().unwrap();
                        assert_eq!(task_id, TaskId(1));
                        Ok(())
                    });
                    assert!(res.is_ok());
                    assert!(Context::drain_sub_tasks().await.is_ok());

                    // Draining doesn't change the current TaskId.
                    let (_ctx, task_id) = Context::current_task().unwrap();
                    assert_eq!(task_id, TaskId(1));
                })
                .await
                .unwrap();

            assert!(Context::drain_sub_tasks().await.is_ok());

            let (_ctx, task_id) = Context::current_task().unwrap();
            assert_eq!(task_id, TaskId(0));
        });
        futures::executor::block_on(join_handle).unwrap();
    }

    // Sub tasks only execute when drained; draining with nothing queued
    // is a no-op.
    #[test]
    fn drain_sub_tasks() {
        gst::init().unwrap();

        let context = Context::acquire("drain_sub_tasks", SLEEP_DURATION).unwrap();

        let join_handle = context.spawn(async {
            let (sender, mut receiver) = mpsc::channel(1);
            let sender: Arc<Mutex<mpsc::Sender<Item>>> = Arc::new(Mutex::new(sender));

            // Helper: queue a sub task which sends `item` on the channel.
            let add_sub_task = move |item| {
                let sender = sender.clone();
                Context::current_task()
                    .ok_or(())
                    .and_then(|(ctx, task_id)| {
                        ctx.add_sub_task(task_id, async move {
                            sender
                                .lock()
                                .await
                                .send(item)
                                .await
                                .map_err(|_| gst::FlowError::Error)
                        })
                        .map_err(drop)
                    })
            };

            // Draining with no sub task queued must succeed.
            let drain_fut = Context::drain_sub_tasks();
            drain_fut.await.unwrap();

            add_sub_task(0).unwrap();

            // Queued but not drained yet: nothing was sent.
            receiver.try_recv().unwrap_err();

            // Draining executes the sub task and the item arrives.
            let drain_fut = Context::drain_sub_tasks();
            drain_fut.await.unwrap();
            assert_eq!(receiver.try_recv(), Ok(0));

            add_sub_task(1).unwrap();
            receiver.try_recv().unwrap_err();

            // Hand the receiver back with sub task 1 still pending.
            receiver
        });

        let mut receiver = futures::executor::block_on(join_handle).unwrap();

        // The pending sub task was never drained, so nothing else arrives.
        match receiver.try_recv() {
            Err(_) => (),
            other => panic!("Unexpected {other:?}"),
        }
    }

    // A synchronous caller can block on a JoinHandle from a Context.
    #[test]
    fn block_on_from_sync() {
        gst::init().unwrap();

        let context = Context::acquire("block_on_from_sync", SLEEP_DURATION).unwrap();

        let bytes_sent = crate::runtime::executor::block_on(context.spawn(async {
            let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5001);
            let socket = Async::<UdpSocket>::bind(saddr).unwrap();
            let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4001);
            socket.send_to(&[0; 10], saddr).await.unwrap()
        }))
        .unwrap();
        assert_eq!(bytes_sent, 10);

        let elapsed = crate::runtime::executor::block_on(context.spawn(async {
            let start = Instant::now();
            crate::runtime::timer::delay_for(DELAY).await;
            start.elapsed()
        }))
        .unwrap();
        // Allow half a throttling period of tolerance on the timer.
        assert!(elapsed + SLEEP_DURATION / 2 >= DELAY);
    }

    // block_on must panic when invoked from a Context thread, since it
    // would block the scheduler.
    #[test]
    #[should_panic]
    fn block_on_from_context() {
        gst::init().unwrap();

        let context = Context::acquire("block_on_from_context", SLEEP_DURATION).unwrap();

        let join_handle = context.spawn(async {
            crate::runtime::executor::block_on(crate::runtime::timer::delay_for(DELAY));
        });

        futures::executor::block_on(join_handle).unwrap_err();
    }

    // enter() works when already running on another executor.
    #[test]
    fn enter_context_from_scheduler() {
        gst::init().unwrap();

        let elapsed = crate::runtime::executor::block_on(async {
            let context = Context::acquire("enter_context_from_executor", SLEEP_DURATION).unwrap();
            let socket = context
                .enter(|| {
                    let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5002);
                    Async::<UdpSocket>::bind(saddr)
                })
                .unwrap();

            let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4002);
            let bytes_sent = socket.send_to(&[0; 10], saddr).await.unwrap();
            assert_eq!(bytes_sent, 10);

            // A timer created within enter() must tick on the Context.
            let (start, timer) =
                context.enter(|| (Instant::now(), crate::runtime::timer::delay_for(DELAY)));
            timer.await;
            start.elapsed()
        });

        assert!(elapsed + SLEEP_DURATION / 2 >= DELAY);
    }

    // enter() also works from a plain synchronous thread.
    #[test]
    fn enter_context_from_sync() {
        gst::init().unwrap();

        let context = Context::acquire("enter_context_from_sync", SLEEP_DURATION).unwrap();
        let socket = context
            .enter(|| {
                let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5003);
                Async::<UdpSocket>::bind(saddr)
            })
            .unwrap();

        let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4003);
        let bytes_sent = futures::executor::block_on(socket.send_to(&[0; 10], saddr)).unwrap();
        assert_eq!(bytes_sent, 10);

        let (start, timer) =
            context.enter(|| (Instant::now(), crate::runtime::timer::delay_for(DELAY)));
        let elapsed = crate::runtime::executor::block_on(async move {
            timer.await;
            start.elapsed()
        });
        assert!(elapsed + SLEEP_DURATION / 2 >= DELAY);
    }
}