pyo3_async/
asyncio.rs

1//! `asyncio` compatible coroutine and async generator implementation.
2use std::{
3    future::Future,
4    pin::Pin,
5    task::{ready, Context, Poll},
6};
7
8use futures::{FutureExt, Stream, StreamExt};
9use pyo3::{
10    exceptions::{PyStopAsyncIteration, PyStopIteration},
11    intern,
12    prelude::*,
13};
14
15use crate::{coroutine, utils};
16
17utils::module!(Asyncio, "asyncio", Future);
18
19fn asyncio_future(py: Python) -> PyResult<PyObject> {
20    Asyncio::get(py)?.Future.call0(py)
21}
22
23pub(crate) struct Waker {
24    call_soon_threadsafe: PyObject,
25    future: PyObject,
26}
27
28impl coroutine::CoroutineWaker for Waker {
29    fn new(py: Python) -> PyResult<Self> {
30        let future = asyncio_future(py)?;
31        let call_soon_threadsafe = future
32            .call_method0(py, intern!(py, "get_loop"))?
33            .getattr(py, intern!(py, "call_soon_threadsafe"))?;
34        Ok(Waker {
35            call_soon_threadsafe,
36            future,
37        })
38    }
39
40    fn yield_(&self, py: Python) -> PyResult<PyObject> {
41        self.future
42            .call_method0(py, intern!(py, "__await__"))?
43            .call_method0(py, intern!(py, "__next__"))
44    }
45
46    fn wake(&self, py: Python) {
47        self.future
48            .call_method1(py, intern!(py, "set_result"), (py.None(),))
49            .expect("error while calling EventLoop.call_soon_threadsafe");
50    }
51
52    fn wake_threadsafe(&self, py: Python) {
53        let set_result = self
54            .future
55            .getattr(py, intern!(py, "set_result"))
56            .expect("error while calling Future.set_result");
57        self.call_soon_threadsafe
58            .call1(py, (set_result, py.None()))
59            .expect("error while calling EventLoop.call_soon_threadsafe");
60    }
61
62    fn update(&mut self, py: Python) -> PyResult<()> {
63        self.future = Asyncio::get(py)?.Future.call0(py)?;
64        Ok(())
65    }
66
67    fn raise(&self, py: Python) -> PyResult<()> {
68        self.future.call_method0(py, intern!(py, "result"))?;
69        Ok(())
70    }
71}
72
73utils::generate!(Waker);
74
75/// [`Future`] wrapper for a Python awaitable (in `asyncio` context).
76///
77/// The future should be polled in the thread where the event loop is running.
78pub struct AwaitableWrapper {
79    future_iter: PyObject,
80    future: Option<PyObject>,
81}
82
83impl AwaitableWrapper {
84    /// Wrap a Python awaitable.
85    pub fn new(awaitable: &PyAny) -> PyResult<Self> {
86        Ok(Self {
87            future_iter: awaitable
88                .call_method0(intern!(awaitable.py(), "__await__"))?
89                .extract()?,
90            future: None,
91        })
92    }
93
94    /// GIL-bound [`Future`] reference.
95    pub fn as_mut<'a>(
96        &'a mut self,
97        py: Python<'a>,
98    ) -> impl Future<Output = PyResult<PyObject>> + Unpin + 'a {
99        utils::WithGil { inner: self, py }
100    }
101}
102
103impl<'a> Future for utils::WithGil<'_, &'a mut AwaitableWrapper> {
104    type Output = PyResult<PyObject>;
105
106    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
107        if let Some(fut) = self.inner.future.as_ref() {
108            fut.call_method0(self.py, intern!(self.py, "result"))?;
109        }
110        match self
111            .inner
112            .future_iter
113            .call_method0(self.py, intern!(self.py, "__next__"))
114        {
115            Ok(future) => {
116                let callback = utils::wake_callback(self.py, cx.waker().clone())?;
117                future.call_method1(self.py, intern!(self.py, "add_done_callback"), (callback,))?;
118                self.inner.future = Some(future);
119                Poll::Pending
120            }
121            Err(err) if err.is_instance_of::<PyStopIteration>(self.py) => Poll::Ready(Ok(err
122                .value(self.py)
123                .getattr(intern!(self.py, "value"))?
124                .into())),
125            Err(err) => Poll::Ready(Err(err)),
126        }
127    }
128}
129
130impl Future for AwaitableWrapper {
131    type Output = PyResult<PyObject>;
132
133    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
134        Python::with_gil(|gil| Pin::into_inner(self).as_mut(gil).poll_unpin(cx))
135    }
136}
137
138/// [`Future`] wrapper for Python future.
139///
140/// Because its duck-typed, it can work either with [`asyncio.Future`](https://docs.python.org/3/library/asyncio-future.html#asyncio.Future) or [`concurrent.futures.Future`](https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Future).
141#[derive(Debug)]
142pub struct FutureWrapper {
143    future: PyObject,
144    cancel_on_drop: Option<CancelOnDrop>,
145}
146
147/// Cancel-on-drop error handling policy (see [`FutureWrapper::new`]).
148#[derive(Debug, Copy, Clone)]
149pub enum CancelOnDrop {
150    IgnoreError,
151    PanicOnError,
152}
153
154impl FutureWrapper {
155    /// Wrap a Python future.
156    ///
157    /// If `cancel_on_drop` is not `None`, the Python future will be cancelled, and error may be
158    /// handled following the provided policy.
159    pub fn new(future: impl Into<PyObject>, cancel_on_drop: Option<CancelOnDrop>) -> Self {
160        Self {
161            future: future.into(),
162            cancel_on_drop,
163        }
164    }
165
166    /// GIL-bound [`Future`] reference.
167    pub fn as_mut<'a>(
168        &'a mut self,
169        py: Python<'a>,
170    ) -> impl Future<Output = PyResult<PyObject>> + Unpin + 'a {
171        utils::WithGil { inner: self, py }
172    }
173}
174
175impl<'a> Future for utils::WithGil<'_, &'a mut FutureWrapper> {
176    type Output = PyResult<PyObject>;
177
178    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
179        if self
180            .inner
181            .future
182            .call_method0(self.py, intern!(self.py, "done"))?
183            .is_true(self.py)?
184        {
185            self.inner.cancel_on_drop = None;
186            return Poll::Ready(
187                self.inner
188                    .future
189                    .call_method0(self.py, intern!(self.py, "result")),
190            );
191        }
192        let callback = utils::wake_callback(self.py, cx.waker().clone())?;
193        self.inner.future.call_method1(
194            self.py,
195            intern!(self.py, "add_done_callback"),
196            (callback,),
197        )?;
198        Poll::Pending
199    }
200}
201
202impl Future for FutureWrapper {
203    type Output = PyResult<PyObject>;
204
205    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
206        Python::with_gil(|gil| Pin::into_inner(self).as_mut(gil).poll_unpin(cx))
207    }
208}
209
210impl Drop for FutureWrapper {
211    fn drop(&mut self) {
212        if let Some(cancel) = self.cancel_on_drop {
213            let res = Python::with_gil(|gil| self.future.call_method0(gil, intern!(gil, "cancel")));
214            if let (Err(err), CancelOnDrop::PanicOnError) = (res, cancel) {
215                panic!("Cancel error while dropping FutureWrapper: {err:?}");
216            }
217        }
218    }
219}
220
221/// [`Stream`] wrapper for a Python async generator (in `asyncio` context).
222///
223/// The stream should be polled in the thread where the event loop is running.
224///
225/// [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html
226pub struct AsyncGeneratorWrapper {
227    async_generator: PyObject,
228    next: Option<AwaitableWrapper>,
229}
230
231impl AsyncGeneratorWrapper {
232    /// Wrap a Python async generator.
233    pub fn new(async_generator: &PyAny) -> Self {
234        Self {
235            async_generator: async_generator.into(),
236            next: None,
237        }
238    }
239
240    /// GIL-bound [`Stream`] reference.
241    ///
242    /// [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html
243    pub fn as_mut<'a>(
244        &'a mut self,
245        py: Python<'a>,
246    ) -> impl Stream<Item = PyResult<PyObject>> + Unpin + 'a {
247        utils::WithGil { inner: self, py }
248    }
249}
250
251impl<'a> Stream for utils::WithGil<'_, &'a mut AsyncGeneratorWrapper> {
252    type Item = PyResult<PyObject>;
253
254    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
255        if self.inner.next.is_none() {
256            let next = self
257                .inner
258                .async_generator
259                .as_ref(self.py)
260                .call_method0(intern!(self.py, "__anext__"))?;
261            self.inner.next = Some(AwaitableWrapper::new(next)?);
262        }
263        let res = ready!(self.inner.next.as_mut().unwrap().poll_unpin(cx));
264        self.inner.next = None;
265        Poll::Ready(match res {
266            Ok(obj) => Some(Ok(obj)),
267            Err(err) if err.is_instance_of::<PyStopAsyncIteration>(self.py) => None,
268            Err(err) => Some(Err(err)),
269        })
270    }
271}
272
273impl Stream for AsyncGeneratorWrapper {
274    type Item = PyResult<PyObject>;
275
276    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
277        Python::with_gil(|gil| Pin::into_inner(self).as_mut(gil).poll_next_unpin(cx))
278    }
279}