pyo3-async 0.3.2

PyO3 bindings to various Python asynchronous frameworks.
Documentation
use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};

use futures::task::ArcWake;
use pyo3::{exceptions::PyRuntimeError, iter::IterNextOutput, prelude::*};

use crate::{
    utils::{current_thread_id, ThreadId},
    PyFuture, ThrowCallback,
};

/// Framework-specific hooks used by [`Coroutine`] and [`Waker`] to suspend and
/// resume a Rust future from Python.
///
/// Implementations are wrapped in a [`Waker`], which dispatches to
/// [`wake`](CoroutineWaker::wake) or
/// [`wake_threadsafe`](CoroutineWaker::wake_threadsafe) depending on the
/// calling thread.
pub(crate) trait CoroutineWaker: Sized {
    /// Builds a new waker; `Coroutine::poll` calls this when it has no waker
    /// it can reuse.
    fn new(py: Python) -> PyResult<Self>;
    /// Produces the Python object that `Coroutine::poll` yields when the
    /// wrapped future is still pending.
    fn yield_(&self, py: Python) -> PyResult<PyObject>;
    /// Wake-up path used when the wake happens on the thread that created the
    /// enclosing [`Waker`].
    fn wake(&self, py: Python);
    /// Wake-up path used when the wake happens on any other thread.
    fn wake_threadsafe(&self, py: Python);
    /// Called by `Coroutine::poll` before polling when the existing waker is
    /// not shared and can therefore be mutated in place. No-op by default.
    fn update(&mut self, _py: Python) -> PyResult<()> {
        Ok(())
    }
    /// Checked by `Coroutine::poll` when no exception was passed explicitly;
    /// an `Err` returned here is treated as the exception to deliver.
    /// Returns `Ok(())` by default.
    fn raise(&self, _py: Python) -> PyResult<()> {
        Ok(())
    }
}

/// Pairs a framework-specific waker with the identifier of the thread it was
/// created on, so a wake-up can pick the same-thread or cross-thread path.
pub(crate) struct Waker<W> {
    /// The wrapped waker implementation.
    inner: W,
    /// Identifier of the thread that was current when this waker was built.
    thread_id: ThreadId,
}

impl<W: CoroutineWaker + Send + Sync> ArcWake for Waker<W> {
    /// Wakes the coroutine, choosing the wake-up path from the calling thread.
    ///
    /// The thread comparison is done before taking the GIL; the matching
    /// `CoroutineWaker` hook is then invoked with the GIL held.
    fn wake_by_ref(arc_self: &Arc<Self>) {
        let on_creating_thread = current_thread_id() == arc_self.thread_id;
        Python::with_gil(|py| {
            if on_creating_thread {
                // Same thread as the one recorded at construction time.
                W::wake(&arc_self.inner, py)
            } else {
                // Any other thread goes through the thread-safe hook.
                W::wake_threadsafe(&arc_self.inner, py)
            }
        })
    }
}

/// State of a Rust future driven as a Python coroutine.
pub(crate) struct Coroutine<W> {
    /// The wrapped future; set to `None` once it has completed, has been
    /// closed, or has been aborted by an exception with no `throw` callback.
    future: Option<Pin<Box<dyn PyFuture>>>,
    /// Optional callback used to hand an exception (or `None` on close) to
    /// the future before it is polled again.
    throw: Option<ThrowCallback>,
    /// Waker handed to the future on each poll; created lazily by `poll`.
    waker: Option<Arc<Waker<W>>>,
}

impl<W> Coroutine<W> {
    /// Wraps `future` in a coroutine that has not been polled yet.
    ///
    /// `throw`, when provided, is the callback later used to deliver
    /// exceptions (or a close notification) to the future.
    pub(crate) fn new(future: Pin<Box<dyn PyFuture>>, throw: Option<ThrowCallback>) -> Self {
        let future = Some(future);
        Self {
            future,
            throw,
            waker: None,
        }
    }

    /// Closes the coroutine, dropping the wrapped future.
    ///
    /// If a `throw` callback is registered, it is first called with `None`
    /// and the future is polled one last time with a no-op waker; an error
    /// produced by that final poll is returned. Without a callback, or if the
    /// future is already gone, this simply returns `Ok(())`.
    pub(crate) fn close(&mut self, py: Python) -> PyResult<()> {
        // Taking the future makes any later `poll` fail as "already awaited".
        let Some(mut pending) = self.future.take() else {
            return Ok(());
        };
        let Some(notify) = self.throw.as_mut() else {
            return Ok(());
        };
        notify(py, None);
        let noop = futures::task::noop_waker();
        let mut cx = Context::from_waker(&noop);
        let outcome = pending.as_mut().poll_py(py, &mut cx);
        match outcome {
            Poll::Ready(Err(err)) => Err(err),
            _ => Ok(()),
        }
    }
}

impl<W: CoroutineWaker + Send + Sync + 'static> Coroutine<W> {
    /// Advances the wrapped future by one step.
    ///
    /// `exc` is an exception to deliver to the future; when it is `None`, the
    /// current waker (if any) is asked through `CoroutineWaker::raise`
    /// whether it has one instead.
    ///
    /// Returns `Return(value)` when the future completes and `Yield(obj)`
    /// while it is pending, `obj` coming from `CoroutineWaker::yield_`.
    ///
    /// # Errors
    ///
    /// - `PyRuntimeError` if the future is already gone.
    /// - The delivered exception itself when no `throw` callback is
    ///   registered (the future is dropped in that case).
    /// - Any error from creating/updating the waker, from the future's
    ///   result, or from `yield_`.
    pub(crate) fn poll(
        &mut self,
        py: Python,
        exc: Option<PyErr>,
    ) -> PyResult<IterNextOutput<PyObject, PyObject>> {
        let future_rs = match self.future.as_mut() {
            Some(fut) => fut,
            None => {
                return Err(PyRuntimeError::new_err(
                    "cannot reuse already awaited coroutine",
                ))
            }
        };

        // An explicit exception takes precedence over one reported by the waker.
        let to_deliver = match exc {
            Some(err) => Some(err),
            None => self.waker.as_ref().and_then(|w| w.inner.raise(py).err()),
        };
        if let Some(err) = to_deliver {
            match self.throw.as_mut() {
                Some(throw) => throw(py, Some(err)),
                None => {
                    // Nobody can handle it: drop the future and propagate.
                    self.future.take();
                    return Err(err);
                }
            }
        }

        // Reuse the waker only when no clone of it is still held elsewhere.
        match self.waker.as_mut().and_then(Arc::get_mut) {
            Some(unique) => unique.inner.update(py)?,
            None => {
                let fresh = Waker {
                    inner: W::new(py)?,
                    thread_id: current_thread_id(),
                };
                self.waker = Some(Arc::new(fresh));
            }
        }

        let shared = self.waker.as_ref().map(Arc::clone).unwrap();
        let task_waker = futures::task::waker(shared);
        let mut cx = Context::from_waker(&task_waker);
        let polled = future_rs.as_mut().poll_py(py, &mut cx);
        match polled {
            Poll::Ready(outcome) => {
                // The future is finished; drop it before reporting its result.
                self.future.take();
                Ok(IterNextOutput::Return(outcome?))
            }
            Poll::Pending => {
                let yielded = self.waker.as_ref().unwrap().inner.yield_(py)?;
                Ok(IterNextOutput::Yield(yielded))
            }
        }
    }
}