pyo3-async 0.3.2

PyO3 bindings to various Python asynchronous frameworks.
Documentation
use std::{
    marker::PhantomData,
    pin::Pin,
    sync::{Arc, Mutex},
    task::{ready, Context, Poll},
};

use pyo3::{exceptions::PyStopAsyncIteration, prelude::*};

use crate::{PyFuture, PyStream, ThrowCallback};

/// Stream slot shared between an [`AsyncGenerator`] and every [`PyStreamNext`] future it
/// hands out.
///
/// The slot starts as `Some(stream)` and is reset to `None` (dropping the stream) once the
/// stream reports its end, or after an item is produced by a closing [`PyStreamNext`].
type SharedStream = Arc<Mutex<Option<Pin<Box<dyn PyStream>>>>>;

/// Future that resolves to the next item pulled from a [`SharedStream`].
struct PyStreamNext {
    /// Slot holding the stream to poll; `None` means nothing is left to pull.
    stream: SharedStream,
    /// If set, the stream is dropped as soon as this future resolves to an item.
    close: bool,
}

impl PyFuture for PyStreamNext {
    fn poll_py(self: Pin<&mut Self>, py: Python, cx: &mut Context) -> Poll<PyResult<PyObject>> {
        let err = || Err(PyStopAsyncIteration::new_err(py.None()));
        let this = Pin::into_inner(self);
        let mut guard = this.stream.lock().unwrap();
        let Some(ref mut stream) = *guard else {
            return Poll::Ready(err());
        };
        let opt_res = ready!(stream.as_mut().poll_next_py(py, cx));
        if let Some(res) = opt_res {
            if this.close {
                *guard = None;
            }
            return Poll::Ready(res);
        }
        *guard = None;
        Poll::Ready(err())
    }
}

/// Builds the Python-side object (a coroutine, per the associated type's name) that wraps a
/// [`PyFuture`].
///
/// [`AsyncGenerator`] is generic over this trait, so the same stream plumbing can hand out
/// different coroutine types (presumably one implementation per supported Python async
/// framework — TODO confirm against the implementors).
pub(crate) trait CoroutineFactory {
    /// Object returned to Python; must be convertible into a `PyObject`.
    type Coroutine: IntoPy<PyObject>;
    /// Wraps `future` (which must be `'static`) into a `Self::Coroutine`.
    fn coroutine(future: impl PyFuture + 'static) -> Self::Coroutine;
}

/// Async-generator-like state wrapping a [`PyStream`].
///
/// `C` selects the [`CoroutineFactory`] that builds the objects returned to Python. No `C`
/// value is stored, hence the `PhantomData`.
pub(crate) struct AsyncGenerator<C> {
    /// Stream slot shared with every [`PyStreamNext`] handed out; `None` once finished.
    stream: SharedStream,
    /// Optional callback: `throw` invokes it with `Some(exc)`, `close` with `None`.
    throw: Option<ThrowCallback>,
    _phantom: PhantomData<C>,
}

impl<C> AsyncGenerator<C> {
    /// Creates a generator over `stream`, optionally paired with a `throw` callback.
    pub(crate) fn new(stream: Pin<Box<dyn PyStream>>, throw: Option<ThrowCallback>) -> Self {
        let shared: SharedStream = Arc::new(Mutex::new(Some(stream)));
        Self {
            _phantom: PhantomData,
            throw,
            stream: shared,
        }
    }
}

impl<C: CoroutineFactory> AsyncGenerator<C> {
    pub(crate) fn _next(&mut self, py: Python, close: bool) -> PyResult<PyObject> {
        let stream = self.stream.clone();
        Ok(C::coroutine(PyStreamNext { stream, close }).into_py(py))
    }

    pub(crate) fn next(&mut self, py: Python) -> PyResult<PyObject> {
        self._next(py, false)
    }

    pub(crate) fn throw(&mut self, py: Python, exc: PyErr) -> PyResult<PyObject> {
        let Some(throw) = &mut self.throw else {
            return Ok(C::coroutine(async move { Err::<(), _>(exc) }).into_py(py));
        };
        throw(py, Some(exc));
        self._next(py, false)
    }

    pub(crate) fn close(&mut self, py: Python) -> PyResult<PyObject> {
        if let Some(throw) = &mut self.throw {
            throw(py, None);
        }
        self._next(py, true)
    }
}