pyo3_arrow/
record_batch_reader.rs

1use std::fmt::Display;
2use std::sync::{Arc, Mutex};
3
4use arrow_array::{ArrayRef, RecordBatchIterator, RecordBatchReader, StructArray};
5use arrow_schema::{Field, SchemaRef};
6use pyo3::exceptions::{PyIOError, PyStopIteration, PyValueError};
7use pyo3::prelude::*;
8use pyo3::types::{PyCapsule, PyTuple, PyType};
9use pyo3::{intern, IntoPyObjectExt};
10
11use crate::error::PyArrowResult;
12use crate::export::{Arro3RecordBatch, Arro3Schema, Arro3Table};
13use crate::ffi::from_python::utils::import_stream_pycapsule;
14use crate::ffi::to_python::chunked::ArrayIterator;
15use crate::ffi::to_python::nanoarrow::to_nanoarrow_array_stream;
16use crate::ffi::to_python::to_stream_pycapsule;
17use crate::ffi::to_schema_pycapsule;
18use crate::input::AnyRecordBatch;
19use crate::schema::display_schema;
20use crate::{PyRecordBatch, PySchema, PyTable};
21
22/// A Python-facing Arrow record batch reader.
23///
24/// This is a wrapper around a [RecordBatchReader].
25#[pyclass(
26    module = "arro3.core._core",
27    name = "RecordBatchReader",
28    subclass,
29    frozen
30)]
31pub struct PyRecordBatchReader(pub(crate) Mutex<Option<Box<dyn RecordBatchReader + Send>>>);
32
33impl PyRecordBatchReader {
34    /// Construct a new PyRecordBatchReader from an existing [RecordBatchReader].
35    pub fn new(reader: Box<dyn RecordBatchReader + Send>) -> Self {
36        Self(Mutex::new(Some(reader)))
37    }
38
39    /// Construct from a raw Arrow C Stream capsule
40    pub fn from_arrow_pycapsule(capsule: &Bound<PyCapsule>) -> PyResult<Self> {
41        let stream = import_stream_pycapsule(capsule)?;
42        let stream_reader = arrow_array::ffi_stream::ArrowArrayStreamReader::try_new(stream)
43            .map_err(|err| PyValueError::new_err(err.to_string()))?;
44
45        Ok(Self::new(Box::new(stream_reader)))
46    }
47
48    /// Consume this reader and convert into a [RecordBatchReader].
49    ///
50    /// The reader can only be consumed once. Calling `into_reader`
51    pub fn into_reader(self) -> PyResult<Box<dyn RecordBatchReader + Send>> {
52        let stream = self
53            .0
54            .lock()
55            .unwrap()
56            .take()
57            .ok_or(PyIOError::new_err("Cannot read from closed stream."))?;
58        Ok(stream)
59    }
60
61    /// Consume this reader and create a [PyTable] object
62    pub fn into_table(self) -> PyArrowResult<PyTable> {
63        let stream = self
64            .0
65            .lock()
66            .unwrap()
67            .take()
68            .ok_or(PyIOError::new_err("Cannot read from closed stream."))?;
69        let schema = stream.schema();
70        let mut batches = vec![];
71        for batch in stream {
72            batches.push(batch?);
73        }
74        Ok(PyTable::try_new(batches, schema)?)
75    }
76
77    /// Access the [SchemaRef] of this RecordBatchReader.
78    ///
79    /// If the stream has already been consumed, this method will error.
80    pub fn schema_ref(&self) -> PyResult<SchemaRef> {
81        let inner = self.0.lock().unwrap();
82        let stream = inner
83            .as_ref()
84            .ok_or(PyIOError::new_err("Stream already closed."))?;
85        Ok(stream.schema())
86    }
87
88    /// Export this to a Python `arro3.core.RecordBatchReader`.
89    pub fn to_arro3<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
90        let arro3_mod = py.import(intern!(py, "arro3.core"))?;
91        arro3_mod
92            .getattr(intern!(py, "RecordBatchReader"))?
93            .call_method1(
94                intern!(py, "from_arrow_pycapsule"),
95                PyTuple::new(py, vec![self.__arrow_c_stream__(py, None)?])?,
96            )
97    }
98
99    /// Export this to a Python `arro3.core.RecordBatchReader`.
100    pub fn into_arro3(self, py: Python) -> PyResult<Bound<PyAny>> {
101        let arro3_mod = py.import(intern!(py, "arro3.core"))?;
102        let reader = self
103            .0
104            .lock()
105            .unwrap()
106            .take()
107            .ok_or(PyIOError::new_err("Cannot read from closed stream"))?;
108        let capsule = Self::to_stream_pycapsule(py, reader, None)?;
109        arro3_mod
110            .getattr(intern!(py, "RecordBatchReader"))?
111            .call_method1(
112                intern!(py, "from_arrow_pycapsule"),
113                PyTuple::new(py, vec![capsule])?,
114            )
115    }
116
117    /// Export this to a Python `nanoarrow.ArrayStream`.
118    pub fn to_nanoarrow<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
119        to_nanoarrow_array_stream(py, &self.__arrow_c_stream__(py, None)?)
120    }
121
122    /// Export to a pyarrow.RecordBatchReader
123    ///
124    /// Requires pyarrow >=15
125    pub fn to_pyarrow(self, py: Python) -> PyResult<PyObject> {
126        let pyarrow_mod = py.import(intern!(py, "pyarrow"))?;
127        let record_batch_reader_class = pyarrow_mod.getattr(intern!(py, "RecordBatchReader"))?;
128        let pyarrow_obj = record_batch_reader_class.call_method1(
129            intern!(py, "from_stream"),
130            PyTuple::new(py, vec![self.into_pyobject(py)?])?,
131        )?;
132        pyarrow_obj.into_py_any(py)
133    }
134
135    pub(crate) fn to_stream_pycapsule<'py>(
136        py: Python<'py>,
137        reader: Box<dyn RecordBatchReader + Send>,
138        requested_schema: Option<Bound<'py, PyCapsule>>,
139    ) -> PyArrowResult<Bound<'py, PyCapsule>> {
140        let schema = reader.schema().clone();
141        let array_reader = reader.into_iter().map(|maybe_batch| {
142            let arr: ArrayRef = Arc::new(StructArray::from(maybe_batch?));
143            Ok(arr)
144        });
145        let array_reader = Box::new(ArrayIterator::new(
146            array_reader,
147            Field::new_struct("", schema.fields().clone(), false)
148                .with_metadata(schema.metadata.clone())
149                .into(),
150        ));
151        to_stream_pycapsule(py, array_reader, requested_schema)
152    }
153}
154
155impl From<Box<dyn RecordBatchReader + Send>> for PyRecordBatchReader {
156    fn from(value: Box<dyn RecordBatchReader + Send>) -> Self {
157        Self::new(value)
158    }
159}
160
161impl Display for PyRecordBatchReader {
162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163        writeln!(f, "arro3.core.RecordBatchReader")?;
164        writeln!(f, "-----------------------")?;
165        if let Ok(schema) = self.schema_ref() {
166            display_schema(&schema, f)
167        } else {
168            writeln!(f, "Closed stream")
169        }
170    }
171}
172
173#[pymethods]
174impl PyRecordBatchReader {
175    fn __arrow_c_schema__<'py>(&'py self, py: Python<'py>) -> PyArrowResult<Bound<'py, PyCapsule>> {
176        to_schema_pycapsule(py, self.schema_ref()?.as_ref())
177    }
178
179    #[pyo3(signature = (requested_schema=None))]
180    fn __arrow_c_stream__<'py>(
181        &'py self,
182        py: Python<'py>,
183        requested_schema: Option<Bound<'py, PyCapsule>>,
184    ) -> PyArrowResult<Bound<'py, PyCapsule>> {
185        let reader = self
186            .0
187            .lock()
188            .unwrap()
189            .take()
190            .ok_or(PyIOError::new_err("Cannot read from closed stream"))?;
191        Self::to_stream_pycapsule(py, reader, requested_schema)
192    }
193
194    // Return self
195    // https://stackoverflow.com/a/52056290
196    fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
197        slf
198    }
199
200    fn __next__(&self) -> PyArrowResult<Arro3RecordBatch> {
201        self.read_next_batch()
202    }
203
204    fn __repr__(&self) -> String {
205        self.to_string()
206    }
207
208    #[classmethod]
209    fn from_arrow(_cls: &Bound<PyType>, input: AnyRecordBatch) -> PyArrowResult<Self> {
210        let reader = input.into_reader()?;
211        Ok(Self::new(reader))
212    }
213
214    #[classmethod]
215    #[pyo3(name = "from_arrow_pycapsule")]
216    fn from_arrow_pycapsule_py(_cls: &Bound<PyType>, capsule: &Bound<PyCapsule>) -> PyResult<Self> {
217        Self::from_arrow_pycapsule(capsule)
218    }
219
220    #[classmethod]
221    fn from_batches(_cls: &Bound<PyType>, schema: PySchema, batches: Vec<PyRecordBatch>) -> Self {
222        let batches = batches
223            .into_iter()
224            .map(|batch| batch.into_inner())
225            .collect::<Vec<_>>();
226        Self::new(Box::new(RecordBatchIterator::new(
227            batches.into_iter().map(Ok),
228            schema.into_inner(),
229        )))
230    }
231
232    #[classmethod]
233    fn from_stream(_cls: &Bound<PyType>, data: &Bound<PyAny>) -> PyResult<Self> {
234        data.extract()
235    }
236
237    #[getter]
238    fn closed(&self) -> bool {
239        self.0.lock().unwrap().is_none()
240    }
241
242    fn read_all(&self) -> PyArrowResult<Arro3Table> {
243        let stream = self
244            .0
245            .lock()
246            .unwrap()
247            .take()
248            .ok_or(PyIOError::new_err("Cannot read from closed stream."))?;
249        let schema = stream.schema();
250        let mut batches = vec![];
251        for batch in stream {
252            batches.push(batch?);
253        }
254        Ok(PyTable::try_new(batches, schema)?.into())
255    }
256
257    fn read_next_batch(&self) -> PyArrowResult<Arro3RecordBatch> {
258        let mut inner = self.0.lock().unwrap();
259        let stream = inner
260            .as_mut()
261            .ok_or(PyIOError::new_err("Cannot read from closed stream."))?;
262
263        if let Some(next_batch) = stream.next() {
264            Ok(next_batch?.into())
265        } else {
266            Err(PyStopIteration::new_err("").into())
267        }
268    }
269
270    #[getter]
271    fn schema(&self) -> PyResult<Arro3Schema> {
272        Ok(PySchema::new(self.schema_ref()?.clone()).into())
273    }
274}