pyo3_async/
allow_threads.rs

1use std::{
2    future::Future,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use futures::Stream;
8use pin_project::pin_project;
9use pyo3::Python;
10
11/// Wrapper for [`Future`]/[`Stream`] that releases GIL while polling in
12/// [`PyFuture`](crate::PyFuture)/[`PyStream`](crate::PyStream).
13///
14/// Can be instantiated with [`AllowThreadsExt::allow_threads`].
15///
16/// [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html
17#[derive(Debug)]
18#[repr(transparent)]
19#[pin_project]
20pub struct AllowThreads<T>(#[pin] pub T);
21
22impl<F> Future for AllowThreads<F>
23where
24    F: Future + Send,
25    F::Output: Send,
26{
27    type Output = F::Output;
28
29    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
30        let this = self.project();
31        let waker = cx.waker();
32        Python::with_gil(|gil| gil.allow_threads(|| this.0.poll(&mut Context::from_waker(waker))))
33    }
34}
35
36impl<S> Stream for AllowThreads<S>
37where
38    S: Stream + Send,
39    S::Item: Send,
40{
41    type Item = S::Item;
42
43    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
44        let this = self.project();
45        let waker = cx.waker();
46        Python::with_gil(|gil| {
47            gil.allow_threads(|| this.0.poll_next(&mut Context::from_waker(waker)))
48        })
49    }
50}
51
52/// Extension trait to allow threads while polling [`Future`] or [`Stream`].
53///
54/// It is implemented for every types.
55///
56/// [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html
57pub trait AllowThreadsExt: Sized {
58    fn allow_threads(self) -> AllowThreads<Self> {
59        AllowThreads(self)
60    }
61}
62
63impl<T> AllowThreadsExt for T {}