// pyo3_arrow/record_batch.rs

1use std::fmt::Display;
2use std::sync::Arc;
3
4use arrow::array::AsArray;
5use arrow::compute::{concat_batches, take_record_batch};
6use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, StructArray};
7use arrow_schema::{DataType, Field, Schema, SchemaBuilder};
8use indexmap::IndexMap;
9use pyo3::exceptions::{PyTypeError, PyValueError};
10use pyo3::prelude::*;
11use pyo3::types::{PyCapsule, PyTuple, PyType};
12use pyo3::{intern, IntoPyObjectExt};
13
14use crate::error::PyArrowResult;
15use crate::export::{Arro3Array, Arro3Field, Arro3RecordBatch, Arro3Schema};
16use crate::ffi::from_python::utils::import_array_pycapsules;
17use crate::ffi::to_python::nanoarrow::to_nanoarrow_array;
18use crate::ffi::to_python::to_array_pycapsules;
19use crate::ffi::to_schema_pycapsule;
20use crate::input::{AnyRecordBatch, FieldIndexInput, MetadataInput, NameOrField, SelectIndices};
21use crate::schema::display_schema;
22use crate::{PyArray, PyField, PySchema};
23
/// A Python-facing Arrow record batch.
///
/// This is a thin newtype wrapper around a [RecordBatch].
///
/// Exposed to Python as `arro3.core.RecordBatch`. The class is `frozen`
/// (immutable from Python; all mutating-style methods return new batches)
/// and `subclass`able.
#[pyclass(module = "arro3.core._core", name = "RecordBatch", subclass, frozen)]
#[derive(Debug)]
pub struct PyRecordBatch(RecordBatch);
30
31impl PyRecordBatch {
32    /// Construct a new PyRecordBatch from a [RecordBatch].
33    pub fn new(batch: RecordBatch) -> Self {
34        Self(batch)
35    }
36
37    /// Construct from raw Arrow capsules
38    pub fn from_arrow_pycapsule(
39        schema_capsule: &Bound<PyCapsule>,
40        array_capsule: &Bound<PyCapsule>,
41    ) -> PyResult<Self> {
42        let (array, field, data_len) = import_array_pycapsules(schema_capsule, array_capsule)?;
43
44        match field.data_type() {
45            DataType::Struct(fields) => {
46                let struct_array = array.as_struct();
47                let schema = SchemaBuilder::from(fields)
48                    .finish()
49                    .with_metadata(field.metadata().clone());
50                assert_eq!(
51                    struct_array.null_count(),
52                    0,
53                    "Cannot convert nullable StructArray to RecordBatch"
54                );
55
56                let columns = struct_array.columns().to_vec();
57
58                // Special cast to handle zero-column RecordBatches with positive length
59                let batch = if array.len() == 0 && data_len > 0 {
60                    RecordBatch::try_new_with_options(
61                        Arc::new(schema),
62                        columns,
63                        &RecordBatchOptions::new().with_row_count(Some(data_len)),
64                    )
65                    .map_err(|err| PyValueError::new_err(err.to_string()))?
66                } else {
67                    RecordBatch::try_new(Arc::new(schema), columns)
68                        .map_err(|err| PyValueError::new_err(err.to_string()))?
69                };
70                Ok(Self::new(batch))
71            }
72            dt => Err(PyValueError::new_err(format!(
73                "Unexpected data type {}",
74                dt
75            ))),
76        }
77    }
78
79    /// Consume this, returning its internal [RecordBatch].
80    pub fn into_inner(self) -> RecordBatch {
81        self.0
82    }
83
84    /// Export this to a Python `arro3.core.RecordBatch`.
85    pub fn to_arro3<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
86        let arro3_mod = py.import(intern!(py, "arro3.core"))?;
87        arro3_mod.getattr(intern!(py, "RecordBatch"))?.call_method1(
88            intern!(py, "from_arrow_pycapsule"),
89            self.__arrow_c_array__(py, None)?,
90        )
91    }
92
93    /// Export this to a Python `arro3.core.RecordBatch`.
94    pub fn into_arro3(self, py: Python) -> PyResult<Bound<PyAny>> {
95        let arro3_mod = py.import(intern!(py, "arro3.core"))?;
96        let capsules = Self::to_array_pycapsules(py, self.0.clone(), None)?;
97        arro3_mod
98            .getattr(intern!(py, "RecordBatch"))?
99            .call_method1(intern!(py, "from_arrow_pycapsule"), capsules)
100    }
101
102    /// Export this to a Python `nanoarrow.Array`.
103    pub fn to_nanoarrow<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
104        to_nanoarrow_array(py, &self.__arrow_c_array__(py, None)?)
105    }
106
107    /// Export to a pyarrow.RecordBatch
108    ///
109    /// Requires pyarrow >=14
110    pub fn to_pyarrow(self, py: Python) -> PyResult<PyObject> {
111        let pyarrow_mod = py.import(intern!(py, "pyarrow"))?;
112        let pyarrow_obj = pyarrow_mod
113            .getattr(intern!(py, "record_batch"))?
114            .call1(PyTuple::new(py, vec![self.into_pyobject(py)?])?)?;
115        pyarrow_obj.into_py_any(py)
116    }
117
118    pub(crate) fn to_array_pycapsules<'py>(
119        py: Python<'py>,
120        record_batch: RecordBatch,
121        requested_schema: Option<Bound<'py, PyCapsule>>,
122    ) -> PyArrowResult<Bound<'py, PyTuple>> {
123        let field = Field::new_struct("", record_batch.schema_ref().fields().clone(), false);
124        let array: ArrayRef = Arc::new(StructArray::from(record_batch.clone()));
125        to_array_pycapsules(py, field.into(), &array, requested_schema)
126    }
127}
128
129impl From<RecordBatch> for PyRecordBatch {
130    fn from(value: RecordBatch) -> Self {
131        Self(value)
132    }
133}
134
135impl From<PyRecordBatch> for RecordBatch {
136    fn from(value: PyRecordBatch) -> Self {
137        value.0
138    }
139}
140
impl AsRef<RecordBatch> for PyRecordBatch {
    /// Borrow the wrapped [RecordBatch].
    fn as_ref(&self) -> &RecordBatch {
        &self.0
    }
}
146
147impl Display for PyRecordBatch {
148    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149        writeln!(f, "arro3.core.RecordBatch")?;
150        writeln!(f, "-----------------")?;
151        display_schema(&self.0.schema(), f)
152    }
153}
154
155#[pymethods]
156impl PyRecordBatch {
157    #[new]
158    #[pyo3(signature = (data, *,  schema=None, metadata=None))]
159    fn init(
160        py: Python,
161        data: &Bound<PyAny>,
162        schema: Option<PySchema>,
163        metadata: Option<MetadataInput>,
164    ) -> PyArrowResult<Self> {
165        if let Ok(data) = data.extract::<PyRecordBatch>() {
166            Ok(data)
167        } else if let Ok(mapping) = data.extract::<IndexMap<String, PyArray>>() {
168            Self::from_pydict(&py.get_type::<PyRecordBatch>(), mapping, metadata)
169        } else if let Ok(arrays) = data.extract::<Vec<PyArray>>() {
170            Self::from_arrays(
171                &py.get_type::<PyRecordBatch>(),
172                arrays,
173                schema.ok_or(PyValueError::new_err(
174                    "Schema must be passed with list of arrays",
175                ))?,
176            )
177        } else {
178            Err(PyTypeError::new_err(
179                "Expected RecordBatch-like input or dict of arrays or list of arrays.",
180            )
181            .into())
182        }
183    }
184
185    #[pyo3(signature = (requested_schema=None))]
186    fn __arrow_c_array__<'py>(
187        &'py self,
188        py: Python<'py>,
189        requested_schema: Option<Bound<'py, PyCapsule>>,
190    ) -> PyArrowResult<Bound<'py, PyTuple>> {
191        Self::to_array_pycapsules(py, self.0.clone(), requested_schema)
192    }
193
194    fn __arrow_c_schema__<'py>(&'py self, py: Python<'py>) -> PyArrowResult<Bound<'py, PyCapsule>> {
195        to_schema_pycapsule(py, self.0.schema_ref().as_ref())
196    }
197
198    fn __eq__(&self, other: &PyRecordBatch) -> bool {
199        self.0 == other.0
200    }
201
202    fn __getitem__(&self, key: FieldIndexInput) -> PyResult<Arro3Array> {
203        self.column(key)
204    }
205
206    fn __repr__(&self) -> String {
207        self.to_string()
208    }
209
210    #[classmethod]
211    #[pyo3(signature = (arrays, *, schema))]
212    fn from_arrays(
213        _cls: &Bound<PyType>,
214        arrays: Vec<PyArray>,
215        schema: PySchema,
216    ) -> PyArrowResult<Self> {
217        let rb = RecordBatch::try_new(
218            schema.into(),
219            arrays
220                .into_iter()
221                .map(|arr| {
222                    let (arr, _field) = arr.into_inner();
223                    arr
224                })
225                .collect(),
226        )?;
227        Ok(Self::new(rb))
228    }
229
230    #[classmethod]
231    #[pyo3(signature = (mapping, *, metadata=None))]
232    fn from_pydict(
233        _cls: &Bound<PyType>,
234        mapping: IndexMap<String, PyArray>,
235        metadata: Option<MetadataInput>,
236    ) -> PyArrowResult<Self> {
237        let mut fields = vec![];
238        let mut arrays = vec![];
239        mapping.into_iter().for_each(|(name, py_array)| {
240            let (arr, field) = py_array.into_inner();
241            fields.push(field.as_ref().clone().with_name(name));
242            arrays.push(arr);
243        });
244        let schema =
245            Schema::new_with_metadata(fields, metadata.unwrap_or_default().into_string_hashmap()?);
246        let rb = RecordBatch::try_new(schema.into(), arrays)?;
247        Ok(Self::new(rb))
248    }
249
250    #[classmethod]
251    fn from_struct_array(_cls: &Bound<PyType>, struct_array: PyArray) -> PyArrowResult<Self> {
252        let (array, field) = struct_array.into_inner();
253        match field.data_type() {
254            DataType::Struct(fields) => {
255                let schema = Schema::new_with_metadata(fields.clone(), field.metadata().clone());
256                let struct_arr = array.as_struct();
257                let columns = struct_arr.columns().to_vec();
258                let rb = RecordBatch::try_new(schema.into(), columns)?;
259                Ok(Self::new(rb))
260            }
261            _ => Err(PyTypeError::new_err("Expected struct array").into()),
262        }
263    }
264
265    #[classmethod]
266    fn from_arrow(_cls: &Bound<PyType>, input: AnyRecordBatch) -> PyArrowResult<Self> {
267        match input {
268            AnyRecordBatch::RecordBatch(rb) => Ok(rb),
269            AnyRecordBatch::Stream(stream) => {
270                let (batches, schema) = stream.into_table()?.into_inner();
271                let single_batch = concat_batches(&schema, batches.iter())?;
272                Ok(Self::new(single_batch))
273            }
274        }
275    }
276
277    #[classmethod]
278    #[pyo3(name = "from_arrow_pycapsule")]
279    fn from_arrow_pycapsule_py(
280        _cls: &Bound<PyType>,
281        schema_capsule: &Bound<PyCapsule>,
282        array_capsule: &Bound<PyCapsule>,
283    ) -> PyResult<Self> {
284        Self::from_arrow_pycapsule(schema_capsule, array_capsule)
285    }
286
287    fn add_column(
288        &self,
289        i: usize,
290        field: NameOrField,
291        column: PyArray,
292    ) -> PyArrowResult<Arro3RecordBatch> {
293        let mut fields = self.0.schema_ref().fields().to_vec();
294        fields.insert(i, field.into_field(column.field()));
295        let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
296
297        let mut arrays = self.0.columns().to_vec();
298        arrays.insert(i, column.array().clone());
299
300        let new_rb = RecordBatch::try_new(schema.into(), arrays)?;
301        Ok(PyRecordBatch::new(new_rb).into())
302    }
303
304    fn append_column(
305        &self,
306        field: NameOrField,
307        column: PyArray,
308    ) -> PyArrowResult<Arro3RecordBatch> {
309        let mut fields = self.0.schema_ref().fields().to_vec();
310        fields.push(field.into_field(column.field()));
311        let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
312
313        let mut arrays = self.0.columns().to_vec();
314        arrays.push(column.array().clone());
315
316        let new_rb = RecordBatch::try_new(schema.into(), arrays)?;
317        Ok(PyRecordBatch::new(new_rb).into())
318    }
319
320    fn column(&self, i: FieldIndexInput) -> PyResult<Arro3Array> {
321        let column_index = i.into_position(self.0.schema_ref())?;
322        let field = self.0.schema().field(column_index).clone();
323        let array = self.0.column(column_index).clone();
324        Ok(PyArray::new(array, field.into()).into())
325    }
326
327    #[getter]
328    fn column_names(&self) -> Vec<String> {
329        self.0
330            .schema()
331            .fields()
332            .iter()
333            .map(|f| f.name().clone())
334            .collect()
335    }
336
337    #[getter]
338    fn columns(&self) -> PyResult<Vec<Arro3Array>> {
339        (0..self.num_columns())
340            .map(|i| self.column(FieldIndexInput::Position(i)))
341            .collect()
342    }
343
344    fn equals(&self, other: PyRecordBatch) -> bool {
345        self.0 == other.0
346    }
347
348    fn field(&self, i: FieldIndexInput) -> PyResult<Arro3Field> {
349        let schema_ref = self.0.schema_ref();
350        let field = schema_ref.field(i.into_position(schema_ref)?);
351        Ok(PyField::new(field.clone().into()).into())
352    }
353
354    #[getter]
355    fn nbytes(&self) -> usize {
356        self.0.get_array_memory_size()
357    }
358
359    #[getter]
360    fn num_columns(&self) -> usize {
361        self.0.num_columns()
362    }
363
364    #[getter]
365    fn num_rows(&self) -> usize {
366        self.0.num_rows()
367    }
368
369    fn remove_column(&self, i: usize) -> Arro3RecordBatch {
370        let mut rb = self.0.clone();
371        rb.remove_column(i);
372        PyRecordBatch::new(rb).into()
373    }
374
375    #[getter]
376    fn schema(&self) -> Arro3Schema {
377        self.0.schema().into()
378    }
379
380    fn select(&self, columns: SelectIndices) -> PyArrowResult<Arro3RecordBatch> {
381        let positions = columns.into_positions(self.0.schema_ref().fields())?;
382        Ok(self.0.project(&positions)?.into())
383    }
384
385    fn set_column(
386        &self,
387        i: usize,
388        field: NameOrField,
389        column: PyArray,
390    ) -> PyArrowResult<Arro3RecordBatch> {
391        let mut fields = self.0.schema_ref().fields().to_vec();
392        fields[i] = field.into_field(column.field());
393        let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
394
395        let mut arrays = self.0.columns().to_vec();
396        arrays[i] = column.array().clone();
397
398        Ok(RecordBatch::try_new(schema.into(), arrays)?.into())
399    }
400
401    #[getter]
402    fn shape(&self) -> (usize, usize) {
403        (self.num_rows(), self.num_columns())
404    }
405
406    #[pyo3(signature = (offset=0, length=None))]
407    fn slice(&self, offset: usize, length: Option<usize>) -> Arro3RecordBatch {
408        let length = length.unwrap_or_else(|| self.num_rows() - offset);
409        self.0.slice(offset, length).into()
410    }
411
412    fn take(&self, indices: PyArray) -> PyArrowResult<Arro3RecordBatch> {
413        let new_batch = take_record_batch(self.as_ref(), indices.as_ref())?;
414        Ok(new_batch.into())
415    }
416
417    fn to_struct_array(&self) -> Arro3Array {
418        let struct_array: StructArray = self.0.clone().into();
419        let field = Field::new_struct("", self.0.schema_ref().fields().clone(), false)
420            .with_metadata(self.0.schema_ref().metadata.clone());
421        PyArray::new(Arc::new(struct_array), field.into()).into()
422    }
423
424    fn with_schema(&self, schema: PySchema) -> PyArrowResult<Arro3RecordBatch> {
425        let new_schema = schema.into_inner();
426        let new_batch = RecordBatch::try_new(new_schema.clone(), self.0.columns().to_vec())?;
427        Ok(new_batch.into())
428    }
429}