1use std::{
3 future::Future,
4 pin::Pin,
5 task::{ready, Context, Poll},
6};
7
8use futures::{FutureExt, Stream, StreamExt};
9use pyo3::{
10 exceptions::{PyStopAsyncIteration, PyStopIteration},
11 intern,
12 prelude::*,
13};
14
15use crate::{coroutine, utils};
16
// Handle on Python's `asyncio` module exposing its `Future` attribute
// (accessed below as `Asyncio::get(py)?.Future`).
// NOTE(review): presumably resolved once and cached by `utils::module!` — confirm in `utils`.
utils::module!(Asyncio, "asyncio", Future);
18
19fn asyncio_future(py: Python) -> PyResult<PyObject> {
20 Asyncio::get(py)?.Future.call0(py)
21}
22
/// Coroutine waker backed by an `asyncio.Future`.
pub(crate) struct Waker {
    /// Bound `call_soon_threadsafe` method of the event loop returned by
    /// `future.get_loop()`, used to complete `future` from another thread.
    call_soon_threadsafe: PyObject,
    /// The `asyncio.Future` whose `__await__` iterator is advanced by `yield_`
    /// and whose result is set by `wake` / `wake_threadsafe`.
    future: PyObject,
}
27
28impl coroutine::CoroutineWaker for Waker {
29 fn new(py: Python) -> PyResult<Self> {
30 let future = asyncio_future(py)?;
31 let call_soon_threadsafe = future
32 .call_method0(py, intern!(py, "get_loop"))?
33 .getattr(py, intern!(py, "call_soon_threadsafe"))?;
34 Ok(Waker {
35 call_soon_threadsafe,
36 future,
37 })
38 }
39
40 fn yield_(&self, py: Python) -> PyResult<PyObject> {
41 self.future
42 .call_method0(py, intern!(py, "__await__"))?
43 .call_method0(py, intern!(py, "__next__"))
44 }
45
46 fn wake(&self, py: Python) {
47 self.future
48 .call_method1(py, intern!(py, "set_result"), (py.None(),))
49 .expect("error while calling EventLoop.call_soon_threadsafe");
50 }
51
52 fn wake_threadsafe(&self, py: Python) {
53 let set_result = self
54 .future
55 .getattr(py, intern!(py, "set_result"))
56 .expect("error while calling Future.set_result");
57 self.call_soon_threadsafe
58 .call1(py, (set_result, py.None()))
59 .expect("error while calling EventLoop.call_soon_threadsafe");
60 }
61
62 fn update(&mut self, py: Python) -> PyResult<()> {
63 self.future = Asyncio::get(py)?.Future.call0(py)?;
64 Ok(())
65 }
66
67 fn raise(&self, py: Python) -> PyResult<()> {
68 self.future.call_method0(py, intern!(py, "result"))?;
69 Ok(())
70 }
71}
72
// NOTE(review): presumably generates the asyncio-specific coroutine glue built on
// `Waker` above — check `utils::generate!` to confirm what it emits.
utils::generate!(Waker);
74
/// Rust [`Future`] adapter that drives a Python awaitable through the iterator
/// returned by its `__await__` method.
pub struct AwaitableWrapper {
    /// Iterator returned by the awaitable's `__await__()`.
    future_iter: PyObject,
    /// Last object yielded by `future_iter`, whose `result()` is checked before
    /// the iterator is advanced again; `None` before the first poll.
    future: Option<PyObject>,
}
82
83impl AwaitableWrapper {
84 pub fn new(awaitable: &PyAny) -> PyResult<Self> {
86 Ok(Self {
87 future_iter: awaitable
88 .call_method0(intern!(awaitable.py(), "__await__"))?
89 .extract()?,
90 future: None,
91 })
92 }
93
94 pub fn as_mut<'a>(
96 &'a mut self,
97 py: Python<'a>,
98 ) -> impl Future<Output = PyResult<PyObject>> + Unpin + 'a {
99 utils::WithGil { inner: self, py }
100 }
101}
102
103impl<'a> Future for utils::WithGil<'_, &'a mut AwaitableWrapper> {
104 type Output = PyResult<PyObject>;
105
106 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
107 if let Some(fut) = self.inner.future.as_ref() {
108 fut.call_method0(self.py, intern!(self.py, "result"))?;
109 }
110 match self
111 .inner
112 .future_iter
113 .call_method0(self.py, intern!(self.py, "__next__"))
114 {
115 Ok(future) => {
116 let callback = utils::wake_callback(self.py, cx.waker().clone())?;
117 future.call_method1(self.py, intern!(self.py, "add_done_callback"), (callback,))?;
118 self.inner.future = Some(future);
119 Poll::Pending
120 }
121 Err(err) if err.is_instance_of::<PyStopIteration>(self.py) => Poll::Ready(Ok(err
122 .value(self.py)
123 .getattr(intern!(self.py, "value"))?
124 .into())),
125 Err(err) => Poll::Ready(Err(err)),
126 }
127 }
128}
129
130impl Future for AwaitableWrapper {
131 type Output = PyResult<PyObject>;
132
133 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
134 Python::with_gil(|gil| Pin::into_inner(self).as_mut(gil).poll_unpin(cx))
135 }
136}
137
/// Rust [`Future`] adapter over a Python future-like object, driven through its
/// `done`, `result` and `add_done_callback` methods (and `cancel` on drop).
#[derive(Debug)]
pub struct FutureWrapper {
    /// The wrapped Python future.
    future: PyObject,
    /// Cancellation policy applied on drop; cleared once a poll observes the
    /// future as done.
    cancel_on_drop: Option<CancelOnDrop>,
}
146
/// How dropping a [`FutureWrapper`] reacts when calling `cancel()` on the still
/// pending Python future raises.
#[derive(Debug, Copy, Clone)]
pub enum CancelOnDrop {
    /// Call `cancel()` and discard any error it raises.
    IgnoreError,
    /// Call `cancel()` and panic if it raises.
    PanicOnError,
}
153
154impl FutureWrapper {
155 pub fn new(future: impl Into<PyObject>, cancel_on_drop: Option<CancelOnDrop>) -> Self {
160 Self {
161 future: future.into(),
162 cancel_on_drop,
163 }
164 }
165
166 pub fn as_mut<'a>(
168 &'a mut self,
169 py: Python<'a>,
170 ) -> impl Future<Output = PyResult<PyObject>> + Unpin + 'a {
171 utils::WithGil { inner: self, py }
172 }
173}
174
175impl<'a> Future for utils::WithGil<'_, &'a mut FutureWrapper> {
176 type Output = PyResult<PyObject>;
177
178 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
179 if self
180 .inner
181 .future
182 .call_method0(self.py, intern!(self.py, "done"))?
183 .is_true(self.py)?
184 {
185 self.inner.cancel_on_drop = None;
186 return Poll::Ready(
187 self.inner
188 .future
189 .call_method0(self.py, intern!(self.py, "result")),
190 );
191 }
192 let callback = utils::wake_callback(self.py, cx.waker().clone())?;
193 self.inner.future.call_method1(
194 self.py,
195 intern!(self.py, "add_done_callback"),
196 (callback,),
197 )?;
198 Poll::Pending
199 }
200}
201
202impl Future for FutureWrapper {
203 type Output = PyResult<PyObject>;
204
205 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
206 Python::with_gil(|gil| Pin::into_inner(self).as_mut(gil).poll_unpin(cx))
207 }
208}
209
210impl Drop for FutureWrapper {
211 fn drop(&mut self) {
212 if let Some(cancel) = self.cancel_on_drop {
213 let res = Python::with_gil(|gil| self.future.call_method0(gil, intern!(gil, "cancel")));
214 if let (Err(err), CancelOnDrop::PanicOnError) = (res, cancel) {
215 panic!("Cancel error while dropping FutureWrapper: {err:?}");
216 }
217 }
218 }
219}
220
/// Rust [`Stream`] adapter over a Python async generator, driven by awaiting
/// successive `__anext__()` calls.
pub struct AsyncGeneratorWrapper {
    /// The wrapped Python async generator (any object with `__anext__`).
    async_generator: PyObject,
    /// In-flight `__anext__()` awaitable, present while an item is being awaited.
    next: Option<AwaitableWrapper>,
}
230
231impl AsyncGeneratorWrapper {
232 pub fn new(async_generator: &PyAny) -> Self {
234 Self {
235 async_generator: async_generator.into(),
236 next: None,
237 }
238 }
239
240 pub fn as_mut<'a>(
244 &'a mut self,
245 py: Python<'a>,
246 ) -> impl Stream<Item = PyResult<PyObject>> + Unpin + 'a {
247 utils::WithGil { inner: self, py }
248 }
249}
250
251impl<'a> Stream for utils::WithGil<'_, &'a mut AsyncGeneratorWrapper> {
252 type Item = PyResult<PyObject>;
253
254 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
255 if self.inner.next.is_none() {
256 let next = self
257 .inner
258 .async_generator
259 .as_ref(self.py)
260 .call_method0(intern!(self.py, "__anext__"))?;
261 self.inner.next = Some(AwaitableWrapper::new(next)?);
262 }
263 let res = ready!(self.inner.next.as_mut().unwrap().poll_unpin(cx));
264 self.inner.next = None;
265 Poll::Ready(match res {
266 Ok(obj) => Some(Ok(obj)),
267 Err(err) if err.is_instance_of::<PyStopAsyncIteration>(self.py) => None,
268 Err(err) => Some(Err(err)),
269 })
270 }
271}
272
273impl Stream for AsyncGeneratorWrapper {
274 type Item = PyResult<PyObject>;
275
276 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
277 Python::with_gil(|gil| Pin::into_inner(self).as_mut(gil).poll_next_unpin(cx))
278 }
279}