datafusion_python/
record_batch.rs1use std::sync::Arc;
19
20use crate::errors::PyDataFusionError;
21use crate::utils::wait_for_future;
22use datafusion::arrow::pyarrow::ToPyArrow;
23use datafusion::arrow::record_batch::RecordBatch;
24use datafusion::physical_plan::SendableRecordBatchStream;
25use futures::StreamExt;
26use pyo3::exceptions::{PyStopAsyncIteration, PyStopIteration};
27use pyo3::prelude::*;
28use pyo3::{pyclass, pymethods, PyObject, PyResult, Python};
29use tokio::sync::Mutex;
30
31#[pyclass(name = "RecordBatch", module = "datafusion", subclass)]
32pub struct PyRecordBatch {
33 batch: RecordBatch,
34}
35
36#[pymethods]
37impl PyRecordBatch {
38 fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
39 self.batch.to_pyarrow(py)
40 }
41}
42
43impl From<RecordBatch> for PyRecordBatch {
44 fn from(batch: RecordBatch) -> Self {
45 Self { batch }
46 }
47}
48
49#[pyclass(name = "RecordBatchStream", module = "datafusion", subclass)]
50pub struct PyRecordBatchStream {
51 stream: Arc<Mutex<SendableRecordBatchStream>>,
52}
53
54impl PyRecordBatchStream {
55 pub fn new(stream: SendableRecordBatchStream) -> Self {
56 Self {
57 stream: Arc::new(Mutex::new(stream)),
58 }
59 }
60}
61
62#[pymethods]
63impl PyRecordBatchStream {
64 fn next(&mut self, py: Python) -> PyResult<PyRecordBatch> {
65 let stream = self.stream.clone();
66 wait_for_future(py, next_stream(stream, true))
67 }
68
69 fn __next__(&mut self, py: Python) -> PyResult<PyRecordBatch> {
70 self.next(py)
71 }
72
73 fn __anext__<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
74 let stream = self.stream.clone();
75 pyo3_async_runtimes::tokio::future_into_py(py, next_stream(stream, false))
76 }
77
78 fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
79 slf
80 }
81
82 fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
83 slf
84 }
85}
86
87async fn next_stream(
88 stream: Arc<Mutex<SendableRecordBatchStream>>,
89 sync: bool,
90) -> PyResult<PyRecordBatch> {
91 let mut stream = stream.lock().await;
92 match stream.next().await {
93 Some(Ok(batch)) => Ok(batch.into()),
94 Some(Err(e)) => Err(PyDataFusionError::from(e))?,
95 None => {
96 if sync {
99 Err(PyStopIteration::new_err("stream exhausted"))
100 } else {
101 Err(PyStopAsyncIteration::new_err("stream exhausted"))
102 }
103 }
104 }
105}