pyo3_async_runtimes/
generic.rs

1//! Generic implementations of PyO3 Asyncio utilities that can be used for any Rust runtime
2//!
3//! Items marked with
4//! <span
5//!   class="module-item stab portability"
6//!   style="display: inline; border-radius: 3px; padding: 2px; font-size: 80%; line-height: 1.2;"
7//! ><code>unstable-streams</code></span>
8//! > are only available when the `unstable-streams` Cargo feature is enabled:
9//!
10//! ```toml
11//! [dependencies.pyo3-async-runtimes]
12//! version = "0.24"
13//! features = ["unstable-streams"]
14//! ```
15
16use std::{
17    future::Future,
18    pin::Pin,
19    sync::{Arc, Mutex},
20    task::{Context, Poll},
21};
22
23use crate::{
24    asyncio, call_soon_threadsafe, close, create_future, dump_err, err::RustPanic,
25    get_running_loop, into_future_with_locals, TaskLocals,
26};
27use futures::channel::oneshot;
28#[cfg(feature = "unstable-streams")]
29use futures::{channel::mpsc, SinkExt};
30use pin_project_lite::pin_project;
31use pyo3::prelude::*;
32use pyo3::IntoPyObjectExt;
33#[cfg(feature = "unstable-streams")]
34use std::marker::PhantomData;
35
/// Runtime-agnostic view of the error produced by awaiting a spawned task's join handle.
pub trait JoinError {
    /// Returns `true` when the spawned task ended by panicking.
    fn is_panic(&self) -> bool;

    /// Consumes the error and hands back the panic payload.
    ///
    /// # Panics
    ///
    /// Panics if [`JoinError::is_panic`] is not true.
    fn into_panic(self) -> Box<dyn std::any::Any + Send>;
}
43
44/// Generic Rust async/await runtime
45pub trait Runtime: Send + 'static {
46    /// The error returned by a JoinHandle after being awaited
47    type JoinError: JoinError + Send;
48    /// A future that completes with the result of the spawned task
49    type JoinHandle: Future<Output = Result<(), Self::JoinError>> + Send;
50
51    /// Spawn a future onto this runtime's event loop
52    fn spawn<F>(fut: F) -> Self::JoinHandle
53    where
54        F: Future<Output = ()> + Send + 'static;
55
56    /// Spawn a function onto this runtime's blocking event loop
57    fn spawn_blocking<F>(f: F) -> Self::JoinHandle
58    where
59        F: FnOnce() + Send + 'static;
60}
61
62/// Extension trait for async/await runtimes that support spawning local tasks
63pub trait SpawnLocalExt: Runtime {
64    /// Spawn a !Send future onto this runtime's event loop
65    fn spawn_local<F>(fut: F) -> Self::JoinHandle
66    where
67        F: Future<Output = ()> + 'static;
68}
69
70/// Exposes the utilities necessary for using task-local data in the Runtime
71pub trait ContextExt: Runtime {
72    /// Set the task locals for the given future
73    fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
74    where
75        F: Future<Output = R> + Send + 'static;
76
77    /// Get the task locals for the current task
78    fn get_task_locals() -> Option<TaskLocals>;
79}
80
81/// Adds the ability to scope task-local data for !Send futures
82pub trait LocalContextExt: Runtime {
83    /// Set the task locals for the given !Send future
84    fn scope_local<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R>>>
85    where
86        F: Future<Output = R> + 'static;
87}
88
89/// Get the current event loop from either Python or Rust async task local context
90///
91/// This function first checks if the runtime has a task-local reference to the Python event loop.
92/// If not, it calls [`get_running_loop`](crate::get_running_loop`) to get the event loop associated
93/// with the current OS thread.
94pub fn get_current_loop<R>(py: Python) -> PyResult<Bound<PyAny>>
95where
96    R: ContextExt,
97{
98    if let Some(locals) = R::get_task_locals() {
99        Ok(locals.0.event_loop.clone_ref(py).into_bound(py))
100    } else {
101        get_running_loop(py)
102    }
103}
104
105/// Either copy the task locals from the current task OR get the current running loop and
106/// contextvars from Python.
107pub fn get_current_locals<R>(py: Python) -> PyResult<TaskLocals>
108where
109    R: ContextExt,
110{
111    if let Some(locals) = R::get_task_locals() {
112        Ok(locals)
113    } else {
114        Ok(TaskLocals::with_running_loop(py)?.copy_context(py)?)
115    }
116}
117
118/// Run the event loop until the given Future completes
119///
120/// After this function returns, the event loop can be resumed with [`run_until_complete`]
121///
122/// # Arguments
123/// * `event_loop` - The Python event loop that should run the future
124/// * `fut` - The future to drive to completion
125///
126/// # Examples
127///
128/// ```no_run
129/// # use std::{any::Any, task::{Context, Poll}, pin::Pin, future::Future};
130/// #
131/// # use pyo3_async_runtimes::{
132/// #     TaskLocals,
133/// #     generic::{JoinError, SpawnLocalExt, ContextExt, LocalContextExt, Runtime}
134/// # };
135/// #
136/// # struct MyCustomJoinError;
137/// #
138/// # impl JoinError for MyCustomJoinError {
139/// #     fn is_panic(&self) -> bool {
140/// #         unreachable!()
141/// #     }
142/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
143/// #         unreachable!()
144/// #     }
145/// # }
146/// #
147/// # struct MyCustomJoinHandle;
148/// #
149/// # impl Future for MyCustomJoinHandle {
150/// #     type Output = Result<(), MyCustomJoinError>;
151/// #
152/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
153/// #         unreachable!()
154/// #     }
155/// # }
156/// #
157/// # struct MyCustomRuntime;
158/// #
159/// # impl Runtime for MyCustomRuntime {
160/// #     type JoinError = MyCustomJoinError;
161/// #     type JoinHandle = MyCustomJoinHandle;
162/// #
163/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
164/// #     where
165/// #         F: Future<Output = ()> + Send + 'static
166/// #     {
167/// #         unreachable!()
168/// #     }
169/// #
170/// #     fn spawn_blocking<F>(f: F) -> Self::JoinHandle where F: FnOnce() + Send + 'static {
171/// #         unreachable!()
172/// #     }
173/// # }
174/// #
175/// # impl ContextExt for MyCustomRuntime {
176/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
177/// #     where
178/// #         F: Future<Output = R> + Send + 'static
179/// #     {
180/// #         unreachable!()
181/// #     }
182/// #     fn get_task_locals() -> Option<TaskLocals> {
183/// #         unreachable!()
184/// #     }
185/// # }
186/// #
187/// # use std::time::Duration;
188/// #
189/// # use pyo3::prelude::*;
190/// #
191/// # Python::attach(|py| -> PyResult<()> {
192/// # let event_loop = py.import("asyncio")?.call_method0("new_event_loop")?;
193/// # #[cfg(feature = "tokio-runtime")]
194/// pyo3_async_runtimes::generic::run_until_complete::<MyCustomRuntime, _, _>(&event_loop, async move {
195///     tokio::time::sleep(Duration::from_secs(1)).await;
196///     Ok(())
197/// })?;
198/// # Ok(())
199/// # }).unwrap();
200/// ```
201pub fn run_until_complete<R, F, T>(event_loop: &Bound<PyAny>, fut: F) -> PyResult<T>
202where
203    R: Runtime + ContextExt,
204    F: Future<Output = PyResult<T>> + Send + 'static,
205    T: Send + Sync + 'static,
206{
207    let py = event_loop.py();
208    let result_tx = Arc::new(Mutex::new(None));
209    let result_rx = Arc::clone(&result_tx);
210    let coro = future_into_py_with_locals::<R, _, ()>(
211        py,
212        TaskLocals::new(event_loop.clone()).copy_context(py)?,
213        async move {
214            let val = fut.await?;
215            if let Ok(mut result) = result_tx.lock() {
216                *result = Some(val);
217            }
218            Ok(())
219        },
220    )?;
221
222    event_loop.call_method1(pyo3::intern!(py, "run_until_complete"), (coro,))?;
223
224    let result = result_rx.lock().unwrap().take().unwrap();
225    Ok(result)
226}
227
228/// Run the event loop until the given Future completes
229///
230/// # Arguments
231/// * `py` - The current PyO3 GIL guard
232/// * `fut` - The future to drive to completion
233///
234/// # Examples
235///
236/// ```no_run
237/// # use std::{any::Any, task::{Context, Poll}, pin::Pin, future::Future};
238/// #
239/// # use pyo3_async_runtimes::{
240/// #     TaskLocals,
241/// #     generic::{JoinError, SpawnLocalExt, ContextExt, LocalContextExt, Runtime}
242/// # };
243/// #
244/// # struct MyCustomJoinError;
245/// #
246/// # impl JoinError for MyCustomJoinError {
247/// #     fn is_panic(&self) -> bool {
248/// #         unreachable!()
249/// #     }
250/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
251/// #         unreachable!()
252/// #     }
253/// # }
254/// #
255/// # struct MyCustomJoinHandle;
256/// #
257/// # impl Future for MyCustomJoinHandle {
258/// #     type Output = Result<(), MyCustomJoinError>;
259/// #
260/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
261/// #         unreachable!()
262/// #     }
263/// # }
264/// #
265/// # struct MyCustomRuntime;
266/// #
267/// # impl Runtime for MyCustomRuntime {
268/// #     type JoinError = MyCustomJoinError;
269/// #     type JoinHandle = MyCustomJoinHandle;
270/// #
271/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
272/// #     where
273/// #         F: Future<Output = ()> + Send + 'static
274/// #     {
275/// #         unreachable!()
276/// #     }
277/// #
278/// #     fn spawn_blocking<F>(f: F) -> Self::JoinHandle where F: FnOnce() + Send + 'static {
279/// #         unreachable!()
280/// #     }
281/// # }
282/// #
283/// # impl ContextExt for MyCustomRuntime {
284/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
285/// #     where
286/// #         F: Future<Output = R> + Send + 'static
287/// #     {
288/// #         unreachable!()
289/// #     }
290/// #     fn get_task_locals() -> Option<TaskLocals> {
291/// #         unreachable!()
292/// #     }
293/// # }
294/// #
295/// # use std::time::Duration;
296/// # async fn custom_sleep(_duration: Duration) { }
297/// #
298/// # use pyo3::prelude::*;
299/// #
300/// fn main() {
301///     Python::attach(|py| {
302///         pyo3_async_runtimes::generic::run::<MyCustomRuntime, _, _>(py, async move {
303///             custom_sleep(Duration::from_secs(1)).await;
304///             Ok(())
305///         })
306///         .map_err(|e| {
307///             e.print_and_set_sys_last_vars(py);
308///         })
309///         .unwrap();
310///     })
311/// }
312/// ```
313pub fn run<R, F, T>(py: Python, fut: F) -> PyResult<T>
314where
315    R: Runtime + ContextExt,
316    F: Future<Output = PyResult<T>> + Send + 'static,
317    T: Send + Sync + 'static,
318{
319    let event_loop = asyncio(py)?.call_method0(pyo3::intern!(py, "new_event_loop"))?;
320
321    let result = run_until_complete::<R, F, T>(&event_loop, fut);
322
323    close(event_loop)?;
324
325    result
326}
327
328fn cancelled(future: &Bound<PyAny>) -> PyResult<bool> {
329    future
330        .getattr(pyo3::intern!(future.py(), "cancelled"))?
331        .call0()?
332        .is_truthy()
333}
334
335#[pyclass]
336struct CheckedCompletor;
337
338#[pymethods]
339impl CheckedCompletor {
340    fn __call__(
341        &self,
342        future: &Bound<PyAny>,
343        complete: &Bound<PyAny>,
344        value: &Bound<PyAny>,
345    ) -> PyResult<()> {
346        if cancelled(future)? {
347            return Ok(());
348        }
349
350        complete.call1((value,))?;
351
352        Ok(())
353    }
354}
355
356fn set_result(
357    event_loop: &Bound<PyAny>,
358    future: &Bound<PyAny>,
359    result: PyResult<Py<PyAny>>,
360) -> PyResult<()> {
361    let py = event_loop.py();
362    let none = py.None().into_bound(py);
363
364    let (complete, val) = match result {
365        Ok(val) => (
366            future.getattr(pyo3::intern!(py, "set_result"))?,
367            val.into_pyobject(py)?,
368        ),
369        Err(err) => (
370            future.getattr(pyo3::intern!(py, "set_exception"))?,
371            err.into_bound_py_any(py)?,
372        ),
373    };
374    call_soon_threadsafe(event_loop, &none, (CheckedCompletor, future, complete, val))?;
375
376    Ok(())
377}
378
379/// Convert a Python `awaitable` into a Rust Future
380///
381/// This function simply forwards the future and the task locals returned by [`get_current_locals`]
382/// to [`into_future_with_locals`](`crate::into_future_with_locals`). See
383/// [`into_future_with_locals`](`crate::into_future_with_locals`) for more details.
384///
385/// # Arguments
386/// * `awaitable` - The Python `awaitable` to be converted
387///
388/// # Examples
389///
390/// ```no_run
391/// # use std::{any::Any, pin::Pin, future::Future, task::{Context, Poll}, time::Duration};
392/// # use std::ffi::CString;
393/// #
394/// # use pyo3::prelude::*;
395/// #
396/// # use pyo3_async_runtimes::{
397/// #     TaskLocals,
398/// #     generic::{JoinError, SpawnLocalExt, ContextExt, LocalContextExt, Runtime}
399/// # };
400/// #
401/// # struct MyCustomJoinError;
402/// #
403/// # impl JoinError for MyCustomJoinError {
404/// #     fn is_panic(&self) -> bool {
405/// #         unreachable!()
406/// #     }
407/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
408/// #         unreachable!()
409/// #     }
410/// # }
411/// #
412/// # struct MyCustomJoinHandle;
413/// #
414/// # impl Future for MyCustomJoinHandle {
415/// #     type Output = Result<(), MyCustomJoinError>;
416/// #
417/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
418/// #         unreachable!()
419/// #     }
420/// # }
421/// #
422/// # struct MyCustomRuntime;
423/// #
424/// # impl MyCustomRuntime {
425/// #     async fn sleep(_: Duration) {
426/// #         unreachable!()
427/// #     }
428/// # }
429/// #
430/// # impl Runtime for MyCustomRuntime {
431/// #     type JoinError = MyCustomJoinError;
432/// #     type JoinHandle = MyCustomJoinHandle;
433/// #
434/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
435/// #     where
436/// #         F: Future<Output = ()> + Send + 'static
437/// #     {
438/// #         unreachable!()
439/// #     }
440/// #
441/// #     fn spawn_blocking<F>(f: F) -> Self::JoinHandle where F: FnOnce() + Send + 'static {
442/// #         unreachable!()
443/// #     }
444/// # }
445/// #
446/// # impl ContextExt for MyCustomRuntime {
447/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
448/// #     where
449/// #         F: Future<Output = R> + Send + 'static
450/// #     {
451/// #         unreachable!()
452/// #     }
453/// #     fn get_task_locals() -> Option<TaskLocals> {
454/// #         unreachable!()
455/// #     }
456/// # }
457/// #
458/// const PYTHON_CODE: &'static str = r#"
459/// import asyncio
460///
461/// async def py_sleep(duration):
462///     await asyncio.sleep(duration)
463/// "#;
464///
465/// async fn py_sleep(seconds: f32) -> PyResult<()> {
466///     let test_mod = Python::attach(|py| -> PyResult<Py<PyAny>> {
467///         Ok(
468///             PyModule::from_code(
469///                 py,
470///                 &CString::new(PYTHON_CODE).unwrap(),
471///                 &CString::new("test_into_future/test_mod.py").unwrap(),
472///                 &CString::new("test_mod").unwrap(),
473///             )?
474///             .into()
475///         )
476///     })?;
477///
478///     Python::attach(|py| {
479///         pyo3_async_runtimes::generic::into_future::<MyCustomRuntime>(
480///             test_mod
481///                 .call_method1(py, "py_sleep", (seconds,))?
482///                 .into_bound(py),
483///         )
484///     })?
485///     .await?;
486///     Ok(())
487/// }
488/// ```
489pub fn into_future<R>(
490    awaitable: Bound<PyAny>,
491) -> PyResult<impl Future<Output = PyResult<Py<PyAny>>> + Send>
492where
493    R: Runtime + ContextExt,
494{
495    into_future_with_locals(&get_current_locals::<R>(awaitable.py())?, awaitable)
496}
497
498/// Convert a Rust Future into a Python awaitable with a generic runtime
499///
500/// If the `asyncio.Future` returned by this conversion is cancelled via `asyncio.Future.cancel`,
501/// the Rust future will be cancelled as well (new behaviour in `v0.15`).
502///
503/// Python `contextvars` are preserved when calling async Python functions within the Rust future
504/// via [`into_future`] (new behaviour in `v0.15`).
505///
506/// > Although `contextvars` are preserved for async Python functions, synchronous functions will
507/// > unfortunately fail to resolve them when called within the Rust future. This is because the
508/// > function is being called from a Rust thread, not inside an actual Python coroutine context.
509/// >
510/// > As a workaround, you can get the `contextvars` from the current task locals using
511/// > [`get_current_locals`] and [`TaskLocals::context`](`crate::TaskLocals::context`), then wrap your
512/// > synchronous function in a call to `contextvars.Context.run`. This will set the context, call the
513/// > synchronous function, and restore the previous context when it returns or raises an exception.
514///
515/// # Arguments
516/// * `py` - PyO3 GIL guard
517/// * `locals` - The task-local data for Python
518/// * `fut` - The Rust future to be converted
519///
520/// # Examples
521///
522/// ```no_run
523/// # use std::{any::Any, task::{Context, Poll}, pin::Pin, future::Future};
524/// #
525/// # use pyo3_async_runtimes::{
526/// #     TaskLocals,
527/// #     generic::{JoinError, SpawnLocalExt, ContextExt, LocalContextExt, Runtime}
528/// # };
529/// #
530/// # struct MyCustomJoinError;
531/// #
532/// # impl JoinError for MyCustomJoinError {
533/// #     fn is_panic(&self) -> bool {
534/// #         unreachable!()
535/// #     }
536/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
537/// #         unreachable!()
538/// #     }
539/// # }
540/// #
541/// # struct MyCustomJoinHandle;
542/// #
543/// # impl Future for MyCustomJoinHandle {
544/// #     type Output = Result<(), MyCustomJoinError>;
545/// #
546/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
547/// #         unreachable!()
548/// #     }
549/// # }
550/// #
551/// # struct MyCustomRuntime;
552/// #
553/// # impl MyCustomRuntime {
554/// #     async fn sleep(_: Duration) {
555/// #         unreachable!()
556/// #     }
557/// # }
558/// #
559/// # impl Runtime for MyCustomRuntime {
560/// #     type JoinError = MyCustomJoinError;
561/// #     type JoinHandle = MyCustomJoinHandle;
562/// #
563/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
564/// #     where
565/// #         F: Future<Output = ()> + Send + 'static
566/// #     {
567/// #         unreachable!()
568/// #     }
569/// #
570/// #     fn spawn_blocking<F>(f: F) -> Self::JoinHandle where F: FnOnce() + Send + 'static {
571/// #         unreachable!()
572/// #     }
573/// # }
574/// #
575/// # impl ContextExt for MyCustomRuntime {
576/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
577/// #     where
578/// #         F: Future<Output = R> + Send + 'static
579/// #     {
580/// #         unreachable!()
581/// #     }
582/// #     fn get_task_locals() -> Option<TaskLocals> {
583/// #         unreachable!()
584/// #     }
585/// # }
586/// #
587/// use std::time::Duration;
588///
589/// use pyo3::prelude::*;
590///
591/// /// Awaitable sleep function
592/// #[pyfunction]
593/// fn sleep_for<'p>(py: Python<'p>, secs: Bound<'p, PyAny>) -> PyResult<Bound<'p, PyAny>> {
594///     let secs = secs.extract()?;
595///     pyo3_async_runtimes::generic::future_into_py_with_locals::<MyCustomRuntime, _, _>(
596///         py,
597///         pyo3_async_runtimes::generic::get_current_locals::<MyCustomRuntime>(py)?,
598///         async move {
599///             MyCustomRuntime::sleep(Duration::from_secs(secs)).await;
600///             Ok(())
601///         }
602///     )
603/// }
604/// ```
605#[allow(unused_must_use)]
606pub fn future_into_py_with_locals<R, F, T>(
607    py: Python,
608    locals: TaskLocals,
609    fut: F,
610) -> PyResult<Bound<PyAny>>
611where
612    R: Runtime + ContextExt,
613    F: Future<Output = PyResult<T>> + Send + 'static,
614    T: for<'py> IntoPyObject<'py> + Send + 'static,
615{
616    let (cancel_tx, cancel_rx) = oneshot::channel();
617
618    let py_fut = create_future(locals.0.event_loop.bind(py).clone())?;
619    py_fut.call_method1(
620        pyo3::intern!(py, "add_done_callback"),
621        (PyDoneCallback {
622            cancel_tx: Some(cancel_tx),
623        },),
624    )?;
625
626    let future_tx1: Py<PyAny> = py_fut.clone().into();
627    let future_tx2 = future_tx1.clone_ref(py);
628
629    R::spawn(async move {
630        let locals2 = locals.clone();
631
632        if let Err(e) = R::spawn(async move {
633            let result = R::scope(
634                locals2.clone(),
635                Cancellable::new_with_cancel_rx(fut, cancel_rx),
636            )
637            .await;
638
639            // We should not hold GIL inside async-std/tokio event loop,
640            // because a blocked task may prevent other tasks from progressing.
641            R::spawn_blocking(|| {
642                Python::attach(move |py| {
643                    if cancelled(future_tx1.bind(py))
644                        .map_err(dump_err(py))
645                        .unwrap_or(false)
646                    {
647                        return;
648                    }
649
650                    let _ = set_result(
651                        &locals2.event_loop(py),
652                        future_tx1.bind(py),
653                        result.and_then(|val| val.into_py_any(py)),
654                    )
655                    .map_err(dump_err(py));
656                });
657            });
658        })
659        .await
660        {
661            if e.is_panic() {
662                R::spawn_blocking(|| {
663                    Python::attach(move |py| {
664                        if cancelled(future_tx2.bind(py))
665                            .map_err(dump_err(py))
666                            .unwrap_or(false)
667                        {
668                            return;
669                        }
670
671                        let panic_message = format!(
672                            "rust future panicked: {}",
673                            get_panic_message(&e.into_panic())
674                        );
675                        let _ = set_result(
676                            locals.0.event_loop.bind(py),
677                            future_tx2.bind(py),
678                            Err(RustPanic::new_err(panic_message)),
679                        )
680                        .map_err(dump_err(py));
681                    });
682                });
683            }
684        }
685    });
686
687    Ok(py_fut)
688}
689
/// Extract a human-readable message from a panic payload.
///
/// `panic!` payloads are normally `&'static str` (literal messages) or `String` (formatted
/// messages). A payload that is itself a `Box<dyn Any + Send>` is unwrapped first. Passing
/// `&Box<dyn Any + Send>` where `&dyn Any` is expected coerces the box itself into the trait
/// object, which would otherwise hide the real payload from the downcasts below.
///
/// Returns `"unknown error"` for any other payload type.
fn get_panic_message(any: &dyn std::any::Any) -> &str {
    if let Some(boxed) = any.downcast_ref::<Box<dyn std::any::Any + Send>>() {
        get_panic_message(&**boxed)
    } else if let Some(str_slice) = any.downcast_ref::<&str>() {
        str_slice
    } else if let Some(string) = any.downcast_ref::<String>() {
        string.as_str()
    } else {
        "unknown error"
    }
}
699
700pin_project! {
701    /// Future returned by [`timeout`](timeout) and [`timeout_at`](timeout_at).
702    #[must_use = "futures do nothing unless you `.await` or poll them"]
703    #[derive(Debug)]
704    struct Cancellable<T> {
705        #[pin]
706        future: T,
707        #[pin]
708        cancel_rx: oneshot::Receiver<()>,
709
710        poll_cancel_rx: bool
711    }
712}
713
714impl<T> Cancellable<T> {
715    fn new_with_cancel_rx(future: T, cancel_rx: oneshot::Receiver<()>) -> Self {
716        Self {
717            future,
718            cancel_rx,
719
720            poll_cancel_rx: true,
721        }
722    }
723}
724
725impl<'py, F, T> Future for Cancellable<F>
726where
727    F: Future<Output = PyResult<T>>,
728    T: IntoPyObject<'py>,
729{
730    type Output = F::Output;
731
732    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
733        let this = self.project();
734
735        // First, try polling the future
736        if let Poll::Ready(v) = this.future.poll(cx) {
737            return Poll::Ready(v);
738        }
739
740        // Now check for cancellation
741        if *this.poll_cancel_rx {
742            match this.cancel_rx.poll(cx) {
743                Poll::Ready(Ok(())) => {
744                    *this.poll_cancel_rx = false;
745                    // The python future has already been cancelled, so this return value will never
746                    // be used.
747                    Poll::Ready(Err(pyo3::exceptions::PyBaseException::new_err(
748                        "unreachable",
749                    )))
750                }
751                Poll::Ready(Err(_)) => {
752                    *this.poll_cancel_rx = false;
753                    Poll::Pending
754                }
755                Poll::Pending => Poll::Pending,
756            }
757        } else {
758            Poll::Pending
759        }
760    }
761}
762
763#[pyclass]
764struct PyDoneCallback {
765    cancel_tx: Option<oneshot::Sender<()>>,
766}
767
768#[pymethods]
769impl PyDoneCallback {
770    pub fn __call__(&mut self, fut: &Bound<PyAny>) -> PyResult<()> {
771        let py = fut.py();
772
773        if cancelled(fut).map_err(dump_err(py)).unwrap_or(false) {
774            let _ = self.cancel_tx.take().unwrap().send(());
775        }
776
777        Ok(())
778    }
779}
780
781/// Convert a Rust Future into a Python awaitable with a generic runtime
782///
783/// If the `asyncio.Future` returned by this conversion is cancelled via `asyncio.Future.cancel`,
784/// the Rust future will be cancelled as well (new behaviour in `v0.15`).
785///
786/// Python `contextvars` are preserved when calling async Python functions within the Rust future
787/// via [`into_future`] (new behaviour in `v0.15`).
788///
789/// > Although `contextvars` are preserved for async Python functions, synchronous functions will
790/// > unfortunately fail to resolve them when called within the Rust future. This is because the
791/// > function is being called from a Rust thread, not inside an actual Python coroutine context.
792/// >
793/// > As a workaround, you can get the `contextvars` from the current task locals using
794/// > [`get_current_locals`] and [`TaskLocals::context`](`crate::TaskLocals::context`), then wrap your
795/// > synchronous function in a call to `contextvars.Context.run`. This will set the context, call the
796/// > synchronous function, and restore the previous context when it returns or raises an exception.
797///
798/// # Arguments
799/// * `py` - The current PyO3 GIL guard
800/// * `fut` - The Rust future to be converted
801///
802/// # Examples
803///
804/// ```no_run
805/// # use std::{any::Any, task::{Context, Poll}, pin::Pin, future::Future};
806/// #
807/// # use pyo3_async_runtimes::{
808/// #     TaskLocals,
809/// #     generic::{JoinError, SpawnLocalExt, ContextExt, LocalContextExt, Runtime}
810/// # };
811/// #
812/// # struct MyCustomJoinError;
813/// #
814/// # impl JoinError for MyCustomJoinError {
815/// #     fn is_panic(&self) -> bool {
816/// #         unreachable!()
817/// #     }
818/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
819/// #         unreachable!()
820/// #     }
821/// # }
822/// #
823/// # struct MyCustomJoinHandle;
824/// #
825/// # impl Future for MyCustomJoinHandle {
826/// #     type Output = Result<(), MyCustomJoinError>;
827/// #
828/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
829/// #         unreachable!()
830/// #     }
831/// # }
832/// #
833/// # struct MyCustomRuntime;
834/// #
835/// # impl MyCustomRuntime {
836/// #     async fn sleep(_: Duration) {
837/// #         unreachable!()
838/// #     }
839/// # }
840/// #
841/// # impl Runtime for MyCustomRuntime {
842/// #     type JoinError = MyCustomJoinError;
843/// #     type JoinHandle = MyCustomJoinHandle;
844/// #
845/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
846/// #     where
847/// #         F: Future<Output = ()> + Send + 'static
848/// #     {
849/// #         unreachable!()
850/// #     }
851/// #
852/// #     fn spawn_blocking<F>(f: F) -> Self::JoinHandle where F: FnOnce() + Send + 'static {
853/// #         unreachable!()
854/// #     }
855/// # }
856/// #
857/// # impl ContextExt for MyCustomRuntime {
858/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
859/// #     where
860/// #         F: Future<Output = R> + Send + 'static
861/// #     {
862/// #         unreachable!()
863/// #     }
864/// #     fn get_task_locals() -> Option<TaskLocals> {
865/// #         unreachable!()
866/// #     }
867/// # }
868/// #
869/// use std::time::Duration;
870///
871/// use pyo3::prelude::*;
872///
873/// /// Awaitable sleep function
874/// #[pyfunction]
875/// fn sleep_for<'p>(py: Python<'p>, secs: Bound<'p, PyAny>) -> PyResult<Bound<'p, PyAny>> {
876///     let secs = secs.extract()?;
877///     pyo3_async_runtimes::generic::future_into_py::<MyCustomRuntime, _, _>(py, async move {
878///         MyCustomRuntime::sleep(Duration::from_secs(secs)).await;
879///         Ok(())
880///     })
881/// }
882/// ```
883pub fn future_into_py<R, F, T>(py: Python, fut: F) -> PyResult<Bound<PyAny>>
884where
885    R: Runtime + ContextExt,
886    F: Future<Output = PyResult<T>> + Send + 'static,
887    T: for<'py> IntoPyObject<'py> + Send + 'static,
888{
889    future_into_py_with_locals::<R, F, T>(py, get_current_locals::<R>(py)?, fut)
890}
891
892/// Convert a `!Send` Rust Future into a Python awaitable with a generic runtime and manual
893/// specification of task locals.
894///
895/// If the `asyncio.Future` returned by this conversion is cancelled via `asyncio.Future.cancel`,
896/// the Rust future will be cancelled as well (new behaviour in `v0.15`).
897///
898/// Python `contextvars` are preserved when calling async Python functions within the Rust future
899/// via [`into_future`] (new behaviour in `v0.15`).
900///
901/// > Although `contextvars` are preserved for async Python functions, synchronous functions will
902/// > unfortunately fail to resolve them when called within the Rust future. This is because the
903/// > function is being called from a Rust thread, not inside an actual Python coroutine context.
904/// >
905/// > As a workaround, you can get the `contextvars` from the current task locals using
906/// > [`get_current_locals`] and [`TaskLocals::context`](`crate::TaskLocals::context`), then wrap your
907/// > synchronous function in a call to `contextvars.Context.run`. This will set the context, call the
908/// > synchronous function, and restore the previous context when it returns or raises an exception.
909///
910/// # Arguments
911/// * `py` - PyO3 GIL guard
912/// * `locals` - The task locals for the future
913/// * `fut` - The Rust future to be converted
914///
915/// # Examples
916///
917/// ```no_run
918/// # use std::{any::Any, task::{Context, Poll}, pin::Pin, future::Future};
919/// #
920/// # use pyo3_async_runtimes::{
921/// #     TaskLocals,
922/// #     generic::{JoinError, SpawnLocalExt, ContextExt, LocalContextExt, Runtime}
923/// # };
924/// #
925/// # struct MyCustomJoinError;
926/// #
927/// # impl JoinError for MyCustomJoinError {
928/// #     fn is_panic(&self) -> bool {
929/// #         unreachable!()
930/// #     }
931/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
932/// #         unreachable!()
933/// #     }
934/// # }
935/// #
936/// # struct MyCustomJoinHandle;
937/// #
938/// # impl Future for MyCustomJoinHandle {
939/// #     type Output = Result<(), MyCustomJoinError>;
940/// #
941/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
942/// #         unreachable!()
943/// #     }
944/// # }
945/// #
946/// # struct MyCustomRuntime;
947/// #
948/// # impl MyCustomRuntime {
949/// #     async fn sleep(_: Duration) {
950/// #         unreachable!()
951/// #     }
952/// # }
953/// #
954/// # impl Runtime for MyCustomRuntime {
955/// #     type JoinError = MyCustomJoinError;
956/// #     type JoinHandle = MyCustomJoinHandle;
957/// #
958/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
959/// #     where
960/// #         F: Future<Output = ()> + Send + 'static
961/// #     {
962/// #         unreachable!()
963/// #     }
964/// #
965/// #     fn spawn_blocking<F>(f: F) -> Self::JoinHandle where F: FnOnce() + Send + 'static {
966/// #         unreachable!()
967/// #     }
968/// # }
969/// #
970/// # impl ContextExt for MyCustomRuntime {
971/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
972/// #     where
973/// #         F: Future<Output = R> + Send + 'static
974/// #     {
975/// #         unreachable!()
976/// #     }
977/// #     fn get_task_locals() -> Option<TaskLocals> {
978/// #         unreachable!()
979/// #     }
980/// # }
981/// #
982/// # impl SpawnLocalExt for MyCustomRuntime {
983/// #     fn spawn_local<F>(fut: F) -> Self::JoinHandle
984/// #     where
985/// #         F: Future<Output = ()> + 'static
986/// #     {
987/// #         unreachable!()
988/// #     }
989/// # }
990/// #
991/// # impl LocalContextExt for MyCustomRuntime {
992/// #     fn scope_local<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R>>>
993/// #     where
994/// #         F: Future<Output = R> + 'static
995/// #     {
996/// #         unreachable!()
997/// #     }
998/// # }
999/// #
1000/// use std::{rc::Rc, time::Duration};
1001///
1002/// use pyo3::prelude::*;
1003///
1004/// /// Awaitable sleep function
1005/// #[pyfunction]
1006/// fn sleep_for(py: Python, secs: u64) -> PyResult<Bound<PyAny>> {
1007///     // Rc is !Send so it cannot be passed into pyo3_async_runtimes::generic::future_into_py
1008///     let secs = Rc::new(secs);
1009///
1010///     pyo3_async_runtimes::generic::local_future_into_py_with_locals::<MyCustomRuntime, _, _>(
1011///         py,
1012///         pyo3_async_runtimes::generic::get_current_locals::<MyCustomRuntime>(py)?,
1013///         async move {
1014///             MyCustomRuntime::sleep(Duration::from_secs(*secs)).await;
1015///             Ok(())
1016///         }
1017///     )
1018/// }
1019/// ```
#[deprecated(
    since = "0.18.0",
    note = "Questionable whether these conversions have real-world utility (see https://github.com/awestlake87/pyo3-asyncio/issues/59#issuecomment-1008038497 and let me know if you disagree!)"
)]
#[allow(unused_must_use)] // the outer `R::spawn_local` JoinHandle is dropped on purpose (detached task)
pub fn local_future_into_py_with_locals<R, F, T>(
    py: Python,
    locals: TaskLocals,
    fut: F,
) -> PyResult<Bound<PyAny>>
where
    R: Runtime + SpawnLocalExt + LocalContextExt,
    F: Future<Output = PyResult<T>> + 'static,
    T: for<'py> IntoPyObject<'py>,
{
    // `cancel_tx` goes to the done-callback registered on the Python future below, and
    // `cancel_rx` is wired into `Cancellable`. NOTE(review): presumably the callback fires
    // `cancel_tx` when the Python future is cancelled so the Rust future stops — confirm in
    // `PyDoneCallback`.
    let (cancel_tx, cancel_rx) = oneshot::channel();

    // The Python future returned to the caller; it is resolved later from the spawned task.
    let py_fut = create_future(locals.0.event_loop.clone_ref(py).into_bound(py))?;
    py_fut.call_method1(
        pyo3::intern!(py, "add_done_callback"),
        (PyDoneCallback {
            cancel_tx: Some(cancel_tx),
        },),
    )?;

    // Two owned handles to the same Python future: one for the normal completion path and
    // one for the panic-reporting path.
    let future_tx1: Py<PyAny> = py_fut.clone().into();
    let future_tx2 = future_tx1.clone_ref(py);

    R::spawn_local(async move {
        let locals2 = locals.clone();

        // Run `fut` in its own spawned task so that a panic inside it is observed here as a
        // `JoinError` (see `JoinError::is_panic`) and can be reported back to Python.
        if let Err(e) = R::spawn_local(async move {
            // Poll `fut` with `locals2` as the current task locals, cancellable via `cancel_rx`.
            let result = R::scope_local(
                locals2.clone(),
                Cancellable::new_with_cancel_rx(fut, cancel_rx),
            )
            .await;

            Python::attach(move |py| {
                // Leave an already-cancelled Python future alone. An error from the check is
                // passed to `dump_err` and treated as "not cancelled".
                if cancelled(future_tx1.bind(py))
                    .map_err(dump_err(py))
                    .unwrap_or(false)
                {
                    return;
                }

                // There is no caller to propagate a `set_result` failure to, so it is only
                // passed to `dump_err`.
                let _ = set_result(
                    locals2.0.event_loop.bind(py),
                    future_tx1.bind(py),
                    result.and_then(|val| val.into_py_any(py)),
                )
                .map_err(dump_err(py));
            });
        })
        .await
        {
            // NOTE(review): only panics are reported; any other `JoinError` leaves the Python
            // future unresolved.
            if e.is_panic() {
                Python::attach(move |py| {
                    if cancelled(future_tx2.bind(py))
                        .map_err(dump_err(py))
                        .unwrap_or(false)
                    {
                        return;
                    }

                    // Surface the panic payload to Python as a `RustPanic` exception.
                    let panic_message = format!(
                        "rust future panicked: {}",
                        get_panic_message(&e.into_panic())
                    );
                    let _ = set_result(
                        locals.0.event_loop.bind(py),
                        future_tx2.bind(py),
                        Err(RustPanic::new_err(panic_message)),
                    )
                    .map_err(dump_err(py));
                });
            }
        }
    });

    Ok(py_fut)
}
1102
1103/// Convert a `!Send` Rust Future into a Python awaitable with a generic runtime
1104///
1105/// If the `asyncio.Future` returned by this conversion is cancelled via `asyncio.Future.cancel`,
1106/// the Rust future will be cancelled as well (new behaviour in `v0.15`).
1107///
1108/// Python `contextvars` are preserved when calling async Python functions within the Rust future
1109/// via [`into_future`] (new behaviour in `v0.15`).
1110///
1111/// > Although `contextvars` are preserved for async Python functions, synchronous functions will
1112/// > unfortunately fail to resolve them when called within the Rust future. This is because the
1113/// > function is being called from a Rust thread, not inside an actual Python coroutine context.
1114/// >
1115/// > As a workaround, you can get the `contextvars` from the current task locals using
1116/// > [`get_current_locals`] and [`TaskLocals::context`](`crate::TaskLocals::context`), then wrap your
1117/// > synchronous function in a call to `contextvars.Context.run`. This will set the context, call the
1118/// > synchronous function, and restore the previous context when it returns or raises an exception.
1119///
1120/// # Arguments
1121/// * `py` - The current PyO3 GIL guard
1122/// * `fut` - The Rust future to be converted
1123///
1124/// # Examples
1125///
1126/// ```no_run
1127/// # use std::{any::Any, task::{Context, Poll}, pin::Pin, future::Future};
1128/// #
1129/// # use pyo3_async_runtimes::{
1130/// #     TaskLocals,
1131/// #     generic::{JoinError, SpawnLocalExt, ContextExt, LocalContextExt, Runtime}
1132/// # };
1133/// #
1134/// # struct MyCustomJoinError;
1135/// #
1136/// # impl JoinError for MyCustomJoinError {
1137/// #     fn is_panic(&self) -> bool {
1138/// #         unreachable!()
1139/// #     }
1140/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
1141/// #         unreachable!()
1142/// #     }
1143/// # }
1144/// #
1145/// # struct MyCustomJoinHandle;
1146/// #
1147/// # impl Future for MyCustomJoinHandle {
1148/// #     type Output = Result<(), MyCustomJoinError>;
1149/// #
1150/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
1151/// #         unreachable!()
1152/// #     }
1153/// # }
1154/// #
1155/// # struct MyCustomRuntime;
1156/// #
1157/// # impl MyCustomRuntime {
1158/// #     async fn sleep(_: Duration) {
1159/// #         unreachable!()
1160/// #     }
1161/// # }
1162/// #
1163/// # impl Runtime for MyCustomRuntime {
1164/// #     type JoinError = MyCustomJoinError;
1165/// #     type JoinHandle = MyCustomJoinHandle;
1166/// #
1167/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
1168/// #     where
1169/// #         F: Future<Output = ()> + Send + 'static
1170/// #     {
1171/// #         unreachable!()
1172/// #     }
1173/// #
1174/// #     fn spawn_blocking<F>(f: F) -> Self::JoinHandle where F: FnOnce() + Send + 'static {
1175/// #         unreachable!()
1176/// #     }
1177/// # }
1178/// #
1179/// # impl ContextExt for MyCustomRuntime {
1180/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
1181/// #     where
1182/// #         F: Future<Output = R> + Send + 'static
1183/// #     {
1184/// #         unreachable!()
1185/// #     }
1186/// #     fn get_task_locals() -> Option<TaskLocals> {
1187/// #         unreachable!()
1188/// #     }
1189/// # }
1190/// #
1191/// # impl SpawnLocalExt for MyCustomRuntime {
1192/// #     fn spawn_local<F>(fut: F) -> Self::JoinHandle
1193/// #     where
1194/// #         F: Future<Output = ()> + 'static
1195/// #     {
1196/// #         unreachable!()
1197/// #     }
1198/// # }
1199/// #
1200/// # impl LocalContextExt for MyCustomRuntime {
1201/// #     fn scope_local<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R>>>
1202/// #     where
1203/// #         F: Future<Output = R> + 'static
1204/// #     {
1205/// #         unreachable!()
1206/// #     }
1207/// # }
1208/// #
1209/// use std::{rc::Rc, time::Duration};
1210///
1211/// use pyo3::prelude::*;
1212///
1213/// /// Awaitable sleep function
1214/// #[pyfunction]
1215/// fn sleep_for(py: Python, secs: u64) -> PyResult<Bound<PyAny>> {
1216///     // Rc is !Send so it cannot be passed into pyo3_async_runtimes::generic::future_into_py
1217///     let secs = Rc::new(secs);
1218///
1219///     pyo3_async_runtimes::generic::local_future_into_py::<MyCustomRuntime, _, _>(py, async move {
1220///         MyCustomRuntime::sleep(Duration::from_secs(*secs)).await;
1221///         Ok(())
1222///     })
1223/// }
1224/// ```
1225#[deprecated(
1226    since = "0.18.0",
1227    note = "Questionable whether these conversions have real-world utility (see https://github.com/awestlake87/pyo3-asyncio/issues/59#issuecomment-1008038497 and let me know if you disagree!)"
1228)]
1229#[allow(deprecated)]
1230pub fn local_future_into_py<R, F, T>(py: Python, fut: F) -> PyResult<Bound<PyAny>>
1231where
1232    R: Runtime + ContextExt + SpawnLocalExt + LocalContextExt,
1233    F: Future<Output = PyResult<T>> + 'static,
1234    T: for<'py> IntoPyObject<'py>,
1235{
1236    local_future_into_py_with_locals::<R, F, T>(py, get_current_locals::<R>(py)?, fut)
1237}
1238
1239/// <span class="module-item stab portability" style="display: inline; border-radius: 3px; padding: 2px; font-size: 80%; line-height: 1.2;"><code>unstable-streams</code></span> Convert an async generator into a stream
1240///
1241/// **This API is marked as unstable** and is only available when the
1242/// `unstable-streams` crate feature is enabled. This comes with no
1243/// stability guarantees, and could be changed or removed at any time.
1244///
1245/// # Arguments
1246/// * `locals` - The current task locals
1247/// * `gen` - The Python async generator to be converted
1248///
1249/// # Examples
1250/// ```no_run
1251/// # use std::{any::Any, task::{Context, Poll}, pin::Pin, future::Future};
1252/// #
1253/// # use pyo3_async_runtimes::{
1254/// #     TaskLocals,
1255/// #     generic::{JoinError, ContextExt, Runtime}
1256/// # };
1257/// #
1258/// # struct MyCustomJoinError;
1259/// #
1260/// # impl JoinError for MyCustomJoinError {
1261/// #     fn is_panic(&self) -> bool {
1262/// #         unreachable!()
1263/// #     }
1264/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
1265/// #         unreachable!()
1266/// #     }
1267/// # }
1268/// #
1269/// # struct MyCustomJoinHandle;
1270/// #
1271/// # impl Future for MyCustomJoinHandle {
1272/// #     type Output = Result<(), MyCustomJoinError>;
1273/// #
1274/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
1275/// #         unreachable!()
1276/// #     }
1277/// # }
1278/// #
1279/// # struct MyCustomRuntime;
1280/// #
1281/// # impl Runtime for MyCustomRuntime {
1282/// #     type JoinError = MyCustomJoinError;
1283/// #     type JoinHandle = MyCustomJoinHandle;
1284/// #
1285/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
1286/// #     where
1287/// #         F: Future<Output = ()> + Send + 'static
1288/// #     {
1289/// #         unreachable!()
1290/// #     }
1291/// #
1292/// #     fn spawn_blocking<F>(f: F) -> Self::JoinHandle where F: FnOnce() + Send + 'static {
1293/// #         unreachable!()
1294/// #     }
1295/// # }
1296/// #
1297/// # impl ContextExt for MyCustomRuntime {
1298/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
1299/// #     where
1300/// #         F: Future<Output = R> + Send + 'static
1301/// #     {
1302/// #         unreachable!()
1303/// #     }
1304/// #     fn get_task_locals() -> Option<TaskLocals> {
1305/// #         unreachable!()
1306/// #     }
1307/// # }
1308///
1309/// use pyo3::prelude::*;
1310/// use futures::{StreamExt, TryStreamExt};
1311/// use std::ffi::CString;
1312///
1313/// const TEST_MOD: &str = r#"
1314/// import asyncio
1315///
1316/// async def gen():
1317///     for i in range(10):
1318///         await asyncio.sleep(0.1)
1319///         yield i
1320/// "#;
1321///
1322/// # async fn test_async_gen() -> PyResult<()> {
1323/// let stream = Python::attach(|py| {
1324///     let test_mod = PyModule::from_code(
1325///         py,
1326///         &CString::new(TEST_MOD).unwrap(),
1327///         &CString::new("test_rust_coroutine/test_mod.py").unwrap(),
1328///         &CString::new("test_mod").unwrap(),
1329///     )?;
1330///
1331///     pyo3_async_runtimes::generic::into_stream_with_locals_v1::<MyCustomRuntime>(
1332///         pyo3_async_runtimes::generic::get_current_locals::<MyCustomRuntime>(py)?,
1333///         test_mod.call_method0("gen")?
1334///     )
1335/// })?;
1336///
1337/// let vals = stream
1338///     .map(|item| Python::attach(|py| -> PyResult<i32> { Ok(item?.bind(py).extract()?) }))
1339///     .try_collect::<Vec<i32>>()
1340///     .await?;
1341///
1342/// assert_eq!((0..10).collect::<Vec<i32>>(), vals);
1343///
1344/// Ok(())
1345/// # }
1346/// ```
#[cfg(feature = "unstable-streams")]
#[allow(unused_must_use)] // the `R::spawn` JoinHandle is dropped on purpose: the task is detached
pub fn into_stream_with_locals_v1<R>(
    locals: TaskLocals,
    gen: Bound<'_, PyAny>,
) -> PyResult<impl futures::Stream<Item = PyResult<Py<PyAny>>> + 'static>
where
    R: Runtime,
{
    // Capacity 1: the forwarding task stays at most one item ahead of the consumer.
    let (tx, rx) = async_channel::bounded(1);
    let py = gen.py();
    let anext: Py<PyAny> = gen.getattr(pyo3::intern!(py, "__anext__"))?.into();

    R::spawn(async move {
        loop {
            // Requesting the next item can fail in two places: the `__anext__()` call itself
            // may raise, or the awaitable it returns may raise.
            let next = Python::attach(|py| -> PyResult<_> {
                into_future_with_locals(&locals, anext.bind(py).call0()?)
            });
            let item = match next {
                Ok(fut) => fut.await,
                Err(e) => Err(e),
            };

            // `StopAsyncIteration` from either stage means the iterator is exhausted. Checking
            // both stages matters: an `__anext__` that raises synchronously would otherwise be
            // forwarded as an error and then called again forever.
            if let Err(e) = &item {
                let stop_iter = Python::attach(|py| {
                    e.is_instance_of::<pyo3::exceptions::PyStopAsyncIteration>(py)
                });
                if stop_iter {
                    // end the iteration
                    break;
                }
            }

            if tx.send(item).await.is_err() {
                // receiving side was dropped
                break;
            }
        }
    });

    Ok(rx)
}
1393
1394/// <span class="module-item stab portability" style="display: inline; border-radius: 3px; padding: 2px; font-size: 80%; line-height: 1.2;"><code>unstable-streams</code></span> Convert an async generator into a stream
1395///
1396/// **This API is marked as unstable** and is only available when the
1397/// `unstable-streams` crate feature is enabled. This comes with no
1398/// stability guarantees, and could be changed or removed at any time.
1399///
1400/// # Arguments
1401/// * `gen` - The Python async generator to be converted
1402///
1403/// # Examples
1404/// ```no_run
1405/// # use std::{any::Any, task::{Context, Poll}, pin::Pin, future::Future};
1406/// #
1407/// # use pyo3_async_runtimes::{
1408/// #     TaskLocals,
1409/// #     generic::{JoinError, ContextExt, Runtime}
1410/// # };
1411/// #
1412/// # struct MyCustomJoinError;
1413/// #
1414/// # impl JoinError for MyCustomJoinError {
1415/// #     fn is_panic(&self) -> bool {
1416/// #         unreachable!()
1417/// #     }
1418/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
1419/// #         unreachable!()
1420/// #     }
1421/// # }
1422/// #
1423/// # struct MyCustomJoinHandle;
1424/// #
1425/// # impl Future for MyCustomJoinHandle {
1426/// #     type Output = Result<(), MyCustomJoinError>;
1427/// #
1428/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
1429/// #         unreachable!()
1430/// #     }
1431/// # }
1432/// #
1433/// # struct MyCustomRuntime;
1434/// #
1435/// # impl Runtime for MyCustomRuntime {
1436/// #     type JoinError = MyCustomJoinError;
1437/// #     type JoinHandle = MyCustomJoinHandle;
1438/// #
1439/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
1440/// #     where
1441/// #         F: Future<Output = ()> + Send + 'static
1442/// #     {
1443/// #         unreachable!()
1444/// #     }
1445/// #
1446/// #     fn spawn_blocking<F>(f: F) -> Self::JoinHandle where F: FnOnce() + Send + 'static {
1447/// #         unreachable!()
1448/// #     }
1449/// # }
1450/// #
1451/// # impl ContextExt for MyCustomRuntime {
1452/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
1453/// #     where
1454/// #         F: Future<Output = R> + Send + 'static
1455/// #     {
1456/// #         unreachable!()
1457/// #     }
1458/// #     fn get_task_locals() -> Option<TaskLocals> {
1459/// #         unreachable!()
1460/// #     }
1461/// # }
1462///
1463/// use pyo3::prelude::*;
1464/// use futures::{StreamExt, TryStreamExt};
1465/// use std::ffi::CString;
1466///
1467/// const TEST_MOD: &str = r#"
1468/// import asyncio
1469///
1470/// async def gen():
1471///     for i in range(10):
1472///         await asyncio.sleep(0.1)
1473///         yield i
1474/// "#;
1475///
1476/// # async fn test_async_gen() -> PyResult<()> {
1477/// let stream = Python::attach(|py| {
1478///     let test_mod = PyModule::from_code(
1479///         py,
1480///         &CString::new(TEST_MOD).unwrap(),
1481///         &CString::new("test_rust_coroutine/test_mod.py").unwrap(),
1482///         &CString::new("test_mod").unwrap(),
1483///     )?;
1484///
1485///     pyo3_async_runtimes::generic::into_stream_v1::<MyCustomRuntime>(test_mod.call_method0("gen")?)
1486/// })?;
1487///
1488/// let vals = stream
1489///     .map(|item| Python::attach(|py| -> PyResult<i32> { Ok(item?.bind(py).extract()?) }))
1490///     .try_collect::<Vec<i32>>()
1491///     .await?;
1492///
1493/// assert_eq!((0..10).collect::<Vec<i32>>(), vals);
1494///
1495/// Ok(())
1496/// # }
1497/// ```
#[cfg(feature = "unstable-streams")]
pub fn into_stream_v1<R>(
    gen: Bound<'_, PyAny>,
) -> PyResult<impl futures::Stream<Item = PyResult<Py<PyAny>>> + 'static>
where
    R: Runtime + ContextExt,
{
    // Resolve the current task locals, then delegate to the explicit-locals variant.
    let current_locals = get_current_locals::<R>(gen.py())?;
    into_stream_with_locals_v1::<R>(current_locals, gen)
}
1507
/// Type-erased sink that `SenderGlue` forwards Python async-generator items into.
///
/// Implementations must be `Send + 'static` because they are stored behind
/// `Arc<Mutex<dyn Sender>>` inside a Python-visible object.
trait Sender: Send + 'static {
    /// Offer `item` to the Rust side.
    ///
    /// The returned Python object tells the Python glue whether to keep forwarding: a truthy
    /// value means continue, a falsy value means stop. An implementation may instead return
    /// an awaitable resolving to such a value when it needs to apply backpressure.
    fn send(&mut self, py: Python, locals: TaskLocals, item: Py<PyAny>) -> PyResult<Py<PyAny>>;
    /// Signal that no further items will be sent.
    fn close(&mut self) -> PyResult<()>;
}
1512
#[cfg(feature = "unstable-streams")]
/// `Sender` implementation backed by a bounded `futures` mpsc channel.
///
/// `R` exists only at the type level: it selects the runtime that
/// `future_into_py_with_locals` uses while waiting for channel capacity.
struct GenericSender<R>
where
    R: Runtime,
{
    runtime: PhantomData<R>,
    tx: mpsc::Sender<Py<PyAny>>,
}

#[cfg(feature = "unstable-streams")]
impl<R> Sender for GenericSender<R>
where
    R: Runtime + ContextExt,
{
    /// Try to enqueue `item` without waiting.
    ///
    /// Returns Python `True` if the item was accepted immediately, or `False` if the channel
    /// rejected it for any reason other than being full. When the channel is full, returns an
    /// awaitable that resolves to `True` once the item is enqueued, or `False` if the
    /// receiving side disconnects first.
    fn send(&mut self, py: Python, locals: TaskLocals, item: Py<PyAny>) -> PyResult<Py<PyAny>> {
        let channel_full = match self.tx.try_send(item.clone_ref(py)) {
            Ok(()) => return true.into_py_any(py),
            Err(err) => err.is_full(),
        };

        if !channel_full {
            // not a capacity problem, so the receiver has gone away
            return false.into_py_any(py);
        }

        // Backpressure path: give Python an awaitable that finishes once a cloned sender has
        // delivered the item.
        let mut sender = self.tx.clone();
        future_into_py_with_locals::<R, _, bool>(py, locals, async move {
            if sender.flush().await.is_err() {
                // receiving side disconnected
                return Ok(false);
            }
            // a send error also means the receiving side disconnected
            Ok(sender.send(item).await.is_ok())
        })
        .map(Bound::unbind)
    }

    /// Close the channel from the sending side.
    fn close(&mut self) -> PyResult<()> {
        self.tx.close_channel();
        Ok(())
    }
}
1557
1558#[pyclass]
1559struct SenderGlue {
1560    locals: TaskLocals,
1561    tx: Arc<Mutex<dyn Sender>>,
1562}
1563#[pymethods]
1564impl SenderGlue {
1565    pub fn send(&mut self, item: Py<PyAny>) -> PyResult<Py<PyAny>> {
1566        Python::attach(|py| self.tx.lock().unwrap().send(py, self.locals.clone(), item))
1567    }
1568    pub fn close(&mut self) -> PyResult<()> {
1569        self.tx.lock().unwrap().close()
1570    }
1571}
1572
#[cfg(feature = "unstable-streams")]
/// Python glue that drains an async generator into a `SenderGlue`.
///
/// `sender.send(item)` returns either a bool or, when the Rust channel is full, the
/// `asyncio.Future` produced by `future_into_py_with_locals`. Any awaitable result is awaited
/// before the next item is pulled. `asyncio.iscoroutine` is not sufficient here, because it
/// returns `False` for an `asyncio.Future`. A falsy result means the Rust receiver is gone, so
/// forwarding stops. The sender is closed in a `finally` block, so the Rust stream also ends
/// when the generator raises or the task is cancelled.
const STREAM_GLUE: &str = r#"
import inspect

async def forward(gen, sender):
    try:
        async for item in gen:
            should_continue = sender.send(item)

            if inspect.isawaitable(should_continue):
                should_continue = await should_continue

            if not should_continue:
                break
    finally:
        sender.close()
"#;
1591
1592/// <span class="module-item stab portability" style="display: inline; border-radius: 3px; padding: 2px; font-size: 80%; line-height: 1.2;"><code>unstable-streams</code></span> Convert an async generator into a stream
1593///
1594/// **This API is marked as unstable** and is only available when the
1595/// `unstable-streams` crate feature is enabled. This comes with no
1596/// stability guarantees, and could be changed or removed at any time.
1597///
1598/// # Arguments
1599/// * `locals` - The current task locals
1600/// * `gen` - The Python async generator to be converted
1601///
1602/// # Examples
1603/// ```no_run
1604/// # use std::{any::Any, task::{Context, Poll}, pin::Pin, future::Future};
1605/// #
1606/// # use pyo3_async_runtimes::{
1607/// #     TaskLocals,
1608/// #     generic::{JoinError, ContextExt, Runtime}
1609/// # };
1610/// #
1611/// # struct MyCustomJoinError;
1612/// #
1613/// # impl JoinError for MyCustomJoinError {
1614/// #     fn is_panic(&self) -> bool {
1615/// #         unreachable!()
1616/// #     }
1617/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
1618/// #         unreachable!()
1619/// #     }
1620/// # }
1621/// #
1622/// # struct MyCustomJoinHandle;
1623/// #
1624/// # impl Future for MyCustomJoinHandle {
1625/// #     type Output = Result<(), MyCustomJoinError>;
1626/// #
1627/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
1628/// #         unreachable!()
1629/// #     }
1630/// # }
1631/// #
1632/// # struct MyCustomRuntime;
1633/// #
1634/// # impl Runtime for MyCustomRuntime {
1635/// #     type JoinError = MyCustomJoinError;
1636/// #     type JoinHandle = MyCustomJoinHandle;
1637/// #
1638/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
1639/// #     where
1640/// #         F: Future<Output = ()> + Send + 'static
1641/// #     {
1642/// #         unreachable!()
1643/// #     }
1644/// #
1645/// #     fn spawn_blocking<F>(f: F) -> Self::JoinHandle where F: FnOnce() + Send + 'static {
1646/// #         unreachable!()
1647/// #     }
1648/// # }
1649/// #
1650/// # impl ContextExt for MyCustomRuntime {
1651/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
1652/// #     where
1653/// #         F: Future<Output = R> + Send + 'static
1654/// #     {
1655/// #         unreachable!()
1656/// #     }
1657/// #     fn get_task_locals() -> Option<TaskLocals> {
1658/// #         unreachable!()
1659/// #     }
1660/// # }
1661///
1662/// use pyo3::prelude::*;
1663/// use futures::{StreamExt, TryStreamExt};
1664/// use std::ffi::CString;
1665///
1666/// const TEST_MOD: &str = r#"
1667/// import asyncio
1668///
1669/// async def gen():
1670///     for i in range(10):
1671///         await asyncio.sleep(0.1)
1672///         yield i
1673/// "#;
1674///
1675/// # async fn test_async_gen() -> PyResult<()> {
1676/// let stream = Python::attach(|py| {
1677///     let test_mod = PyModule::from_code(
1678///         py,
1679///         &CString::new(TEST_MOD).unwrap(),
1680///         &CString::new("test_rust_coroutine/test_mod.py").unwrap(),
1681///         &CString::new("test_mod").unwrap(),
1682///     )?;
1683///
1684///     pyo3_async_runtimes::generic::into_stream_with_locals_v2::<MyCustomRuntime>(
1685///         pyo3_async_runtimes::generic::get_current_locals::<MyCustomRuntime>(py)?,
1686///         test_mod.call_method0("gen")?
1687///     )
1688/// })?;
1689///
1690/// let vals = stream
1691///     .map(|item| Python::attach(|py| -> PyResult<i32> { Ok(item.bind(py).extract()?) }))
1692///     .try_collect::<Vec<i32>>()
1693///     .await?;
1694///
1695/// assert_eq!((0..10).collect::<Vec<i32>>(), vals);
1696///
1697/// Ok(())
1698/// # }
1699/// ```
#[cfg(feature = "unstable-streams")]
pub fn into_stream_with_locals_v2<R>(
    locals: TaskLocals,
    gen: Bound<'_, PyAny>,
) -> PyResult<impl futures::Stream<Item = Py<PyAny>> + 'static>
where
    R: Runtime + ContextExt,
{
    use std::ffi::CString;

    use pyo3::sync::PyOnceLock;

    // The glue module is compiled at most once and then cached. A static inside a generic
    // function is shared by every instantiation, so all runtimes `R` reuse the same module.
    static GLUE_MOD: PyOnceLock<Py<PyAny>> = PyOnceLock::new();

    let py = gen.py();
    let glue = GLUE_MOD
        .get_or_try_init(py, || -> PyResult<Py<PyAny>> {
            let code = CString::new(STREAM_GLUE).unwrap();
            let path = CString::new("pyo3_async_runtimes/pyo3_async_runtimes_glue.py").unwrap();
            let name = CString::new("pyo3_async_runtimes_glue").unwrap();
            let module = PyModule::from_code(py, &code, &path, &name)?;
            Ok(module.into())
        })?
        .bind(py);

    // Bounded channel that the Python `forward` task feeds and the returned stream drains.
    let (tx, rx) = mpsc::channel(10);

    let event_loop = locals.event_loop(py);
    let create_task = event_loop.getattr(pyo3::intern!(py, "create_task"))?;
    let sender = SenderGlue {
        locals,
        tx: Arc::new(Mutex::new(GenericSender {
            runtime: PhantomData::<R>,
            tx,
        })),
    };

    // Build the `forward(gen, sender)` coroutine here, then have the event loop wrap it in a
    // task via `call_soon_threadsafe(create_task, coroutine)`.
    let forward = glue.call_method1(pyo3::intern!(py, "forward"), (gen, sender))?;
    event_loop.call_method1(
        pyo3::intern!(py, "call_soon_threadsafe"),
        (create_task, forward),
    )?;

    Ok(rx)
}
1751
1752/// <span class="module-item stab portability" style="display: inline; border-radius: 3px; padding: 2px; font-size: 80%; line-height: 1.2;"><code>unstable-streams</code></span> Convert an async generator into a stream
1753///
1754/// **This API is marked as unstable** and is only available when the
1755/// `unstable-streams` crate feature is enabled. This comes with no
1756/// stability guarantees, and could be changed or removed at any time.
1757///
1758/// # Arguments
1759/// * `gen` - The Python async generator to be converted
1760///
1761/// # Examples
1762/// ```no_run
1763/// # use std::{any::Any, task::{Context, Poll}, pin::Pin, future::Future};
1764/// #
1765/// # use pyo3_async_runtimes::{
1766/// #     TaskLocals,
1767/// #     generic::{JoinError, ContextExt, Runtime}
1768/// # };
1769/// #
1770/// # struct MyCustomJoinError;
1771/// #
1772/// # impl JoinError for MyCustomJoinError {
1773/// #     fn is_panic(&self) -> bool {
1774/// #         unreachable!()
1775/// #     }
1776/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
1777/// #         unreachable!()
1778/// #     }
1779/// # }
1780/// #
1781/// # struct MyCustomJoinHandle;
1782/// #
1783/// # impl Future for MyCustomJoinHandle {
1784/// #     type Output = Result<(), MyCustomJoinError>;
1785/// #
1786/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
1787/// #         unreachable!()
1788/// #     }
1789/// # }
1790/// #
1791/// # struct MyCustomRuntime;
1792/// #
1793/// # impl Runtime for MyCustomRuntime {
1794/// #     type JoinError = MyCustomJoinError;
1795/// #     type JoinHandle = MyCustomJoinHandle;
1796/// #
1797/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
1798/// #     where
1799/// #         F: Future<Output = ()> + Send + 'static
1800/// #     {
1801/// #         unreachable!()
1802/// #     }
1803/// #
1804/// #     fn spawn_blocking<F>(f: F) -> Self::JoinHandle where F: FnOnce() + Send + 'static {
1805/// #         unreachable!()
1806/// #     }
1807/// # }
1808/// #
1809/// # impl ContextExt for MyCustomRuntime {
1810/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
1811/// #     where
1812/// #         F: Future<Output = R> + Send + 'static
1813/// #     {
1814/// #         unreachable!()
1815/// #     }
1816/// #     fn get_task_locals() -> Option<TaskLocals> {
1817/// #         unreachable!()
1818/// #     }
1819/// # }
1820///
1821/// use pyo3::prelude::*;
1822/// use futures::{StreamExt, TryStreamExt};
1823/// use std::ffi::CString;
1824///
1825/// const TEST_MOD: &str = r#"
1826/// import asyncio
1827///
1828/// async def gen():
1829///     for i in range(10):
1830///         await asyncio.sleep(0.1)
1831///         yield i
1832/// "#;
1833///
1834/// # async fn test_async_gen() -> PyResult<()> {
1835/// let stream = Python::attach(|py| {
1836///     let test_mod = PyModule::from_code(
1837///         py,
1838///         &CString::new(TEST_MOD).unwrap(),
1839///         &CString::new("test_rust_coroutine/test_mod.py").unwrap(),
1840///         &CString::new("test_mod").unwrap(),
1841///     )?;
1842///
1843///     pyo3_async_runtimes::generic::into_stream_v2::<MyCustomRuntime>(test_mod.call_method0("gen")?)
1844/// })?;
1845///
1846/// let vals = stream
1847///     .map(|item| Python::attach(|py| -> PyResult<i32> { Ok(item.bind(py).extract()?) }))
1848///     .try_collect::<Vec<i32>>()
1849///     .await?;
1850///
1851/// assert_eq!((0..10).collect::<Vec<i32>>(), vals);
1852///
1853/// Ok(())
1854/// # }
1855/// ```
#[cfg(feature = "unstable-streams")]
pub fn into_stream_v2<R>(
    gen: Bound<'_, PyAny>,
) -> PyResult<impl futures::Stream<Item = Py<PyAny>> + 'static>
where
    R: Runtime + ContextExt,
{
    // Resolve the current task locals, then delegate to the explicit-locals variant.
    let current_locals = get_current_locals::<R>(gen.py())?;
    into_stream_with_locals_v2::<R>(current_locals, gen)
}