//! Python-facing wrapper for an Arrow `RecordBatch` (pyo3_arrow/record_batch.rs).
1use std::fmt::Display;
2use std::sync::Arc;
3
4use arrow_array::cast::AsArray;
5use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, StructArray};
6use arrow_cast::pretty::pretty_format_batches_with_options;
7use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
8use arrow_select::concat::concat_batches;
9use arrow_select::take::take_record_batch;
10use indexmap::IndexMap;
11use pyo3::exceptions::{PyTypeError, PyValueError};
12use pyo3::intern;
13use pyo3::prelude::*;
14use pyo3::types::{PyCapsule, PyTuple, PyType};
15
16use crate::error::PyArrowResult;
17use crate::export::{Arro3Array, Arro3Field, Arro3RecordBatch, Arro3Schema};
18use crate::ffi::from_python::utils::import_array_pycapsules;
19use crate::ffi::to_python::nanoarrow::to_nanoarrow_array;
20use crate::ffi::to_python::to_array_pycapsules;
21use crate::ffi::to_schema_pycapsule;
22use crate::input::{AnyRecordBatch, FieldIndexInput, MetadataInput, NameOrField, SelectIndices};
23use crate::utils::default_repr_options;
24use crate::{PyArray, PyField, PySchema};
25
/// A Python-facing Arrow record batch.
///
/// This is a wrapper around a [RecordBatch].
///
/// Exposed to Python as `arro3.core.RecordBatch`. The pyclass is `frozen`, so
/// the wrapped batch is never mutated in place; methods that "modify" the
/// batch return a new object instead.
#[pyclass(module = "arro3.core._core", name = "RecordBatch", subclass, frozen)]
#[derive(Debug)]
pub struct PyRecordBatch(RecordBatch);
32
33impl PyRecordBatch {
34    /// Construct a new PyRecordBatch from a [RecordBatch].
35    pub fn new(batch: RecordBatch) -> Self {
36        Self(batch)
37    }
38
39    /// Construct from raw Arrow capsules
40    pub fn from_arrow_pycapsule(
41        schema_capsule: &Bound<PyCapsule>,
42        array_capsule: &Bound<PyCapsule>,
43    ) -> PyResult<Self> {
44        let (array, field) = import_array_pycapsules(schema_capsule, array_capsule)?;
45
46        match field.data_type() {
47            DataType::Struct(fields) => {
48                let struct_array = array.as_struct();
49                let schema = SchemaBuilder::from(fields)
50                    .finish()
51                    .with_metadata(field.metadata().clone());
52                assert_eq!(
53                    struct_array.null_count(),
54                    0,
55                    "Cannot convert nullable StructArray to RecordBatch"
56                );
57
58                let columns = struct_array.columns().to_vec();
59
60                let batch = RecordBatch::try_new_with_options(
61                    Arc::new(schema),
62                    columns,
63                    &RecordBatchOptions::new().with_row_count(Some(array.len())),
64                )
65                .map_err(|err| PyValueError::new_err(err.to_string()))?;
66                Ok(Self::new(batch))
67            }
68            dt => Err(PyValueError::new_err(format!(
69                "Unexpected data type {}",
70                dt
71            ))),
72        }
73    }
74
75    /// Consume this, returning its internal [RecordBatch].
76    pub fn into_inner(self) -> RecordBatch {
77        self.0
78    }
79
80    /// Export this to a Python `arro3.core.RecordBatch`.
81    pub fn to_arro3<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
82        let arro3_mod = py.import(intern!(py, "arro3.core"))?;
83        arro3_mod.getattr(intern!(py, "RecordBatch"))?.call_method1(
84            intern!(py, "from_arrow_pycapsule"),
85            self.__arrow_c_array__(py, None)?,
86        )
87    }
88
89    /// Export this to a Python `arro3.core.RecordBatch`.
90    pub fn into_arro3(self, py: Python) -> PyResult<Bound<PyAny>> {
91        let arro3_mod = py.import(intern!(py, "arro3.core"))?;
92        let capsules = Self::to_array_pycapsules(py, self.0.clone(), None)?;
93        arro3_mod
94            .getattr(intern!(py, "RecordBatch"))?
95            .call_method1(intern!(py, "from_arrow_pycapsule"), capsules)
96    }
97
98    /// Export this to a Python `nanoarrow.Array`.
99    pub fn to_nanoarrow<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
100        to_nanoarrow_array(py, self.__arrow_c_array__(py, None)?)
101    }
102
103    /// Export to a pyarrow.RecordBatch
104    ///
105    /// Requires pyarrow >=14
106    pub fn into_pyarrow(self, py: Python) -> PyResult<Bound<PyAny>> {
107        let pyarrow_mod = py.import(intern!(py, "pyarrow"))?;
108        pyarrow_mod
109            .getattr(intern!(py, "record_batch"))?
110            .call1(PyTuple::new(py, vec![self.into_pyobject(py)?])?)
111    }
112
113    pub(crate) fn to_array_pycapsules<'py>(
114        py: Python<'py>,
115        record_batch: RecordBatch,
116        requested_schema: Option<Bound<'py, PyCapsule>>,
117    ) -> PyArrowResult<Bound<'py, PyTuple>> {
118        let field = Field::new_struct("", record_batch.schema_ref().fields().clone(), false)
119            .with_metadata(record_batch.schema_ref().metadata().clone());
120        let array: ArrayRef = Arc::new(StructArray::from(record_batch.clone()));
121        to_array_pycapsules(py, field.into(), &array, requested_schema)
122    }
123}
124
125impl From<RecordBatch> for PyRecordBatch {
126    fn from(value: RecordBatch) -> Self {
127        Self(value)
128    }
129}
130
131impl From<PyRecordBatch> for RecordBatch {
132    fn from(value: PyRecordBatch) -> Self {
133        value.0
134    }
135}
136
137impl AsRef<RecordBatch> for PyRecordBatch {
138    fn as_ref(&self) -> &RecordBatch {
139        &self.0
140    }
141}
142
143impl Display for PyRecordBatch {
144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        writeln!(f, "arro3.core.RecordBatch")?;
146        pretty_format_batches_with_options(
147            &[self.0.slice(0, 10.min(self.0.num_rows()))],
148            &default_repr_options(),
149        )
150        .map_err(|_| std::fmt::Error)?
151        .fmt(f)?;
152
153        Ok(())
154    }
155}
156
157#[pymethods]
158impl PyRecordBatch {
159    #[new]
160    #[pyo3(signature = (data, *, names=None, schema=None, metadata=None))]
161    fn init(
162        py: Python,
163        data: &Bound<PyAny>,
164        names: Option<Vec<String>>,
165        schema: Option<PySchema>,
166        metadata: Option<MetadataInput>,
167    ) -> PyArrowResult<Self> {
168        if data.hasattr(intern!(py, "__arrow_c_array__"))? {
169            Ok(data.extract::<PyRecordBatch>()?)
170        } else if let Ok(mapping) = data.extract::<IndexMap<String, PyArray>>() {
171            Self::from_pydict(&py.get_type::<Self>(), mapping, metadata)
172        } else if let Ok(arrays) = data.extract::<Vec<PyArray>>() {
173            Self::from_arrays(&py.get_type::<Self>(), arrays, names, schema, metadata)
174        } else {
175            Err(PyTypeError::new_err(
176                "Expected RecordBatch-like input or dict of arrays or list of arrays.",
177            )
178            .into())
179        }
180    }
181
182    #[pyo3(signature = (requested_schema=None))]
183    fn __arrow_c_array__<'py>(
184        &'py self,
185        py: Python<'py>,
186        requested_schema: Option<Bound<'py, PyCapsule>>,
187    ) -> PyArrowResult<Bound<'py, PyTuple>> {
188        Self::to_array_pycapsules(py, self.0.clone(), requested_schema)
189    }
190
191    fn __arrow_c_schema__<'py>(&'py self, py: Python<'py>) -> PyArrowResult<Bound<'py, PyCapsule>> {
192        to_schema_pycapsule(py, self.0.schema_ref().as_ref())
193    }
194
195    fn __eq__(&self, other: PyRecordBatch) -> bool {
196        self.0 == other.0
197    }
198
199    fn __getitem__(&self, key: FieldIndexInput) -> PyResult<Arro3Array> {
200        self.column(key)
201    }
202
203    fn __repr__(&self) -> String {
204        self.to_string()
205    }
206
207    #[classmethod]
208    #[pyo3(signature = (arrays, *, names=None, schema=None, metadata=None))]
209    fn from_arrays(
210        _cls: &Bound<PyType>,
211        arrays: Vec<PyArray>,
212        names: Option<Vec<String>>,
213        schema: Option<PySchema>,
214        metadata: Option<MetadataInput>,
215    ) -> PyArrowResult<Self> {
216        if schema.is_some() && metadata.is_some() {
217            return Err(PyValueError::new_err("Cannot pass both schema and metadata").into());
218        }
219
220        let (arrays, fields): (Vec<ArrayRef>, Vec<FieldRef>) =
221            arrays.into_iter().map(|arr| arr.into_inner()).unzip();
222
223        let schema: SchemaRef = if let Some(schema) = schema {
224            schema.into_inner()
225        } else {
226            let names = names.ok_or(PyValueError::new_err(
227                "names must be passed if schema is not passed.",
228            ))?;
229
230            let fields: Vec<_> = fields
231                .iter()
232                .zip(names.iter())
233                .map(|(field, name)| field.as_ref().clone().with_name(name))
234                .collect();
235
236            Arc::new(
237                Schema::new(fields)
238                    .with_metadata(metadata.unwrap_or_default().into_string_hashmap()?),
239            )
240        };
241
242        if arrays.is_empty() {
243            let rb = RecordBatch::try_new(schema, vec![])?;
244            return Ok(Self::new(rb));
245        }
246
247        let rb = RecordBatch::try_new(schema, arrays)?;
248        Ok(Self::new(rb))
249    }
250
251    #[classmethod]
252    #[pyo3(signature = (mapping, *, metadata=None))]
253    fn from_pydict(
254        _cls: &Bound<PyType>,
255        mapping: IndexMap<String, PyArray>,
256        metadata: Option<MetadataInput>,
257    ) -> PyArrowResult<Self> {
258        let mut fields = vec![];
259        let mut arrays = vec![];
260        mapping.into_iter().for_each(|(name, py_array)| {
261            let (arr, field) = py_array.into_inner();
262            fields.push(field.as_ref().clone().with_name(name));
263            arrays.push(arr);
264        });
265        let schema =
266            Schema::new_with_metadata(fields, metadata.unwrap_or_default().into_string_hashmap()?);
267        let rb = RecordBatch::try_new(schema.into(), arrays)?;
268        Ok(Self::new(rb))
269    }
270
271    #[classmethod]
272    fn from_struct_array(_cls: &Bound<PyType>, struct_array: PyArray) -> PyArrowResult<Self> {
273        let (array, field) = struct_array.into_inner();
274        match field.data_type() {
275            DataType::Struct(fields) => {
276                let schema = Schema::new_with_metadata(fields.clone(), field.metadata().clone());
277                let struct_arr = array.as_struct();
278                let columns = struct_arr.columns().to_vec();
279                let rb = RecordBatch::try_new(schema.into(), columns)?;
280                Ok(Self::new(rb))
281            }
282            _ => Err(PyTypeError::new_err("Expected struct array").into()),
283        }
284    }
285
286    #[classmethod]
287    fn from_arrow(_cls: &Bound<PyType>, input: AnyRecordBatch) -> PyArrowResult<Self> {
288        match input {
289            AnyRecordBatch::RecordBatch(rb) => Ok(rb),
290            AnyRecordBatch::Stream(stream) => {
291                let (batches, schema) = stream.into_table()?.into_inner();
292                let single_batch = concat_batches(&schema, batches.iter())?;
293                Ok(Self::new(single_batch))
294            }
295        }
296    }
297
298    #[classmethod]
299    #[pyo3(name = "from_arrow_pycapsule")]
300    fn from_arrow_pycapsule_py(
301        _cls: &Bound<PyType>,
302        schema_capsule: &Bound<PyCapsule>,
303        array_capsule: &Bound<PyCapsule>,
304    ) -> PyResult<Self> {
305        Self::from_arrow_pycapsule(schema_capsule, array_capsule)
306    }
307
308    fn add_column(
309        &self,
310        i: usize,
311        field: NameOrField,
312        column: PyArray,
313    ) -> PyArrowResult<Arro3RecordBatch> {
314        let mut fields = self.0.schema_ref().fields().to_vec();
315        fields.insert(i, field.into_field(column.field()));
316        let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
317
318        let mut arrays = self.0.columns().to_vec();
319        arrays.insert(i, column.array().clone());
320
321        let new_rb = RecordBatch::try_new(schema.into(), arrays)?;
322        Ok(PyRecordBatch::new(new_rb).into())
323    }
324
325    fn append_column(
326        &self,
327        field: NameOrField,
328        column: PyArray,
329    ) -> PyArrowResult<Arro3RecordBatch> {
330        let mut fields = self.0.schema_ref().fields().to_vec();
331        fields.push(field.into_field(column.field()));
332        let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
333
334        let mut arrays = self.0.columns().to_vec();
335        arrays.push(column.array().clone());
336
337        let new_rb = RecordBatch::try_new(schema.into(), arrays)?;
338        Ok(PyRecordBatch::new(new_rb).into())
339    }
340
341    fn column(&self, i: FieldIndexInput) -> PyResult<Arro3Array> {
342        let column_index = i.into_position(self.0.schema_ref())?;
343        let field = self.0.schema().field(column_index).clone();
344        let array = self.0.column(column_index).clone();
345        Ok(PyArray::new(array, field.into()).into())
346    }
347
348    #[getter]
349    fn column_names(&self) -> Vec<String> {
350        self.0
351            .schema()
352            .fields()
353            .iter()
354            .map(|f| f.name().clone())
355            .collect()
356    }
357
358    #[getter]
359    fn columns(&self) -> PyResult<Vec<Arro3Array>> {
360        (0..self.num_columns())
361            .map(|i| self.column(FieldIndexInput::Position(i)))
362            .collect()
363    }
364
365    fn equals(&self, other: PyRecordBatch) -> bool {
366        self.0 == other.0
367    }
368
369    fn field(&self, i: FieldIndexInput) -> PyResult<Arro3Field> {
370        let schema_ref = self.0.schema_ref();
371        let field = schema_ref.field(i.into_position(schema_ref)?);
372        Ok(PyField::new(field.clone().into()).into())
373    }
374
375    #[getter]
376    fn nbytes(&self) -> usize {
377        self.0.get_array_memory_size()
378    }
379
380    #[getter]
381    fn num_columns(&self) -> usize {
382        self.0.num_columns()
383    }
384
385    #[getter]
386    fn num_rows(&self) -> usize {
387        self.0.num_rows()
388    }
389
390    fn remove_column(&self, i: usize) -> Arro3RecordBatch {
391        let mut rb = self.0.clone();
392        rb.remove_column(i);
393        PyRecordBatch::new(rb).into()
394    }
395
396    #[getter]
397    fn schema(&self) -> Arro3Schema {
398        self.0.schema().into()
399    }
400
401    fn select(&self, columns: SelectIndices) -> PyArrowResult<Arro3RecordBatch> {
402        let positions = columns.into_positions(self.0.schema_ref().fields())?;
403        Ok(self.0.project(&positions)?.into())
404    }
405
406    fn set_column(
407        &self,
408        i: usize,
409        field: NameOrField,
410        column: PyArray,
411    ) -> PyArrowResult<Arro3RecordBatch> {
412        let mut fields = self.0.schema_ref().fields().to_vec();
413        fields[i] = field.into_field(column.field());
414        let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
415
416        let mut arrays = self.0.columns().to_vec();
417        arrays[i] = column.array().clone();
418
419        Ok(RecordBatch::try_new(schema.into(), arrays)?.into())
420    }
421
422    #[getter]
423    fn shape(&self) -> (usize, usize) {
424        (self.num_rows(), self.num_columns())
425    }
426
427    #[pyo3(signature = (offset=0, length=None))]
428    fn slice(&self, offset: usize, length: Option<usize>) -> Arro3RecordBatch {
429        let length = length.unwrap_or_else(|| self.num_rows() - offset);
430        self.0.slice(offset, length).into()
431    }
432
433    fn take(&self, indices: PyArray) -> PyArrowResult<Arro3RecordBatch> {
434        let new_batch = take_record_batch(self.as_ref(), indices.as_ref())?;
435        Ok(new_batch.into())
436    }
437
438    fn to_struct_array(&self) -> Arro3Array {
439        let struct_array: StructArray = self.0.clone().into();
440        let field = Field::new_struct("", self.0.schema_ref().fields().clone(), false)
441            .with_metadata(self.0.schema_ref().metadata.clone());
442        PyArray::new(Arc::new(struct_array), field.into()).into()
443    }
444
445    fn with_schema(&self, schema: PySchema) -> PyArrowResult<Arro3RecordBatch> {
446        let new_schema = schema.into_inner();
447        let new_batch = RecordBatch::try_new(new_schema.clone(), self.0.columns().to_vec())?;
448        Ok(new_batch.into())
449    }
450}