1use std::fmt::Display;
2use std::sync::Arc;
3
4use arrow_array::cast::AsArray;
5use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, StructArray};
6use arrow_cast::pretty::pretty_format_batches_with_options;
7use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
8use arrow_select::concat::concat_batches;
9use arrow_select::take::take_record_batch;
10use indexmap::IndexMap;
11use pyo3::exceptions::{PyTypeError, PyValueError};
12use pyo3::intern;
13use pyo3::prelude::*;
14use pyo3::types::{PyCapsule, PyTuple, PyType};
15
16use crate::error::PyArrowResult;
17use crate::export::{Arro3Array, Arro3Field, Arro3RecordBatch, Arro3Schema};
18use crate::ffi::from_python::utils::import_array_pycapsules;
19use crate::ffi::to_python::nanoarrow::to_nanoarrow_array;
20use crate::ffi::to_python::to_array_pycapsules;
21use crate::ffi::to_schema_pycapsule;
22use crate::input::{AnyRecordBatch, FieldIndexInput, MetadataInput, NameOrField, SelectIndices};
23use crate::utils::default_repr_options;
24use crate::{PyArray, PyField, PySchema};
25
/// A Python-facing wrapper around an Arrow [`RecordBatch`].
///
/// Exposed to Python as `arro3.core.RecordBatch`; the class is frozen (immutable
/// from Python) and subclassable.
#[pyclass(module = "arro3.core._core", name = "RecordBatch", subclass, frozen)]
#[derive(Debug)]
pub struct PyRecordBatch(RecordBatch);
32
33impl PyRecordBatch {
34 pub fn new(batch: RecordBatch) -> Self {
36 Self(batch)
37 }
38
39 pub fn from_arrow_pycapsule(
41 schema_capsule: &Bound<PyCapsule>,
42 array_capsule: &Bound<PyCapsule>,
43 ) -> PyResult<Self> {
44 let (array, field) = import_array_pycapsules(schema_capsule, array_capsule)?;
45
46 match field.data_type() {
47 DataType::Struct(fields) => {
48 let struct_array = array.as_struct();
49 let schema = SchemaBuilder::from(fields)
50 .finish()
51 .with_metadata(field.metadata().clone());
52 assert_eq!(
53 struct_array.null_count(),
54 0,
55 "Cannot convert nullable StructArray to RecordBatch"
56 );
57
58 let columns = struct_array.columns().to_vec();
59
60 let batch = RecordBatch::try_new_with_options(
61 Arc::new(schema),
62 columns,
63 &RecordBatchOptions::new().with_row_count(Some(array.len())),
64 )
65 .map_err(|err| PyValueError::new_err(err.to_string()))?;
66 Ok(Self::new(batch))
67 }
68 dt => Err(PyValueError::new_err(format!(
69 "Unexpected data type {}",
70 dt
71 ))),
72 }
73 }
74
75 pub fn into_inner(self) -> RecordBatch {
77 self.0
78 }
79
80 pub fn to_arro3<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
82 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
83 arro3_mod.getattr(intern!(py, "RecordBatch"))?.call_method1(
84 intern!(py, "from_arrow_pycapsule"),
85 self.__arrow_c_array__(py, None)?,
86 )
87 }
88
89 pub fn into_arro3(self, py: Python) -> PyResult<Bound<PyAny>> {
91 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
92 let capsules = Self::to_array_pycapsules(py, self.0.clone(), None)?;
93 arro3_mod
94 .getattr(intern!(py, "RecordBatch"))?
95 .call_method1(intern!(py, "from_arrow_pycapsule"), capsules)
96 }
97
98 pub fn to_nanoarrow<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
100 to_nanoarrow_array(py, self.__arrow_c_array__(py, None)?)
101 }
102
103 pub fn into_pyarrow(self, py: Python) -> PyResult<Bound<PyAny>> {
107 let pyarrow_mod = py.import(intern!(py, "pyarrow"))?;
108 pyarrow_mod
109 .getattr(intern!(py, "record_batch"))?
110 .call1(PyTuple::new(py, vec![self.into_pyobject(py)?])?)
111 }
112
113 pub(crate) fn to_array_pycapsules<'py>(
114 py: Python<'py>,
115 record_batch: RecordBatch,
116 requested_schema: Option<Bound<'py, PyCapsule>>,
117 ) -> PyArrowResult<Bound<'py, PyTuple>> {
118 let field = Field::new_struct("", record_batch.schema_ref().fields().clone(), false)
119 .with_metadata(record_batch.schema_ref().metadata().clone());
120 let array: ArrayRef = Arc::new(StructArray::from(record_batch.clone()));
121 to_array_pycapsules(py, field.into(), &array, requested_schema)
122 }
123}
124
125impl From<RecordBatch> for PyRecordBatch {
126 fn from(value: RecordBatch) -> Self {
127 Self(value)
128 }
129}
130
131impl From<PyRecordBatch> for RecordBatch {
132 fn from(value: PyRecordBatch) -> Self {
133 value.0
134 }
135}
136
137impl AsRef<RecordBatch> for PyRecordBatch {
138 fn as_ref(&self) -> &RecordBatch {
139 &self.0
140 }
141}
142
143impl Display for PyRecordBatch {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 writeln!(f, "arro3.core.RecordBatch")?;
146 pretty_format_batches_with_options(
147 &[self.0.slice(0, 10.min(self.0.num_rows()))],
148 &default_repr_options(),
149 )
150 .map_err(|_| std::fmt::Error)?
151 .fmt(f)?;
152
153 Ok(())
154 }
155}
156
157#[pymethods]
158impl PyRecordBatch {
    /// Python constructor: `RecordBatch(data, *, names=None, schema=None, metadata=None)`.
    ///
    /// Accepts, in priority order:
    /// 1. any object exposing `__arrow_c_array__` (Arrow PyCapsule interface);
    ///    `names`/`schema`/`metadata` are ignored on this path,
    /// 2. a mapping of column name -> array (delegates to `from_pydict`; `names`
    ///    and `schema` are ignored on this path),
    /// 3. a list of arrays (delegates to `from_arrays`).
    ///
    /// Raises `TypeError` if `data` matches none of the above.
    #[new]
    #[pyo3(signature = (data, *, names=None, schema=None, metadata=None))]
    fn init(
        py: Python,
        data: &Bound<PyAny>,
        names: Option<Vec<String>>,
        schema: Option<PySchema>,
        metadata: Option<MetadataInput>,
    ) -> PyArrowResult<Self> {
        if data.hasattr(intern!(py, "__arrow_c_array__"))? {
            Ok(data.extract::<PyRecordBatch>()?)
        } else if let Ok(mapping) = data.extract::<IndexMap<String, PyArray>>() {
            Self::from_pydict(&py.get_type::<Self>(), mapping, metadata)
        } else if let Ok(arrays) = data.extract::<Vec<PyArray>>() {
            Self::from_arrays(&py.get_type::<Self>(), arrays, names, schema, metadata)
        } else {
            Err(PyTypeError::new_err(
                "Expected RecordBatch-like input or dict of arrays or list of arrays.",
            )
            .into())
        }
    }
181
    /// Arrow PyCapsule interface: export this batch as a `(schema, array)` capsule
    /// pair, representing the batch as a non-nullable struct array.
    #[pyo3(signature = (requested_schema=None))]
    fn __arrow_c_array__<'py>(
        &'py self,
        py: Python<'py>,
        requested_schema: Option<Bound<'py, PyCapsule>>,
    ) -> PyArrowResult<Bound<'py, PyTuple>> {
        // Cloning a RecordBatch is cheap: columns are Arc'd buffers.
        Self::to_array_pycapsules(py, self.0.clone(), requested_schema)
    }
190
191 fn __arrow_c_schema__<'py>(&'py self, py: Python<'py>) -> PyArrowResult<Bound<'py, PyCapsule>> {
192 to_schema_pycapsule(py, self.0.schema_ref().as_ref())
193 }
194
195 fn __eq__(&self, other: PyRecordBatch) -> bool {
196 self.0 == other.0
197 }
198
    /// `batch[key]`: look up a single column by name or integer position.
    fn __getitem__(&self, key: FieldIndexInput) -> PyResult<Arro3Array> {
        self.column(key)
    }
202
203 fn __repr__(&self) -> String {
204 self.to_string()
205 }
206
207 #[classmethod]
208 #[pyo3(signature = (arrays, *, names=None, schema=None, metadata=None))]
209 fn from_arrays(
210 _cls: &Bound<PyType>,
211 arrays: Vec<PyArray>,
212 names: Option<Vec<String>>,
213 schema: Option<PySchema>,
214 metadata: Option<MetadataInput>,
215 ) -> PyArrowResult<Self> {
216 if schema.is_some() && metadata.is_some() {
217 return Err(PyValueError::new_err("Cannot pass both schema and metadata").into());
218 }
219
220 let (arrays, fields): (Vec<ArrayRef>, Vec<FieldRef>) =
221 arrays.into_iter().map(|arr| arr.into_inner()).unzip();
222
223 let schema: SchemaRef = if let Some(schema) = schema {
224 schema.into_inner()
225 } else {
226 let names = names.ok_or(PyValueError::new_err(
227 "names must be passed if schema is not passed.",
228 ))?;
229
230 let fields: Vec<_> = fields
231 .iter()
232 .zip(names.iter())
233 .map(|(field, name)| field.as_ref().clone().with_name(name))
234 .collect();
235
236 Arc::new(
237 Schema::new(fields)
238 .with_metadata(metadata.unwrap_or_default().into_string_hashmap()?),
239 )
240 };
241
242 if arrays.is_empty() {
243 let rb = RecordBatch::try_new(schema, vec![])?;
244 return Ok(Self::new(rb));
245 }
246
247 let rb = RecordBatch::try_new(schema, arrays)?;
248 Ok(Self::new(rb))
249 }
250
251 #[classmethod]
252 #[pyo3(signature = (mapping, *, metadata=None))]
253 fn from_pydict(
254 _cls: &Bound<PyType>,
255 mapping: IndexMap<String, PyArray>,
256 metadata: Option<MetadataInput>,
257 ) -> PyArrowResult<Self> {
258 let mut fields = vec![];
259 let mut arrays = vec![];
260 mapping.into_iter().for_each(|(name, py_array)| {
261 let (arr, field) = py_array.into_inner();
262 fields.push(field.as_ref().clone().with_name(name));
263 arrays.push(arr);
264 });
265 let schema =
266 Schema::new_with_metadata(fields, metadata.unwrap_or_default().into_string_hashmap()?);
267 let rb = RecordBatch::try_new(schema.into(), arrays)?;
268 Ok(Self::new(rb))
269 }
270
271 #[classmethod]
272 fn from_struct_array(_cls: &Bound<PyType>, struct_array: PyArray) -> PyArrowResult<Self> {
273 let (array, field) = struct_array.into_inner();
274 match field.data_type() {
275 DataType::Struct(fields) => {
276 let schema = Schema::new_with_metadata(fields.clone(), field.metadata().clone());
277 let struct_arr = array.as_struct();
278 let columns = struct_arr.columns().to_vec();
279 let rb = RecordBatch::try_new(schema.into(), columns)?;
280 Ok(Self::new(rb))
281 }
282 _ => Err(PyTypeError::new_err("Expected struct array").into()),
283 }
284 }
285
286 #[classmethod]
287 fn from_arrow(_cls: &Bound<PyType>, input: AnyRecordBatch) -> PyArrowResult<Self> {
288 match input {
289 AnyRecordBatch::RecordBatch(rb) => Ok(rb),
290 AnyRecordBatch::Stream(stream) => {
291 let (batches, schema) = stream.into_table()?.into_inner();
292 let single_batch = concat_batches(&schema, batches.iter())?;
293 Ok(Self::new(single_batch))
294 }
295 }
296 }
297
    /// Python-visible classmethod wrapping [`Self::from_arrow_pycapsule`].
    #[classmethod]
    #[pyo3(name = "from_arrow_pycapsule")]
    fn from_arrow_pycapsule_py(
        _cls: &Bound<PyType>,
        schema_capsule: &Bound<PyCapsule>,
        array_capsule: &Bound<PyCapsule>,
    ) -> PyResult<Self> {
        Self::from_arrow_pycapsule(schema_capsule, array_capsule)
    }
307
308 fn add_column(
309 &self,
310 i: usize,
311 field: NameOrField,
312 column: PyArray,
313 ) -> PyArrowResult<Arro3RecordBatch> {
314 let mut fields = self.0.schema_ref().fields().to_vec();
315 fields.insert(i, field.into_field(column.field()));
316 let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
317
318 let mut arrays = self.0.columns().to_vec();
319 arrays.insert(i, column.array().clone());
320
321 let new_rb = RecordBatch::try_new(schema.into(), arrays)?;
322 Ok(PyRecordBatch::new(new_rb).into())
323 }
324
325 fn append_column(
326 &self,
327 field: NameOrField,
328 column: PyArray,
329 ) -> PyArrowResult<Arro3RecordBatch> {
330 let mut fields = self.0.schema_ref().fields().to_vec();
331 fields.push(field.into_field(column.field()));
332 let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
333
334 let mut arrays = self.0.columns().to_vec();
335 arrays.push(column.array().clone());
336
337 let new_rb = RecordBatch::try_new(schema.into(), arrays)?;
338 Ok(PyRecordBatch::new(new_rb).into())
339 }
340
341 fn column(&self, i: FieldIndexInput) -> PyResult<Arro3Array> {
342 let column_index = i.into_position(self.0.schema_ref())?;
343 let field = self.0.schema().field(column_index).clone();
344 let array = self.0.column(column_index).clone();
345 Ok(PyArray::new(array, field.into()).into())
346 }
347
348 #[getter]
349 fn column_names(&self) -> Vec<String> {
350 self.0
351 .schema()
352 .fields()
353 .iter()
354 .map(|f| f.name().clone())
355 .collect()
356 }
357
358 #[getter]
359 fn columns(&self) -> PyResult<Vec<Arro3Array>> {
360 (0..self.num_columns())
361 .map(|i| self.column(FieldIndexInput::Position(i)))
362 .collect()
363 }
364
365 fn equals(&self, other: PyRecordBatch) -> bool {
366 self.0 == other.0
367 }
368
369 fn field(&self, i: FieldIndexInput) -> PyResult<Arro3Field> {
370 let schema_ref = self.0.schema_ref();
371 let field = schema_ref.field(i.into_position(schema_ref)?);
372 Ok(PyField::new(field.clone().into()).into())
373 }
374
    /// Total memory (in bytes) held by the batch's arrays, including buffer
    /// capacity, not just the logical data size.
    #[getter]
    fn nbytes(&self) -> usize {
        self.0.get_array_memory_size()
    }
379
    /// Number of columns in the batch.
    #[getter]
    fn num_columns(&self) -> usize {
        self.0.num_columns()
    }
384
    /// Number of rows in the batch.
    #[getter]
    fn num_rows(&self) -> usize {
        self.0.num_rows()
    }
389
    /// Return a new batch with column `i` removed (the removed column is discarded).
    ///
    /// NOTE(review): panics (surfaced to Python as a PanicException) when `i` is
    /// out of bounds — confirm whether the Python layer validates the index first.
    fn remove_column(&self, i: usize) -> Arro3RecordBatch {
        let mut rb = self.0.clone();
        rb.remove_column(i);
        PyRecordBatch::new(rb).into()
    }
395
    /// The schema of the batch.
    #[getter]
    fn schema(&self) -> Arro3Schema {
        self.0.schema().into()
    }
400
401 fn select(&self, columns: SelectIndices) -> PyArrowResult<Arro3RecordBatch> {
402 let positions = columns.into_positions(self.0.schema_ref().fields())?;
403 Ok(self.0.project(&positions)?.into())
404 }
405
406 fn set_column(
407 &self,
408 i: usize,
409 field: NameOrField,
410 column: PyArray,
411 ) -> PyArrowResult<Arro3RecordBatch> {
412 let mut fields = self.0.schema_ref().fields().to_vec();
413 fields[i] = field.into_field(column.field());
414 let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
415
416 let mut arrays = self.0.columns().to_vec();
417 arrays[i] = column.array().clone();
418
419 Ok(RecordBatch::try_new(schema.into(), arrays)?.into())
420 }
421
422 #[getter]
423 fn shape(&self) -> (usize, usize) {
424 (self.num_rows(), self.num_columns())
425 }
426
427 #[pyo3(signature = (offset=0, length=None))]
428 fn slice(&self, offset: usize, length: Option<usize>) -> Arro3RecordBatch {
429 let length = length.unwrap_or_else(|| self.num_rows() - offset);
430 self.0.slice(offset, length).into()
431 }
432
433 fn take(&self, indices: PyArray) -> PyArrowResult<Arro3RecordBatch> {
434 let new_batch = take_record_batch(self.as_ref(), indices.as_ref())?;
435 Ok(new_batch.into())
436 }
437
438 fn to_struct_array(&self) -> Arro3Array {
439 let struct_array: StructArray = self.0.clone().into();
440 let field = Field::new_struct("", self.0.schema_ref().fields().clone(), false)
441 .with_metadata(self.0.schema_ref().metadata.clone());
442 PyArray::new(Arc::new(struct_array), field.into()).into()
443 }
444
445 fn with_schema(&self, schema: PySchema) -> PyArrowResult<Arro3RecordBatch> {
446 let new_schema = schema.into_inner();
447 let new_batch = RecordBatch::try_new(new_schema.clone(), self.0.columns().to_vec())?;
448 Ok(new_batch.into())
449 }
450}