datafusion_python/record_batch.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use crate::errors::PyDataFusionError;
use crate::utils::wait_for_future;
use datafusion::arrow::pyarrow::ToPyArrow;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::physical_plan::SendableRecordBatchStream;
use futures::StreamExt;
use pyo3::exceptions::{PyStopAsyncIteration, PyStopIteration};
use pyo3::prelude::*;
use pyo3::{pyclass, pymethods, PyObject, PyResult, Python};
use tokio::sync::Mutex;

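/// Python wrapper around a single Arrow [`RecordBatch`].
///
/// The batch is handed back to Python by converting it into a
/// `pyarrow.RecordBatch` through the [`ToPyArrow`] trait.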
#[pyclass(name = "RecordBatch", module = "datafusion", subclass)]
pub struct PyRecordBatch {
    batch: RecordBatch,
}

#[pymethods]
impl PyRecordBatch {
    fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
        self.batch.to_pyarrow(py)
    }
}

impl From<RecordBatch> for PyRecordBatch {
    fn from(batch: RecordBatch) -> Self {
        Self { batch }
    }
}

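/// Python wrapper around a [`SendableRecordBatchStream`].
///
/// The stream is shared behind an async [`Mutex`] so that the blocking
/// iterator protocol (`__iter__`/`__next__`) and the async iterator protocol
/// (`__aiter__`/`__anext__`) can both poll the same underlying stream.
///
/// A minimal sketch of Python-side usage, assuming a `DataFrame.execute_stream()`
/// method that returns this class (not defined in this file):
///
/// ```python
/// stream = df.execute_stream()
/// for batch in stream:          # synchronous iteration
///     print(batch.to_pyarrow())
///
/// # or, inside an async function:
/// async for batch in df.execute_stream():
///     print(batch.to_pyarrow())
/// ```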
#[pyclass(name = "RecordBatchStream", module = "datafusion", subclass)]
pub struct PyRecordBatchStream {
    stream: Arc<Mutex<SendableRecordBatchStream>>,
}

impl PyRecordBatchStream {
    pub fn new(stream: SendableRecordBatchStream) -> Self {
        Self {
            stream: Arc::new(Mutex::new(stream)),
        }
    }
}

#[pymethods]
impl PyRecordBatchStream {
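    /// Synchronously pull the next batch from the stream, raising
    /// `StopIteration` once the stream is exhausted. `__next__` delegates
    /// here so the class also satisfies the Python iterator protocol.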
    fn next(&mut self, py: Python) -> PyResult<PyRecordBatch> {
        let stream = self.stream.clone();
        wait_for_future(py, next_stream(stream, true))
    }

    fn __next__(&mut self, py: Python) -> PyResult<PyRecordBatch> {
        self.next(py)
    }

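    /// Asynchronous counterpart of `__next__`: returns an awaitable that is
    /// driven on the tokio runtime via `pyo3_async_runtimes` and raises
    /// `StopAsyncIteration` once the stream is exhausted.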
    fn __anext__<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let stream = self.stream.clone();
        pyo3_async_runtimes::tokio::future_into_py(py, next_stream(stream, false))
    }

    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
        slf
    }

    fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
        slf
    }
}

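/// Poll the shared stream for the next batch. Errors are converted into
/// Python exceptions, and a `None` result is mapped to `StopIteration` or
/// `StopAsyncIteration` depending on whether the caller iterates
/// synchronously or asynchronously.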
async fn next_stream(
    stream: Arc<Mutex<SendableRecordBatchStream>>,
    sync: bool,
) -> PyResult<PyRecordBatch> {
    let mut stream = stream.lock().await;
    match stream.next().await {
        Some(Ok(batch)) => Ok(batch.into()),
        Some(Err(e)) => Err(PyDataFusionError::from(e))?,
        None => {
            // Depending on whether the iteration is sync or not, we raise either a
            // StopIteration or a StopAsyncIteration
            if sync {
                Err(PyStopIteration::new_err("stream exhausted"))
            } else {
                Err(PyStopAsyncIteration::new_err("stream exhausted"))
            }
        }
    }
}