pyo3_async_runtimes/
generic.rs

1//! Generic implementations of PyO3 Asyncio utilities that can be used for any Rust runtime
2//!
3//! Items marked with
4//! <span
5//!   class="module-item stab portability"
6//!   style="display: inline; border-radius: 3px; padding: 2px; font-size: 80%; line-height: 1.2;"
7//! ><code>unstable-streams</code></span>
8//! > are only available when the `unstable-streams` Cargo feature is enabled:
9//!
10//! ```toml
11//! [dependencies.pyo3-async-runtimes]
12//! version = "0.24"
13//! features = ["unstable-streams"]
14//! ```
15
16use std::{
17    future::Future,
18    pin::Pin,
19    sync::{Arc, Mutex},
20    task::{Context, Poll},
21};
22
23use crate::{
24    asyncio, call_soon_threadsafe, close, create_future, dump_err, err::RustPanic,
25    get_running_loop, into_future_with_locals, TaskLocals,
26};
27use futures::channel::oneshot;
28#[cfg(feature = "unstable-streams")]
29use futures::{channel::mpsc, SinkExt};
30#[cfg(feature = "unstable-streams")]
31use once_cell::sync::OnceCell;
32use pin_project_lite::pin_project;
33use pyo3::prelude::*;
34use pyo3::BoundObject;
35#[cfg(feature = "unstable-streams")]
36use std::marker::PhantomData;
37
/// Common interface over the join errors reported by different async runtimes.
pub trait JoinError {
    /// Consume the error and hand back the panic payload it carries.
    ///
    /// Panics if [`JoinError::is_panic`] would return `false` for this error.
    fn into_panic(self) -> Box<dyn std::any::Any + Send + 'static>;

    /// Return `true` when the spawned task ended because it panicked.
    fn is_panic(&self) -> bool;
}
45
46/// Generic Rust async/await runtime
47pub trait Runtime: Send + 'static {
48    /// The error returned by a JoinHandle after being awaited
49    type JoinError: JoinError + Send;
50    /// A future that completes with the result of the spawned task
51    type JoinHandle: Future<Output = Result<(), Self::JoinError>> + Send;
52
53    /// Spawn a future onto this runtime's event loop
54    fn spawn<F>(fut: F) -> Self::JoinHandle
55    where
56        F: Future<Output = ()> + Send + 'static;
57}
58
59/// Extension trait for async/await runtimes that support spawning local tasks
60pub trait SpawnLocalExt: Runtime {
61    /// Spawn a !Send future onto this runtime's event loop
62    fn spawn_local<F>(fut: F) -> Self::JoinHandle
63    where
64        F: Future<Output = ()> + 'static;
65}
66
67/// Exposes the utilities necessary for using task-local data in the Runtime
68pub trait ContextExt: Runtime {
69    /// Set the task locals for the given future
70    fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
71    where
72        F: Future<Output = R> + Send + 'static;
73
74    /// Get the task locals for the current task
75    fn get_task_locals() -> Option<TaskLocals>;
76}
77
78/// Adds the ability to scope task-local data for !Send futures
79pub trait LocalContextExt: Runtime {
80    /// Set the task locals for the given !Send future
81    fn scope_local<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R>>>
82    where
83        F: Future<Output = R> + 'static;
84}
85
86/// Get the current event loop from either Python or Rust async task local context
87///
88/// This function first checks if the runtime has a task-local reference to the Python event loop.
89/// If not, it calls [`get_running_loop`](crate::get_running_loop`) to get the event loop associated
90/// with the current OS thread.
91pub fn get_current_loop<R>(py: Python) -> PyResult<Bound<PyAny>>
92where
93    R: ContextExt,
94{
95    if let Some(locals) = R::get_task_locals() {
96        Ok(locals.event_loop.into_bound(py))
97    } else {
98        get_running_loop(py)
99    }
100}
101
102/// Either copy the task locals from the current task OR get the current running loop and
103/// contextvars from Python.
104pub fn get_current_locals<R>(py: Python) -> PyResult<TaskLocals>
105where
106    R: ContextExt,
107{
108    if let Some(locals) = R::get_task_locals() {
109        Ok(locals)
110    } else {
111        Ok(TaskLocals::with_running_loop(py)?.copy_context(py)?)
112    }
113}
114
115/// Run the event loop until the given Future completes
116///
117/// After this function returns, the event loop can be resumed with [`run_until_complete`]
118///
119/// # Arguments
120/// * `event_loop` - The Python event loop that should run the future
121/// * `fut` - The future to drive to completion
122///
123/// # Examples
124///
125/// ```no_run
126/// # use std::{any::Any, task::{Context, Poll}, pin::Pin, future::Future};
127/// #
128/// # use pyo3_async_runtimes::{
129/// #     TaskLocals,
130/// #     generic::{JoinError, SpawnLocalExt, ContextExt, LocalContextExt, Runtime}
131/// # };
132/// #
133/// # struct MyCustomJoinError;
134/// #
135/// # impl JoinError for MyCustomJoinError {
136/// #     fn is_panic(&self) -> bool {
137/// #         unreachable!()
138/// #     }
139/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
140/// #         unreachable!()
141/// #     }
142/// # }
143/// #
144/// # struct MyCustomJoinHandle;
145/// #
146/// # impl Future for MyCustomJoinHandle {
147/// #     type Output = Result<(), MyCustomJoinError>;
148/// #
149/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
150/// #         unreachable!()
151/// #     }
152/// # }
153/// #
154/// # struct MyCustomRuntime;
155/// #
156/// # impl Runtime for MyCustomRuntime {
157/// #     type JoinError = MyCustomJoinError;
158/// #     type JoinHandle = MyCustomJoinHandle;
159/// #
160/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
161/// #     where
162/// #         F: Future<Output = ()> + Send + 'static
163/// #     {
164/// #         unreachable!()
165/// #     }
166/// # }
167/// #
168/// # impl ContextExt for MyCustomRuntime {
169/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
170/// #     where
171/// #         F: Future<Output = R> + Send + 'static
172/// #     {
173/// #         unreachable!()
174/// #     }
175/// #     fn get_task_locals() -> Option<TaskLocals> {
176/// #         unreachable!()
177/// #     }
178/// # }
179/// #
180/// # use std::time::Duration;
181/// #
182/// # use pyo3::prelude::*;
183/// #
184/// # Python::with_gil(|py| -> PyResult<()> {
185/// # let event_loop = py.import("asyncio")?.call_method0("new_event_loop")?;
186/// # #[cfg(feature = "tokio-runtime")]
187/// pyo3_async_runtimes::generic::run_until_complete::<MyCustomRuntime, _, _>(&event_loop, async move {
188///     tokio::time::sleep(Duration::from_secs(1)).await;
189///     Ok(())
190/// })?;
191/// # Ok(())
192/// # }).unwrap();
193/// ```
194pub fn run_until_complete<R, F, T>(event_loop: &Bound<PyAny>, fut: F) -> PyResult<T>
195where
196    R: Runtime + ContextExt,
197    F: Future<Output = PyResult<T>> + Send + 'static,
198    T: Send + Sync + 'static,
199{
200    let py = event_loop.py();
201    let result_tx = Arc::new(Mutex::new(None));
202    let result_rx = Arc::clone(&result_tx);
203    let coro = future_into_py_with_locals::<R, _, ()>(
204        py,
205        TaskLocals::new(event_loop.clone()).copy_context(py)?,
206        async move {
207            let val = fut.await?;
208            if let Ok(mut result) = result_tx.lock() {
209                *result = Some(val);
210            }
211            Ok(())
212        },
213    )?;
214
215    event_loop.call_method1("run_until_complete", (coro,))?;
216
217    let result = result_rx.lock().unwrap().take().unwrap();
218    Ok(result)
219}
220
221/// Run the event loop until the given Future completes
222///
223/// # Arguments
224/// * `py` - The current PyO3 GIL guard
225/// * `fut` - The future to drive to completion
226///
227/// # Examples
228///
229/// ```no_run
230/// # use std::{any::Any, task::{Context, Poll}, pin::Pin, future::Future};
231/// #
232/// # use pyo3_async_runtimes::{
233/// #     TaskLocals,
234/// #     generic::{JoinError, SpawnLocalExt, ContextExt, LocalContextExt, Runtime}
235/// # };
236/// #
237/// # struct MyCustomJoinError;
238/// #
239/// # impl JoinError for MyCustomJoinError {
240/// #     fn is_panic(&self) -> bool {
241/// #         unreachable!()
242/// #     }
243/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
244/// #         unreachable!()
245/// #     }
246/// # }
247/// #
248/// # struct MyCustomJoinHandle;
249/// #
250/// # impl Future for MyCustomJoinHandle {
251/// #     type Output = Result<(), MyCustomJoinError>;
252/// #
253/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
254/// #         unreachable!()
255/// #     }
256/// # }
257/// #
258/// # struct MyCustomRuntime;
259/// #
260/// # impl Runtime for MyCustomRuntime {
261/// #     type JoinError = MyCustomJoinError;
262/// #     type JoinHandle = MyCustomJoinHandle;
263/// #
264/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
265/// #     where
266/// #         F: Future<Output = ()> + Send + 'static
267/// #     {
268/// #         unreachable!()
269/// #     }
270/// # }
271/// #
272/// # impl ContextExt for MyCustomRuntime {
273/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
274/// #     where
275/// #         F: Future<Output = R> + Send + 'static
276/// #     {
277/// #         unreachable!()
278/// #     }
279/// #     fn get_task_locals() -> Option<TaskLocals> {
280/// #         unreachable!()
281/// #     }
282/// # }
283/// #
284/// # use std::time::Duration;
285/// # async fn custom_sleep(_duration: Duration) { }
286/// #
287/// # use pyo3::prelude::*;
288/// #
289/// fn main() {
290///     Python::with_gil(|py| {
291///         pyo3_async_runtimes::generic::run::<MyCustomRuntime, _, _>(py, async move {
292///             custom_sleep(Duration::from_secs(1)).await;
293///             Ok(())
294///         })
295///         .map_err(|e| {
296///             e.print_and_set_sys_last_vars(py);
297///         })
298///         .unwrap();
299///     })
300/// }
301/// ```
302pub fn run<R, F, T>(py: Python, fut: F) -> PyResult<T>
303where
304    R: Runtime + ContextExt,
305    F: Future<Output = PyResult<T>> + Send + 'static,
306    T: Send + Sync + 'static,
307{
308    let event_loop = asyncio(py)?.call_method0("new_event_loop")?;
309
310    let result = run_until_complete::<R, F, T>(&event_loop, fut);
311
312    close(event_loop)?;
313
314    result
315}
316
317fn cancelled(future: &Bound<PyAny>) -> PyResult<bool> {
318    future.getattr("cancelled")?.call0()?.is_truthy()
319}
320
321#[pyclass]
322struct CheckedCompletor;
323
324#[pymethods]
325impl CheckedCompletor {
326    fn __call__(
327        &self,
328        future: &Bound<PyAny>,
329        complete: &Bound<PyAny>,
330        value: &Bound<PyAny>,
331    ) -> PyResult<()> {
332        if cancelled(future)? {
333            return Ok(());
334        }
335
336        complete.call1((value,))?;
337
338        Ok(())
339    }
340}
341
342fn set_result(
343    event_loop: &Bound<PyAny>,
344    future: &Bound<PyAny>,
345    result: PyResult<PyObject>,
346) -> PyResult<()> {
347    let py = event_loop.py();
348    let none = py.None().into_bound(py);
349
350    let (complete, val) = match result {
351        Ok(val) => (future.getattr("set_result")?, val.into_pyobject(py)?),
352        Err(err) => (
353            future.getattr("set_exception")?,
354            err.into_pyobject(py)?.into_any(),
355        ),
356    };
357    call_soon_threadsafe(event_loop, &none, (CheckedCompletor, future, complete, val))?;
358
359    Ok(())
360}
361
362/// Convert a Python `awaitable` into a Rust Future
363///
364/// This function simply forwards the future and the task locals returned by [`get_current_locals`]
365/// to [`into_future_with_locals`](`crate::into_future_with_locals`). See
366/// [`into_future_with_locals`](`crate::into_future_with_locals`) for more details.
367///
368/// # Arguments
369/// * `awaitable` - The Python `awaitable` to be converted
370///
371/// # Examples
372///
373/// ```no_run
374/// # use std::{any::Any, pin::Pin, future::Future, task::{Context, Poll}, time::Duration};
375/// # use std::ffi::CString;
376/// #
377/// # use pyo3::prelude::*;
378/// #
379/// # use pyo3_async_runtimes::{
380/// #     TaskLocals,
381/// #     generic::{JoinError, SpawnLocalExt, ContextExt, LocalContextExt, Runtime}
382/// # };
383/// #
384/// # struct MyCustomJoinError;
385/// #
386/// # impl JoinError for MyCustomJoinError {
387/// #     fn is_panic(&self) -> bool {
388/// #         unreachable!()
389/// #     }
390/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
391/// #         unreachable!()
392/// #     }
393/// # }
394/// #
395/// # struct MyCustomJoinHandle;
396/// #
397/// # impl Future for MyCustomJoinHandle {
398/// #     type Output = Result<(), MyCustomJoinError>;
399/// #
400/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
401/// #         unreachable!()
402/// #     }
403/// # }
404/// #
405/// # struct MyCustomRuntime;
406/// #
407/// # impl MyCustomRuntime {
408/// #     async fn sleep(_: Duration) {
409/// #         unreachable!()
410/// #     }
411/// # }
412/// #
413/// # impl Runtime for MyCustomRuntime {
414/// #     type JoinError = MyCustomJoinError;
415/// #     type JoinHandle = MyCustomJoinHandle;
416/// #
417/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
418/// #     where
419/// #         F: Future<Output = ()> + Send + 'static
420/// #     {
421/// #         unreachable!()
422/// #     }
423/// # }
424/// #
425/// # impl ContextExt for MyCustomRuntime {
426/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
427/// #     where
428/// #         F: Future<Output = R> + Send + 'static
429/// #     {
430/// #         unreachable!()
431/// #     }
432/// #     fn get_task_locals() -> Option<TaskLocals> {
433/// #         unreachable!()
434/// #     }
435/// # }
436/// #
437/// const PYTHON_CODE: &'static str = r#"
438/// import asyncio
439///
440/// async def py_sleep(duration):
441///     await asyncio.sleep(duration)
442/// "#;
443///
444/// async fn py_sleep(seconds: f32) -> PyResult<()> {
445///     let test_mod = Python::with_gil(|py| -> PyResult<PyObject> {
446///         Ok(
447///             PyModule::from_code(
448///                 py,
449///                 &CString::new(PYTHON_CODE).unwrap(),
450///                 &CString::new("test_into_future/test_mod.py").unwrap(),
451///                 &CString::new("test_mod").unwrap(),
452///             )?
453///             .into()
454///         )
455///     })?;
456///
457///     Python::with_gil(|py| {
458///         pyo3_async_runtimes::generic::into_future::<MyCustomRuntime>(
459///             test_mod
460///                 .call_method1(py, "py_sleep", (seconds.into_pyobject(py).unwrap(),))?
461///                 .into_bound(py),
462///         )
463///     })?
464///     .await?;
465///     Ok(())
466/// }
467/// ```
468pub fn into_future<R>(
469    awaitable: Bound<PyAny>,
470) -> PyResult<impl Future<Output = PyResult<PyObject>> + Send>
471where
472    R: Runtime + ContextExt,
473{
474    into_future_with_locals(&get_current_locals::<R>(awaitable.py())?, awaitable)
475}
476
477/// Convert a Rust Future into a Python awaitable with a generic runtime
478///
479/// If the `asyncio.Future` returned by this conversion is cancelled via `asyncio.Future.cancel`,
480/// the Rust future will be cancelled as well (new behaviour in `v0.15`).
481///
482/// Python `contextvars` are preserved when calling async Python functions within the Rust future
483/// via [`into_future`] (new behaviour in `v0.15`).
484///
485/// > Although `contextvars` are preserved for async Python functions, synchronous functions will
486/// > unfortunately fail to resolve them when called within the Rust future. This is because the
487/// > function is being called from a Rust thread, not inside an actual Python coroutine context.
488/// >
489/// > As a workaround, you can get the `contextvars` from the current task locals using
490/// > [`get_current_locals`] and [`TaskLocals::context`](`crate::TaskLocals::context`), then wrap your
491/// > synchronous function in a call to `contextvars.Context.run`. This will set the context, call the
492/// > synchronous function, and restore the previous context when it returns or raises an exception.
493///
494/// # Arguments
495/// * `py` - PyO3 GIL guard
496/// * `locals` - The task-local data for Python
497/// * `fut` - The Rust future to be converted
498///
499/// # Examples
500///
501/// ```no_run
502/// # use std::{any::Any, task::{Context, Poll}, pin::Pin, future::Future};
503/// #
504/// # use pyo3_async_runtimes::{
505/// #     TaskLocals,
506/// #     generic::{JoinError, SpawnLocalExt, ContextExt, LocalContextExt, Runtime}
507/// # };
508/// #
509/// # struct MyCustomJoinError;
510/// #
511/// # impl JoinError for MyCustomJoinError {
512/// #     fn is_panic(&self) -> bool {
513/// #         unreachable!()
514/// #     }
515/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
516/// #         unreachable!()
517/// #     }
518/// # }
519/// #
520/// # struct MyCustomJoinHandle;
521/// #
522/// # impl Future for MyCustomJoinHandle {
523/// #     type Output = Result<(), MyCustomJoinError>;
524/// #
525/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
526/// #         unreachable!()
527/// #     }
528/// # }
529/// #
530/// # struct MyCustomRuntime;
531/// #
532/// # impl MyCustomRuntime {
533/// #     async fn sleep(_: Duration) {
534/// #         unreachable!()
535/// #     }
536/// # }
537/// #
538/// # impl Runtime for MyCustomRuntime {
539/// #     type JoinError = MyCustomJoinError;
540/// #     type JoinHandle = MyCustomJoinHandle;
541/// #
542/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
543/// #     where
544/// #         F: Future<Output = ()> + Send + 'static
545/// #     {
546/// #         unreachable!()
547/// #     }
548/// # }
549/// #
550/// # impl ContextExt for MyCustomRuntime {
551/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
552/// #     where
553/// #         F: Future<Output = R> + Send + 'static
554/// #     {
555/// #         unreachable!()
556/// #     }
557/// #     fn get_task_locals() -> Option<TaskLocals> {
558/// #         unreachable!()
559/// #     }
560/// # }
561/// #
562/// use std::time::Duration;
563///
564/// use pyo3::prelude::*;
565///
566/// /// Awaitable sleep function
567/// #[pyfunction]
568/// fn sleep_for<'p>(py: Python<'p>, secs: Bound<'p, PyAny>) -> PyResult<Bound<'p, PyAny>> {
569///     let secs = secs.extract()?;
570///     pyo3_async_runtimes::generic::future_into_py_with_locals::<MyCustomRuntime, _, _>(
571///         py,
572///         pyo3_async_runtimes::generic::get_current_locals::<MyCustomRuntime>(py)?,
573///         async move {
574///             MyCustomRuntime::sleep(Duration::from_secs(secs)).await;
575///             Ok(())
576///         }
577///     )
578/// }
579/// ```
580#[allow(unused_must_use)]
581pub fn future_into_py_with_locals<R, F, T>(
582    py: Python,
583    locals: TaskLocals,
584    fut: F,
585) -> PyResult<Bound<PyAny>>
586where
587    R: Runtime + ContextExt,
588    F: Future<Output = PyResult<T>> + Send + 'static,
589    T: for<'py> IntoPyObject<'py>,
590{
591    let (cancel_tx, cancel_rx) = oneshot::channel();
592
593    let py_fut = create_future(locals.event_loop.bind(py).clone())?;
594    py_fut.call_method1(
595        "add_done_callback",
596        (PyDoneCallback {
597            cancel_tx: Some(cancel_tx),
598        },),
599    )?;
600
601    let future_tx1 = PyObject::from(py_fut.clone());
602    let future_tx2 = future_tx1.clone_ref(py);
603
604    R::spawn(async move {
605        let locals2 = Python::with_gil(|py| locals.clone_ref(py));
606
607        if let Err(e) = R::spawn(async move {
608            let result = R::scope(
609                Python::with_gil(|py| locals2.clone_ref(py)),
610                Cancellable::new_with_cancel_rx(fut, cancel_rx),
611            )
612            .await;
613
614            Python::with_gil(move |py| {
615                if cancelled(future_tx1.bind(py))
616                    .map_err(dump_err(py))
617                    .unwrap_or(false)
618                {
619                    return;
620                }
621
622                let _ = set_result(
623                    &locals2.event_loop(py),
624                    future_tx1.bind(py),
625                    result.and_then(|val| match val.into_pyobject(py) {
626                        Ok(obj) => Ok(obj.into_any().unbind()),
627                        Err(err) => Err(err.into()),
628                    }),
629                )
630                .map_err(dump_err(py));
631            });
632        })
633        .await
634        {
635            if e.is_panic() {
636                Python::with_gil(move |py| {
637                    if cancelled(future_tx2.bind(py))
638                        .map_err(dump_err(py))
639                        .unwrap_or(false)
640                    {
641                        return;
642                    }
643
644                    let panic_message = format!(
645                        "rust future panicked: {}",
646                        get_panic_message(&e.into_panic())
647                    );
648                    let _ = set_result(
649                        locals.event_loop.bind(py),
650                        future_tx2.bind(py),
651                        Err(RustPanic::new_err(panic_message)),
652                    )
653                    .map_err(dump_err(py));
654                });
655            }
656        }
657    });
658
659    Ok(py_fut)
660}
661
/// Best-effort extraction of a human-readable message from a panic payload.
///
/// `panic!` payloads are normally a `&'static str` (literal message) or a `String` (formatted
/// message). A caller may instead pass `&Box<dyn Any + Send>`, for example
/// `&join_error.into_panic()`. In that case the Box itself is coerced to `dyn Any`, so the
/// payload is still boxed; it is unwrapped before matching. Any other payload yields
/// `"unknown error"`.
fn get_panic_message(any: &dyn std::any::Any) -> &str {
    use std::any::Any;

    if let Some(str_slice) = any.downcast_ref::<&str>() {
        str_slice
    } else if let Some(string) = any.downcast_ref::<String>() {
        string.as_str()
    } else if let Some(boxed) = any.downcast_ref::<Box<dyn Any + Send>>() {
        get_panic_message(&**boxed)
    } else if let Some(boxed) = any.downcast_ref::<Box<dyn Any>>() {
        get_panic_message(&**boxed)
    } else {
        "unknown error"
    }
}
671
672pin_project! {
673    /// Future returned by [`timeout`](timeout) and [`timeout_at`](timeout_at).
674    #[must_use = "futures do nothing unless you `.await` or poll them"]
675    #[derive(Debug)]
676    struct Cancellable<T> {
677        #[pin]
678        future: T,
679        #[pin]
680        cancel_rx: oneshot::Receiver<()>,
681
682        poll_cancel_rx: bool
683    }
684}
685
686impl<T> Cancellable<T> {
687    fn new_with_cancel_rx(future: T, cancel_rx: oneshot::Receiver<()>) -> Self {
688        Self {
689            future,
690            cancel_rx,
691
692            poll_cancel_rx: true,
693        }
694    }
695}
696
697impl<'py, F, T> Future for Cancellable<F>
698where
699    F: Future<Output = PyResult<T>>,
700    T: IntoPyObject<'py>,
701{
702    type Output = F::Output;
703
704    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
705        let this = self.project();
706
707        // First, try polling the future
708        if let Poll::Ready(v) = this.future.poll(cx) {
709            return Poll::Ready(v);
710        }
711
712        // Now check for cancellation
713        if *this.poll_cancel_rx {
714            match this.cancel_rx.poll(cx) {
715                Poll::Ready(Ok(())) => {
716                    *this.poll_cancel_rx = false;
717                    // The python future has already been cancelled, so this return value will never
718                    // be used.
719                    Poll::Ready(Err(pyo3::exceptions::PyBaseException::new_err(
720                        "unreachable",
721                    )))
722                }
723                Poll::Ready(Err(_)) => {
724                    *this.poll_cancel_rx = false;
725                    Poll::Pending
726                }
727                Poll::Pending => Poll::Pending,
728            }
729        } else {
730            Poll::Pending
731        }
732    }
733}
734
735#[pyclass]
736struct PyDoneCallback {
737    cancel_tx: Option<oneshot::Sender<()>>,
738}
739
740#[pymethods]
741impl PyDoneCallback {
742    pub fn __call__(&mut self, fut: &Bound<PyAny>) -> PyResult<()> {
743        let py = fut.py();
744
745        if cancelled(fut).map_err(dump_err(py)).unwrap_or(false) {
746            let _ = self.cancel_tx.take().unwrap().send(());
747        }
748
749        Ok(())
750    }
751}
752
753/// Convert a Rust Future into a Python awaitable with a generic runtime
754///
755/// If the `asyncio.Future` returned by this conversion is cancelled via `asyncio.Future.cancel`,
756/// the Rust future will be cancelled as well (new behaviour in `v0.15`).
757///
758/// Python `contextvars` are preserved when calling async Python functions within the Rust future
759/// via [`into_future`] (new behaviour in `v0.15`).
760///
761/// > Although `contextvars` are preserved for async Python functions, synchronous functions will
762/// > unfortunately fail to resolve them when called within the Rust future. This is because the
763/// > function is being called from a Rust thread, not inside an actual Python coroutine context.
764/// >
765/// > As a workaround, you can get the `contextvars` from the current task locals using
766/// > [`get_current_locals`] and [`TaskLocals::context`](`crate::TaskLocals::context`), then wrap your
767/// > synchronous function in a call to `contextvars.Context.run`. This will set the context, call the
768/// > synchronous function, and restore the previous context when it returns or raises an exception.
769///
770/// # Arguments
771/// * `py` - The current PyO3 GIL guard
772/// * `fut` - The Rust future to be converted
773///
774/// # Examples
775///
776/// ```no_run
777/// # use std::{any::Any, task::{Context, Poll}, pin::Pin, future::Future};
778/// #
779/// # use pyo3_async_runtimes::{
780/// #     TaskLocals,
781/// #     generic::{JoinError, SpawnLocalExt, ContextExt, LocalContextExt, Runtime}
782/// # };
783/// #
784/// # struct MyCustomJoinError;
785/// #
786/// # impl JoinError for MyCustomJoinError {
787/// #     fn is_panic(&self) -> bool {
788/// #         unreachable!()
789/// #     }
790/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
791/// #         unreachable!()
792/// #     }
793/// # }
794/// #
795/// # struct MyCustomJoinHandle;
796/// #
797/// # impl Future for MyCustomJoinHandle {
798/// #     type Output = Result<(), MyCustomJoinError>;
799/// #
800/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
801/// #         unreachable!()
802/// #     }
803/// # }
804/// #
805/// # struct MyCustomRuntime;
806/// #
807/// # impl MyCustomRuntime {
808/// #     async fn sleep(_: Duration) {
809/// #         unreachable!()
810/// #     }
811/// # }
812/// #
813/// # impl Runtime for MyCustomRuntime {
814/// #     type JoinError = MyCustomJoinError;
815/// #     type JoinHandle = MyCustomJoinHandle;
816/// #
817/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
818/// #     where
819/// #         F: Future<Output = ()> + Send + 'static
820/// #     {
821/// #         unreachable!()
822/// #     }
823/// # }
824/// #
825/// # impl ContextExt for MyCustomRuntime {
826/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
827/// #     where
828/// #         F: Future<Output = R> + Send + 'static
829/// #     {
830/// #         unreachable!()
831/// #     }
832/// #     fn get_task_locals() -> Option<TaskLocals> {
833/// #         unreachable!()
834/// #     }
835/// # }
836/// #
837/// use std::time::Duration;
838///
839/// use pyo3::prelude::*;
840///
841/// /// Awaitable sleep function
842/// #[pyfunction]
843/// fn sleep_for<'p>(py: Python<'p>, secs: Bound<'p, PyAny>) -> PyResult<Bound<'p, PyAny>> {
844///     let secs = secs.extract()?;
845///     pyo3_async_runtimes::generic::future_into_py::<MyCustomRuntime, _, _>(py, async move {
846///         MyCustomRuntime::sleep(Duration::from_secs(secs)).await;
847///         Ok(())
848///     })
849/// }
850/// ```
851pub fn future_into_py<R, F, T>(py: Python, fut: F) -> PyResult<Bound<PyAny>>
852where
853    R: Runtime + ContextExt,
854    F: Future<Output = PyResult<T>> + Send + 'static,
855    T: for<'py> IntoPyObject<'py>,
856{
857    future_into_py_with_locals::<R, F, T>(py, get_current_locals::<R>(py)?, fut)
858}
859
860/// Convert a `!Send` Rust Future into a Python awaitable with a generic runtime and manual
861/// specification of task locals.
862///
863/// If the `asyncio.Future` returned by this conversion is cancelled via `asyncio.Future.cancel`,
864/// the Rust future will be cancelled as well (new behaviour in `v0.15`).
865///
866/// Python `contextvars` are preserved when calling async Python functions within the Rust future
867/// via [`into_future`] (new behaviour in `v0.15`).
868///
869/// > Although `contextvars` are preserved for async Python functions, synchronous functions will
870/// > unfortunately fail to resolve them when called within the Rust future. This is because the
871/// > function is being called from a Rust thread, not inside an actual Python coroutine context.
872/// >
873/// > As a workaround, you can get the `contextvars` from the current task locals using
874/// > [`get_current_locals`] and [`TaskLocals::context`](`crate::TaskLocals::context`), then wrap your
875/// > synchronous function in a call to `contextvars.Context.run`. This will set the context, call the
876/// > synchronous function, and restore the previous context when it returns or raises an exception.
877///
878/// # Arguments
879/// * `py` - PyO3 GIL guard
880/// * `locals` - The task locals for the future
881/// * `fut` - The Rust future to be converted
882///
883/// # Examples
884///
885/// ```no_run
886/// # use std::{any::Any, task::{Context, Poll}, pin::Pin, future::Future};
887/// #
888/// # use pyo3_async_runtimes::{
889/// #     TaskLocals,
890/// #     generic::{JoinError, SpawnLocalExt, ContextExt, LocalContextExt, Runtime}
891/// # };
892/// #
893/// # struct MyCustomJoinError;
894/// #
895/// # impl JoinError for MyCustomJoinError {
896/// #     fn is_panic(&self) -> bool {
897/// #         unreachable!()
898/// #     }
899/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
900/// #         unreachable!()
901/// #     }
902/// # }
903/// #
904/// # struct MyCustomJoinHandle;
905/// #
906/// # impl Future for MyCustomJoinHandle {
907/// #     type Output = Result<(), MyCustomJoinError>;
908/// #
909/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
910/// #         unreachable!()
911/// #     }
912/// # }
913/// #
914/// # struct MyCustomRuntime;
915/// #
916/// # impl MyCustomRuntime {
917/// #     async fn sleep(_: Duration) {
918/// #         unreachable!()
919/// #     }
920/// # }
921/// #
922/// # impl Runtime for MyCustomRuntime {
923/// #     type JoinError = MyCustomJoinError;
924/// #     type JoinHandle = MyCustomJoinHandle;
925/// #
926/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
927/// #     where
928/// #         F: Future<Output = ()> + Send + 'static
929/// #     {
930/// #         unreachable!()
931/// #     }
932/// # }
933/// #
934/// # impl ContextExt for MyCustomRuntime {
935/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
936/// #     where
937/// #         F: Future<Output = R> + Send + 'static
938/// #     {
939/// #         unreachable!()
940/// #     }
941/// #     fn get_task_locals() -> Option<TaskLocals> {
942/// #         unreachable!()
943/// #     }
944/// # }
945/// #
946/// # impl SpawnLocalExt for MyCustomRuntime {
947/// #     fn spawn_local<F>(fut: F) -> Self::JoinHandle
948/// #     where
949/// #         F: Future<Output = ()> + 'static
950/// #     {
951/// #         unreachable!()
952/// #     }
953/// # }
954/// #
955/// # impl LocalContextExt for MyCustomRuntime {
956/// #     fn scope_local<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R>>>
957/// #     where
958/// #         F: Future<Output = R> + 'static
959/// #     {
960/// #         unreachable!()
961/// #     }
962/// # }
963/// #
964/// use std::{rc::Rc, time::Duration};
965///
966/// use pyo3::prelude::*;
967///
968/// /// Awaitable sleep function
969/// #[pyfunction]
970/// fn sleep_for(py: Python, secs: u64) -> PyResult<Bound<PyAny>> {
971///     // Rc is !Send so it cannot be passed into pyo3_async_runtimes::generic::future_into_py
972///     let secs = Rc::new(secs);
973///
974///     pyo3_async_runtimes::generic::local_future_into_py_with_locals::<MyCustomRuntime, _, _>(
975///         py,
976///         pyo3_async_runtimes::generic::get_current_locals::<MyCustomRuntime>(py)?,
977///         async move {
978///             MyCustomRuntime::sleep(Duration::from_secs(*secs)).await;
979///             Ok(())
980///         }
981///     )
982/// }
983/// ```
#[deprecated(
    since = "0.18.0",
    note = "Questionable whether these conversions have real-world utility (see https://github.com/awestlake87/pyo3-asyncio/issues/59#issuecomment-1008038497 and let me know if you disagree!)"
)]
// The JoinHandle returned by the outer `R::spawn_local` below is deliberately discarded
// (the outer task is detached), which would otherwise trigger `unused_must_use`.
#[allow(unused_must_use)]
pub fn local_future_into_py_with_locals<R, F, T>(
    py: Python,
    locals: TaskLocals,
    fut: F,
) -> PyResult<Bound<PyAny>>
where
    R: Runtime + SpawnLocalExt + LocalContextExt,
    F: Future<Output = PyResult<T>> + 'static,
    T: for<'py> IntoPyObject<'py>,
{
    // `cancel_tx` is handed to the done-callback registered on the Python future and
    // `cancel_rx` to the `Cancellable` wrapper around `fut`; presumably this is how a
    // Python-side cancellation reaches the Rust future (see `PyDoneCallback` / `Cancellable`).
    let (cancel_tx, cancel_rx) = oneshot::channel();

    // The Python future returned to the caller, created on the event loop from `locals`.
    let py_fut = create_future(locals.event_loop.clone_ref(py).into_bound(py))?;
    py_fut.call_method1(
        "add_done_callback",
        (PyDoneCallback {
            cancel_tx: Some(cancel_tx),
        },),
    )?;

    // Two owned handles to `py_fut`: one for the normal-completion path inside the nested
    // task, one for the panic-reporting path in the outer task.
    let future_tx1 = PyObject::from(py_fut.clone());
    let future_tx2 = future_tx1.clone_ref(py);

    R::spawn_local(async move {
        // Second handle to the task locals for the nested task; the outer `locals` stays
        // available for the panic path below.
        let locals2 = Python::with_gil(|py| locals.clone_ref(py));

        // Run the actual work in a nested local task so that a panic in `fut` can be
        // observed here through `JoinError::is_panic`.
        if let Err(e) = R::spawn_local(async move {
            let result = R::scope_local(
                Python::with_gil(|py| locals2.clone_ref(py)),
                Cancellable::new_with_cancel_rx(fut, cancel_rx),
            )
            .await;

            Python::with_gil(move |py| {
                // Do not touch a Python future that was already cancelled; an error from the
                // check itself is passed to `dump_err` and treated as "not cancelled".
                if cancelled(future_tx1.bind(py))
                    .map_err(dump_err(py))
                    .unwrap_or(false)
                {
                    return;
                }

                // Convert the successful value into a Python object (conversion failures
                // become a `PyErr`) and resolve the Python future with it; failures while
                // resolving are passed to `dump_err` and otherwise ignored.
                let _ = set_result(
                    locals2.event_loop.bind(py),
                    future_tx1.bind(py),
                    result.and_then(|val| match val.into_pyobject(py) {
                        Ok(obj) => Ok(obj.into_any().unbind()),
                        Err(err) => Err(err.into()),
                    }),
                )
                .map_err(dump_err(py));
            });
        })
        .await
        {
            // Only panics are reported to Python; other join errors are ignored here.
            if e.is_panic() {
                Python::with_gil(move |py| {
                    if cancelled(future_tx2.bind(py))
                        .map_err(dump_err(py))
                        .unwrap_or(false)
                    {
                        return;
                    }

                    // Surface the panic payload to Python as a `RustPanic` exception.
                    let panic_message = format!(
                        "rust future panicked: {}",
                        get_panic_message(&e.into_panic())
                    );
                    let _ = set_result(
                        locals.event_loop.bind(py),
                        future_tx2.bind(py),
                        Err(RustPanic::new_err(panic_message)),
                    )
                    .map_err(dump_err(py));
                });
            }
        }
    });

    Ok(py_fut)
}
1069
1070/// Convert a `!Send` Rust Future into a Python awaitable with a generic runtime
1071///
1072/// If the `asyncio.Future` returned by this conversion is cancelled via `asyncio.Future.cancel`,
1073/// the Rust future will be cancelled as well (new behaviour in `v0.15`).
1074///
1075/// Python `contextvars` are preserved when calling async Python functions within the Rust future
1076/// via [`into_future`] (new behaviour in `v0.15`).
1077///
1078/// > Although `contextvars` are preserved for async Python functions, synchronous functions will
1079/// > unfortunately fail to resolve them when called within the Rust future. This is because the
1080/// > function is being called from a Rust thread, not inside an actual Python coroutine context.
1081/// >
1082/// > As a workaround, you can get the `contextvars` from the current task locals using
1083/// > [`get_current_locals`] and [`TaskLocals::context`](`crate::TaskLocals::context`), then wrap your
1084/// > synchronous function in a call to `contextvars.Context.run`. This will set the context, call the
1085/// > synchronous function, and restore the previous context when it returns or raises an exception.
1086///
1087/// # Arguments
1088/// * `py` - The current PyO3 GIL guard
1089/// * `fut` - The Rust future to be converted
1090///
1091/// # Examples
1092///
1093/// ```no_run
1094/// # use std::{any::Any, task::{Context, Poll}, pin::Pin, future::Future};
1095/// #
1096/// # use pyo3_async_runtimes::{
1097/// #     TaskLocals,
1098/// #     generic::{JoinError, SpawnLocalExt, ContextExt, LocalContextExt, Runtime}
1099/// # };
1100/// #
1101/// # struct MyCustomJoinError;
1102/// #
1103/// # impl JoinError for MyCustomJoinError {
1104/// #     fn is_panic(&self) -> bool {
1105/// #         unreachable!()
1106/// #     }
1107/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
1108/// #         unreachable!()
1109/// #     }
1110/// # }
1111/// #
1112/// # struct MyCustomJoinHandle;
1113/// #
1114/// # impl Future for MyCustomJoinHandle {
1115/// #     type Output = Result<(), MyCustomJoinError>;
1116/// #
1117/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
1118/// #         unreachable!()
1119/// #     }
1120/// # }
1121/// #
1122/// # struct MyCustomRuntime;
1123/// #
1124/// # impl MyCustomRuntime {
1125/// #     async fn sleep(_: Duration) {
1126/// #         unreachable!()
1127/// #     }
1128/// # }
1129/// #
1130/// # impl Runtime for MyCustomRuntime {
1131/// #     type JoinError = MyCustomJoinError;
1132/// #     type JoinHandle = MyCustomJoinHandle;
1133/// #
1134/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
1135/// #     where
1136/// #         F: Future<Output = ()> + Send + 'static
1137/// #     {
1138/// #         unreachable!()
1139/// #     }
1140/// # }
1141/// #
1142/// # impl ContextExt for MyCustomRuntime {
1143/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
1144/// #     where
1145/// #         F: Future<Output = R> + Send + 'static
1146/// #     {
1147/// #         unreachable!()
1148/// #     }
1149/// #     fn get_task_locals() -> Option<TaskLocals> {
1150/// #         unreachable!()
1151/// #     }
1152/// # }
1153/// #
1154/// # impl SpawnLocalExt for MyCustomRuntime {
1155/// #     fn spawn_local<F>(fut: F) -> Self::JoinHandle
1156/// #     where
1157/// #         F: Future<Output = ()> + 'static
1158/// #     {
1159/// #         unreachable!()
1160/// #     }
1161/// # }
1162/// #
1163/// # impl LocalContextExt for MyCustomRuntime {
1164/// #     fn scope_local<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R>>>
1165/// #     where
1166/// #         F: Future<Output = R> + 'static
1167/// #     {
1168/// #         unreachable!()
1169/// #     }
1170/// # }
1171/// #
1172/// use std::{rc::Rc, time::Duration};
1173///
1174/// use pyo3::prelude::*;
1175///
1176/// /// Awaitable sleep function
1177/// #[pyfunction]
1178/// fn sleep_for(py: Python, secs: u64) -> PyResult<Bound<PyAny>> {
1179///     // Rc is !Send so it cannot be passed into pyo3_async_runtimes::generic::future_into_py
1180///     let secs = Rc::new(secs);
1181///
1182///     pyo3_async_runtimes::generic::local_future_into_py::<MyCustomRuntime, _, _>(py, async move {
1183///         MyCustomRuntime::sleep(Duration::from_secs(*secs)).await;
1184///         Ok(())
1185///     })
1186/// }
1187/// ```
1188#[deprecated(
1189    since = "0.18.0",
1190    note = "Questionable whether these conversions have real-world utility (see https://github.com/awestlake87/pyo3-asyncio/issues/59#issuecomment-1008038497 and let me know if you disagree!)"
1191)]
1192#[allow(deprecated)]
1193pub fn local_future_into_py<R, F, T>(py: Python, fut: F) -> PyResult<Bound<PyAny>>
1194where
1195    R: Runtime + ContextExt + SpawnLocalExt + LocalContextExt,
1196    F: Future<Output = PyResult<T>> + 'static,
1197    T: for<'py> IntoPyObject<'py>,
1198{
1199    local_future_into_py_with_locals::<R, F, T>(py, get_current_locals::<R>(py)?, fut)
1200}
1201
1202/// <span class="module-item stab portability" style="display: inline; border-radius: 3px; padding: 2px; font-size: 80%; line-height: 1.2;"><code>unstable-streams</code></span> Convert an async generator into a stream
1203///
1204/// **This API is marked as unstable** and is only available when the
1205/// `unstable-streams` crate feature is enabled. This comes with no
1206/// stability guarantees, and could be changed or removed at any time.
1207///
1208/// # Arguments
1209/// * `locals` - The current task locals
1210/// * `gen` - The Python async generator to be converted
1211///
1212/// # Examples
1213/// ```no_run
1214/// # use std::{any::Any, task::{Context, Poll}, pin::Pin, future::Future};
1215/// #
1216/// # use pyo3_async_runtimes::{
1217/// #     TaskLocals,
1218/// #     generic::{JoinError, ContextExt, Runtime}
1219/// # };
1220/// #
1221/// # struct MyCustomJoinError;
1222/// #
1223/// # impl JoinError for MyCustomJoinError {
1224/// #     fn is_panic(&self) -> bool {
1225/// #         unreachable!()
1226/// #     }
1227/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
1228/// #         unreachable!()
1229/// #     }
1230/// # }
1231/// #
1232/// # struct MyCustomJoinHandle;
1233/// #
1234/// # impl Future for MyCustomJoinHandle {
1235/// #     type Output = Result<(), MyCustomJoinError>;
1236/// #
1237/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
1238/// #         unreachable!()
1239/// #     }
1240/// # }
1241/// #
1242/// # struct MyCustomRuntime;
1243/// #
1244/// # impl Runtime for MyCustomRuntime {
1245/// #     type JoinError = MyCustomJoinError;
1246/// #     type JoinHandle = MyCustomJoinHandle;
1247/// #
1248/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
1249/// #     where
1250/// #         F: Future<Output = ()> + Send + 'static
1251/// #     {
1252/// #         unreachable!()
1253/// #     }
1254/// # }
1255/// #
1256/// # impl ContextExt for MyCustomRuntime {
1257/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
1258/// #     where
1259/// #         F: Future<Output = R> + Send + 'static
1260/// #     {
1261/// #         unreachable!()
1262/// #     }
1263/// #     fn get_task_locals() -> Option<TaskLocals> {
1264/// #         unreachable!()
1265/// #     }
1266/// # }
1267///
1268/// use pyo3::prelude::*;
1269/// use futures::{StreamExt, TryStreamExt};
1270/// use std::ffi::CString;
1271///
1272/// const TEST_MOD: &str = r#"
1273/// import asyncio
1274///
1275/// async def gen():
1276///     for i in range(10):
1277///         await asyncio.sleep(0.1)
1278///         yield i
1279/// "#;
1280///
1281/// # async fn test_async_gen() -> PyResult<()> {
1282/// let stream = Python::with_gil(|py| {
1283///     let test_mod = PyModule::from_code(
1284///         py,
1285///         &CString::new(TEST_MOD).unwrap(),
1286///         &CString::new("test_rust_coroutine/test_mod.py").unwrap(),
1287///         &CString::new("test_mod").unwrap(),
1288///     )?;
1289///
1290///     pyo3_async_runtimes::generic::into_stream_with_locals_v1::<MyCustomRuntime>(
1291///         pyo3_async_runtimes::generic::get_current_locals::<MyCustomRuntime>(py)?,
1292///         test_mod.call_method0("gen")?
1293///     )
1294/// })?;
1295///
1296/// let vals = stream
1297///     .map(|item| Python::with_gil(|py| -> PyResult<i32> { Ok(item?.bind(py).extract()?) }))
1298///     .try_collect::<Vec<i32>>()
1299///     .await?;
1300///
1301/// assert_eq!((0..10).collect::<Vec<i32>>(), vals);
1302///
1303/// Ok(())
1304/// # }
1305/// ```
#[cfg(feature = "unstable-streams")]
#[allow(unused_must_use)] // False positive unused lint on `R::spawn`
pub fn into_stream_with_locals_v1<R>(
    locals: TaskLocals,
    gen: Bound<'_, PyAny>,
) -> PyResult<impl futures::Stream<Item = PyResult<PyObject>> + 'static>
where
    R: Runtime,
{
    // Capacity 1: the pump task blocks on `send` while an item is still waiting to be read.
    let (item_tx, item_rx) = async_channel::bounded(1);
    let anext = PyObject::from(gen.getattr("__anext__")?);

    // Background pump: repeatedly await `gen.__anext__()` and forward every outcome.
    R::spawn(async move {
        loop {
            let pending = Python::with_gil(|py| -> PyResult<_> {
                into_future_with_locals(&locals, anext.bind(py).call0()?)
            });

            let outcome = match pending {
                Err(err) => Err(err),
                Ok(next) => match next.await {
                    Ok(value) => Ok(value),
                    // `StopAsyncIteration` marks the natural end of the generator.
                    Err(err)
                        if Python::with_gil(|py| {
                            err.is_instance_of::<pyo3::exceptions::PyStopAsyncIteration>(py)
                        }) =>
                    {
                        break;
                    }
                    Err(err) => Err(err),
                },
            };

            // A failed send means the consumer dropped the stream: stop pumping.
            if item_tx.send(outcome).await.is_err() {
                break;
            }
        }
    });

    Ok(item_rx)
}
1351
1352/// <span class="module-item stab portability" style="display: inline; border-radius: 3px; padding: 2px; font-size: 80%; line-height: 1.2;"><code>unstable-streams</code></span> Convert an async generator into a stream
1353///
1354/// **This API is marked as unstable** and is only available when the
1355/// `unstable-streams` crate feature is enabled. This comes with no
1356/// stability guarantees, and could be changed or removed at any time.
1357///
1358/// # Arguments
1359/// * `gen` - The Python async generator to be converted
1360///
1361/// # Examples
1362/// ```no_run
1363/// # use std::{any::Any, task::{Context, Poll}, pin::Pin, future::Future};
1364/// #
1365/// # use pyo3_async_runtimes::{
1366/// #     TaskLocals,
1367/// #     generic::{JoinError, ContextExt, Runtime}
1368/// # };
1369/// #
1370/// # struct MyCustomJoinError;
1371/// #
1372/// # impl JoinError for MyCustomJoinError {
1373/// #     fn is_panic(&self) -> bool {
1374/// #         unreachable!()
1375/// #     }
1376/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
1377/// #         unreachable!()
1378/// #     }
1379/// # }
1380/// #
1381/// # struct MyCustomJoinHandle;
1382/// #
1383/// # impl Future for MyCustomJoinHandle {
1384/// #     type Output = Result<(), MyCustomJoinError>;
1385/// #
1386/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
1387/// #         unreachable!()
1388/// #     }
1389/// # }
1390/// #
1391/// # struct MyCustomRuntime;
1392/// #
1393/// # impl Runtime for MyCustomRuntime {
1394/// #     type JoinError = MyCustomJoinError;
1395/// #     type JoinHandle = MyCustomJoinHandle;
1396/// #
1397/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
1398/// #     where
1399/// #         F: Future<Output = ()> + Send + 'static
1400/// #     {
1401/// #         unreachable!()
1402/// #     }
1403/// # }
1404/// #
1405/// # impl ContextExt for MyCustomRuntime {
1406/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
1407/// #     where
1408/// #         F: Future<Output = R> + Send + 'static
1409/// #     {
1410/// #         unreachable!()
1411/// #     }
1412/// #     fn get_task_locals() -> Option<TaskLocals> {
1413/// #         unreachable!()
1414/// #     }
1415/// # }
1416///
1417/// use pyo3::prelude::*;
1418/// use futures::{StreamExt, TryStreamExt};
1419/// use std::ffi::CString;
1420///
1421/// const TEST_MOD: &str = r#"
1422/// import asyncio
1423///
1424/// async def gen():
1425///     for i in range(10):
1426///         await asyncio.sleep(0.1)
1427///         yield i
1428/// "#;
1429///
1430/// # async fn test_async_gen() -> PyResult<()> {
1431/// let stream = Python::with_gil(|py| {
1432///     let test_mod = PyModule::from_code(
1433///         py,
1434///         &CString::new(TEST_MOD).unwrap(),
1435///         &CString::new("test_rust_coroutine/test_mod.py").unwrap(),
1436///         &CString::new("test_mod").unwrap(),
1437///     )?;
1438///
1439///     pyo3_async_runtimes::generic::into_stream_v1::<MyCustomRuntime>(test_mod.call_method0("gen")?)
1440/// })?;
1441///
1442/// let vals = stream
1443///     .map(|item| Python::with_gil(|py| -> PyResult<i32> { Ok(item?.bind(py).extract()?) }))
1444///     .try_collect::<Vec<i32>>()
1445///     .await?;
1446///
1447/// assert_eq!((0..10).collect::<Vec<i32>>(), vals);
1448///
1449/// Ok(())
1450/// # }
1451/// ```
#[cfg(feature = "unstable-streams")]
pub fn into_stream_v1<R>(
    gen: Bound<'_, PyAny>,
) -> PyResult<impl futures::Stream<Item = PyResult<PyObject>> + 'static>
where
    R: Runtime + ContextExt,
{
    // Look up the task locals of the current context before handing off.
    let locals = get_current_locals::<R>(gen.py())?;
    into_stream_with_locals_v1::<R>(locals, gen)
}
1461
1462trait Sender: Send + 'static {
1463    fn send(&mut self, py: Python, locals: TaskLocals, item: PyObject) -> PyResult<PyObject>;
1464    fn close(&mut self) -> PyResult<()>;
1465}
1466
/// `Sender` implementation backed by a `futures` mpsc channel. The runtime `R` is the one
/// used to wait for capacity when the channel is full.
#[cfg(feature = "unstable-streams")]
struct GenericSender<R>
where
    R: Runtime,
{
    /// Zero-sized marker binding this sender to runtime `R`.
    runtime: PhantomData<R>,
    /// Producer half of the channel that feeds the Rust stream.
    tx: mpsc::Sender<PyObject>,
}
1475
#[cfg(feature = "unstable-streams")]
impl<R> Sender for GenericSender<R>
where
    R: Runtime + ContextExt,
{
    /// Offer `item` to the Rust channel.
    ///
    /// Returns Python `True` when the item was queued immediately, and Python `False` when
    /// the receiving side is gone.
    ///
    /// When the channel is full, returns an awaitable instead. It is created with `locals`
    /// and resolves to `True` once the item has been queued, or to `False` if the receiver
    /// disconnects first.
    fn send(&mut self, py: Python, locals: TaskLocals, item: PyObject) -> PyResult<PyObject> {
        match self.tx.try_send(item.clone_ref(py)) {
            Ok(_) => Ok(true.into_pyobject(py)?.into_any().unbind()),
            Err(e) if e.is_full() => {
                // Wait for capacity on a cloned sender inside a runtime task. The GIL is
                // already held through `py`, so it is used directly rather than being
                // re-acquired with a nested `Python::with_gil`.
                let mut tx = self.tx.clone();
                let awaitable =
                    future_into_py_with_locals::<R, _, PyObject>(py, locals, async move {
                        // `flush` waits until the sender is ready and `send` queues the item;
                        // a failure of either means the receiving side disconnected.
                        let delivered = tx.flush().await.is_ok() && tx.send(item).await.is_ok();
                        Python::with_gil(|py| Ok(delivered.into_pyobject(py)?.into_any().unbind()))
                    })?;
                Ok(awaitable.unbind())
            }
            // Any other `try_send` error means the receiving side disconnected.
            Err(_) => Ok(false.into_pyobject(py)?.into_any().unbind()),
        }
    }

    /// Close the channel from the sender side so no further items can be queued.
    fn close(&mut self) -> PyResult<()> {
        self.tx.close_channel();
        Ok(())
    }
}
1520
1521#[pyclass]
1522struct SenderGlue {
1523    locals: TaskLocals,
1524    tx: Arc<Mutex<dyn Sender>>,
1525}
1526#[pymethods]
1527impl SenderGlue {
1528    pub fn send(&mut self, item: PyObject) -> PyResult<PyObject> {
1529        Python::with_gil(|py| {
1530            self.tx
1531                .lock()
1532                .unwrap()
1533                .send(py, self.locals.clone_ref(py), item)
1534        })
1535    }
1536    pub fn close(&mut self) -> PyResult<()> {
1537        self.tx.lock().unwrap().close()
1538    }
1539}
1540
/// Python glue used by `into_stream_with_locals_v2`. Its `forward` coroutine drives the
/// async generator `gen` and pushes each item into `sender` (a `SenderGlue`).
///
/// `sender.send` returns either a bool or an awaitable resolving to a bool; the awaitable
/// is returned when the Rust channel is full.
///
/// The check uses `inspect.isawaitable` because that awaitable may be an `asyncio.Future`,
/// which `asyncio.iscoroutine` does not recognise. Without awaiting it, back-pressure is
/// lost, items can be delivered out of order, and the truthy future is misread as
/// "continue".
///
/// `sender.close()` runs in a `finally` block. This closes the Rust stream even when the
/// generator raises or the task is cancelled.
#[cfg(feature = "unstable-streams")]
const STREAM_GLUE: &str = r#"
import inspect

async def forward(gen, sender):
    try:
        async for item in gen:
            should_continue = sender.send(item)

            if inspect.isawaitable(should_continue):
                should_continue = await should_continue

            if not should_continue:
                break
    finally:
        sender.close()
"#;
1559
1560/// <span class="module-item stab portability" style="display: inline; border-radius: 3px; padding: 2px; font-size: 80%; line-height: 1.2;"><code>unstable-streams</code></span> Convert an async generator into a stream
1561///
1562/// **This API is marked as unstable** and is only available when the
1563/// `unstable-streams` crate feature is enabled. This comes with no
1564/// stability guarantees, and could be changed or removed at any time.
1565///
1566/// # Arguments
1567/// * `locals` - The current task locals
1568/// * `gen` - The Python async generator to be converted
1569///
1570/// # Examples
1571/// ```no_run
1572/// # use std::{any::Any, task::{Context, Poll}, pin::Pin, future::Future};
1573/// #
1574/// # use pyo3_async_runtimes::{
1575/// #     TaskLocals,
1576/// #     generic::{JoinError, ContextExt, Runtime}
1577/// # };
1578/// #
1579/// # struct MyCustomJoinError;
1580/// #
1581/// # impl JoinError for MyCustomJoinError {
1582/// #     fn is_panic(&self) -> bool {
1583/// #         unreachable!()
1584/// #     }
1585/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
1586/// #         unreachable!()
1587/// #     }
1588/// # }
1589/// #
1590/// # struct MyCustomJoinHandle;
1591/// #
1592/// # impl Future for MyCustomJoinHandle {
1593/// #     type Output = Result<(), MyCustomJoinError>;
1594/// #
1595/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
1596/// #         unreachable!()
1597/// #     }
1598/// # }
1599/// #
1600/// # struct MyCustomRuntime;
1601/// #
1602/// # impl Runtime for MyCustomRuntime {
1603/// #     type JoinError = MyCustomJoinError;
1604/// #     type JoinHandle = MyCustomJoinHandle;
1605/// #
1606/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
1607/// #     where
1608/// #         F: Future<Output = ()> + Send + 'static
1609/// #     {
1610/// #         unreachable!()
1611/// #     }
1612/// # }
1613/// #
1614/// # impl ContextExt for MyCustomRuntime {
1615/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
1616/// #     where
1617/// #         F: Future<Output = R> + Send + 'static
1618/// #     {
1619/// #         unreachable!()
1620/// #     }
1621/// #     fn get_task_locals() -> Option<TaskLocals> {
1622/// #         unreachable!()
1623/// #     }
1624/// # }
1625///
1626/// use pyo3::prelude::*;
1627/// use futures::{StreamExt, TryStreamExt};
1628/// use std::ffi::CString;
1629///
1630/// const TEST_MOD: &str = r#"
1631/// import asyncio
1632///
1633/// async def gen():
1634///     for i in range(10):
1635///         await asyncio.sleep(0.1)
1636///         yield i
1637/// "#;
1638///
1639/// # async fn test_async_gen() -> PyResult<()> {
1640/// let stream = Python::with_gil(|py| {
1641///     let test_mod = PyModule::from_code(
1642///         py,
1643///         &CString::new(TEST_MOD).unwrap(),
1644///         &CString::new("test_rust_coroutine/test_mod.py").unwrap(),
1645///         &CString::new("test_mod").unwrap(),
1646///     )?;
1647///
1648///     pyo3_async_runtimes::generic::into_stream_with_locals_v2::<MyCustomRuntime>(
1649///         pyo3_async_runtimes::generic::get_current_locals::<MyCustomRuntime>(py)?,
1650///         test_mod.call_method0("gen")?
1651///     )
1652/// })?;
1653///
1654/// let vals = stream
1655///     .map(|item| Python::with_gil(|py| -> PyResult<i32> { Ok(item.bind(py).extract()?) }))
1656///     .try_collect::<Vec<i32>>()
1657///     .await?;
1658///
1659/// assert_eq!((0..10).collect::<Vec<i32>>(), vals);
1660///
1661/// Ok(())
1662/// # }
1663/// ```
#[cfg(feature = "unstable-streams")]
pub fn into_stream_with_locals_v2<R>(
    locals: TaskLocals,
    gen: Bound<'_, PyAny>,
) -> PyResult<impl futures::Stream<Item = PyObject> + 'static>
where
    R: Runtime + ContextExt,
{
    use std::ffi::CString;

    // The Python glue module is compiled on first use and cached for the process.
    static GLUE_MOD: OnceCell<PyObject> = OnceCell::new();

    let py = gen.py();
    let glue = GLUE_MOD
        .get_or_try_init(|| -> PyResult<PyObject> {
            let module = PyModule::from_code(
                py,
                &CString::new(STREAM_GLUE).unwrap(),
                &CString::new("pyo3_async_runtimes/pyo3_async_runtimes_glue.py").unwrap(),
                &CString::new("pyo3_async_runtimes_glue").unwrap(),
            )?;
            Ok(module.into())
        })?
        .bind(py);

    // Items travel from the Python `forward` task to the returned stream over this channel.
    let (tx, rx) = mpsc::channel(10);

    let event_loop = locals.event_loop(py);
    let create_task = event_loop.getattr("create_task")?;
    let sender = SenderGlue {
        locals,
        tx: Arc::new(Mutex::new(GenericSender {
            runtime: PhantomData::<R>,
            tx,
        })),
    };
    let forward = glue.call_method1("forward", (gen, sender))?;

    // `forward(gen, sender)` is just a coroutine object. Schedule `create_task` for it via
    // `call_soon_threadsafe`, so scheduling is safe from any thread.
    event_loop.call_method1("call_soon_threadsafe", (create_task, forward))?;

    Ok(rx)
}
1711
1712/// <span class="module-item stab portability" style="display: inline; border-radius: 3px; padding: 2px; font-size: 80%; line-height: 1.2;"><code>unstable-streams</code></span> Convert an async generator into a stream
1713///
1714/// **This API is marked as unstable** and is only available when the
1715/// `unstable-streams` crate feature is enabled. This comes with no
1716/// stability guarantees, and could be changed or removed at any time.
1717///
1718/// # Arguments
1719/// * `gen` - The Python async generator to be converted
1720///
1721/// # Examples
1722/// ```no_run
1723/// # use std::{any::Any, task::{Context, Poll}, pin::Pin, future::Future};
1724/// #
1725/// # use pyo3_async_runtimes::{
1726/// #     TaskLocals,
1727/// #     generic::{JoinError, ContextExt, Runtime}
1728/// # };
1729/// #
1730/// # struct MyCustomJoinError;
1731/// #
1732/// # impl JoinError for MyCustomJoinError {
1733/// #     fn is_panic(&self) -> bool {
1734/// #         unreachable!()
1735/// #     }
1736/// #     fn into_panic(self) -> Box<(dyn Any + Send + 'static)> {
1737/// #         unreachable!()
1738/// #     }
1739/// # }
1740/// #
1741/// # struct MyCustomJoinHandle;
1742/// #
1743/// # impl Future for MyCustomJoinHandle {
1744/// #     type Output = Result<(), MyCustomJoinError>;
1745/// #
1746/// #     fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
1747/// #         unreachable!()
1748/// #     }
1749/// # }
1750/// #
1751/// # struct MyCustomRuntime;
1752/// #
1753/// # impl Runtime for MyCustomRuntime {
1754/// #     type JoinError = MyCustomJoinError;
1755/// #     type JoinHandle = MyCustomJoinHandle;
1756/// #
1757/// #     fn spawn<F>(fut: F) -> Self::JoinHandle
1758/// #     where
1759/// #         F: Future<Output = ()> + Send + 'static
1760/// #     {
1761/// #         unreachable!()
1762/// #     }
1763/// # }
1764/// #
1765/// # impl ContextExt for MyCustomRuntime {
1766/// #     fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
1767/// #     where
1768/// #         F: Future<Output = R> + Send + 'static
1769/// #     {
1770/// #         unreachable!()
1771/// #     }
1772/// #     fn get_task_locals() -> Option<TaskLocals> {
1773/// #         unreachable!()
1774/// #     }
1775/// # }
1776///
1777/// use pyo3::prelude::*;
1778/// use futures::{StreamExt, TryStreamExt};
1779/// use std::ffi::CString;
1780///
1781/// const TEST_MOD: &str = r#"
1782/// import asyncio
1783///
1784/// async def gen():
1785///     for i in range(10):
1786///         await asyncio.sleep(0.1)
1787///         yield i
1788/// "#;
1789///
1790/// # async fn test_async_gen() -> PyResult<()> {
1791/// let stream = Python::with_gil(|py| {
1792///     let test_mod = PyModule::from_code(
1793///         py,
1794///         &CString::new(TEST_MOD).unwrap(),
1795///         &CString::new("test_rust_coroutine/test_mod.py").unwrap(),
1796///         &CString::new("test_mod").unwrap(),
1797///     )?;
1798///
1799///     pyo3_async_runtimes::generic::into_stream_v2::<MyCustomRuntime>(test_mod.call_method0("gen")?)
1800/// })?;
1801///
1802/// let vals = stream
1803///     .map(|item| Python::with_gil(|py| -> PyResult<i32> { Ok(item.bind(py).extract()?) }))
1804///     .try_collect::<Vec<i32>>()
1805///     .await?;
1806///
1807/// assert_eq!((0..10).collect::<Vec<i32>>(), vals);
1808///
1809/// Ok(())
1810/// # }
1811/// ```
#[cfg(feature = "unstable-streams")]
pub fn into_stream_v2<R>(
    gen: Bound<'_, PyAny>,
) -> PyResult<impl futures::Stream<Item = PyObject> + 'static>
where
    R: Runtime + ContextExt,
{
    // Look up the task locals of the current context before handing off.
    let locals = get_current_locals::<R>(gen.py())?;
    into_stream_with_locals_v2::<R>(locals, gen)
}