1use std::fmt::Display;
2use std::sync::Arc;
3
4use arrow::array::AsArray;
5use arrow::compute::{concat_batches, take_record_batch};
6use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, StructArray};
7use arrow_schema::{DataType, Field, Schema, SchemaBuilder};
8use indexmap::IndexMap;
9use pyo3::exceptions::{PyTypeError, PyValueError};
10use pyo3::prelude::*;
11use pyo3::types::{PyCapsule, PyTuple, PyType};
12use pyo3::{intern, IntoPyObjectExt};
13
14use crate::error::PyArrowResult;
15use crate::export::{Arro3Array, Arro3Field, Arro3RecordBatch, Arro3Schema};
16use crate::ffi::from_python::utils::import_array_pycapsules;
17use crate::ffi::to_python::nanoarrow::to_nanoarrow_array;
18use crate::ffi::to_python::to_array_pycapsules;
19use crate::ffi::to_schema_pycapsule;
20use crate::input::{AnyRecordBatch, FieldIndexInput, MetadataInput, NameOrField, SelectIndices};
21use crate::schema::display_schema;
22use crate::{PyArray, PyField, PySchema};
23
/// A Python-facing wrapper around an Arrow [`RecordBatch`].
///
/// Exposed to Python as `arro3.core.RecordBatch`. The class is declared
/// `frozen`, so the wrapped batch is immutable once constructed; all
/// mutating-style methods below return a *new* batch.
#[pyclass(module = "arro3.core._core", name = "RecordBatch", subclass, frozen)]
#[derive(Debug)]
pub struct PyRecordBatch(RecordBatch);
30
31impl PyRecordBatch {
32 pub fn new(batch: RecordBatch) -> Self {
34 Self(batch)
35 }
36
37 pub fn from_arrow_pycapsule(
39 schema_capsule: &Bound<PyCapsule>,
40 array_capsule: &Bound<PyCapsule>,
41 ) -> PyResult<Self> {
42 let (array, field, data_len) = import_array_pycapsules(schema_capsule, array_capsule)?;
43
44 match field.data_type() {
45 DataType::Struct(fields) => {
46 let struct_array = array.as_struct();
47 let schema = SchemaBuilder::from(fields)
48 .finish()
49 .with_metadata(field.metadata().clone());
50 assert_eq!(
51 struct_array.null_count(),
52 0,
53 "Cannot convert nullable StructArray to RecordBatch"
54 );
55
56 let columns = struct_array.columns().to_vec();
57
58 let batch = if array.len() == 0 && data_len > 0 {
60 RecordBatch::try_new_with_options(
61 Arc::new(schema),
62 columns,
63 &RecordBatchOptions::new().with_row_count(Some(data_len)),
64 )
65 .map_err(|err| PyValueError::new_err(err.to_string()))?
66 } else {
67 RecordBatch::try_new(Arc::new(schema), columns)
68 .map_err(|err| PyValueError::new_err(err.to_string()))?
69 };
70 Ok(Self::new(batch))
71 }
72 dt => Err(PyValueError::new_err(format!(
73 "Unexpected data type {}",
74 dt
75 ))),
76 }
77 }
78
79 pub fn into_inner(self) -> RecordBatch {
81 self.0
82 }
83
84 pub fn to_arro3<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
86 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
87 arro3_mod.getattr(intern!(py, "RecordBatch"))?.call_method1(
88 intern!(py, "from_arrow_pycapsule"),
89 self.__arrow_c_array__(py, None)?,
90 )
91 }
92
93 pub fn into_arro3(self, py: Python) -> PyResult<Bound<PyAny>> {
95 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
96 let capsules = Self::to_array_pycapsules(py, self.0.clone(), None)?;
97 arro3_mod
98 .getattr(intern!(py, "RecordBatch"))?
99 .call_method1(intern!(py, "from_arrow_pycapsule"), capsules)
100 }
101
102 pub fn to_nanoarrow<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
104 to_nanoarrow_array(py, &self.__arrow_c_array__(py, None)?)
105 }
106
107 pub fn to_pyarrow(self, py: Python) -> PyResult<PyObject> {
111 let pyarrow_mod = py.import(intern!(py, "pyarrow"))?;
112 let pyarrow_obj = pyarrow_mod
113 .getattr(intern!(py, "record_batch"))?
114 .call1(PyTuple::new(py, vec![self.into_pyobject(py)?])?)?;
115 pyarrow_obj.into_py_any(py)
116 }
117
118 pub(crate) fn to_array_pycapsules<'py>(
119 py: Python<'py>,
120 record_batch: RecordBatch,
121 requested_schema: Option<Bound<'py, PyCapsule>>,
122 ) -> PyArrowResult<Bound<'py, PyTuple>> {
123 let field = Field::new_struct("", record_batch.schema_ref().fields().clone(), false);
124 let array: ArrayRef = Arc::new(StructArray::from(record_batch.clone()));
125 to_array_pycapsules(py, field.into(), &array, requested_schema)
126 }
127}
128
129impl From<RecordBatch> for PyRecordBatch {
130 fn from(value: RecordBatch) -> Self {
131 Self(value)
132 }
133}
134
135impl From<PyRecordBatch> for RecordBatch {
136 fn from(value: PyRecordBatch) -> Self {
137 value.0
138 }
139}
140
impl AsRef<RecordBatch> for PyRecordBatch {
    /// Borrow the wrapped [`RecordBatch`].
    fn as_ref(&self) -> &RecordBatch {
        &self.0
    }
}
146
147impl Display for PyRecordBatch {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 writeln!(f, "arro3.core.RecordBatch")?;
150 writeln!(f, "-----------------")?;
151 display_schema(&self.0.schema(), f)
152 }
153}
154
155#[pymethods]
156impl PyRecordBatch {
157 #[new]
158 #[pyo3(signature = (data, *, schema=None, metadata=None))]
159 fn init(
160 py: Python,
161 data: &Bound<PyAny>,
162 schema: Option<PySchema>,
163 metadata: Option<MetadataInput>,
164 ) -> PyArrowResult<Self> {
165 if let Ok(data) = data.extract::<PyRecordBatch>() {
166 Ok(data)
167 } else if let Ok(mapping) = data.extract::<IndexMap<String, PyArray>>() {
168 Self::from_pydict(&py.get_type::<PyRecordBatch>(), mapping, metadata)
169 } else if let Ok(arrays) = data.extract::<Vec<PyArray>>() {
170 Self::from_arrays(
171 &py.get_type::<PyRecordBatch>(),
172 arrays,
173 schema.ok_or(PyValueError::new_err(
174 "Schema must be passed with list of arrays",
175 ))?,
176 )
177 } else {
178 Err(PyTypeError::new_err(
179 "Expected RecordBatch-like input or dict of arrays or list of arrays.",
180 )
181 .into())
182 }
183 }
184
185 #[pyo3(signature = (requested_schema=None))]
186 fn __arrow_c_array__<'py>(
187 &'py self,
188 py: Python<'py>,
189 requested_schema: Option<Bound<'py, PyCapsule>>,
190 ) -> PyArrowResult<Bound<'py, PyTuple>> {
191 Self::to_array_pycapsules(py, self.0.clone(), requested_schema)
192 }
193
194 fn __arrow_c_schema__<'py>(&'py self, py: Python<'py>) -> PyArrowResult<Bound<'py, PyCapsule>> {
195 to_schema_pycapsule(py, self.0.schema_ref().as_ref())
196 }
197
198 fn __eq__(&self, other: &PyRecordBatch) -> bool {
199 self.0 == other.0
200 }
201
202 fn __getitem__(&self, key: FieldIndexInput) -> PyResult<Arro3Array> {
203 self.column(key)
204 }
205
206 fn __repr__(&self) -> String {
207 self.to_string()
208 }
209
210 #[classmethod]
211 #[pyo3(signature = (arrays, *, schema))]
212 fn from_arrays(
213 _cls: &Bound<PyType>,
214 arrays: Vec<PyArray>,
215 schema: PySchema,
216 ) -> PyArrowResult<Self> {
217 let rb = RecordBatch::try_new(
218 schema.into(),
219 arrays
220 .into_iter()
221 .map(|arr| {
222 let (arr, _field) = arr.into_inner();
223 arr
224 })
225 .collect(),
226 )?;
227 Ok(Self::new(rb))
228 }
229
230 #[classmethod]
231 #[pyo3(signature = (mapping, *, metadata=None))]
232 fn from_pydict(
233 _cls: &Bound<PyType>,
234 mapping: IndexMap<String, PyArray>,
235 metadata: Option<MetadataInput>,
236 ) -> PyArrowResult<Self> {
237 let mut fields = vec![];
238 let mut arrays = vec![];
239 mapping.into_iter().for_each(|(name, py_array)| {
240 let (arr, field) = py_array.into_inner();
241 fields.push(field.as_ref().clone().with_name(name));
242 arrays.push(arr);
243 });
244 let schema =
245 Schema::new_with_metadata(fields, metadata.unwrap_or_default().into_string_hashmap()?);
246 let rb = RecordBatch::try_new(schema.into(), arrays)?;
247 Ok(Self::new(rb))
248 }
249
250 #[classmethod]
251 fn from_struct_array(_cls: &Bound<PyType>, struct_array: PyArray) -> PyArrowResult<Self> {
252 let (array, field) = struct_array.into_inner();
253 match field.data_type() {
254 DataType::Struct(fields) => {
255 let schema = Schema::new_with_metadata(fields.clone(), field.metadata().clone());
256 let struct_arr = array.as_struct();
257 let columns = struct_arr.columns().to_vec();
258 let rb = RecordBatch::try_new(schema.into(), columns)?;
259 Ok(Self::new(rb))
260 }
261 _ => Err(PyTypeError::new_err("Expected struct array").into()),
262 }
263 }
264
265 #[classmethod]
266 fn from_arrow(_cls: &Bound<PyType>, input: AnyRecordBatch) -> PyArrowResult<Self> {
267 match input {
268 AnyRecordBatch::RecordBatch(rb) => Ok(rb),
269 AnyRecordBatch::Stream(stream) => {
270 let (batches, schema) = stream.into_table()?.into_inner();
271 let single_batch = concat_batches(&schema, batches.iter())?;
272 Ok(Self::new(single_batch))
273 }
274 }
275 }
276
277 #[classmethod]
278 #[pyo3(name = "from_arrow_pycapsule")]
279 fn from_arrow_pycapsule_py(
280 _cls: &Bound<PyType>,
281 schema_capsule: &Bound<PyCapsule>,
282 array_capsule: &Bound<PyCapsule>,
283 ) -> PyResult<Self> {
284 Self::from_arrow_pycapsule(schema_capsule, array_capsule)
285 }
286
287 fn add_column(
288 &self,
289 i: usize,
290 field: NameOrField,
291 column: PyArray,
292 ) -> PyArrowResult<Arro3RecordBatch> {
293 let mut fields = self.0.schema_ref().fields().to_vec();
294 fields.insert(i, field.into_field(column.field()));
295 let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
296
297 let mut arrays = self.0.columns().to_vec();
298 arrays.insert(i, column.array().clone());
299
300 let new_rb = RecordBatch::try_new(schema.into(), arrays)?;
301 Ok(PyRecordBatch::new(new_rb).into())
302 }
303
304 fn append_column(
305 &self,
306 field: NameOrField,
307 column: PyArray,
308 ) -> PyArrowResult<Arro3RecordBatch> {
309 let mut fields = self.0.schema_ref().fields().to_vec();
310 fields.push(field.into_field(column.field()));
311 let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
312
313 let mut arrays = self.0.columns().to_vec();
314 arrays.push(column.array().clone());
315
316 let new_rb = RecordBatch::try_new(schema.into(), arrays)?;
317 Ok(PyRecordBatch::new(new_rb).into())
318 }
319
320 fn column(&self, i: FieldIndexInput) -> PyResult<Arro3Array> {
321 let column_index = i.into_position(self.0.schema_ref())?;
322 let field = self.0.schema().field(column_index).clone();
323 let array = self.0.column(column_index).clone();
324 Ok(PyArray::new(array, field.into()).into())
325 }
326
327 #[getter]
328 fn column_names(&self) -> Vec<String> {
329 self.0
330 .schema()
331 .fields()
332 .iter()
333 .map(|f| f.name().clone())
334 .collect()
335 }
336
337 #[getter]
338 fn columns(&self) -> PyResult<Vec<Arro3Array>> {
339 (0..self.num_columns())
340 .map(|i| self.column(FieldIndexInput::Position(i)))
341 .collect()
342 }
343
344 fn equals(&self, other: PyRecordBatch) -> bool {
345 self.0 == other.0
346 }
347
348 fn field(&self, i: FieldIndexInput) -> PyResult<Arro3Field> {
349 let schema_ref = self.0.schema_ref();
350 let field = schema_ref.field(i.into_position(schema_ref)?);
351 Ok(PyField::new(field.clone().into()).into())
352 }
353
354 #[getter]
355 fn nbytes(&self) -> usize {
356 self.0.get_array_memory_size()
357 }
358
359 #[getter]
360 fn num_columns(&self) -> usize {
361 self.0.num_columns()
362 }
363
364 #[getter]
365 fn num_rows(&self) -> usize {
366 self.0.num_rows()
367 }
368
369 fn remove_column(&self, i: usize) -> Arro3RecordBatch {
370 let mut rb = self.0.clone();
371 rb.remove_column(i);
372 PyRecordBatch::new(rb).into()
373 }
374
375 #[getter]
376 fn schema(&self) -> Arro3Schema {
377 self.0.schema().into()
378 }
379
380 fn select(&self, columns: SelectIndices) -> PyArrowResult<Arro3RecordBatch> {
381 let positions = columns.into_positions(self.0.schema_ref().fields())?;
382 Ok(self.0.project(&positions)?.into())
383 }
384
385 fn set_column(
386 &self,
387 i: usize,
388 field: NameOrField,
389 column: PyArray,
390 ) -> PyArrowResult<Arro3RecordBatch> {
391 let mut fields = self.0.schema_ref().fields().to_vec();
392 fields[i] = field.into_field(column.field());
393 let schema = Schema::new_with_metadata(fields, self.0.schema_ref().metadata().clone());
394
395 let mut arrays = self.0.columns().to_vec();
396 arrays[i] = column.array().clone();
397
398 Ok(RecordBatch::try_new(schema.into(), arrays)?.into())
399 }
400
401 #[getter]
402 fn shape(&self) -> (usize, usize) {
403 (self.num_rows(), self.num_columns())
404 }
405
406 #[pyo3(signature = (offset=0, length=None))]
407 fn slice(&self, offset: usize, length: Option<usize>) -> Arro3RecordBatch {
408 let length = length.unwrap_or_else(|| self.num_rows() - offset);
409 self.0.slice(offset, length).into()
410 }
411
412 fn take(&self, indices: PyArray) -> PyArrowResult<Arro3RecordBatch> {
413 let new_batch = take_record_batch(self.as_ref(), indices.as_ref())?;
414 Ok(new_batch.into())
415 }
416
417 fn to_struct_array(&self) -> Arro3Array {
418 let struct_array: StructArray = self.0.clone().into();
419 let field = Field::new_struct("", self.0.schema_ref().fields().clone(), false)
420 .with_metadata(self.0.schema_ref().metadata.clone());
421 PyArray::new(Arc::new(struct_array), field.into()).into()
422 }
423
424 fn with_schema(&self, schema: PySchema) -> PyArrowResult<Arro3RecordBatch> {
425 let new_schema = schema.into_inner();
426 let new_batch = RecordBatch::try_new(new_schema.clone(), self.0.columns().to_vec())?;
427 Ok(new_batch.into())
428 }
429}