// pyo3_async/utils.rs

1use std::sync::atomic::{AtomicUsize, Ordering};
2
3use pyo3::{exceptions::PyStopIteration, prelude::*, pyclass::IterNextOutput, types::PyCFunction};
4
/// Cheap identifier of the calling thread, as returned by [`current_thread_id`].
///
/// `std::thread::current` is deliberately not used: obtaining its `ThreadId` costs an `Arc`
/// clone and drop on every call, which is wasted work when only a comparable value is needed.
pub(crate) type ThreadId = usize;

/// Returns an integer identifying the calling thread.
///
/// A thread draws its id from a global, monotonically incremented counter the first time it
/// calls this function; the value is then cached in thread-local storage, so later calls on
/// the same thread return the same id and different threads get different ids.
pub(crate) fn current_thread_id() -> ThreadId {
    static NEXT_ID: AtomicUsize = AtomicUsize::new(0);

    thread_local! {
        // Only distinctness of the drawn values matters, so no ordering with other memory
        // operations is required.
        static CACHED_ID: ThreadId = NEXT_ID.fetch_add(1, Ordering::Relaxed);
    }

    CACHED_ID.with(|&id| id)
}

15pub(crate) struct WithGil<'py, T> {
16    pub(crate) inner: T,
17    pub(crate) py: Python<'py>,
18}
19
20pub(crate) fn wake_callback(py: Python, waker: std::task::Waker) -> PyResult<&PyAny> {
21    let func = PyCFunction::new_closure(py, None, None, move |_, _| waker.wake_by_ref())?;
22    Ok(func)
23}
24
/// Declares a lazily imported Python module together with a handle to some of its attributes.
///
/// `module!(Name, "python.module.path", attr_a, attr_b)` expands to:
/// - `struct Name`, holding one `PyObject` field per listed attribute (field name = attribute
///   name);
/// - `static Name: GILOnceCell<Name>`, caching that struct (the static lives in the value
///   namespace and the struct in the type namespace, so they can share the identifier);
/// - `Name::get(py)`, which on first use imports the module and fetches every listed attribute,
///   then returns the cached struct on subsequent calls.
///
/// All pyo3 names in the expansion are fully qualified, so the call site does not need any
/// particular pyo3 items imported.
///
/// # Errors
///
/// `Name::get` returns the `PyErr` raised by the import or by any attribute lookup. In that case
/// `GILOnceCell::get_or_try_init` leaves the cell empty, so a later call tries again.
macro_rules! module {
    ($name:ident, $path:literal, $($field:ident),* $(,)?) => {
        // The static deliberately shares the (CamelCase) name of the struct it caches.
        #[allow(non_upper_case_globals)]
        static $name: ::pyo3::sync::GILOnceCell<$name> = ::pyo3::sync::GILOnceCell::new();

        // Field names are copied verbatim from Python attribute names, which may be CamelCase.
        #[allow(non_snake_case)]
        struct $name {
            $($field: ::pyo3::PyObject),*
        }

        impl $name {
            fn get(py: ::pyo3::Python<'_>) -> ::pyo3::PyResult<&Self> {
                $name.get_or_try_init(py, || {
                    let module = py.import($path)?;
                    Ok(Self {
                        $($field: module.getattr(stringify!($field))?.into(),)*
                    })
                })
            }
        }
    };
}

pub(crate) use module;

50pub(crate) fn poll_result(result: IterNextOutput<PyObject, PyObject>) -> PyResult<PyObject> {
51    match result {
52        IterNextOutput::Yield(ob) => Ok(ob),
53        IterNextOutput::Return(ob) => Err(PyStopIteration::new_err(ob)),
54    }
55}
56
/// Generates the `Coroutine` and `AsyncGenerator` Python classes for a given waker type.
///
/// The expansion defines, in the invoking module:
/// - `Coroutine`, a `#[pyclass]` wrapping `$crate::coroutine::Coroutine<$waker>` and exposing
///   `send`, `throw`, `close`, `__await__`, `__iter__` and `__next__`;
/// - an implementation of `$crate::async_generator::CoroutineFactory` for `Coroutine`;
/// - `AsyncGenerator`, a `#[pyclass]` wrapping
///   `$crate::async_generator::AsyncGenerator<Coroutine>` and exposing `asend`, `athrow`,
///   `aclose`, `__aiter__` and `__anext__`.
///
/// The expansion refers to `pyclass`, `pymethods`, `Python`, `PyResult`, `PyObject`, `PyAny`,
/// `PyCell` and `PyErr` unqualified, so the invoking module must have them in scope (e.g. via
/// `pyo3::prelude::*`).
macro_rules! generate {
    ($waker:ty) => {
        /// Python coroutine wrapping a [`PyFuture`](crate::PyFuture).
        #[pyclass]
        pub struct Coroutine($crate::coroutine::Coroutine<$waker>);

        impl Coroutine {
            /// Wraps a boxed future into a Python coroutine.
            ///
            /// If a `throw` callback is provided:
            /// - the coroutine `throw` method will call it with the passed exception before
            ///   polling;
            /// - the coroutine `close` method will call it with `None` before polling and
            ///   dropping the future.
            ///
            /// If no `throw` callback is provided, the future will be dropped without an
            /// additional poll.
            pub fn new(
                future: ::std::pin::Pin<Box<dyn $crate::PyFuture>>,
                throw: Option<$crate::ThrowCallback>,
            ) -> Self {
                Self($crate::coroutine::Coroutine::new(future, throw))
            }

            /// Wraps a generic future into a Python coroutine, without a `throw` callback.
            pub fn from_future(future: impl $crate::PyFuture + 'static) -> Self {
                Self::new(Box::pin(future), None)
            }
        }

        #[pymethods]
        impl Coroutine {
            // The sent value is ignored: every `send` simply polls the wrapped future.
            fn send(&mut self, py: Python, _value: &PyAny) -> PyResult<PyObject> {
                $crate::utils::poll_result(self.0.poll(py, None)?)
            }

            // Polls with `exc` converted to a `PyErr`; see `new` for how a `throw` callback
            // (if any) receives it.
            fn throw(&mut self, py: Python, exc: &PyAny) -> PyResult<PyObject> {
                $crate::utils::poll_result(self.0.poll(py, Some(PyErr::from_value(exc)))?)
            }

            fn close(&mut self, py: Python) -> PyResult<()> {
                self.0.close(py)
            }

            // The coroutine object is its own awaitable iterator.
            fn __await__(self_: &PyCell<Self>) -> PyResult<&PyAny> {
                Ok(self_)
            }

            fn __iter__(self_: &PyCell<Self>) -> PyResult<&PyAny> {
                Ok(self_)
            }

            // Same poll as `send`, but the raw `IterNextOutput` is returned so pyo3 itself
            // converts `Return` into `StopIteration`.
            fn __next__(
                &mut self,
                py: Python,
            ) -> PyResult<::pyo3::pyclass::IterNextOutput<PyObject, PyObject>> {
                self.0.poll(py, None)
            }
        }

        impl $crate::async_generator::CoroutineFactory for Coroutine {
            type Coroutine = Self;
            fn coroutine(future: impl $crate::PyFuture + 'static) -> Self::Coroutine {
                Self::from_future(future)
            }
        }

        /// Python async generator wrapping a [`PyStream`](crate::PyStream).
        #[pyclass]
        pub struct AsyncGenerator($crate::async_generator::AsyncGenerator<Coroutine>);

        impl AsyncGenerator {
            /// Wraps a boxed stream into a Python async generator.
            ///
            /// If a `throw` callback is provided:
            /// - the async generator `athrow` method will call it with the passed exception
            ///   before polling;
            /// - the async generator `aclose` method will call it with `None` before polling
            ///   and dropping the stream.
            ///
            /// If no `throw` callback is provided, the stream will be dropped without an
            /// additional poll.
            pub fn new(
                stream: ::std::pin::Pin<Box<dyn $crate::PyStream>>,
                throw: Option<$crate::ThrowCallback>,
            ) -> Self {
                Self($crate::async_generator::AsyncGenerator::new(stream, throw))
            }

            /// Wraps a generic stream into a Python async generator, without a `throw`
            /// callback.
            pub fn from_stream(stream: impl $crate::PyStream + 'static) -> Self {
                Self::new(Box::pin(stream), None)
            }
        }

        #[pymethods]
        impl AsyncGenerator {
            // The sent value is ignored: `asend` behaves like `__anext__`.
            fn asend(&mut self, py: Python, _value: &PyAny) -> PyResult<PyObject> {
                self.0.next(py)
            }

            fn athrow(&mut self, py: Python, exc: &PyAny) -> PyResult<PyObject> {
                self.0.throw(py, PyErr::from_value(exc))
            }

            fn aclose(&mut self, py: Python) -> PyResult<PyObject> {
                self.0.close(py)
            }

            // The async generator object is its own async iterator.
            fn __aiter__(self_: &PyCell<Self>) -> PyResult<&PyAny> {
                Ok(self_)
            }

            // `Option` because https://github.com/PyO3/pyo3/issues/3190
            fn __anext__(&mut self, py: Python) -> PyResult<Option<PyObject>> {
                self.0.next(py).map(Some)
            }
        }
    };
}
pub(crate) use generate;