gstthreadshare/runtime/executor/
context.rs1use futures::prelude::*;
11
12use std::sync::LazyLock;
13
14use std::collections::HashMap;
15use std::io;
16use std::pin::Pin;
17use std::sync::{Arc, Mutex};
18use std::task::{self, Poll};
19use std::time::Duration;
20
21use super::{scheduler, JoinHandle, SubTaskOutput, TaskId};
22use crate::runtime::RUNTIME_CAT;
23
24static CONTEXTS: LazyLock<Mutex<HashMap<Arc<str>, ContextWeak>>> =
35 LazyLock::new(|| Mutex::new(HashMap::new()));
36
37#[track_caller]
50pub fn block_on_or_add_subtask<Fut>(future: Fut) -> Option<Fut::Output>
51where
52 Fut: Future + Send + 'static,
53 Fut::Output: Send + 'static,
54{
55 if let Some((cur_context, cur_task_id)) = Context::current_task() {
56 gst::debug!(
57 RUNTIME_CAT,
58 "Adding subtask to task {:?} on context {}",
59 cur_task_id,
60 cur_context.name()
61 );
62 let _ = cur_context.add_sub_task(cur_task_id, async move {
63 future.await;
64 Ok(())
65 });
66 return None;
67 }
68
69 Some(block_on(future))
71}
72
73#[track_caller]
84pub fn block_on<Fut>(future: Fut) -> Fut::Output
85where
86 Fut: Future + Send + 'static,
87 Fut::Output: Send + 'static,
88{
89 gst::log!(RUNTIME_CAT, "Blocking on local thread");
90 scheduler::Blocking::block_on(future)
91}
92
93#[inline]
95pub fn yield_now() -> YieldNow {
96 YieldNow::default()
97}
98
/// Future returned by `yield_now`.
///
/// The inner flag records whether the future has already yielded once.
#[derive(Debug, Default)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct YieldNow(bool);

impl Future for YieldNow {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        // Mark as yielded and learn whether that had already happened.
        let already_yielded = std::mem::replace(&mut self.0, true);
        if already_yielded {
            return Poll::Ready(());
        }

        // Ask to be polled again right away, then hand control back to the executor.
        cx.waker().wake_by_ref();
        Poll::Pending
    }
}
116
117#[derive(Clone, Debug)]
118pub struct ContextWeak(scheduler::ThrottlingHandleWeak);
119
120impl ContextWeak {
121 pub fn upgrade(&self) -> Option<Context> {
122 self.0.upgrade().map(Context)
123 }
124}
125
126#[derive(Clone, Debug)]
137pub struct Context(scheduler::ThrottlingHandle);
138
139impl PartialEq for Context {
140 fn eq(&self, other: &Self) -> bool {
141 self.0.eq(&other.0)
142 }
143}
144
145impl Eq for Context {}
146
147impl Context {
148 pub fn acquire(context_name: &str, wait: Duration) -> Result<Self, io::Error> {
149 let mut contexts = CONTEXTS.lock().unwrap();
150
151 if let Some(context_weak) = contexts.get(context_name) {
152 if let Some(context) = context_weak.upgrade() {
153 gst::debug!(RUNTIME_CAT, "Joining Context '{}'", context.name());
154 return Ok(context);
155 }
156 }
157
158 let context = Context(scheduler::Throttling::start(context_name, wait));
159 contexts.insert(context_name.into(), context.downgrade());
160
161 gst::debug!(
162 RUNTIME_CAT,
163 "New Context '{}' throttling {wait:?}",
164 context.name(),
165 );
166 Ok(context)
167 }
168
169 pub fn downgrade(&self) -> ContextWeak {
170 ContextWeak(self.0.downgrade())
171 }
172
173 pub fn name(&self) -> &str {
174 self.0.context_name()
175 }
176
177 pub fn wait_duration(&self) -> Duration {
181 self.0.max_throttling()
182 }
183
184 #[cfg(feature = "tuning")]
188 pub fn parked_duration(&self) -> Duration {
189 self.0.parked_duration()
190 }
191
192 pub fn is_context_thread() -> bool {
194 scheduler::Throttling::is_throttling_thread()
195 }
196
197 pub fn current() -> Option<Context> {
199 scheduler::Throttling::current().map(Context)
200 }
201
202 pub fn current_task() -> Option<(Context, TaskId)> {
204 Option::zip(
205 scheduler::Throttling::current().map(Context),
206 TaskId::current(),
207 )
208 }
209
210 #[track_caller]
220 pub fn enter<'a, F, O>(&'a self, f: F) -> O
221 where
222 F: FnOnce() -> O + Send + 'a,
223 O: Send + 'a,
224 {
225 if let Some(cur) = Context::current().as_ref() {
226 if cur == self {
227 panic!(
228 "Attempt to enter Context {} within itself, this would deadlock",
229 self.name()
230 );
231 } else {
232 gst::warning!(
233 RUNTIME_CAT,
234 "Entering Context {} within {}",
235 self.name(),
236 cur.name()
237 );
238 }
239 } else {
240 gst::debug!(RUNTIME_CAT, "Entering Context {}", self.name());
241 }
242
243 self.0.enter(f)
244 }
245
246 pub fn spawn<Fut>(&self, future: Fut) -> JoinHandle<Fut::Output>
247 where
248 Fut: Future + Send + 'static,
249 Fut::Output: Send + 'static,
250 {
251 self.0.spawn(future)
252 }
253
254 pub fn spawn_and_unpark<Fut>(&self, future: Fut) -> JoinHandle<Fut::Output>
255 where
256 Fut: Future + Send + 'static,
257 Fut::Output: Send + 'static,
258 {
259 self.0.spawn_and_unpark(future)
260 }
261
262 pub(in crate::runtime) fn unpark(&self) {
270 self.0.unpark();
271 }
272
273 pub fn add_sub_task<T>(&self, task_id: TaskId, sub_task: T) -> Result<(), T>
274 where
275 T: Future<Output = SubTaskOutput> + Send + 'static,
276 {
277 self.0.add_sub_task(task_id, sub_task)
278 }
279
280 pub async fn drain_sub_tasks() -> SubTaskOutput {
281 let (ctx, task_id) = match Context::current_task() {
282 Some(task) => task,
283 None => return Ok(()),
284 };
285
286 ctx.0.drain_sub_tasks(task_id).await
287 }
288}
289
290impl From<scheduler::ThrottlingHandle> for Context {
291 fn from(handle: scheduler::ThrottlingHandle) -> Self {
292 Context(handle)
293 }
294}
295
#[cfg(test)]
mod tests {
    use futures::channel::mpsc;
    use futures::lock::Mutex;
    use futures::prelude::*;

    use std::net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    use super::Context;
    use crate::runtime::Async;

    type Item = i32;

    const SLEEP_DURATION_MS: u64 = 2;
    const SLEEP_DURATION: Duration = Duration::from_millis(SLEEP_DURATION_MS);
    const DELAY: Duration = Duration::from_millis(SLEEP_DURATION_MS * 10);

    /// Loopback address with an OS-assigned port.
    ///
    /// The sockets in these tests only send, so a fixed local port would only risk
    /// spurious `EADDRINUSE` failures when the port is already taken.
    fn ephemeral_local_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)
    }

    #[test]
    fn block_on_timer() {
        gst::init().unwrap();

        let elapsed = crate::runtime::executor::block_on(async {
            let now = Instant::now();
            crate::runtime::timer::delay_for(DELAY).await;
            now.elapsed()
        });

        assert!(elapsed >= DELAY);
    }

    #[test]
    fn context_task_id() {
        use super::TaskId;

        gst::init().unwrap();

        let context = Context::acquire("context_task_id", SLEEP_DURATION).unwrap();
        let join_handle = context.spawn(async {
            let (ctx, task_id) = Context::current_task().unwrap();
            assert_eq!(ctx.name(), "context_task_id");
            assert_eq!(task_id, TaskId(0));
        });
        futures::executor::block_on(join_handle).unwrap();
        let ctx_weak = context.downgrade();
        let join_handle = context.spawn(async move {
            let (ctx, task_id) = Context::current_task().unwrap();
            assert_eq!(task_id, TaskId(0));

            let res = ctx.add_sub_task(task_id, async move {
                let (_ctx, task_id) = Context::current_task().unwrap();
                assert_eq!(task_id, TaskId(0));
                Ok(())
            });
            assert!(res.is_ok());

            // A task spawned from within a task gets its own id.
            ctx_weak
                .upgrade()
                .unwrap()
                .spawn(async {
                    let (ctx, task_id) = Context::current_task().unwrap();
                    assert_eq!(task_id, TaskId(1));

                    let res = ctx.add_sub_task(task_id, async move {
                        let (_ctx, task_id) = Context::current_task().unwrap();
                        assert_eq!(task_id, TaskId(1));
                        Ok(())
                    });
                    assert!(res.is_ok());
                    assert!(Context::drain_sub_tasks().await.is_ok());

                    let (_ctx, task_id) = Context::current_task().unwrap();
                    assert_eq!(task_id, TaskId(1));
                })
                .await
                .unwrap();

            assert!(Context::drain_sub_tasks().await.is_ok());

            let (_ctx, task_id) = Context::current_task().unwrap();
            assert_eq!(task_id, TaskId(0));
        });
        futures::executor::block_on(join_handle).unwrap();
    }

    #[test]
    fn drain_sub_tasks() {
        gst::init().unwrap();

        let context = Context::acquire("drain_sub_tasks", SLEEP_DURATION).unwrap();

        let join_handle = context.spawn(async {
            let (sender, mut receiver) = mpsc::channel(1);
            let sender: Arc<Mutex<mpsc::Sender<Item>>> = Arc::new(Mutex::new(sender));

            let add_sub_task = move |item| {
                let sender = sender.clone();
                Context::current_task()
                    .ok_or(())
                    .and_then(|(ctx, task_id)| {
                        ctx.add_sub_task(task_id, async move {
                            sender
                                .lock()
                                .await
                                .send(item)
                                .await
                                .map_err(|_| gst::FlowError::Error)
                        })
                        .map_err(drop)
                    })
            };

            // Nothing queued yet: draining succeeds right away.
            let drain_fut = Context::drain_sub_tasks();
            drain_fut.await.unwrap();

            add_sub_task(0).unwrap();

            // The sub-task has not run yet, so nothing has been sent.
            receiver.try_next().unwrap_err();

            let drain_fut = Context::drain_sub_tasks();
            drain_fut.await.unwrap();
            assert_eq!(receiver.try_next().unwrap(), Some(0));

            add_sub_task(1).unwrap();
            receiver.try_next().unwrap_err();

            receiver
        });

        let mut receiver = futures::executor::block_on(join_handle).unwrap();

        // Sub-task 1 was never drained, so item 1 must not have been delivered.
        match receiver.try_next() {
            Ok(None) | Err(_) => (),
            other => panic!("Unexpected {other:?}"),
        }
    }

    #[test]
    fn block_on_from_sync() {
        gst::init().unwrap();

        let context = Context::acquire("block_on_from_sync", SLEEP_DURATION).unwrap();

        let bytes_sent = crate::runtime::executor::block_on(context.spawn(async {
            let socket = Async::<UdpSocket>::bind(ephemeral_local_addr()).unwrap();
            let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4001);
            socket.send_to(&[0; 10], saddr).await.unwrap()
        }))
        .unwrap();
        assert_eq!(bytes_sent, 10);

        let elapsed = crate::runtime::executor::block_on(context.spawn(async {
            let start = Instant::now();
            crate::runtime::timer::delay_for(DELAY).await;
            start.elapsed()
        }))
        .unwrap();
        assert!(elapsed + SLEEP_DURATION / 2 >= DELAY);
    }

    #[test]
    #[should_panic]
    fn block_on_from_context() {
        gst::init().unwrap();

        let context = Context::acquire("block_on_from_context", SLEEP_DURATION).unwrap();

        // Blocking from within a Context task must panic.
        let join_handle = context.spawn(async {
            crate::runtime::executor::block_on(crate::runtime::timer::delay_for(DELAY));
        });

        futures::executor::block_on(join_handle).unwrap_err();
    }

    #[test]
    fn enter_context_from_scheduler() {
        gst::init().unwrap();

        let elapsed = crate::runtime::executor::block_on(async {
            let context =
                Context::acquire("enter_context_from_scheduler", SLEEP_DURATION).unwrap();
            let socket = context
                .enter(|| Async::<UdpSocket>::bind(ephemeral_local_addr()))
                .unwrap();

            let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4002);
            let bytes_sent = socket.send_to(&[0; 10], saddr).await.unwrap();
            assert_eq!(bytes_sent, 10);

            let (start, timer) =
                context.enter(|| (Instant::now(), crate::runtime::timer::delay_for(DELAY)));
            timer.await;
            start.elapsed()
        });

        assert!(elapsed + SLEEP_DURATION / 2 >= DELAY);
    }

    #[test]
    fn enter_context_from_sync() {
        gst::init().unwrap();

        let context = Context::acquire("enter_context_from_sync", SLEEP_DURATION).unwrap();
        let socket = context
            .enter(|| Async::<UdpSocket>::bind(ephemeral_local_addr()))
            .unwrap();

        let saddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 4003);
        let bytes_sent = futures::executor::block_on(socket.send_to(&[0; 10], saddr)).unwrap();
        assert_eq!(bytes_sent, 10);

        let (start, timer) =
            context.enter(|| (Instant::now(), crate::runtime::timer::delay_for(DELAY)));
        let elapsed = crate::runtime::executor::block_on(async move {
            timer.await;
            start.elapsed()
        });
        assert!(elapsed + SLEEP_DURATION / 2 >= DELAY);
    }
}