use std::{pin::Pin, sync::Arc};

use futures::{Stream, StreamExt};
use icechunk::zarr::StoreError;
use pyo3::{
    exceptions::{PyStopAsyncIteration, PyValueError},
    prelude::*,
};
use tokio::sync::Mutex;

/// Shared, lockable handle to a pinned, boxed stream whose items are either
/// Python objects or `StoreError`s.
///
/// Wrapping the stream in `Arc<tokio::sync::Mutex<..>>` lets each future built
/// by `PyAsyncGenerator::__anext__` own its own clone of the handle and take
/// exclusive access to the stream while waiting for the next item.
type PyObjectStream =
Arc<Mutex<Pin<Box<dyn Stream<Item = Result<PyObject, StoreError>> + Send>>>>;
/// Python-visible async iterator backed by a Rust stream of Python objects.
///
/// Exposed to Python through `#[pyclass]`; the `__aiter__` / `__anext__`
/// protocol methods live in the `#[pymethods]` impl of this type.
#[pyclass]
pub(crate) struct PyAsyncGenerator {
/// Shared handle to the underlying stream; `__anext__` clones this `Arc`
/// into every future it creates.
stream: PyObjectStream,
}
impl PyAsyncGenerator {
pub(crate) fn new(stream: PyObjectStream) -> Self {
Self { stream }
}
}
#[pymethods]
impl PyAsyncGenerator {
fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}
fn __anext__<'py>(
slf: PyRefMut<'py, Self>,
py: Python<'py>,
) -> PyResult<Option<PyObject>> {
let stream = slf.stream.clone();
let future = async move {
let mut unlocked = stream.lock().await;
let next = unlocked.next().await;
drop(unlocked);
match next {
Some(Ok(val)) => Ok(Some(val)),
Some(Err(_e)) => Ok(None),
None => Err(PyStopAsyncIteration::new_err("The iterator is exhausted")),
}
};
let result = pyo3_asyncio_0_21::tokio::future_into_py(py, future)?;
Ok(Some(result.to_object(py)))
}
}