pyo3_arrow/
record_batch_reader.rs1use std::fmt::Display;
2use std::sync::{Arc, Mutex};
3
4use arrow_array::{ArrayRef, RecordBatchIterator, RecordBatchReader, StructArray};
5use arrow_schema::{Field, SchemaRef};
6use pyo3::exceptions::{PyIOError, PyStopIteration, PyValueError};
7use pyo3::prelude::*;
8use pyo3::types::{PyCapsule, PyTuple, PyType};
9use pyo3::{intern, IntoPyObjectExt};
10
11use crate::error::PyArrowResult;
12use crate::export::{Arro3RecordBatch, Arro3Schema, Arro3Table};
13use crate::ffi::from_python::utils::import_stream_pycapsule;
14use crate::ffi::to_python::chunked::ArrayIterator;
15use crate::ffi::to_python::nanoarrow::to_nanoarrow_array_stream;
16use crate::ffi::to_python::to_stream_pycapsule;
17use crate::ffi::to_schema_pycapsule;
18use crate::input::AnyRecordBatch;
19use crate::schema::display_schema;
20use crate::{PyRecordBatch, PySchema, PyTable};
21
22#[pyclass(
26 module = "arro3.core._core",
27 name = "RecordBatchReader",
28 subclass,
29 frozen
30)]
31pub struct PyRecordBatchReader(pub(crate) Mutex<Option<Box<dyn RecordBatchReader + Send>>>);
32
33impl PyRecordBatchReader {
34 pub fn new(reader: Box<dyn RecordBatchReader + Send>) -> Self {
36 Self(Mutex::new(Some(reader)))
37 }
38
39 pub fn from_arrow_pycapsule(capsule: &Bound<PyCapsule>) -> PyResult<Self> {
41 let stream = import_stream_pycapsule(capsule)?;
42 let stream_reader = arrow_array::ffi_stream::ArrowArrayStreamReader::try_new(stream)
43 .map_err(|err| PyValueError::new_err(err.to_string()))?;
44
45 Ok(Self::new(Box::new(stream_reader)))
46 }
47
48 pub fn into_reader(self) -> PyResult<Box<dyn RecordBatchReader + Send>> {
52 let stream = self
53 .0
54 .lock()
55 .unwrap()
56 .take()
57 .ok_or(PyIOError::new_err("Cannot read from closed stream."))?;
58 Ok(stream)
59 }
60
61 pub fn into_table(self) -> PyArrowResult<PyTable> {
63 let stream = self
64 .0
65 .lock()
66 .unwrap()
67 .take()
68 .ok_or(PyIOError::new_err("Cannot read from closed stream."))?;
69 let schema = stream.schema();
70 let mut batches = vec![];
71 for batch in stream {
72 batches.push(batch?);
73 }
74 Ok(PyTable::try_new(batches, schema)?)
75 }
76
77 pub fn schema_ref(&self) -> PyResult<SchemaRef> {
81 let inner = self.0.lock().unwrap();
82 let stream = inner
83 .as_ref()
84 .ok_or(PyIOError::new_err("Stream already closed."))?;
85 Ok(stream.schema())
86 }
87
88 pub fn to_arro3<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
90 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
91 arro3_mod
92 .getattr(intern!(py, "RecordBatchReader"))?
93 .call_method1(
94 intern!(py, "from_arrow_pycapsule"),
95 PyTuple::new(py, vec![self.__arrow_c_stream__(py, None)?])?,
96 )
97 }
98
99 pub fn into_arro3(self, py: Python) -> PyResult<Bound<PyAny>> {
101 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
102 let reader = self
103 .0
104 .lock()
105 .unwrap()
106 .take()
107 .ok_or(PyIOError::new_err("Cannot read from closed stream"))?;
108 let capsule = Self::to_stream_pycapsule(py, reader, None)?;
109 arro3_mod
110 .getattr(intern!(py, "RecordBatchReader"))?
111 .call_method1(
112 intern!(py, "from_arrow_pycapsule"),
113 PyTuple::new(py, vec![capsule])?,
114 )
115 }
116
117 pub fn to_nanoarrow<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
119 to_nanoarrow_array_stream(py, &self.__arrow_c_stream__(py, None)?)
120 }
121
122 pub fn to_pyarrow(self, py: Python) -> PyResult<PyObject> {
126 let pyarrow_mod = py.import(intern!(py, "pyarrow"))?;
127 let record_batch_reader_class = pyarrow_mod.getattr(intern!(py, "RecordBatchReader"))?;
128 let pyarrow_obj = record_batch_reader_class.call_method1(
129 intern!(py, "from_stream"),
130 PyTuple::new(py, vec![self.into_pyobject(py)?])?,
131 )?;
132 pyarrow_obj.into_py_any(py)
133 }
134
135 pub(crate) fn to_stream_pycapsule<'py>(
136 py: Python<'py>,
137 reader: Box<dyn RecordBatchReader + Send>,
138 requested_schema: Option<Bound<'py, PyCapsule>>,
139 ) -> PyArrowResult<Bound<'py, PyCapsule>> {
140 let schema = reader.schema().clone();
141 let array_reader = reader.into_iter().map(|maybe_batch| {
142 let arr: ArrayRef = Arc::new(StructArray::from(maybe_batch?));
143 Ok(arr)
144 });
145 let array_reader = Box::new(ArrayIterator::new(
146 array_reader,
147 Field::new_struct("", schema.fields().clone(), false)
148 .with_metadata(schema.metadata.clone())
149 .into(),
150 ));
151 to_stream_pycapsule(py, array_reader, requested_schema)
152 }
153}
154
155impl From<Box<dyn RecordBatchReader + Send>> for PyRecordBatchReader {
156 fn from(value: Box<dyn RecordBatchReader + Send>) -> Self {
157 Self::new(value)
158 }
159}
160
161impl Display for PyRecordBatchReader {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 writeln!(f, "arro3.core.RecordBatchReader")?;
164 writeln!(f, "-----------------------")?;
165 if let Ok(schema) = self.schema_ref() {
166 display_schema(&schema, f)
167 } else {
168 writeln!(f, "Closed stream")
169 }
170 }
171}
172
173#[pymethods]
174impl PyRecordBatchReader {
175 fn __arrow_c_schema__<'py>(&'py self, py: Python<'py>) -> PyArrowResult<Bound<'py, PyCapsule>> {
176 to_schema_pycapsule(py, self.schema_ref()?.as_ref())
177 }
178
179 #[pyo3(signature = (requested_schema=None))]
180 fn __arrow_c_stream__<'py>(
181 &'py self,
182 py: Python<'py>,
183 requested_schema: Option<Bound<'py, PyCapsule>>,
184 ) -> PyArrowResult<Bound<'py, PyCapsule>> {
185 let reader = self
186 .0
187 .lock()
188 .unwrap()
189 .take()
190 .ok_or(PyIOError::new_err("Cannot read from closed stream"))?;
191 Self::to_stream_pycapsule(py, reader, requested_schema)
192 }
193
194 fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
197 slf
198 }
199
200 fn __next__(&self) -> PyArrowResult<Arro3RecordBatch> {
201 self.read_next_batch()
202 }
203
204 fn __repr__(&self) -> String {
205 self.to_string()
206 }
207
208 #[classmethod]
209 fn from_arrow(_cls: &Bound<PyType>, input: AnyRecordBatch) -> PyArrowResult<Self> {
210 let reader = input.into_reader()?;
211 Ok(Self::new(reader))
212 }
213
214 #[classmethod]
215 #[pyo3(name = "from_arrow_pycapsule")]
216 fn from_arrow_pycapsule_py(_cls: &Bound<PyType>, capsule: &Bound<PyCapsule>) -> PyResult<Self> {
217 Self::from_arrow_pycapsule(capsule)
218 }
219
220 #[classmethod]
221 fn from_batches(_cls: &Bound<PyType>, schema: PySchema, batches: Vec<PyRecordBatch>) -> Self {
222 let batches = batches
223 .into_iter()
224 .map(|batch| batch.into_inner())
225 .collect::<Vec<_>>();
226 Self::new(Box::new(RecordBatchIterator::new(
227 batches.into_iter().map(Ok),
228 schema.into_inner(),
229 )))
230 }
231
232 #[classmethod]
233 fn from_stream(_cls: &Bound<PyType>, data: &Bound<PyAny>) -> PyResult<Self> {
234 data.extract()
235 }
236
237 #[getter]
238 fn closed(&self) -> bool {
239 self.0.lock().unwrap().is_none()
240 }
241
242 fn read_all(&self) -> PyArrowResult<Arro3Table> {
243 let stream = self
244 .0
245 .lock()
246 .unwrap()
247 .take()
248 .ok_or(PyIOError::new_err("Cannot read from closed stream."))?;
249 let schema = stream.schema();
250 let mut batches = vec![];
251 for batch in stream {
252 batches.push(batch?);
253 }
254 Ok(PyTable::try_new(batches, schema)?.into())
255 }
256
257 fn read_next_batch(&self) -> PyArrowResult<Arro3RecordBatch> {
258 let mut inner = self.0.lock().unwrap();
259 let stream = inner
260 .as_mut()
261 .ok_or(PyIOError::new_err("Cannot read from closed stream."))?;
262
263 if let Some(next_batch) = stream.next() {
264 Ok(next_batch?.into())
265 } else {
266 Err(PyStopIteration::new_err("").into())
267 }
268 }
269
270 #[getter]
271 fn schema(&self) -> PyResult<Arro3Schema> {
272 Ok(PySchema::new(self.schema_ref()?.clone()).into())
273 }
274}