pyo3_arrow/
record_batch.rs

1use std::fmt::Display;
2use std::sync::Arc;
3
4use arrow_array::cast::AsArray;
5use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, StructArray};
6use arrow_schema::{DataType, Field, Schema, SchemaBuilder};
7use arrow_select::concat::concat_batches;
8use arrow_select::take::take_record_batch;
9use indexmap::IndexMap;
10use pyo3::exceptions::{PyTypeError, PyValueError};
11use pyo3::prelude::*;
12use pyo3::types::{PyCapsule, PyTuple, PyType};
13use pyo3::{intern, IntoPyObjectExt};
14
15use crate::error::PyArrowResult;
16use crate::export::{Arro3Array, Arro3Field, Arro3RecordBatch, Arro3Schema};
17use crate::ffi::from_python::utils::import_array_pycapsules;
18use crate::ffi::to_python::nanoarrow::to_nanoarrow_array;
19use crate::ffi::to_python::to_array_pycapsules;
20use crate::ffi::to_schema_pycapsule;
21use crate::input::{AnyRecordBatch, FieldIndexInput, MetadataInput, NameOrField, SelectIndices};
22use crate::schema::display_schema;
23use crate::{PyArray, PyField, PySchema};
24
25/// A Python-facing Arrow record batch.
26///
27/// This is a wrapper around a [RecordBatch].
28#[pyclass(module = "arro3.core._core", name = "RecordBatch", subclass, frozen)]
29#[derive(Debug)]
30pub struct PyRecordBatch(RecordBatch);
31
32impl PyRecordBatch {
33    /// Construct a new PyRecordBatch from a [RecordBatch].
34    pub fn new(batch: RecordBatch) -> Self {
35        Self(batch)
36    }
37
38    /// Construct from raw Arrow capsules
39    pub fn from_arrow_pycapsule(
40        schema_capsule: &Bound<PyCapsule>,
41        array_capsule: &Bound<PyCapsule>,
42    ) -> PyResult<Self> {
43        let (array, field, data_len) = import_array_pycapsules(schema_capsule, array_capsule)?;
44
45        match field.data_type() {
46            DataType::Struct(fields) => {
47                let struct_array = array.as_struct();
48                let schema = SchemaBuilder::from(fields)
49                    .finish()
50                    .with_metadata(field.metadata().clone());
51                assert_eq!(
52                    struct_array.null_count(),
53                    0,
54                    "Cannot convert nullable StructArray to RecordBatch"
55                );
56
57                let columns = struct_array.columns().to_vec();
58
59                // Special cast to handle zero-column RecordBatches with positive length
60                let batch = if array.is_empty() && data_len > 0 {
61                    RecordBatch::try_new_with_options(
62                        Arc::new(schema),
63                        columns,
64                        &RecordBatchOptions::new().with_row_count(Some(data_len)),
65                    )
66                    .map_err(|err| PyValueError::new_err(err.to_string()))?
67                } else {
68                    RecordBatch::try_new(Arc::new(schema), columns)
69                        .map_err(|err| PyValueError::new_err(err.to_string()))?
70                };
71                Ok(Self::new(batch))
72            }
73            dt => Err(PyValueError::new_err(format!(
74                "Unexpected data type {}",
75                dt
76            ))),
77        }
78    }
79
80    /// Consume this, returning its internal [RecordBatch].
81    pub fn into_inner(self) -> RecordBatch {
82        self.0
83    }
84
85    /// Export this to a Python `arro3.core.RecordBatch`.
86    pub fn to_arro3<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
87        let arro3_mod = py.import(intern!(py, "arro3.core"))?;
88        arro3_mod.getattr(intern!(py, "RecordBatch"))?.call_method1(
89            intern!(py, "from_arrow_pycapsule"),
90            self.__arrow_c_array__(py, None)?,
91        )
92    }
93
94    /// Export this to a Python `arro3.core.RecordBatch`.
95    pub fn into_arro3(self, py: Python) -> PyResult<Bound<PyAny>> {
96        let arro3_mod = py.import(intern!(py, "arro3.core"))?;
97        let capsules = Self::to_array_pycapsules(py, self.0.clone(), None)?;
98        arro3_mod
99            .getattr(intern!(py, "RecordBatch"))?
100            .call_method1(intern!(py, "from_arrow_pycapsule"), capsules)
101    }
102
103    /// Export this to a Python `nanoarrow.Array`.
104    pub fn to_nanoarrow<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
105        to_nanoarrow_array(py, self.__arrow_c_array__(py, None)?)
106    }
107
108    /// Export to a pyarrow.RecordBatch
109    ///
110    /// Requires pyarrow >=14
111    pub fn to_pyarrow(self, py: Python) -> PyResult<PyObject> {
112        let pyarrow_mod = py.import(intern!(py, "pyarrow"))?;
113        let pyarrow_obj = pyarrow_mod
114            .getattr(intern!(py, "record_batch"))?
115            .call1(PyTuple::new(py, vec![self.into_pyobject(py)?])?)?;
116        pyarrow_obj.into_py_any(py)
117    }
118
119    pub(crate) fn to_array_pycapsules<'py>(
120        py: Python<'py>,
121        record_batch: RecordBatch,
122        requested_schema: Option<Bound<'py, PyCapsule>>,
123    ) -> PyArrowResult<Bound<'py, PyTuple>> {
124        let field = Field::new_struct("", record_batch.schema_ref().fields().clone(), false);
125        let array: ArrayRef = Arc::new(StructArray::from(record_batch.clone()));
126        to_array_pycapsules(py, field.into(), &array, requested_schema)
127    }
128}
129
130impl From<RecordBatch> for PyRecordBatch {
131    fn from(value: RecordBatch) -> Self {
132        Self(value)
133    }
134}
135
136impl From<PyRecordBatch> for RecordBatch {
137    fn from(value: PyRecordBatch) -> Self {
138        value.0
139    }
140}
141
142impl AsRef<RecordBatch> for PyRecordBatch {
143    fn as_ref(&self) -> &RecordBatch {
144        &self.0
145    }
146}
147
148impl Display for PyRecordBatch {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        writeln!(f, "arro3.core.RecordBatch")?;
151        writeln!(f, "-----------------")?;
152        display_schema(&self.0.schema(), f)
153    }
154}
155
156#[pymethods]
157impl PyRecordBatch {
158    #[new]
159    #[pyo3(signature = (data, *,  schema=None, metadata=None))]
160    fn init(
161        py: Python,
162        data: &Bound<PyAny>,
163        schema: Option<PySchema>,
164        metadata: Option<MetadataInput>,
165    ) -> PyArrowResult<Self> {
166        if let Ok(data) = data.extract::<PyRecordBatch>() {
167            Ok(data)
168        } else if let Ok(mapping) = data.extract::<IndexMap<String, PyArray>>() {
169            Self::from_pydict(&py.get_type::<PyRecordBatch>(), mapping, metadata)
170        } else if let Ok(arrays) = data.extract::<Vec<PyArray>>() {
171            Self::from_arrays(
172                &py.get_type::<PyRecordBatch>(),
173                arrays,
174                schema.ok_or(PyValueError::new_err(
175                    "Schema must be passed with list of arrays",
176                ))?,
177            )
178        } else {
179            Err(PyTypeError::new_err(
180                "Expected RecordBatch-like input or dict of arrays or list of arrays.",
181            )
182            .into())
183        }
184    }
185
186    #[pyo3(signature = (requested_schema=None))]
187    fn __arrow_c_array__<'py>(
188        &'py self,
189        py: Python<'py>,
190        requested_schema: Option<Bound<'py, PyCapsule>>,
191    ) -> PyArrowResult<Bound<'py, PyTuple>> {
192        Self::to_array_pycapsules(py, self.0.clone(), requested_schema)
193    }
194
195    fn __arrow_c_schema__<'py>(&'py self, py: Python<'py>) -> PyArrowResult<Bound<'py, PyCapsule>> {
196        to_schema_pycapsule(py, self.0.schema_ref().as_ref())
197    }
198
199    fn __eq__(&self, other: &PyRecordBatch) -> bool {
200        self.0 == other.0
201    }
202
203    fn __getitem__(&self, key: FieldIndexInput) -> PyResult<Arro3Array> {
204        self.column(key)
205    }
206
207    fn __repr__(&self) -> String {
208        self.to_string()
209    }
210
211    #[classmethod]
212    #[pyo3(signature = (arrays, *, schema))]
213    fn from_arrays(
214        _cls: &Bound<PyType>,
215        arrays: Vec<PyArray>,
216        schema: PySchema,
217    ) -> PyArrowResult<Self> {
218        let rb = RecordBatch::try_new(
219            schema.into(),
220            arrays
221                .into_iter()
222                .map(|arr| {
223                    let (arr, _field) = arr.into_inner();
224                    arr
225                })
226                .collect(),
227        )?;
228        Ok(Self::new(rb))
229    }
230
231    #[classmethod]
232    #[pyo3(signature = (mapping, *, metadata=None))]
233    fn from_pydict(
234        _cls: &Bound<PyType>,
235        mapping: IndexMap<String, PyArray>,
236        metadata: Option<MetadataInput>,
237    ) -> PyArrowResult<Self> {
238        let mut fields = vec![];
239        let mut arrays = vec![];
240        mapping.into_iter().for_each(|(name, py_array)| {
241            let (arr, field) = py_array.into_inner();
242            fields.push(field.as_ref().clone().with_name(name));
243            arrays.push(arr);
244        });
245        let schema =
246            Schema::new_with_metadata(fields, metadata.unwrap_or_default().into_string_hashmap()?);
247        let rb = RecordBatch::try_new(schema.into(), arrays)?;
248        Ok(Self::new(rb))
249    }
250
251    #[classmethod]
252    fn from_struct_array(_cls: &Bound<PyType>, struct_array: PyArray) -> PyArrowResult<Self> {
253        let (array, field) = struct_array.into_inner();
254        match field.data_type() {
255            DataType::Struct(fields) => {
256                let schema = Schema::new_with_metadata(fields.clone(), field.metadata().clone());
257                let struct_arr = array.as_struct();
258                let columns = struct_arr.columns().to_vec();
259                let rb = RecordBatch::try_new(schema.into(), columns)?;
260                Ok(Self::new(rb))
261            }
262            _ => Err(PyTypeError::new_err("Expected struct array").into()),
263        }
264    }
265
266    #[classmethod]
267    fn from_arrow(_cls: &Bound<PyType>, input: AnyRecordBatch) -> PyArrowResult<Self> {
268        match input {
269            AnyRecordBatch::RecordBatch(rb) => Ok(rb),
270            AnyRecordBatch::Stream(stream) => {
271                let (batches, schema) = stream.into_table()?.into_inner();
272                let single_batch = concat_batches(&schema, batches.iter())?;
273                Ok(Self::new(single_batch))
274            }
275        }
276    }
277
278    #[classmethod]
279    #[pyo3(name = "from_arrow_pycapsule")]
280    fn from_arrow_pycapsule_py(
281        _cls: &Bound<PyType>,
282        schema_capsule: &Bound<PyCapsule>,
283        array_capsule: &Bound<PyCapsule>,
284    ) -> PyResult<Self> {
285        Self::from_arrow_pycapsule(schema_capsule, array_capsule)
286    }
287
288    fn add_column(
289        &self,
290        i: usize,
291        field: NameOrField,
292        column: PyArray,
293    ) -> PyArrowResult<Arro3RecordBatch> {
294        let mut fields = self.0.schema_ref().fields().to_vec();
295        fields.insert(i, field.into_field(column.field()));
296        let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
297
298        let mut arrays = self.0.columns().to_vec();
299        arrays.insert(i, column.array().clone());
300
301        let new_rb = RecordBatch::try_new(schema.into(), arrays)?;
302        Ok(PyRecordBatch::new(new_rb).into())
303    }
304
305    fn append_column(
306        &self,
307        field: NameOrField,
308        column: PyArray,
309    ) -> PyArrowResult<Arro3RecordBatch> {
310        let mut fields = self.0.schema_ref().fields().to_vec();
311        fields.push(field.into_field(column.field()));
312        let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
313
314        let mut arrays = self.0.columns().to_vec();
315        arrays.push(column.array().clone());
316
317        let new_rb = RecordBatch::try_new(schema.into(), arrays)?;
318        Ok(PyRecordBatch::new(new_rb).into())
319    }
320
321    fn column(&self, i: FieldIndexInput) -> PyResult<Arro3Array> {
322        let column_index = i.into_position(self.0.schema_ref())?;
323        let field = self.0.schema().field(column_index).clone();
324        let array = self.0.column(column_index).clone();
325        Ok(PyArray::new(array, field.into()).into())
326    }
327
328    #[getter]
329    fn column_names(&self) -> Vec<String> {
330        self.0
331            .schema()
332            .fields()
333            .iter()
334            .map(|f| f.name().clone())
335            .collect()
336    }
337
338    #[getter]
339    fn columns(&self) -> PyResult<Vec<Arro3Array>> {
340        (0..self.num_columns())
341            .map(|i| self.column(FieldIndexInput::Position(i)))
342            .collect()
343    }
344
345    fn equals(&self, other: PyRecordBatch) -> bool {
346        self.0 == other.0
347    }
348
349    fn field(&self, i: FieldIndexInput) -> PyResult<Arro3Field> {
350        let schema_ref = self.0.schema_ref();
351        let field = schema_ref.field(i.into_position(schema_ref)?);
352        Ok(PyField::new(field.clone().into()).into())
353    }
354
355    #[getter]
356    fn nbytes(&self) -> usize {
357        self.0.get_array_memory_size()
358    }
359
360    #[getter]
361    fn num_columns(&self) -> usize {
362        self.0.num_columns()
363    }
364
365    #[getter]
366    fn num_rows(&self) -> usize {
367        self.0.num_rows()
368    }
369
370    fn remove_column(&self, i: usize) -> Arro3RecordBatch {
371        let mut rb = self.0.clone();
372        rb.remove_column(i);
373        PyRecordBatch::new(rb).into()
374    }
375
376    #[getter]
377    fn schema(&self) -> Arro3Schema {
378        self.0.schema().into()
379    }
380
381    fn select(&self, columns: SelectIndices) -> PyArrowResult<Arro3RecordBatch> {
382        let positions = columns.into_positions(self.0.schema_ref().fields())?;
383        Ok(self.0.project(&positions)?.into())
384    }
385
386    fn set_column(
387        &self,
388        i: usize,
389        field: NameOrField,
390        column: PyArray,
391    ) -> PyArrowResult<Arro3RecordBatch> {
392        let mut fields = self.0.schema_ref().fields().to_vec();
393        fields[i] = field.into_field(column.field());
394        let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
395
396        let mut arrays = self.0.columns().to_vec();
397        arrays[i] = column.array().clone();
398
399        Ok(RecordBatch::try_new(schema.into(), arrays)?.into())
400    }
401
402    #[getter]
403    fn shape(&self) -> (usize, usize) {
404        (self.num_rows(), self.num_columns())
405    }
406
407    #[pyo3(signature = (offset=0, length=None))]
408    fn slice(&self, offset: usize, length: Option<usize>) -> Arro3RecordBatch {
409        let length = length.unwrap_or_else(|| self.num_rows() - offset);
410        self.0.slice(offset, length).into()
411    }
412
413    fn take(&self, indices: PyArray) -> PyArrowResult<Arro3RecordBatch> {
414        let new_batch = take_record_batch(self.as_ref(), indices.as_ref())?;
415        Ok(new_batch.into())
416    }
417
418    fn to_struct_array(&self) -> Arro3Array {
419        let struct_array: StructArray = self.0.clone().into();
420        let field = Field::new_struct("", self.0.schema_ref().fields().clone(), false)
421            .with_metadata(self.0.schema_ref().metadata.clone());
422        PyArray::new(Arc::new(struct_array), field.into()).into()
423    }
424
425    fn with_schema(&self, schema: PySchema) -> PyArrowResult<Arro3RecordBatch> {
426        let new_schema = schema.into_inner();
427        let new_batch = RecordBatch::try_new(new_schema.clone(), self.0.columns().to_vec())?;
428        Ok(new_batch.into())
429    }
430}