// pyo3_arrow/record_batch.rs
1use std::fmt::Display;
2use std::sync::Arc;
3
4use arrow_array::cast::AsArray;
5use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, StructArray};
6use arrow_cast::pretty::pretty_format_batches_with_options;
7use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
8use arrow_select::concat::concat_batches;
9use arrow_select::take::take_record_batch;
10use indexmap::IndexMap;
11use pyo3::exceptions::{PyTypeError, PyValueError};
12use pyo3::intern;
13use pyo3::prelude::*;
14use pyo3::types::{PyCapsule, PyTuple, PyType};
15
16use crate::error::PyArrowResult;
17use crate::export::{Arro3Array, Arro3Field, Arro3RecordBatch, Arro3Schema};
18use crate::ffi::from_python::utils::import_array_pycapsules;
19use crate::ffi::to_python::nanoarrow::to_nanoarrow_array;
20use crate::ffi::to_python::to_array_pycapsules;
21use crate::ffi::to_schema_pycapsule;
22use crate::input::{AnyRecordBatch, FieldIndexInput, MetadataInput, NameOrField, SelectIndices};
23use crate::utils::default_repr_options;
24use crate::{PyArray, PyField, PySchema};
25
/// A Python-facing Arrow record batch.
///
/// This is a wrapper around a [RecordBatch].
///
/// Exposed to Python as `arro3.core.RecordBatch`. `frozen` makes instances
/// immutable from the Python side; `subclass` permits Python subclasses.
#[pyclass(module = "arro3.core._core", name = "RecordBatch", subclass, frozen)]
#[derive(Debug)]
pub struct PyRecordBatch(RecordBatch);
32
33impl PyRecordBatch {
34    /// Construct a new PyRecordBatch from a [RecordBatch].
35    pub fn new(batch: RecordBatch) -> Self {
36        Self(batch)
37    }
38
39    /// Construct from raw Arrow capsules
40    pub fn from_arrow_pycapsule(
41        schema_capsule: &Bound<PyCapsule>,
42        array_capsule: &Bound<PyCapsule>,
43    ) -> PyResult<Self> {
44        let (array, field) = import_array_pycapsules(schema_capsule, array_capsule)?;
45
46        match field.data_type() {
47            DataType::Struct(fields) => {
48                let struct_array = array.as_struct();
49                let schema = SchemaBuilder::from(fields)
50                    .finish()
51                    .with_metadata(field.metadata().clone());
52                assert_eq!(
53                    struct_array.null_count(),
54                    0,
55                    "Cannot convert nullable StructArray to RecordBatch"
56                );
57
58                let columns = struct_array.columns().to_vec();
59
60                let batch = RecordBatch::try_new_with_options(
61                    Arc::new(schema),
62                    columns,
63                    &RecordBatchOptions::new().with_row_count(Some(array.len())),
64                )
65                .map_err(|err| PyValueError::new_err(err.to_string()))?;
66                Ok(Self::new(batch))
67            }
68            dt => Err(PyValueError::new_err(format!(
69                "Unexpected data type {}",
70                dt
71            ))),
72        }
73    }
74
75    /// Consume this, returning its internal [RecordBatch].
76    pub fn into_inner(self) -> RecordBatch {
77        self.0
78    }
79
80    /// Export this to a Python `arro3.core.RecordBatch`.
81    pub fn to_arro3<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
82        let arro3_mod = py.import(intern!(py, "arro3.core"))?;
83        arro3_mod.getattr(intern!(py, "RecordBatch"))?.call_method1(
84            intern!(py, "from_arrow_pycapsule"),
85            self.__arrow_c_array__(py, None)?,
86        )
87    }
88
89    /// Export this to a Python `arro3.core.RecordBatch`.
90    pub fn into_arro3(self, py: Python) -> PyResult<Bound<PyAny>> {
91        let arro3_mod = py.import(intern!(py, "arro3.core"))?;
92        let capsules = Self::to_array_pycapsules(py, self.0.clone(), None)?;
93        arro3_mod
94            .getattr(intern!(py, "RecordBatch"))?
95            .call_method1(intern!(py, "from_arrow_pycapsule"), capsules)
96    }
97
98    /// Export this to a Python `nanoarrow.Array`.
99    pub fn to_nanoarrow<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
100        to_nanoarrow_array(py, self.__arrow_c_array__(py, None)?)
101    }
102
103    /// Export to a pyarrow.RecordBatch
104    ///
105    /// Requires pyarrow >=14
106    pub fn into_pyarrow(self, py: Python) -> PyResult<Bound<PyAny>> {
107        let pyarrow_mod = py.import(intern!(py, "pyarrow"))?;
108        pyarrow_mod
109            .getattr(intern!(py, "record_batch"))?
110            .call1(PyTuple::new(py, vec![self.into_pyobject(py)?])?)
111    }
112
113    pub(crate) fn to_array_pycapsules<'py>(
114        py: Python<'py>,
115        record_batch: RecordBatch,
116        requested_schema: Option<Bound<'py, PyCapsule>>,
117    ) -> PyArrowResult<Bound<'py, PyTuple>> {
118        let field = Field::new_struct("", record_batch.schema_ref().fields().clone(), false);
119        let array: ArrayRef = Arc::new(StructArray::from(record_batch.clone()));
120        to_array_pycapsules(py, field.into(), &array, requested_schema)
121    }
122}
123
124impl From<RecordBatch> for PyRecordBatch {
125    fn from(value: RecordBatch) -> Self {
126        Self(value)
127    }
128}
129
130impl From<PyRecordBatch> for RecordBatch {
131    fn from(value: PyRecordBatch) -> Self {
132        value.0
133    }
134}
135
impl AsRef<RecordBatch> for PyRecordBatch {
    /// Borrow the wrapped [RecordBatch] without cloning.
    fn as_ref(&self) -> &RecordBatch {
        &self.0
    }
}
141
142impl Display for PyRecordBatch {
143    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144        writeln!(f, "arro3.core.RecordBatch")?;
145        pretty_format_batches_with_options(
146            &[self.0.slice(0, 10.min(self.0.num_rows()))],
147            &default_repr_options(),
148        )
149        .map_err(|_| std::fmt::Error)?
150        .fmt(f)?;
151
152        Ok(())
153    }
154}
155
156#[pymethods]
157impl PyRecordBatch {
158    #[new]
159    #[pyo3(signature = (data, *, names=None, schema=None, metadata=None))]
160    fn init(
161        py: Python,
162        data: &Bound<PyAny>,
163        names: Option<Vec<String>>,
164        schema: Option<PySchema>,
165        metadata: Option<MetadataInput>,
166    ) -> PyArrowResult<Self> {
167        if data.hasattr(intern!(py, "__arrow_c_array__"))? {
168            Ok(data.extract::<PyRecordBatch>()?)
169        } else if let Ok(mapping) = data.extract::<IndexMap<String, PyArray>>() {
170            Self::from_pydict(&py.get_type::<Self>(), mapping, metadata)
171        } else if let Ok(arrays) = data.extract::<Vec<PyArray>>() {
172            Self::from_arrays(&py.get_type::<Self>(), arrays, names, schema, metadata)
173        } else {
174            Err(PyTypeError::new_err(
175                "Expected RecordBatch-like input or dict of arrays or list of arrays.",
176            )
177            .into())
178        }
179    }
180
181    #[pyo3(signature = (requested_schema=None))]
182    fn __arrow_c_array__<'py>(
183        &'py self,
184        py: Python<'py>,
185        requested_schema: Option<Bound<'py, PyCapsule>>,
186    ) -> PyArrowResult<Bound<'py, PyTuple>> {
187        Self::to_array_pycapsules(py, self.0.clone(), requested_schema)
188    }
189
190    fn __arrow_c_schema__<'py>(&'py self, py: Python<'py>) -> PyArrowResult<Bound<'py, PyCapsule>> {
191        to_schema_pycapsule(py, self.0.schema_ref().as_ref())
192    }
193
194    fn __eq__(&self, other: &PyRecordBatch) -> bool {
195        self.0 == other.0
196    }
197
198    fn __getitem__(&self, key: FieldIndexInput) -> PyResult<Arro3Array> {
199        self.column(key)
200    }
201
202    fn __repr__(&self) -> String {
203        self.to_string()
204    }
205
206    #[classmethod]
207    #[pyo3(signature = (arrays, *, names=None, schema=None, metadata=None))]
208    fn from_arrays(
209        _cls: &Bound<PyType>,
210        arrays: Vec<PyArray>,
211        names: Option<Vec<String>>,
212        schema: Option<PySchema>,
213        metadata: Option<MetadataInput>,
214    ) -> PyArrowResult<Self> {
215        if schema.is_some() && metadata.is_some() {
216            return Err(PyValueError::new_err("Cannot pass both schema and metadata").into());
217        }
218
219        let (arrays, fields): (Vec<ArrayRef>, Vec<FieldRef>) =
220            arrays.into_iter().map(|arr| arr.into_inner()).unzip();
221
222        let schema: SchemaRef = if let Some(schema) = schema {
223            schema.into_inner()
224        } else {
225            let names = names.ok_or(PyValueError::new_err(
226                "names must be passed if schema is not passed.",
227            ))?;
228
229            let fields: Vec<_> = fields
230                .iter()
231                .zip(names.iter())
232                .map(|(field, name)| field.as_ref().clone().with_name(name))
233                .collect();
234
235            Arc::new(
236                Schema::new(fields)
237                    .with_metadata(metadata.unwrap_or_default().into_string_hashmap()?),
238            )
239        };
240
241        if arrays.is_empty() {
242            let rb = RecordBatch::try_new(schema, vec![])?;
243            return Ok(Self::new(rb));
244        }
245
246        let rb = RecordBatch::try_new(schema, arrays)?;
247        Ok(Self::new(rb))
248    }
249
250    #[classmethod]
251    #[pyo3(signature = (mapping, *, metadata=None))]
252    fn from_pydict(
253        _cls: &Bound<PyType>,
254        mapping: IndexMap<String, PyArray>,
255        metadata: Option<MetadataInput>,
256    ) -> PyArrowResult<Self> {
257        let mut fields = vec![];
258        let mut arrays = vec![];
259        mapping.into_iter().for_each(|(name, py_array)| {
260            let (arr, field) = py_array.into_inner();
261            fields.push(field.as_ref().clone().with_name(name));
262            arrays.push(arr);
263        });
264        let schema =
265            Schema::new_with_metadata(fields, metadata.unwrap_or_default().into_string_hashmap()?);
266        let rb = RecordBatch::try_new(schema.into(), arrays)?;
267        Ok(Self::new(rb))
268    }
269
270    #[classmethod]
271    fn from_struct_array(_cls: &Bound<PyType>, struct_array: PyArray) -> PyArrowResult<Self> {
272        let (array, field) = struct_array.into_inner();
273        match field.data_type() {
274            DataType::Struct(fields) => {
275                let schema = Schema::new_with_metadata(fields.clone(), field.metadata().clone());
276                let struct_arr = array.as_struct();
277                let columns = struct_arr.columns().to_vec();
278                let rb = RecordBatch::try_new(schema.into(), columns)?;
279                Ok(Self::new(rb))
280            }
281            _ => Err(PyTypeError::new_err("Expected struct array").into()),
282        }
283    }
284
285    #[classmethod]
286    fn from_arrow(_cls: &Bound<PyType>, input: AnyRecordBatch) -> PyArrowResult<Self> {
287        match input {
288            AnyRecordBatch::RecordBatch(rb) => Ok(rb),
289            AnyRecordBatch::Stream(stream) => {
290                let (batches, schema) = stream.into_table()?.into_inner();
291                let single_batch = concat_batches(&schema, batches.iter())?;
292                Ok(Self::new(single_batch))
293            }
294        }
295    }
296
297    #[classmethod]
298    #[pyo3(name = "from_arrow_pycapsule")]
299    fn from_arrow_pycapsule_py(
300        _cls: &Bound<PyType>,
301        schema_capsule: &Bound<PyCapsule>,
302        array_capsule: &Bound<PyCapsule>,
303    ) -> PyResult<Self> {
304        Self::from_arrow_pycapsule(schema_capsule, array_capsule)
305    }
306
307    fn add_column(
308        &self,
309        i: usize,
310        field: NameOrField,
311        column: PyArray,
312    ) -> PyArrowResult<Arro3RecordBatch> {
313        let mut fields = self.0.schema_ref().fields().to_vec();
314        fields.insert(i, field.into_field(column.field()));
315        let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
316
317        let mut arrays = self.0.columns().to_vec();
318        arrays.insert(i, column.array().clone());
319
320        let new_rb = RecordBatch::try_new(schema.into(), arrays)?;
321        Ok(PyRecordBatch::new(new_rb).into())
322    }
323
324    fn append_column(
325        &self,
326        field: NameOrField,
327        column: PyArray,
328    ) -> PyArrowResult<Arro3RecordBatch> {
329        let mut fields = self.0.schema_ref().fields().to_vec();
330        fields.push(field.into_field(column.field()));
331        let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
332
333        let mut arrays = self.0.columns().to_vec();
334        arrays.push(column.array().clone());
335
336        let new_rb = RecordBatch::try_new(schema.into(), arrays)?;
337        Ok(PyRecordBatch::new(new_rb).into())
338    }
339
340    fn column(&self, i: FieldIndexInput) -> PyResult<Arro3Array> {
341        let column_index = i.into_position(self.0.schema_ref())?;
342        let field = self.0.schema().field(column_index).clone();
343        let array = self.0.column(column_index).clone();
344        Ok(PyArray::new(array, field.into()).into())
345    }
346
347    #[getter]
348    fn column_names(&self) -> Vec<String> {
349        self.0
350            .schema()
351            .fields()
352            .iter()
353            .map(|f| f.name().clone())
354            .collect()
355    }
356
357    #[getter]
358    fn columns(&self) -> PyResult<Vec<Arro3Array>> {
359        (0..self.num_columns())
360            .map(|i| self.column(FieldIndexInput::Position(i)))
361            .collect()
362    }
363
364    fn equals(&self, other: PyRecordBatch) -> bool {
365        self.0 == other.0
366    }
367
368    fn field(&self, i: FieldIndexInput) -> PyResult<Arro3Field> {
369        let schema_ref = self.0.schema_ref();
370        let field = schema_ref.field(i.into_position(schema_ref)?);
371        Ok(PyField::new(field.clone().into()).into())
372    }
373
374    #[getter]
375    fn nbytes(&self) -> usize {
376        self.0.get_array_memory_size()
377    }
378
379    #[getter]
380    fn num_columns(&self) -> usize {
381        self.0.num_columns()
382    }
383
384    #[getter]
385    fn num_rows(&self) -> usize {
386        self.0.num_rows()
387    }
388
389    fn remove_column(&self, i: usize) -> Arro3RecordBatch {
390        let mut rb = self.0.clone();
391        rb.remove_column(i);
392        PyRecordBatch::new(rb).into()
393    }
394
395    #[getter]
396    fn schema(&self) -> Arro3Schema {
397        self.0.schema().into()
398    }
399
400    fn select(&self, columns: SelectIndices) -> PyArrowResult<Arro3RecordBatch> {
401        let positions = columns.into_positions(self.0.schema_ref().fields())?;
402        Ok(self.0.project(&positions)?.into())
403    }
404
405    fn set_column(
406        &self,
407        i: usize,
408        field: NameOrField,
409        column: PyArray,
410    ) -> PyArrowResult<Arro3RecordBatch> {
411        let mut fields = self.0.schema_ref().fields().to_vec();
412        fields[i] = field.into_field(column.field());
413        let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
414
415        let mut arrays = self.0.columns().to_vec();
416        arrays[i] = column.array().clone();
417
418        Ok(RecordBatch::try_new(schema.into(), arrays)?.into())
419    }
420
421    #[getter]
422    fn shape(&self) -> (usize, usize) {
423        (self.num_rows(), self.num_columns())
424    }
425
426    #[pyo3(signature = (offset=0, length=None))]
427    fn slice(&self, offset: usize, length: Option<usize>) -> Arro3RecordBatch {
428        let length = length.unwrap_or_else(|| self.num_rows() - offset);
429        self.0.slice(offset, length).into()
430    }
431
432    fn take(&self, indices: PyArray) -> PyArrowResult<Arro3RecordBatch> {
433        let new_batch = take_record_batch(self.as_ref(), indices.as_ref())?;
434        Ok(new_batch.into())
435    }
436
437    fn to_struct_array(&self) -> Arro3Array {
438        let struct_array: StructArray = self.0.clone().into();
439        let field = Field::new_struct("", self.0.schema_ref().fields().clone(), false)
440            .with_metadata(self.0.schema_ref().metadata.clone());
441        PyArray::new(Arc::new(struct_array), field.into()).into()
442    }
443
444    fn with_schema(&self, schema: PySchema) -> PyArrowResult<Arro3RecordBatch> {
445        let new_schema = schema.into_inner();
446        let new_batch = RecordBatch::try_new(new_schema.clone(), self.0.columns().to_vec())?;
447        Ok(new_batch.into())
448    }
449}