use std::fmt::Display;
use std::sync::Arc;

use arrow::compute::concat_batches;
use arrow::ffi_stream::ArrowArrayStreamReader as ArrowRecordBatchStreamReader;
use arrow_array::{ArrayRef, RecordBatchReader, StructArray};
use arrow_array::{RecordBatch, RecordBatchIterator};
use arrow_schema::{ArrowError, Field, Schema, SchemaRef};
use indexmap::IndexMap;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyTuple, PyType};
use pyo3::{intern, IntoPyObjectExt};

use crate::error::{PyArrowError, PyArrowResult};
use crate::export::{
    Arro3ChunkedArray, Arro3Field, Arro3RecordBatch, Arro3RecordBatchReader, Arro3Schema,
    Arro3Table,
};
use crate::ffi::from_python::utils::import_stream_pycapsule;
use crate::ffi::to_python::chunked::ArrayIterator;
use crate::ffi::to_python::nanoarrow::to_nanoarrow_array_stream;
use crate::ffi::to_python::to_stream_pycapsule;
use crate::ffi::to_schema_pycapsule;
use crate::input::{
    AnyArray, AnyRecordBatch, FieldIndexInput, MetadataInput, NameOrField, SelectIndices,
};
use crate::schema::display_schema;
use crate::utils::schema_equals;
use crate::{PyChunkedArray, PyField, PyRecordBatch, PyRecordBatchReader, PySchema};

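/// An Arrow Table: an ordered collection of [`RecordBatch`]es that all share the
/// same [`Schema`], exposed to Python as `arro3.core.Table`.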
#[pyclass(module = "arro3.core._core", name = "Table", subclass, frozen)]
#[derive(Debug)]
pub struct PyTable {
    batches: Vec<RecordBatch>,
    schema: SchemaRef,
}

impl PyTable {
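    /// Create a new table, validating that every batch matches the given schema.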
    pub fn try_new(batches: Vec<RecordBatch>, schema: SchemaRef) -> PyResult<Self> {
        if !batches
            .iter()
            .all(|rb| schema_equals(rb.schema_ref(), &schema))
        {
            return Err(PyTypeError::new_err("All batches must have the same schema"));
        }

        Ok(Self { schema, batches })
    }

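    /// Construct from an Arrow C Stream PyCapsule, eagerly collecting all batches.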
    pub fn from_arrow_pycapsule(capsule: &Bound<PyCapsule>) -> PyResult<Self> {
        let stream = import_stream_pycapsule(capsule)?;
        let stream_reader = ArrowRecordBatchStreamReader::try_new(stream)
            .map_err(|err| PyValueError::new_err(err.to_string()))?;
        let schema = stream_reader.schema();

        let mut batches = vec![];
        for batch in stream_reader {
            // A failure while reading the stream is a value error, not a type error.
            let batch = batch.map_err(|err| PyValueError::new_err(err.to_string()))?;
            batches.push(batch);
        }

        Self::try_new(batches, schema)
    }

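    /// Access the underlying record batches.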
    pub fn batches(&self) -> &[RecordBatch] {
        &self.batches
    }

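    /// Consume the table, returning its batches and schema.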
    pub fn into_inner(self) -> (Vec<RecordBatch>, SchemaRef) {
        (self.batches, self.schema)
    }

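    /// Export to an `arro3.core.Table` by passing an Arrow C Stream capsule to
    /// `Table.from_arrow_pycapsule`.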
    pub fn to_arro3<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let arro3_mod = py.import(intern!(py, "arro3.core"))?;
        arro3_mod.getattr(intern!(py, "Table"))?.call_method1(
            intern!(py, "from_arrow_pycapsule"),
            PyTuple::new(py, vec![self.__arrow_c_stream__(py, None)?])?,
        )
    }

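    /// Consume the table and export it to an `arro3.core.Table`.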
    pub fn into_arro3(self, py: Python) -> PyResult<Bound<PyAny>> {
        let arro3_mod = py.import(intern!(py, "arro3.core"))?;
        // `self` is consumed, so batches and schema can be moved rather than cloned.
        let capsule = Self::to_stream_pycapsule(py, self.batches, self.schema, None)?;
        arro3_mod.getattr(intern!(py, "Table"))?.call_method1(
            intern!(py, "from_arrow_pycapsule"),
            PyTuple::new(py, vec![capsule])?,
        )
    }

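    /// Export to a nanoarrow `ArrayStream`.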
    pub fn to_nanoarrow<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        to_nanoarrow_array_stream(py, &self.__arrow_c_stream__(py, None)?)
    }

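    /// Export to a `pyarrow.Table` by round-tripping through `pyarrow.table()`.
    /// Requires pyarrow to be importable.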
    pub fn to_pyarrow(self, py: Python) -> PyResult<PyObject> {
        let pyarrow_mod = py.import(intern!(py, "pyarrow"))?;
        let pyarrow_obj = pyarrow_mod
            .getattr(intern!(py, "table"))?
            .call1(PyTuple::new(py, vec![self.into_pyobject(py)?])?)?;
        pyarrow_obj.into_py_any(py)
    }

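    /// Build an Arrow C Stream PyCapsule from the batches, viewing each
    /// `RecordBatch` as one `StructArray` chunk of a single stream.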
    pub(crate) fn to_stream_pycapsule<'py>(
        py: Python<'py>,
        batches: Vec<RecordBatch>,
        schema: SchemaRef,
        requested_schema: Option<Bound<'py, PyCapsule>>,
    ) -> PyArrowResult<Bound<'py, PyCapsule>> {
        let fields = schema.fields();
        let array_reader = batches.into_iter().map(|batch| {
            let arr: ArrayRef = Arc::new(StructArray::from(batch));
            Ok(arr)
        });
        let array_reader = Box::new(ArrayIterator::new(
            array_reader,
            Field::new_struct("", fields.clone(), false)
                .with_metadata(schema.metadata.clone())
                .into(),
        ));
        to_stream_pycapsule(py, array_reader, requested_schema)
    }

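    /// Re-slice the table into batches with the given lengths. The lengths must
    /// sum to the total number of rows.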
    pub(crate) fn rechunk(&self, chunk_lengths: Vec<usize>) -> PyArrowResult<Self> {
        let total_chunk_length = chunk_lengths.iter().sum::<usize>();
        if total_chunk_length != self.num_rows() {
            return Err(
                PyValueError::new_err("Chunk lengths do not add up to table length").into(),
            );
        }

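        // Short-circuit: if the requested chunking matches the existing one, reuse the batches.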
        let matches_existing_chunking = chunk_lengths.len() == self.batches.len()
            && chunk_lengths
                .iter()
                .zip(self.batches())
                .all(|(length, batch)| *length == batch.num_rows());
        if matches_existing_chunking {
            return Ok(Self::try_new(self.batches.clone(), self.schema.clone())?);
        }

        let mut offset = 0;
        let batches = chunk_lengths
            .iter()
            .map(|chunk_length| {
                let sliced_table = self.slice(offset, *chunk_length)?;
                let sliced_concatenated =
                    concat_batches(&self.schema, sliced_table.batches.iter())?;
                offset += chunk_length;
                Ok(sliced_concatenated)
            })
            .collect::<PyArrowResult<Vec<_>>>()?;

        Ok(Self::try_new(batches, self.schema.clone())?)
    }

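    /// Zero-copy slice of the table: batches are skipped or sliced as needed to
    /// cover `offset..offset + length`.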
    pub(crate) fn slice(&self, mut offset: usize, mut length: usize) -> PyArrowResult<Self> {
        if offset + length > self.num_rows() {
            return Err(
                PyValueError::new_err("offset + length may not exceed length of table").into(),
            );
        }

        let mut sliced_batches: Vec<RecordBatch> = vec![];
        for chunk in self.batches() {
            if chunk.num_rows() == 0 {
                continue;
            }

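            // The slice starts beyond this chunk: skip it and shift the offset.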
            if offset >= chunk.num_rows() {
                offset -= chunk.num_rows();
                continue;
            }

            let take_count = length.min(chunk.num_rows() - offset);
            let sliced_chunk = chunk.slice(offset, take_count);
            sliced_batches.push(sliced_chunk);

            length -= take_count;

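            // Stop once the requested number of rows has been collected; otherwise
            // later chunks are taken from their beginning.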
            if length == 0 {
                break;
            } else {
                offset = 0;
            }
        }

        Ok(Self::try_new(sliced_batches, self.schema.clone())?)
    }
}

impl Display for PyTable {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "arro3.core.Table")?;
        writeln!(f, "-----------")?;
        display_schema(&self.schema, f)
    }
}

#[pymethods]
impl PyTable {
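    /// Construct a Table from tabular input.
    ///
    /// `data` may be anything exposing the Arrow PyCapsule stream interface, a
    /// mapping of column names to arrays, or a sequence of arrays combined with
    /// `names`. A sketch of the accepted call shapes from Python (the column
    /// values are illustrative placeholders, not part of this crate):
    ///
    /// ```python
    /// from arro3.core import Table
    ///
    /// Table(pyarrow_table)                     # any Arrow-compatible table/stream
    /// Table({"a": col_a, "b": col_b})          # mapping of name -> array
    /// Table([col_a, col_b], names=["a", "b"])  # sequence of arrays plus names
    /// ```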
    #[new]
    #[pyo3(signature = (data, *, names=None, schema=None, metadata=None))]
    fn new(
        py: Python,
        data: &Bound<PyAny>,
        names: Option<Vec<String>>,
        schema: Option<PySchema>,
        metadata: Option<MetadataInput>,
    ) -> PyArrowResult<Self> {
        if let Ok(data) = data.extract::<AnyRecordBatch>() {
            Ok(data.into_table()?)
        } else if let Ok(mapping) = data.extract::<IndexMap<String, AnyArray>>() {
            Self::from_pydict(&py.get_type::<PyTable>(), mapping, schema, metadata)
        } else if let Ok(arrays) = data.extract::<Vec<AnyArray>>() {
            Self::from_arrays(&py.get_type::<PyTable>(), arrays, names, schema, metadata)
        } else {
            Err(PyTypeError::new_err(
                "Expected Table-like input, a dict of arrays, or a sequence of arrays.",
            )
            .into())
        }
    }

    fn __arrow_c_schema__<'py>(&'py self, py: Python<'py>) -> PyArrowResult<Bound<'py, PyCapsule>> {
        to_schema_pycapsule(py, self.schema.as_ref())
    }

    #[pyo3(signature = (requested_schema=None))]
    fn __arrow_c_stream__<'py>(
        &'py self,
        py: Python<'py>,
        requested_schema: Option<Bound<'py, PyCapsule>>,
    ) -> PyArrowResult<Bound<'py, PyCapsule>> {
        Self::to_stream_pycapsule(
            py,
            self.batches.clone(),
            self.schema.clone(),
            requested_schema,
        )
    }

    fn __eq__(&self, other: &PyTable) -> bool {
        self.batches == other.batches && self.schema == other.schema
    }

    fn __getitem__(&self, key: FieldIndexInput) -> PyArrowResult<Arro3ChunkedArray> {
        self.column(key)
    }

    fn __len__(&self) -> usize {
        self.batches.iter().fold(0, |acc, x| acc + x.num_rows())
    }

    fn __repr__(&self) -> String {
        self.to_string()
    }

    #[classmethod]
    fn from_arrow(_cls: &Bound<PyType>, input: AnyRecordBatch) -> PyArrowResult<Self> {
        input.into_table()
    }

    #[classmethod]
    #[pyo3(name = "from_arrow_pycapsule")]
    fn from_arrow_pycapsule_py(_cls: &Bound<PyType>, capsule: &Bound<PyCapsule>) -> PyResult<Self> {
        Self::from_arrow_pycapsule(capsule)
    }

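    /// Create a table from a list of record batches. `schema` is required when
    /// `batches` is empty; otherwise it defaults to the first batch's schema.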
    #[classmethod]
    #[pyo3(signature = (batches, *, schema=None))]
    fn from_batches(
        _cls: &Bound<PyType>,
        batches: Vec<PyRecordBatch>,
        schema: Option<PySchema>,
    ) -> PyArrowResult<Self> {
        if batches.is_empty() {
            let schema = schema.ok_or_else(|| {
                PyValueError::new_err("schema must be passed for an empty list of batches")
            })?;
            return Ok(Self::try_new(vec![], schema.into_inner())?);
        }

        let batches = batches
            .into_iter()
            .map(|batch| batch.into_inner())
            .collect::<Vec<_>>();
        let schema = schema
            .map(|s| s.into_inner())
            .unwrap_or_else(|| batches.first().unwrap().schema());
        Ok(Self::try_new(batches, schema)?)
    }

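    /// Create a table from a mapping of column names to arrays or chunked arrays.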
    #[classmethod]
    #[pyo3(signature = (mapping, *, schema=None, metadata=None))]
    fn from_pydict(
        cls: &Bound<PyType>,
        mapping: IndexMap<String, AnyArray>,
        schema: Option<PySchema>,
        metadata: Option<MetadataInput>,
    ) -> PyArrowResult<Self> {
        let (names, arrays): (Vec<_>, Vec<_>) = mapping.into_iter().unzip();
        Self::from_arrays(cls, arrays, Some(names), schema, metadata)
    }

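    /// Create a table from a sequence of arrays or chunked arrays. Every column
    /// must have the same chunk layout, since each set of aligned chunks becomes
    /// one record batch.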
    #[classmethod]
    #[pyo3(signature = (arrays, *, names=None, schema=None, metadata=None))]
    fn from_arrays(
        _cls: &Bound<PyType>,
        arrays: Vec<AnyArray>,
        names: Option<Vec<String>>,
        schema: Option<PySchema>,
        metadata: Option<MetadataInput>,
    ) -> PyArrowResult<Self> {
        let columns = arrays
            .into_iter()
            .map(|array| array.into_chunked_array())
            .collect::<PyArrowResult<Vec<_>>>()?;

        let schema: SchemaRef = if let Some(schema) = schema {
            schema.into_inner()
        } else {
            let names = names.ok_or_else(|| {
                PyValueError::new_err("names must be passed if schema is not passed.")
            })?;

            let fields = columns
                .iter()
                .zip(names.iter())
                .map(|(array, name)| Arc::new(array.field().as_ref().clone().with_name(name)))
                .collect::<Vec<_>>();
            // Propagate metadata conversion errors instead of panicking.
            Arc::new(
                Schema::new(fields)
                    .with_metadata(metadata.unwrap_or_default().into_string_hashmap()?),
            )
        };

        if columns.is_empty() {
            return Ok(Self::try_new(vec![], schema)?);
        }

        let column_chunk_lengths = columns
            .iter()
            .map(|column| {
                column
                    .chunks()
                    .iter()
                    .map(|chunk| chunk.len())
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();
        if !column_chunk_lengths.windows(2).all(|w| w[0] == w[1]) {
            return Err(
                PyValueError::new_err("All columns must have the same chunk lengths").into(),
            );
        }
        let num_batches = column_chunk_lengths[0].len();

        let mut batches = vec![];
        for batch_idx in 0..num_batches {
            let batch = RecordBatch::try_new(
                schema.clone(),
                columns
                    .iter()
                    .map(|column| column.chunks()[batch_idx].clone())
                    .collect(),
            )?;
            batches.push(batch);
        }

        Ok(Self::try_new(batches, schema)?)
    }

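    /// Return a new table with `column` inserted at position `i`. The column is
    /// rechunked to match the table's existing chunk layout.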
    fn add_column(
        &self,
        i: usize,
        field: NameOrField,
        column: PyChunkedArray,
    ) -> PyArrowResult<Arro3Table> {
        if self.num_rows() != column.len() {
            return Err(
                PyValueError::new_err("Number of rows in column does not match table.").into(),
            );
        }

        let column = column.rechunk(self.chunk_lengths())?;

        let mut fields = self.schema.fields().to_vec();
        fields.insert(i, field.into_field(column.field()));
        let new_schema = Arc::new(Schema::new_with_metadata(
            fields,
            self.schema.metadata().clone(),
        ));

        let new_batches = self
            .batches
            .iter()
            .zip(column.chunks())
            .map(|(batch, array)| {
                debug_assert_eq!(
                    array.len(),
                    batch.num_rows(),
                    "Array and batch should have same number of rows."
                );

                let mut columns = batch.columns().to_vec();
                columns.insert(i, array.clone());
                Ok(RecordBatch::try_new(new_schema.clone(), columns)?)
            })
            .collect::<Result<Vec<_>, PyArrowError>>()?;

        Ok(PyTable::try_new(new_batches, new_schema)?.into())
    }

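    /// Return a new table with `column` appended as the last column, rechunked
    /// to match the table's existing chunk layout.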
    fn append_column(
        &self,
        field: NameOrField,
        column: PyChunkedArray,
    ) -> PyArrowResult<Arro3Table> {
        if self.num_rows() != column.len() {
            return Err(
                PyValueError::new_err("Number of rows in column does not match table.").into(),
            );
        }

        let column = column.rechunk(self.chunk_lengths())?;

        let mut fields = self.schema.fields().to_vec();
        fields.push(field.into_field(column.field()));
        let new_schema = Arc::new(Schema::new_with_metadata(
            fields,
            self.schema.metadata().clone(),
        ));

        let new_batches = self
            .batches
            .iter()
            .zip(column.chunks())
            .map(|(batch, array)| {
                debug_assert_eq!(
                    array.len(),
                    batch.num_rows(),
                    "Array and batch should have same number of rows."
                );

                let mut columns = batch.columns().to_vec();
                columns.push(array.clone());
                Ok(RecordBatch::try_new(new_schema.clone(), columns)?)
            })
            .collect::<Result<Vec<_>, PyArrowError>>()?;

        Ok(PyTable::try_new(new_batches, new_schema)?.into())
    }

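    /// The number of rows in each batch of the table.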
    #[getter]
    fn chunk_lengths(&self) -> Vec<usize> {
        self.batches.iter().map(|batch| batch.num_rows()).collect()
    }

    fn column(&self, i: FieldIndexInput) -> PyArrowResult<Arro3ChunkedArray> {
        let column_index = i.into_position(&self.schema)?;
        let field = self.schema.field(column_index).clone();
        let chunks = self
            .batches
            .iter()
            .map(|batch| batch.column(column_index).clone())
            .collect();
        Ok(PyChunkedArray::try_new(chunks, field.into())?.into())
    }

    #[getter]
    fn column_names(&self) -> Vec<String> {
        self.schema
            .fields()
            .iter()
            .map(|f| f.name().clone())
            .collect()
    }

    #[getter]
    fn columns(&self) -> PyArrowResult<Vec<Arro3ChunkedArray>> {
        (0..self.num_columns())
            .map(|i| self.column(FieldIndexInput::Position(i)))
            .collect()
    }

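    /// Concatenate all batches into a single batch, eliminating chunking.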
    fn combine_chunks(&self) -> PyArrowResult<Arro3Table> {
        let batch = concat_batches(&self.schema, &self.batches)?;
        Ok(PyTable::try_new(vec![batch], self.schema.clone())?.into())
    }

    fn field(&self, i: FieldIndexInput) -> PyArrowResult<Arro3Field> {
        let field = self.schema.field(i.into_position(&self.schema)?);
        Ok(PyField::new(field.clone().into()).into())
    }

    #[getter]
    fn nbytes(&self) -> usize {
        self.batches
            .iter()
            .fold(0, |acc, batch| acc + batch.get_array_memory_size())
    }

    #[getter]
    fn num_columns(&self) -> usize {
        self.schema.fields().len()
    }

    #[getter]
    fn num_rows(&self) -> usize {
        self.batches()
            .iter()
            .fold(0, |acc, batch| acc + batch.num_rows())
    }

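    /// Rechunk the table so that no batch is longer than `max_chunksize` rows.
    /// With no argument, the whole table is combined into a single batch.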
    #[pyo3(signature = (*, max_chunksize=None))]
    #[pyo3(name = "rechunk")]
    fn rechunk_py(&self, max_chunksize: Option<usize>) -> PyArrowResult<Arro3Table> {
        let max_chunksize = max_chunksize.unwrap_or(self.num_rows());
        if max_chunksize == 0 {
            return Err(PyValueError::new_err("max_chunksize must be > 0").into());
        }

        let mut chunk_lengths = vec![];
        let mut offset = 0;
        while offset < self.num_rows() {
            let chunk_length = max_chunksize.min(self.num_rows() - offset);
            offset += chunk_length;
            chunk_lengths.push(chunk_length);
        }
        Ok(self.rechunk(chunk_lengths)?.into())
    }

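    /// Return a new table with the column at position `i` removed.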
    fn remove_column(&self, i: usize) -> PyArrowResult<Arro3Table> {
        let mut fields = self.schema.fields().to_vec();
        fields.remove(i);
        let new_schema = Arc::new(Schema::new_with_metadata(
            fields,
            self.schema.metadata().clone(),
        ));

        let new_batches = self
            .batches
            .iter()
            .map(|batch| {
                let mut columns = batch.columns().to_vec();
                columns.remove(i);
                Ok(RecordBatch::try_new(new_schema.clone(), columns)?)
            })
            .collect::<Result<Vec<_>, PyArrowError>>()?;

        Ok(PyTable::try_new(new_batches, new_schema)?.into())
    }

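    /// Return a new table with all columns renamed. `names` must contain exactly
    /// one name per column.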
    fn rename_columns(&self, names: Vec<String>) -> PyArrowResult<Arro3Table> {
        if names.len() != self.num_columns() {
            return Err(PyValueError::new_err(
                "When names is a list[str], must pass the same number of names as there are columns.",
            )
            .into());
        }

        let new_fields = self
            .schema
            .fields()
            .iter()
            .zip(names)
            .map(|(field, name)| field.as_ref().clone().with_name(name))
            .collect::<Vec<_>>();
        let new_schema = Arc::new(Schema::new_with_metadata(
            new_fields,
            self.schema.metadata().clone(),
        ));
        Ok(PyTable::try_new(self.batches.clone(), new_schema)?.into())
    }

    #[getter]
    fn schema(&self) -> Arro3Schema {
        PySchema::new(self.schema.clone()).into()
    }

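    /// Return a new table containing only the selected columns, in the given order.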
    fn select(&self, columns: SelectIndices) -> PyArrowResult<Arro3Table> {
        let positions = columns.into_positions(self.schema.fields())?;

        let new_schema = Arc::new(self.schema.project(&positions)?);
        let new_batches = self
            .batches
            .iter()
            .map(|batch| batch.project(&positions))
            .collect::<Result<Vec<_>, ArrowError>>()?;
        Ok(PyTable::try_new(new_batches, new_schema)?.into())
    }

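    /// Return a new table with the column at position `i` replaced by `column`,
    /// rechunked to match the table's existing chunk layout.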
    fn set_column(
        &self,
        i: usize,
        field: NameOrField,
        column: PyChunkedArray,
    ) -> PyArrowResult<Arro3Table> {
        if self.num_rows() != column.len() {
            return Err(
                PyValueError::new_err("Number of rows in column does not match table.").into(),
            );
        }

        let column = column.rechunk(self.chunk_lengths())?;

        let mut fields = self.schema.fields().to_vec();
        fields[i] = field.into_field(column.field());
        let new_schema = Arc::new(Schema::new_with_metadata(
            fields,
            self.schema.metadata().clone(),
        ));

        let new_batches = self
            .batches
            .iter()
            .zip(column.chunks())
            .map(|(batch, array)| {
                debug_assert_eq!(
                    array.len(),
                    batch.num_rows(),
                    "Array and batch should have same number of rows."
                );

                let mut columns = batch.columns().to_vec();
                columns[i] = array.clone();
                Ok(RecordBatch::try_new(new_schema.clone(), columns)?)
            })
            .collect::<Result<Vec<_>, PyArrowError>>()?;

        Ok(PyTable::try_new(new_batches, new_schema)?.into())
    }

    #[getter]
    fn shape(&self) -> (usize, usize) {
        (self.num_rows(), self.num_columns())
    }

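    /// Zero-copy slice starting at `offset` for `length` rows. `length` defaults
    /// to the remainder of the table.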
    #[pyo3(signature = (offset=0, length=None))]
    #[pyo3(name = "slice")]
    fn slice_py(&self, offset: usize, length: Option<usize>) -> PyArrowResult<Arro3Table> {
        // Saturate so an out-of-range offset surfaces as an error in `slice`
        // rather than as an integer-underflow panic here.
        let length = length.unwrap_or_else(|| self.num_rows().saturating_sub(offset));
        Ok(self.slice(offset, length)?.into())
    }

    fn to_batches(&self) -> Vec<Arro3RecordBatch> {
        self.batches
            .iter()
            .map(|batch| PyRecordBatch::new(batch.clone()).into())
            .collect()
    }

    fn to_reader(&self) -> Arro3RecordBatchReader {
        let reader = Box::new(RecordBatchIterator::new(
            self.batches.clone().into_iter().map(Ok),
            self.schema.clone(),
        ));
        PyRecordBatchReader::new(reader).into()
    }

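    /// View the table as a chunked struct array, with one struct chunk per batch.
    /// The struct field carries the schema metadata.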
    fn to_struct_array(&self) -> PyArrowResult<Arro3ChunkedArray> {
        let chunks = self
            .batches
            .iter()
            .map(|batch| {
                let struct_array: StructArray = batch.clone().into();
                Arc::new(struct_array) as ArrayRef
            })
            .collect::<Vec<_>>();
        let field = Field::new_struct("", self.schema.fields().clone(), false)
            .with_metadata(self.schema.metadata.clone());
        Ok(PyChunkedArray::try_new(chunks, field.into())?.into())
    }

    fn with_schema(&self, schema: PySchema) -> PyArrowResult<Arro3Table> {
        let new_schema = schema.into_inner();
        let new_batches = self
            .batches
            .iter()
            .map(|batch| RecordBatch::try_new(new_schema.clone(), batch.columns().to_vec()))
            .collect::<Result<Vec<_>, ArrowError>>()?;
        Ok(PyTable::try_new(new_batches, new_schema)?.into())
    }
}