1use std::fmt::Display;
2use std::sync::Arc;
3
4use arrow_array::cast::AsArray;
5use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, StructArray};
6use arrow_schema::{DataType, Field, Schema, SchemaBuilder};
7use arrow_select::concat::concat_batches;
8use arrow_select::take::take_record_batch;
9use indexmap::IndexMap;
10use pyo3::exceptions::{PyTypeError, PyValueError};
11use pyo3::prelude::*;
12use pyo3::types::{PyCapsule, PyTuple, PyType};
13use pyo3::{intern, IntoPyObjectExt};
14
15use crate::error::PyArrowResult;
16use crate::export::{Arro3Array, Arro3Field, Arro3RecordBatch, Arro3Schema};
17use crate::ffi::from_python::utils::import_array_pycapsules;
18use crate::ffi::to_python::nanoarrow::to_nanoarrow_array;
19use crate::ffi::to_python::to_array_pycapsules;
20use crate::ffi::to_schema_pycapsule;
21use crate::input::{AnyRecordBatch, FieldIndexInput, MetadataInput, NameOrField, SelectIndices};
22use crate::schema::display_schema;
23use crate::{PyArray, PyField, PySchema};
24
/// A Python-facing wrapper around an arrow [`RecordBatch`].
///
/// This is a newtype over the arrow-rs `RecordBatch`, registered with Python under the name
/// `RecordBatch` in the `arro3.core._core` module. The class is declared `frozen` and
/// `subclass`-able.
#[pyclass(module = "arro3.core._core", name = "RecordBatch", subclass, frozen)]
#[derive(Debug)]
pub struct PyRecordBatch(RecordBatch);
31
32impl PyRecordBatch {
33 pub fn new(batch: RecordBatch) -> Self {
35 Self(batch)
36 }
37
38 pub fn from_arrow_pycapsule(
40 schema_capsule: &Bound<PyCapsule>,
41 array_capsule: &Bound<PyCapsule>,
42 ) -> PyResult<Self> {
43 let (array, field, data_len) = import_array_pycapsules(schema_capsule, array_capsule)?;
44
45 match field.data_type() {
46 DataType::Struct(fields) => {
47 let struct_array = array.as_struct();
48 let schema = SchemaBuilder::from(fields)
49 .finish()
50 .with_metadata(field.metadata().clone());
51 assert_eq!(
52 struct_array.null_count(),
53 0,
54 "Cannot convert nullable StructArray to RecordBatch"
55 );
56
57 let columns = struct_array.columns().to_vec();
58
59 let batch = if array.is_empty() && data_len > 0 {
61 RecordBatch::try_new_with_options(
62 Arc::new(schema),
63 columns,
64 &RecordBatchOptions::new().with_row_count(Some(data_len)),
65 )
66 .map_err(|err| PyValueError::new_err(err.to_string()))?
67 } else {
68 RecordBatch::try_new(Arc::new(schema), columns)
69 .map_err(|err| PyValueError::new_err(err.to_string()))?
70 };
71 Ok(Self::new(batch))
72 }
73 dt => Err(PyValueError::new_err(format!(
74 "Unexpected data type {}",
75 dt
76 ))),
77 }
78 }
79
80 pub fn into_inner(self) -> RecordBatch {
82 self.0
83 }
84
85 pub fn to_arro3<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
87 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
88 arro3_mod.getattr(intern!(py, "RecordBatch"))?.call_method1(
89 intern!(py, "from_arrow_pycapsule"),
90 self.__arrow_c_array__(py, None)?,
91 )
92 }
93
94 pub fn into_arro3(self, py: Python) -> PyResult<Bound<PyAny>> {
96 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
97 let capsules = Self::to_array_pycapsules(py, self.0.clone(), None)?;
98 arro3_mod
99 .getattr(intern!(py, "RecordBatch"))?
100 .call_method1(intern!(py, "from_arrow_pycapsule"), capsules)
101 }
102
103 pub fn to_nanoarrow<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
105 to_nanoarrow_array(py, self.__arrow_c_array__(py, None)?)
106 }
107
108 pub fn to_pyarrow(self, py: Python) -> PyResult<PyObject> {
112 let pyarrow_mod = py.import(intern!(py, "pyarrow"))?;
113 let pyarrow_obj = pyarrow_mod
114 .getattr(intern!(py, "record_batch"))?
115 .call1(PyTuple::new(py, vec![self.into_pyobject(py)?])?)?;
116 pyarrow_obj.into_py_any(py)
117 }
118
119 pub(crate) fn to_array_pycapsules<'py>(
120 py: Python<'py>,
121 record_batch: RecordBatch,
122 requested_schema: Option<Bound<'py, PyCapsule>>,
123 ) -> PyArrowResult<Bound<'py, PyTuple>> {
124 let field = Field::new_struct("", record_batch.schema_ref().fields().clone(), false);
125 let array: ArrayRef = Arc::new(StructArray::from(record_batch.clone()));
126 to_array_pycapsules(py, field.into(), &array, requested_schema)
127 }
128}
129
130impl From<RecordBatch> for PyRecordBatch {
131 fn from(value: RecordBatch) -> Self {
132 Self(value)
133 }
134}
135
136impl From<PyRecordBatch> for RecordBatch {
137 fn from(value: PyRecordBatch) -> Self {
138 value.0
139 }
140}
141
142impl AsRef<RecordBatch> for PyRecordBatch {
143 fn as_ref(&self) -> &RecordBatch {
144 &self.0
145 }
146}
147
148impl Display for PyRecordBatch {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 writeln!(f, "arro3.core.RecordBatch")?;
151 writeln!(f, "-----------------")?;
152 display_schema(&self.0.schema(), f)
153 }
154}
155
156#[pymethods]
157impl PyRecordBatch {
158 #[new]
159 #[pyo3(signature = (data, *, schema=None, metadata=None))]
160 fn init(
161 py: Python,
162 data: &Bound<PyAny>,
163 schema: Option<PySchema>,
164 metadata: Option<MetadataInput>,
165 ) -> PyArrowResult<Self> {
166 if let Ok(data) = data.extract::<PyRecordBatch>() {
167 Ok(data)
168 } else if let Ok(mapping) = data.extract::<IndexMap<String, PyArray>>() {
169 Self::from_pydict(&py.get_type::<PyRecordBatch>(), mapping, metadata)
170 } else if let Ok(arrays) = data.extract::<Vec<PyArray>>() {
171 Self::from_arrays(
172 &py.get_type::<PyRecordBatch>(),
173 arrays,
174 schema.ok_or(PyValueError::new_err(
175 "Schema must be passed with list of arrays",
176 ))?,
177 )
178 } else {
179 Err(PyTypeError::new_err(
180 "Expected RecordBatch-like input or dict of arrays or list of arrays.",
181 )
182 .into())
183 }
184 }
185
186 #[pyo3(signature = (requested_schema=None))]
187 fn __arrow_c_array__<'py>(
188 &'py self,
189 py: Python<'py>,
190 requested_schema: Option<Bound<'py, PyCapsule>>,
191 ) -> PyArrowResult<Bound<'py, PyTuple>> {
192 Self::to_array_pycapsules(py, self.0.clone(), requested_schema)
193 }
194
195 fn __arrow_c_schema__<'py>(&'py self, py: Python<'py>) -> PyArrowResult<Bound<'py, PyCapsule>> {
196 to_schema_pycapsule(py, self.0.schema_ref().as_ref())
197 }
198
199 fn __eq__(&self, other: &PyRecordBatch) -> bool {
200 self.0 == other.0
201 }
202
203 fn __getitem__(&self, key: FieldIndexInput) -> PyResult<Arro3Array> {
204 self.column(key)
205 }
206
207 fn __repr__(&self) -> String {
208 self.to_string()
209 }
210
211 #[classmethod]
212 #[pyo3(signature = (arrays, *, schema))]
213 fn from_arrays(
214 _cls: &Bound<PyType>,
215 arrays: Vec<PyArray>,
216 schema: PySchema,
217 ) -> PyArrowResult<Self> {
218 let rb = RecordBatch::try_new(
219 schema.into(),
220 arrays
221 .into_iter()
222 .map(|arr| {
223 let (arr, _field) = arr.into_inner();
224 arr
225 })
226 .collect(),
227 )?;
228 Ok(Self::new(rb))
229 }
230
231 #[classmethod]
232 #[pyo3(signature = (mapping, *, metadata=None))]
233 fn from_pydict(
234 _cls: &Bound<PyType>,
235 mapping: IndexMap<String, PyArray>,
236 metadata: Option<MetadataInput>,
237 ) -> PyArrowResult<Self> {
238 let mut fields = vec![];
239 let mut arrays = vec![];
240 mapping.into_iter().for_each(|(name, py_array)| {
241 let (arr, field) = py_array.into_inner();
242 fields.push(field.as_ref().clone().with_name(name));
243 arrays.push(arr);
244 });
245 let schema =
246 Schema::new_with_metadata(fields, metadata.unwrap_or_default().into_string_hashmap()?);
247 let rb = RecordBatch::try_new(schema.into(), arrays)?;
248 Ok(Self::new(rb))
249 }
250
251 #[classmethod]
252 fn from_struct_array(_cls: &Bound<PyType>, struct_array: PyArray) -> PyArrowResult<Self> {
253 let (array, field) = struct_array.into_inner();
254 match field.data_type() {
255 DataType::Struct(fields) => {
256 let schema = Schema::new_with_metadata(fields.clone(), field.metadata().clone());
257 let struct_arr = array.as_struct();
258 let columns = struct_arr.columns().to_vec();
259 let rb = RecordBatch::try_new(schema.into(), columns)?;
260 Ok(Self::new(rb))
261 }
262 _ => Err(PyTypeError::new_err("Expected struct array").into()),
263 }
264 }
265
266 #[classmethod]
267 fn from_arrow(_cls: &Bound<PyType>, input: AnyRecordBatch) -> PyArrowResult<Self> {
268 match input {
269 AnyRecordBatch::RecordBatch(rb) => Ok(rb),
270 AnyRecordBatch::Stream(stream) => {
271 let (batches, schema) = stream.into_table()?.into_inner();
272 let single_batch = concat_batches(&schema, batches.iter())?;
273 Ok(Self::new(single_batch))
274 }
275 }
276 }
277
278 #[classmethod]
279 #[pyo3(name = "from_arrow_pycapsule")]
280 fn from_arrow_pycapsule_py(
281 _cls: &Bound<PyType>,
282 schema_capsule: &Bound<PyCapsule>,
283 array_capsule: &Bound<PyCapsule>,
284 ) -> PyResult<Self> {
285 Self::from_arrow_pycapsule(schema_capsule, array_capsule)
286 }
287
288 fn add_column(
289 &self,
290 i: usize,
291 field: NameOrField,
292 column: PyArray,
293 ) -> PyArrowResult<Arro3RecordBatch> {
294 let mut fields = self.0.schema_ref().fields().to_vec();
295 fields.insert(i, field.into_field(column.field()));
296 let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
297
298 let mut arrays = self.0.columns().to_vec();
299 arrays.insert(i, column.array().clone());
300
301 let new_rb = RecordBatch::try_new(schema.into(), arrays)?;
302 Ok(PyRecordBatch::new(new_rb).into())
303 }
304
305 fn append_column(
306 &self,
307 field: NameOrField,
308 column: PyArray,
309 ) -> PyArrowResult<Arro3RecordBatch> {
310 let mut fields = self.0.schema_ref().fields().to_vec();
311 fields.push(field.into_field(column.field()));
312 let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
313
314 let mut arrays = self.0.columns().to_vec();
315 arrays.push(column.array().clone());
316
317 let new_rb = RecordBatch::try_new(schema.into(), arrays)?;
318 Ok(PyRecordBatch::new(new_rb).into())
319 }
320
321 fn column(&self, i: FieldIndexInput) -> PyResult<Arro3Array> {
322 let column_index = i.into_position(self.0.schema_ref())?;
323 let field = self.0.schema().field(column_index).clone();
324 let array = self.0.column(column_index).clone();
325 Ok(PyArray::new(array, field.into()).into())
326 }
327
328 #[getter]
329 fn column_names(&self) -> Vec<String> {
330 self.0
331 .schema()
332 .fields()
333 .iter()
334 .map(|f| f.name().clone())
335 .collect()
336 }
337
338 #[getter]
339 fn columns(&self) -> PyResult<Vec<Arro3Array>> {
340 (0..self.num_columns())
341 .map(|i| self.column(FieldIndexInput::Position(i)))
342 .collect()
343 }
344
345 fn equals(&self, other: PyRecordBatch) -> bool {
346 self.0 == other.0
347 }
348
349 fn field(&self, i: FieldIndexInput) -> PyResult<Arro3Field> {
350 let schema_ref = self.0.schema_ref();
351 let field = schema_ref.field(i.into_position(schema_ref)?);
352 Ok(PyField::new(field.clone().into()).into())
353 }
354
355 #[getter]
356 fn nbytes(&self) -> usize {
357 self.0.get_array_memory_size()
358 }
359
360 #[getter]
361 fn num_columns(&self) -> usize {
362 self.0.num_columns()
363 }
364
365 #[getter]
366 fn num_rows(&self) -> usize {
367 self.0.num_rows()
368 }
369
370 fn remove_column(&self, i: usize) -> Arro3RecordBatch {
371 let mut rb = self.0.clone();
372 rb.remove_column(i);
373 PyRecordBatch::new(rb).into()
374 }
375
376 #[getter]
377 fn schema(&self) -> Arro3Schema {
378 self.0.schema().into()
379 }
380
381 fn select(&self, columns: SelectIndices) -> PyArrowResult<Arro3RecordBatch> {
382 let positions = columns.into_positions(self.0.schema_ref().fields())?;
383 Ok(self.0.project(&positions)?.into())
384 }
385
386 fn set_column(
387 &self,
388 i: usize,
389 field: NameOrField,
390 column: PyArray,
391 ) -> PyArrowResult<Arro3RecordBatch> {
392 let mut fields = self.0.schema_ref().fields().to_vec();
393 fields[i] = field.into_field(column.field());
394 let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
395
396 let mut arrays = self.0.columns().to_vec();
397 arrays[i] = column.array().clone();
398
399 Ok(RecordBatch::try_new(schema.into(), arrays)?.into())
400 }
401
402 #[getter]
403 fn shape(&self) -> (usize, usize) {
404 (self.num_rows(), self.num_columns())
405 }
406
407 #[pyo3(signature = (offset=0, length=None))]
408 fn slice(&self, offset: usize, length: Option<usize>) -> Arro3RecordBatch {
409 let length = length.unwrap_or_else(|| self.num_rows() - offset);
410 self.0.slice(offset, length).into()
411 }
412
413 fn take(&self, indices: PyArray) -> PyArrowResult<Arro3RecordBatch> {
414 let new_batch = take_record_batch(self.as_ref(), indices.as_ref())?;
415 Ok(new_batch.into())
416 }
417
418 fn to_struct_array(&self) -> Arro3Array {
419 let struct_array: StructArray = self.0.clone().into();
420 let field = Field::new_struct("", self.0.schema_ref().fields().clone(), false)
421 .with_metadata(self.0.schema_ref().metadata.clone());
422 PyArray::new(Arc::new(struct_array), field.into()).into()
423 }
424
425 fn with_schema(&self, schema: PySchema) -> PyArrowResult<Arro3RecordBatch> {
426 let new_schema = schema.into_inner();
427 let new_batch = RecordBatch::try_new(new_schema.clone(), self.0.columns().to_vec())?;
428 Ok(new_batch.into())
429 }
430}