1use std::fmt::Display;
2use std::sync::Arc;
3
4use arrow_array::cast::AsArray;
5use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, StructArray};
6use arrow_cast::pretty::pretty_format_batches_with_options;
7use arrow_schema::{DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
8use arrow_select::concat::concat_batches;
9use arrow_select::take::take_record_batch;
10use indexmap::IndexMap;
11use pyo3::exceptions::{PyTypeError, PyValueError};
12use pyo3::intern;
13use pyo3::prelude::*;
14use pyo3::types::{PyCapsule, PyTuple, PyType};
15
16use crate::error::PyArrowResult;
17use crate::export::{Arro3Array, Arro3Field, Arro3RecordBatch, Arro3Schema};
18use crate::ffi::from_python::utils::import_array_pycapsules;
19use crate::ffi::to_python::nanoarrow::to_nanoarrow_array;
20use crate::ffi::to_python::to_array_pycapsules;
21use crate::ffi::to_schema_pycapsule;
22use crate::input::{AnyRecordBatch, FieldIndexInput, MetadataInput, NameOrField, SelectIndices};
23use crate::utils::default_repr_options;
24use crate::{PyArray, PyField, PySchema};
25
/// A Python-facing wrapper around an Arrow [`RecordBatch`].
///
/// Exposed to Python as `arro3.core.RecordBatch`. The pyclass is `frozen`
/// (immutable from Python) and `subclass`-able so downstream libraries can
/// extend it.
#[pyclass(module = "arro3.core._core", name = "RecordBatch", subclass, frozen)]
#[derive(Debug)]
pub struct PyRecordBatch(RecordBatch);
32
33impl PyRecordBatch {
34 pub fn new(batch: RecordBatch) -> Self {
36 Self(batch)
37 }
38
39 pub fn from_arrow_pycapsule(
41 schema_capsule: &Bound<PyCapsule>,
42 array_capsule: &Bound<PyCapsule>,
43 ) -> PyResult<Self> {
44 let (array, field) = import_array_pycapsules(schema_capsule, array_capsule)?;
45
46 match field.data_type() {
47 DataType::Struct(fields) => {
48 let struct_array = array.as_struct();
49 let schema = SchemaBuilder::from(fields)
50 .finish()
51 .with_metadata(field.metadata().clone());
52 assert_eq!(
53 struct_array.null_count(),
54 0,
55 "Cannot convert nullable StructArray to RecordBatch"
56 );
57
58 let columns = struct_array.columns().to_vec();
59
60 let batch = RecordBatch::try_new_with_options(
61 Arc::new(schema),
62 columns,
63 &RecordBatchOptions::new().with_row_count(Some(array.len())),
64 )
65 .map_err(|err| PyValueError::new_err(err.to_string()))?;
66 Ok(Self::new(batch))
67 }
68 dt => Err(PyValueError::new_err(format!(
69 "Unexpected data type {}",
70 dt
71 ))),
72 }
73 }
74
75 pub fn into_inner(self) -> RecordBatch {
77 self.0
78 }
79
80 pub fn to_arro3<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
82 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
83 arro3_mod.getattr(intern!(py, "RecordBatch"))?.call_method1(
84 intern!(py, "from_arrow_pycapsule"),
85 self.__arrow_c_array__(py, None)?,
86 )
87 }
88
89 pub fn into_arro3(self, py: Python) -> PyResult<Bound<PyAny>> {
91 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
92 let capsules = Self::to_array_pycapsules(py, self.0.clone(), None)?;
93 arro3_mod
94 .getattr(intern!(py, "RecordBatch"))?
95 .call_method1(intern!(py, "from_arrow_pycapsule"), capsules)
96 }
97
98 pub fn to_nanoarrow<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
100 to_nanoarrow_array(py, self.__arrow_c_array__(py, None)?)
101 }
102
103 pub fn into_pyarrow(self, py: Python) -> PyResult<Bound<PyAny>> {
107 let pyarrow_mod = py.import(intern!(py, "pyarrow"))?;
108 pyarrow_mod
109 .getattr(intern!(py, "record_batch"))?
110 .call1(PyTuple::new(py, vec![self.into_pyobject(py)?])?)
111 }
112
113 pub(crate) fn to_array_pycapsules<'py>(
114 py: Python<'py>,
115 record_batch: RecordBatch,
116 requested_schema: Option<Bound<'py, PyCapsule>>,
117 ) -> PyArrowResult<Bound<'py, PyTuple>> {
118 let field = Field::new_struct("", record_batch.schema_ref().fields().clone(), false);
119 let array: ArrayRef = Arc::new(StructArray::from(record_batch.clone()));
120 to_array_pycapsules(py, field.into(), &array, requested_schema)
121 }
122}
123
124impl From<RecordBatch> for PyRecordBatch {
125 fn from(value: RecordBatch) -> Self {
126 Self(value)
127 }
128}
129
130impl From<PyRecordBatch> for RecordBatch {
131 fn from(value: PyRecordBatch) -> Self {
132 value.0
133 }
134}
135
impl AsRef<RecordBatch> for PyRecordBatch {
    /// Borrow the wrapped [`RecordBatch`].
    fn as_ref(&self) -> &RecordBatch {
        &self.0
    }
}
141
142impl Display for PyRecordBatch {
143 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144 writeln!(f, "arro3.core.RecordBatch")?;
145 pretty_format_batches_with_options(
146 &[self.0.slice(0, 10.min(self.0.num_rows()))],
147 &default_repr_options(),
148 )
149 .map_err(|_| std::fmt::Error)?
150 .fmt(f)?;
151
152 Ok(())
153 }
154}
155
156#[pymethods]
157impl PyRecordBatch {
158 #[new]
159 #[pyo3(signature = (data, *, names=None, schema=None, metadata=None))]
160 fn init(
161 py: Python,
162 data: &Bound<PyAny>,
163 names: Option<Vec<String>>,
164 schema: Option<PySchema>,
165 metadata: Option<MetadataInput>,
166 ) -> PyArrowResult<Self> {
167 if data.hasattr(intern!(py, "__arrow_c_array__"))? {
168 Ok(data.extract::<PyRecordBatch>()?)
169 } else if let Ok(mapping) = data.extract::<IndexMap<String, PyArray>>() {
170 Self::from_pydict(&py.get_type::<Self>(), mapping, metadata)
171 } else if let Ok(arrays) = data.extract::<Vec<PyArray>>() {
172 Self::from_arrays(&py.get_type::<Self>(), arrays, names, schema, metadata)
173 } else {
174 Err(PyTypeError::new_err(
175 "Expected RecordBatch-like input or dict of arrays or list of arrays.",
176 )
177 .into())
178 }
179 }
180
181 #[pyo3(signature = (requested_schema=None))]
182 fn __arrow_c_array__<'py>(
183 &'py self,
184 py: Python<'py>,
185 requested_schema: Option<Bound<'py, PyCapsule>>,
186 ) -> PyArrowResult<Bound<'py, PyTuple>> {
187 Self::to_array_pycapsules(py, self.0.clone(), requested_schema)
188 }
189
190 fn __arrow_c_schema__<'py>(&'py self, py: Python<'py>) -> PyArrowResult<Bound<'py, PyCapsule>> {
191 to_schema_pycapsule(py, self.0.schema_ref().as_ref())
192 }
193
194 fn __eq__(&self, other: &PyRecordBatch) -> bool {
195 self.0 == other.0
196 }
197
198 fn __getitem__(&self, key: FieldIndexInput) -> PyResult<Arro3Array> {
199 self.column(key)
200 }
201
202 fn __repr__(&self) -> String {
203 self.to_string()
204 }
205
206 #[classmethod]
207 #[pyo3(signature = (arrays, *, names=None, schema=None, metadata=None))]
208 fn from_arrays(
209 _cls: &Bound<PyType>,
210 arrays: Vec<PyArray>,
211 names: Option<Vec<String>>,
212 schema: Option<PySchema>,
213 metadata: Option<MetadataInput>,
214 ) -> PyArrowResult<Self> {
215 if schema.is_some() && metadata.is_some() {
216 return Err(PyValueError::new_err("Cannot pass both schema and metadata").into());
217 }
218
219 let (arrays, fields): (Vec<ArrayRef>, Vec<FieldRef>) =
220 arrays.into_iter().map(|arr| arr.into_inner()).unzip();
221
222 let schema: SchemaRef = if let Some(schema) = schema {
223 schema.into_inner()
224 } else {
225 let names = names.ok_or(PyValueError::new_err(
226 "names must be passed if schema is not passed.",
227 ))?;
228
229 let fields: Vec<_> = fields
230 .iter()
231 .zip(names.iter())
232 .map(|(field, name)| field.as_ref().clone().with_name(name))
233 .collect();
234
235 Arc::new(
236 Schema::new(fields)
237 .with_metadata(metadata.unwrap_or_default().into_string_hashmap()?),
238 )
239 };
240
241 if arrays.is_empty() {
242 let rb = RecordBatch::try_new(schema, vec![])?;
243 return Ok(Self::new(rb));
244 }
245
246 let rb = RecordBatch::try_new(schema, arrays)?;
247 Ok(Self::new(rb))
248 }
249
250 #[classmethod]
251 #[pyo3(signature = (mapping, *, metadata=None))]
252 fn from_pydict(
253 _cls: &Bound<PyType>,
254 mapping: IndexMap<String, PyArray>,
255 metadata: Option<MetadataInput>,
256 ) -> PyArrowResult<Self> {
257 let mut fields = vec![];
258 let mut arrays = vec![];
259 mapping.into_iter().for_each(|(name, py_array)| {
260 let (arr, field) = py_array.into_inner();
261 fields.push(field.as_ref().clone().with_name(name));
262 arrays.push(arr);
263 });
264 let schema =
265 Schema::new_with_metadata(fields, metadata.unwrap_or_default().into_string_hashmap()?);
266 let rb = RecordBatch::try_new(schema.into(), arrays)?;
267 Ok(Self::new(rb))
268 }
269
270 #[classmethod]
271 fn from_struct_array(_cls: &Bound<PyType>, struct_array: PyArray) -> PyArrowResult<Self> {
272 let (array, field) = struct_array.into_inner();
273 match field.data_type() {
274 DataType::Struct(fields) => {
275 let schema = Schema::new_with_metadata(fields.clone(), field.metadata().clone());
276 let struct_arr = array.as_struct();
277 let columns = struct_arr.columns().to_vec();
278 let rb = RecordBatch::try_new(schema.into(), columns)?;
279 Ok(Self::new(rb))
280 }
281 _ => Err(PyTypeError::new_err("Expected struct array").into()),
282 }
283 }
284
285 #[classmethod]
286 fn from_arrow(_cls: &Bound<PyType>, input: AnyRecordBatch) -> PyArrowResult<Self> {
287 match input {
288 AnyRecordBatch::RecordBatch(rb) => Ok(rb),
289 AnyRecordBatch::Stream(stream) => {
290 let (batches, schema) = stream.into_table()?.into_inner();
291 let single_batch = concat_batches(&schema, batches.iter())?;
292 Ok(Self::new(single_batch))
293 }
294 }
295 }
296
297 #[classmethod]
298 #[pyo3(name = "from_arrow_pycapsule")]
299 fn from_arrow_pycapsule_py(
300 _cls: &Bound<PyType>,
301 schema_capsule: &Bound<PyCapsule>,
302 array_capsule: &Bound<PyCapsule>,
303 ) -> PyResult<Self> {
304 Self::from_arrow_pycapsule(schema_capsule, array_capsule)
305 }
306
307 fn add_column(
308 &self,
309 i: usize,
310 field: NameOrField,
311 column: PyArray,
312 ) -> PyArrowResult<Arro3RecordBatch> {
313 let mut fields = self.0.schema_ref().fields().to_vec();
314 fields.insert(i, field.into_field(column.field()));
315 let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
316
317 let mut arrays = self.0.columns().to_vec();
318 arrays.insert(i, column.array().clone());
319
320 let new_rb = RecordBatch::try_new(schema.into(), arrays)?;
321 Ok(PyRecordBatch::new(new_rb).into())
322 }
323
324 fn append_column(
325 &self,
326 field: NameOrField,
327 column: PyArray,
328 ) -> PyArrowResult<Arro3RecordBatch> {
329 let mut fields = self.0.schema_ref().fields().to_vec();
330 fields.push(field.into_field(column.field()));
331 let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
332
333 let mut arrays = self.0.columns().to_vec();
334 arrays.push(column.array().clone());
335
336 let new_rb = RecordBatch::try_new(schema.into(), arrays)?;
337 Ok(PyRecordBatch::new(new_rb).into())
338 }
339
340 fn column(&self, i: FieldIndexInput) -> PyResult<Arro3Array> {
341 let column_index = i.into_position(self.0.schema_ref())?;
342 let field = self.0.schema().field(column_index).clone();
343 let array = self.0.column(column_index).clone();
344 Ok(PyArray::new(array, field.into()).into())
345 }
346
347 #[getter]
348 fn column_names(&self) -> Vec<String> {
349 self.0
350 .schema()
351 .fields()
352 .iter()
353 .map(|f| f.name().clone())
354 .collect()
355 }
356
357 #[getter]
358 fn columns(&self) -> PyResult<Vec<Arro3Array>> {
359 (0..self.num_columns())
360 .map(|i| self.column(FieldIndexInput::Position(i)))
361 .collect()
362 }
363
364 fn equals(&self, other: PyRecordBatch) -> bool {
365 self.0 == other.0
366 }
367
368 fn field(&self, i: FieldIndexInput) -> PyResult<Arro3Field> {
369 let schema_ref = self.0.schema_ref();
370 let field = schema_ref.field(i.into_position(schema_ref)?);
371 Ok(PyField::new(field.clone().into()).into())
372 }
373
374 #[getter]
375 fn nbytes(&self) -> usize {
376 self.0.get_array_memory_size()
377 }
378
379 #[getter]
380 fn num_columns(&self) -> usize {
381 self.0.num_columns()
382 }
383
384 #[getter]
385 fn num_rows(&self) -> usize {
386 self.0.num_rows()
387 }
388
389 fn remove_column(&self, i: usize) -> Arro3RecordBatch {
390 let mut rb = self.0.clone();
391 rb.remove_column(i);
392 PyRecordBatch::new(rb).into()
393 }
394
395 #[getter]
396 fn schema(&self) -> Arro3Schema {
397 self.0.schema().into()
398 }
399
400 fn select(&self, columns: SelectIndices) -> PyArrowResult<Arro3RecordBatch> {
401 let positions = columns.into_positions(self.0.schema_ref().fields())?;
402 Ok(self.0.project(&positions)?.into())
403 }
404
405 fn set_column(
406 &self,
407 i: usize,
408 field: NameOrField,
409 column: PyArray,
410 ) -> PyArrowResult<Arro3RecordBatch> {
411 let mut fields = self.0.schema_ref().fields().to_vec();
412 fields[i] = field.into_field(column.field());
413 let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
414
415 let mut arrays = self.0.columns().to_vec();
416 arrays[i] = column.array().clone();
417
418 Ok(RecordBatch::try_new(schema.into(), arrays)?.into())
419 }
420
421 #[getter]
422 fn shape(&self) -> (usize, usize) {
423 (self.num_rows(), self.num_columns())
424 }
425
426 #[pyo3(signature = (offset=0, length=None))]
427 fn slice(&self, offset: usize, length: Option<usize>) -> Arro3RecordBatch {
428 let length = length.unwrap_or_else(|| self.num_rows() - offset);
429 self.0.slice(offset, length).into()
430 }
431
432 fn take(&self, indices: PyArray) -> PyArrowResult<Arro3RecordBatch> {
433 let new_batch = take_record_batch(self.as_ref(), indices.as_ref())?;
434 Ok(new_batch.into())
435 }
436
437 fn to_struct_array(&self) -> Arro3Array {
438 let struct_array: StructArray = self.0.clone().into();
439 let field = Field::new_struct("", self.0.schema_ref().fields().clone(), false)
440 .with_metadata(self.0.schema_ref().metadata.clone());
441 PyArray::new(Arc::new(struct_array), field.into()).into()
442 }
443
444 fn with_schema(&self, schema: PySchema) -> PyArrowResult<Arro3RecordBatch> {
445 let new_schema = schema.into_inner();
446 let new_batch = RecordBatch::try_new(new_schema.clone(), self.0.columns().to_vec())?;
447 Ok(new_batch.into())
448 }
449}