use std::collections::HashMap;
use std::ffi::CString;
use std::sync::Arc;

use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow::compute::can_cast_types;
use arrow::error::ArrowError;
use arrow::ffi::FFI_ArrowSchema;
use arrow::ffi_stream::FFI_ArrowArrayStream;
use arrow::pyarrow::FromPyArrow;
use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
use datafusion::arrow::util::pretty;
use datafusion::common::UnnestOptions;
use datafusion::config::{CsvOptions, ParquetColumnOptions, ParquetOptions, TableParquetOptions};
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
use datafusion::datasource::TableProvider;
use datafusion::error::DataFusionError;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
use datafusion::prelude::*;
use datafusion_ffi::table_provider::FFI_TableProvider;
use futures::{StreamExt, TryStreamExt};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
use tokio::task::JoinHandle;

use crate::catalog::PyTable;
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError};
use crate::expr::sort_expr::to_sort_expressions;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
use crate::sql::logical::PyLogicalPlan;
use crate::utils::{
    get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future,
};
use crate::{
    errors::PyDataFusionResult,
    expr::{sort_expr::PySortExpr, PyExpr},
};

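/// Wraps a DataFusion [`TableProvider`] so it can be shared with Python.
///
/// The provider is exported through the `__datafusion_table_provider__`
/// PyCapsule interface, which lets other Python libraries consume it over
/// FFI without copying data.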
#[pyclass(name = "TableProvider", module = "datafusion")]
pub struct PyTableProvider {
    provider: Arc<dyn TableProvider + Send>,
}

impl PyTableProvider {
    pub fn new(provider: Arc<dyn TableProvider>) -> Self {
        Self { provider }
    }

    pub fn as_table(&self) -> PyTable {
        let table_provider: Arc<dyn TableProvider> = self.provider.clone();
        PyTable::new(table_provider)
    }
}

#[pymethods]
impl PyTableProvider {
    fn __datafusion_table_provider__<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, PyCapsule>> {
        let name = CString::new("datafusion_table_provider").unwrap();

        let runtime = get_tokio_runtime().0.handle().clone();
        let provider = FFI_TableProvider::new(Arc::clone(&self.provider), false, Some(runtime));

        PyCapsule::new(py, provider, Some(name.clone()))
    }
}

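/// Display configuration read from the Python formatter.
///
/// `max_bytes` caps the estimated memory collected for rendering, `min_rows`
/// is the minimum number of rows to gather when data is available, and
/// `repr_rows` is the maximum number of rows shown in a repr.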
#[derive(Debug, Clone)]
pub struct FormatterConfig {
    pub max_bytes: usize,
    pub min_rows: usize,
    pub repr_rows: usize,
}

impl Default for FormatterConfig {
    fn default() -> Self {
        Self {
            max_bytes: 2 * 1024 * 1024,
            min_rows: 20,
            repr_rows: 10,
        }
    }
}

impl FormatterConfig {
    pub fn validate(&self) -> Result<(), String> {
        if self.max_bytes == 0 {
            return Err("max_bytes must be a positive integer".to_string());
        }

        if self.min_rows == 0 {
            return Err("min_rows must be a positive integer".to_string());
        }

        if self.repr_rows == 0 {
            return Err("repr_rows must be a positive integer".to_string());
        }

        Ok(())
    }
}

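/// A formatter object obtained from the `datafusion.dataframe_formatter`
/// Python module, together with the [`FormatterConfig`] read from it.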
struct PythonFormatter<'py> {
    formatter: Bound<'py, PyAny>,
    config: FormatterConfig,
}

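/// Imports the Python formatter and builds its [`FormatterConfig`] in one step.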
fn get_python_formatter_with_config(py: Python) -> PyResult<PythonFormatter> {
    let formatter = import_python_formatter(py)?;
    let config = build_formatter_config_from_python(&formatter)?;
    Ok(PythonFormatter { formatter, config })
}

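/// Calls `datafusion.dataframe_formatter.get_formatter()` and returns the
/// resulting formatter object.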
fn import_python_formatter(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
    let formatter_module = py.import("datafusion.dataframe_formatter")?;
    let get_formatter = formatter_module.getattr("get_formatter")?;
    get_formatter.call0()
}

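/// Reads `attr_name` from a Python object, falling back to `default_value`
/// when the attribute is missing or cannot be extracted as `T`.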
fn get_attr<'a, T>(py_object: &'a Bound<'a, PyAny>, attr_name: &str, default_value: T) -> T
where
    T: for<'py> pyo3::FromPyObject<'py> + Clone,
{
    py_object
        .getattr(attr_name)
        .and_then(|v| v.extract::<T>())
        .unwrap_or_else(|_| default_value.clone())
}

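/// Builds a validated [`FormatterConfig`] from the attributes exposed by the
/// Python formatter (`max_memory_bytes`, `min_rows_display`, `repr_rows`).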
fn build_formatter_config_from_python(formatter: &Bound<'_, PyAny>) -> PyResult<FormatterConfig> {
    let default_config = FormatterConfig::default();
    let max_bytes = get_attr(formatter, "max_memory_bytes", default_config.max_bytes);
    let min_rows = get_attr(formatter, "min_rows_display", default_config.min_rows);
    let repr_rows = get_attr(formatter, "repr_rows", default_config.repr_rows);

    let config = FormatterConfig {
        max_bytes,
        min_rows,
        repr_rows,
    };

    config.validate().map_err(PyValueError::new_err)?;

    Ok(config)
}

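/// Python-facing wrapper around DataFusion's global [`ParquetOptions`], used
/// by `write_parquet_with_options`.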
#[pyclass(name = "ParquetWriterOptions", module = "datafusion", subclass)]
#[derive(Clone, Default)]
pub struct PyParquetWriterOptions {
    options: ParquetOptions,
}

#[pymethods]
impl PyParquetWriterOptions {
    #[new]
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        data_pagesize_limit: usize,
        write_batch_size: usize,
        writer_version: String,
        skip_arrow_metadata: bool,
        compression: Option<String>,
        dictionary_enabled: Option<bool>,
        dictionary_page_size_limit: usize,
        statistics_enabled: Option<String>,
        max_row_group_size: usize,
        created_by: String,
        column_index_truncate_length: Option<usize>,
        statistics_truncate_length: Option<usize>,
        data_page_row_count_limit: usize,
        encoding: Option<String>,
        bloom_filter_on_write: bool,
        bloom_filter_fpp: Option<f64>,
        bloom_filter_ndv: Option<u64>,
        allow_single_file_parallelism: bool,
        maximum_parallel_row_group_writers: usize,
        maximum_buffered_record_batches_per_stream: usize,
    ) -> Self {
        Self {
            options: ParquetOptions {
                data_pagesize_limit,
                write_batch_size,
                writer_version,
                skip_arrow_metadata,
                compression,
                dictionary_enabled,
                dictionary_page_size_limit,
                statistics_enabled,
                max_row_group_size,
                created_by,
                column_index_truncate_length,
                statistics_truncate_length,
                data_page_row_count_limit,
                encoding,
                bloom_filter_on_write,
                bloom_filter_fpp,
                bloom_filter_ndv,
                allow_single_file_parallelism,
                maximum_parallel_row_group_writers,
                maximum_buffered_record_batches_per_stream,
                ..Default::default()
            },
        }
    }
}

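/// Python-facing wrapper around per-column [`ParquetColumnOptions`].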
#[pyclass(name = "ParquetColumnOptions", module = "datafusion", subclass)]
#[derive(Clone, Default)]
pub struct PyParquetColumnOptions {
    options: ParquetColumnOptions,
}

#[pymethods]
impl PyParquetColumnOptions {
    #[new]
    pub fn new(
        bloom_filter_enabled: Option<bool>,
        encoding: Option<String>,
        dictionary_enabled: Option<bool>,
        compression: Option<String>,
        statistics_enabled: Option<String>,
        bloom_filter_fpp: Option<f64>,
        bloom_filter_ndv: Option<u64>,
    ) -> Self {
        Self {
            options: ParquetColumnOptions {
                bloom_filter_enabled,
                encoding,
                dictionary_enabled,
                compression,
                statistics_enabled,
                bloom_filter_fpp,
                bloom_filter_ndv,
            },
        }
    }
}

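/// A `PyDataFrame` is a representation of a logical plan and an API to compose
/// statements. Results are not materialized until an action such as `collect`
/// is executed.
///
/// The `batches` field caches record batches (plus a "has more" flag) that were
/// already collected for display, so repeated `__repr__`/`_repr_html_` calls in
/// an IPython session do not re-execute the plan.
///
/// A rough sketch of typical usage from Python (names follow the bindings
/// defined below; the surrounding session setup is assumed):
///
/// ```text
/// df = ctx.sql("SELECT a, b FROM t")
/// df = df.filter(col("a") > lit(1)).select(col("a"))
/// batches = df.collect()
/// ```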
#[pyclass(name = "DataFrame", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PyDataFrame {
    df: Arc<DataFrame>,

    batches: Option<(Vec<RecordBatch>, bool)>,
}

impl PyDataFrame {
    pub fn new(df: DataFrame) -> Self {
        Self {
            df: Arc::new(df),
            batches: None,
        }
    }

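    /// Renders the DataFrame through the Python formatter as either HTML or
    /// plain text. Under IPython the collected batches are cached on the
    /// instance so `__repr__` and `_repr_html_` do not execute the plan twice.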
    fn prepare_repr_string(&mut self, py: Python, as_html: bool) -> PyDataFusionResult<String> {
        let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?;

        let should_cache = *is_ipython_env(py) && self.batches.is_none();
        let (batches, has_more) = match self.batches.take() {
            Some(b) => b,
            None => wait_for_future(
                py,
                collect_record_batches_to_display(self.df.as_ref().clone(), config),
            )??,
        };

        if batches.is_empty() {
            return Ok("No data to display".to_string());
        }

        let table_uuid = uuid::Uuid::new_v4().to_string();

        let py_batches = batches
            .iter()
            .map(|rb| rb.to_pyarrow(py))
            .collect::<PyResult<Vec<PyObject>>>()?;

        let py_schema = self.schema().into_pyobject(py)?;

        let kwargs = pyo3::types::PyDict::new(py);
        let py_batches_list = PyList::new(py, py_batches.as_slice())?;
        kwargs.set_item("batches", py_batches_list)?;
        kwargs.set_item("schema", py_schema)?;
        kwargs.set_item("has_more", has_more)?;
        kwargs.set_item("table_uuid", table_uuid)?;

        let method_name = match as_html {
            true => "format_html",
            false => "format_str",
        };

        let html_result = formatter.call_method(method_name, (), Some(&kwargs))?;
        let html_str: String = html_result.extract()?;

        if should_cache {
            self.batches = Some((batches, has_more));
        }

        Ok(html_str)
    }
}

#[pymethods]
impl PyDataFrame {
    fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
        if let Ok(key) = key.extract::<PyBackedStr>() {
            self.select_columns(vec![key])
        } else if let Ok(tuple) = key.downcast::<PyTuple>() {
            let keys = tuple
                .iter()
                .map(|item| item.extract::<PyBackedStr>())
                .collect::<PyResult<Vec<PyBackedStr>>>()?;
            self.select_columns(keys)
        } else if let Ok(keys) = key.extract::<Vec<PyBackedStr>>() {
            self.select_columns(keys)
        } else {
            let message = "DataFrame can only be indexed by string index or indices".to_string();
            Err(PyDataFusionError::Common(message))
        }
    }

    fn __repr__(&mut self, py: Python) -> PyDataFusionResult<String> {
        self.prepare_repr_string(py, false)
    }

    #[staticmethod]
    #[expect(unused_variables)]
    fn default_str_repr<'py>(
        batches: Vec<Bound<'py, PyAny>>,
        schema: &Bound<'py, PyAny>,
        has_more: bool,
        table_uuid: &str,
    ) -> PyResult<String> {
        let batches = batches
            .into_iter()
            .map(|batch| RecordBatch::from_pyarrow_bound(&batch))
            .collect::<PyResult<Vec<RecordBatch>>>()?
            .into_iter()
            .filter(|batch| batch.num_rows() > 0)
            .collect::<Vec<_>>();

        if batches.is_empty() {
            return Ok("No data to display".to_owned());
        }

        let batches_as_displ =
            pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?;

        let additional_str = match has_more {
            true => "\nData truncated.",
            false => "",
        };

        Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}"))
    }

    fn _repr_html_(&mut self, py: Python) -> PyDataFusionResult<String> {
        self.prepare_repr_string(py, true)
    }

    fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
        let df = self.df.as_ref().clone();
        let stat_df = wait_for_future(py, df.describe())??;
        Ok(Self::new(stat_df))
    }

    fn schema(&self) -> PyArrowType<Schema> {
        PyArrowType(self.df.schema().into())
    }

    #[allow(clippy::wrong_self_convention)]
    fn into_view(&self) -> PyDataFusionResult<PyTable> {
        let table_provider = self.df.as_ref().clone().into_view();
        let table_provider = PyTableProvider::new(table_provider);

        Ok(table_provider.as_table())
    }

    #[pyo3(signature = (*args))]
    fn select_columns(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
        let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
        let df = self.df.as_ref().clone().select_columns(&args)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (*args))]
    fn select(&self, args: Vec<PyExpr>) -> PyDataFusionResult<Self> {
        let expr: Vec<Expr> = args.into_iter().map(|e| e.into()).collect();
        let df = self.df.as_ref().clone().select(expr)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (*args))]
    fn drop(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
        let cols = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
        let df = self.df.as_ref().clone().drop_columns(&cols)?;
        Ok(Self::new(df))
    }

    fn filter(&self, predicate: PyExpr) -> PyDataFusionResult<Self> {
        let df = self.df.as_ref().clone().filter(predicate.into())?;
        Ok(Self::new(df))
    }

    fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult<Self> {
        let df = self.df.as_ref().clone().with_column(name, expr.into())?;
        Ok(Self::new(df))
    }

    fn with_columns(&self, exprs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
        let mut df = self.df.as_ref().clone();
        for expr in exprs {
            let expr: Expr = expr.into();
            let name = format!("{}", expr.schema_name());
            df = df.with_column(name.as_str(), expr)?
        }
        Ok(Self::new(df))
    }

    fn with_column_renamed(&self, old_name: &str, new_name: &str) -> PyDataFusionResult<Self> {
        let df = self
            .df
            .as_ref()
            .clone()
            .with_column_renamed(old_name, new_name)?;
        Ok(Self::new(df))
    }

    fn aggregate(&self, group_by: Vec<PyExpr>, aggs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
        let group_by = group_by.into_iter().map(|e| e.into()).collect();
        let aggs = aggs.into_iter().map(|e| e.into()).collect();
        let df = self.df.as_ref().clone().aggregate(group_by, aggs)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (*exprs))]
    fn sort(&self, exprs: Vec<PySortExpr>) -> PyDataFusionResult<Self> {
        let exprs = to_sort_expressions(exprs);
        let df = self.df.as_ref().clone().sort(exprs)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (count, offset=0))]
    fn limit(&self, count: usize, offset: usize) -> PyDataFusionResult<Self> {
        let df = self.df.as_ref().clone().limit(offset, Some(count))?;
        Ok(Self::new(df))
    }

    fn collect(&self, py: Python) -> PyResult<Vec<PyObject>> {
        let batches = wait_for_future(py, self.df.as_ref().clone().collect())?
            .map_err(PyDataFusionError::from)?;
        batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect()
    }

    fn cache(&self, py: Python) -> PyDataFusionResult<Self> {
        let df = wait_for_future(py, self.df.as_ref().clone().cache())??;
        Ok(Self::new(df))
    }

    fn collect_partitioned(&self, py: Python) -> PyResult<Vec<Vec<PyObject>>> {
        let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())?
            .map_err(PyDataFusionError::from)?;

        batches
            .into_iter()
            .map(|rbs| rbs.into_iter().map(|rb| rb.to_pyarrow(py)).collect())
            .collect()
    }

    #[pyo3(signature = (num=20))]
    fn show(&self, py: Python, num: usize) -> PyDataFusionResult<()> {
        let df = self.df.as_ref().clone().limit(0, Some(num))?;
        print_dataframe(py, df)
    }

    fn distinct(&self) -> PyDataFusionResult<Self> {
        let df = self.df.as_ref().clone().distinct()?;
        Ok(Self::new(df))
    }

    fn join(
        &self,
        right: PyDataFrame,
        how: &str,
        left_on: Vec<PyBackedStr>,
        right_on: Vec<PyBackedStr>,
    ) -> PyDataFusionResult<Self> {
        let join_type = match how {
            "inner" => JoinType::Inner,
            "left" => JoinType::Left,
            "right" => JoinType::Right,
            "full" => JoinType::Full,
            "semi" => JoinType::LeftSemi,
            "anti" => JoinType::LeftAnti,
            how => {
                return Err(PyDataFusionError::Common(format!(
                    "The join type {how} does not exist or is not implemented"
                )));
            }
        };

        let left_keys = left_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
        let right_keys = right_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();

        let df = self.df.as_ref().clone().join(
            right.df.as_ref().clone(),
            join_type,
            &left_keys,
            &right_keys,
            None,
        )?;
        Ok(Self::new(df))
    }

    fn join_on(
        &self,
        right: PyDataFrame,
        on_exprs: Vec<PyExpr>,
        how: &str,
    ) -> PyDataFusionResult<Self> {
        let join_type = match how {
            "inner" => JoinType::Inner,
            "left" => JoinType::Left,
            "right" => JoinType::Right,
            "full" => JoinType::Full,
            "semi" => JoinType::LeftSemi,
            "anti" => JoinType::LeftAnti,
            how => {
                return Err(PyDataFusionError::Common(format!(
                    "The join type {how} does not exist or is not implemented"
                )));
            }
        };
        let exprs: Vec<Expr> = on_exprs.into_iter().map(|e| e.into()).collect();

        let df = self
            .df
            .as_ref()
            .clone()
            .join_on(right.df.as_ref().clone(), join_type, exprs)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (verbose=false, analyze=false))]
    fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyDataFusionResult<()> {
        let df = self.df.as_ref().clone().explain(verbose, analyze)?;
        print_dataframe(py, df)
    }

    fn logical_plan(&self) -> PyResult<PyLogicalPlan> {
        Ok(self.df.as_ref().clone().logical_plan().clone().into())
    }

    fn optimized_logical_plan(&self) -> PyDataFusionResult<PyLogicalPlan> {
        Ok(self.df.as_ref().clone().into_optimized_plan()?.into())
    }

    fn execution_plan(&self, py: Python) -> PyDataFusionResult<PyExecutionPlan> {
        let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())??;
        Ok(plan.into())
    }

    fn repartition(&self, num: usize) -> PyDataFusionResult<Self> {
        let new_df = self
            .df
            .as_ref()
            .clone()
            .repartition(Partitioning::RoundRobinBatch(num))?;
        Ok(Self::new(new_df))
    }

    #[pyo3(signature = (*args, num))]
    fn repartition_by_hash(&self, args: Vec<PyExpr>, num: usize) -> PyDataFusionResult<Self> {
        let expr = args.into_iter().map(|py_expr| py_expr.into()).collect();
        let new_df = self
            .df
            .as_ref()
            .clone()
            .repartition(Partitioning::Hash(expr, num))?;
        Ok(Self::new(new_df))
    }

    #[pyo3(signature = (py_df, distinct=false))]
    fn union(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult<Self> {
        let new_df = if distinct {
            self.df
                .as_ref()
                .clone()
                .union_distinct(py_df.df.as_ref().clone())?
        } else {
            self.df.as_ref().clone().union(py_df.df.as_ref().clone())?
        };

        Ok(Self::new(new_df))
    }

    fn union_distinct(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
        let new_df = self
            .df
            .as_ref()
            .clone()
            .union_distinct(py_df.df.as_ref().clone())?;
        Ok(Self::new(new_df))
    }

    #[pyo3(signature = (column, preserve_nulls=true))]
    fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyDataFusionResult<Self> {
        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
        let df = self
            .df
            .as_ref()
            .clone()
            .unnest_columns_with_options(&[column], unnest_options)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (columns, preserve_nulls=true))]
    fn unnest_columns(
        &self,
        columns: Vec<String>,
        preserve_nulls: bool,
    ) -> PyDataFusionResult<Self> {
        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
        let cols = columns.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
        let df = self
            .df
            .as_ref()
            .clone()
            .unnest_columns_with_options(&cols, unnest_options)?;
        Ok(Self::new(df))
    }

    fn intersect(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
        let new_df = self
            .df
            .as_ref()
            .clone()
            .intersect(py_df.df.as_ref().clone())?;
        Ok(Self::new(new_df))
    }

    fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
        let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?;
        Ok(Self::new(new_df))
    }

    fn write_csv(&self, path: &str, with_header: bool, py: Python) -> PyDataFusionResult<()> {
        let csv_options = CsvOptions {
            has_header: Some(with_header),
            ..Default::default()
        };
        wait_for_future(
            py,
            self.df.as_ref().clone().write_csv(
                path,
                DataFrameWriteOptions::new(),
                Some(csv_options),
            ),
        )??;
        Ok(())
    }

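    /// Writes the DataFrame to Parquet, validating the requested compression
    /// codec and level before delegating to DataFusion. A `compression_level`
    /// is required for brotli and zstd, optional for gzip (default 6), and
    /// ignored by the remaining codecs.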
    #[pyo3(signature = (
        path,
        compression="zstd",
        compression_level=None
    ))]
    fn write_parquet(
        &self,
        path: &str,
        compression: &str,
        compression_level: Option<u32>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        fn verify_compression_level(cl: Option<u32>) -> Result<u32, PyErr> {
            cl.ok_or(PyValueError::new_err("compression_level is not defined"))
        }

        let _validated = match compression.to_lowercase().as_str() {
            "snappy" => Compression::SNAPPY,
            "gzip" => Compression::GZIP(
                GzipLevel::try_new(compression_level.unwrap_or(6))
                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
            ),
            "brotli" => Compression::BROTLI(
                BrotliLevel::try_new(verify_compression_level(compression_level)?)
                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
            ),
            "zstd" => Compression::ZSTD(
                ZstdLevel::try_new(verify_compression_level(compression_level)? as i32)
                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
            ),
            "lzo" => Compression::LZO,
            "lz4" => Compression::LZ4,
            "lz4_raw" => Compression::LZ4_RAW,
            "uncompressed" => Compression::UNCOMPRESSED,
            _ => {
                return Err(PyDataFusionError::Common(format!(
                    "Unrecognized compression type {compression}"
                )));
            }
        };

        let mut compression_string = compression.to_string();
        if let Some(level) = compression_level {
            compression_string.push_str(&format!("({level})"));
        }

        let mut options = TableParquetOptions::default();
        options.global.compression = Some(compression_string);

        wait_for_future(
            py,
            self.df.as_ref().clone().write_parquet(
                path,
                DataFrameWriteOptions::new(),
                Option::from(options),
            ),
        )??;
        Ok(())
    }

    fn write_parquet_with_options(
        &self,
        path: &str,
        options: PyParquetWriterOptions,
        column_specific_options: HashMap<String, PyParquetColumnOptions>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let table_options = TableParquetOptions {
            global: options.options,
            column_specific_options: column_specific_options
                .into_iter()
                .map(|(k, v)| (k, v.options))
                .collect(),
            ..Default::default()
        };

        wait_for_future(
            py,
            self.df.as_ref().clone().write_parquet(
                path,
                DataFrameWriteOptions::new(),
                Option::from(table_options),
            ),
        )??;
        Ok(())
    }

    fn write_json(&self, path: &str, py: Python) -> PyDataFusionResult<()> {
        wait_for_future(
            py,
            self.df
                .as_ref()
                .clone()
                .write_json(path, DataFrameWriteOptions::new(), None),
        )??;
        Ok(())
    }

    fn to_arrow_table(&self, py: Python<'_>) -> PyResult<PyObject> {
        let batches = self.collect(py)?.into_pyobject(py)?;
        let schema = self.schema().into_pyobject(py)?;

        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[batches, schema])?;
        let table: PyObject = table_class.call_method1("from_batches", args)?.into();
        Ok(table)
    }

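    /// Implements the Arrow PyCapsule stream interface. When a
    /// `requested_schema` capsule is supplied, the collected batches are
    /// projected and cast to that schema before being exported as an
    /// `FFI_ArrowArrayStream`.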
    #[pyo3(signature = (requested_schema=None))]
    fn __arrow_c_stream__<'py>(
        &'py mut self,
        py: Python<'py>,
        requested_schema: Option<Bound<'py, PyCapsule>>,
    ) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
        let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())??;
        let mut schema: Schema = self.df.schema().to_owned().into();

        if let Some(schema_capsule) = requested_schema {
            validate_pycapsule(&schema_capsule, "arrow_schema")?;

            let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
            let desired_schema = Schema::try_from(schema_ptr)?;

            schema = project_schema(schema, desired_schema)?;

            batches = batches
                .into_iter()
                .map(|record_batch| record_batch_into_schema(record_batch, &schema))
                .collect::<Result<Vec<RecordBatch>, ArrowError>>()?;
        }

        let batches_wrapped = batches.into_iter().map(Ok);

        let reader = RecordBatchIterator::new(batches_wrapped, Arc::new(schema));
        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);

        let ffi_stream = FFI_ArrowArrayStream::new(reader);
        let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
        PyCapsule::new(py, ffi_stream, Some(stream_capsule_name)).map_err(PyDataFusionError::from)
    }

    fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
        let rt = &get_tokio_runtime().0;
        let df = self.df.as_ref().clone();
        let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
            rt.spawn(async move { df.execute_stream().await });
        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
        Ok(PyRecordBatchStream::new(stream))
    }

    fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
        let rt = &get_tokio_runtime().0;
        let df = self.df.as_ref().clone();
        let fut: JoinHandle<datafusion::common::Result<Vec<SendableRecordBatchStream>>> =
            rt.spawn(async move { df.execute_stream_partitioned().await });
        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })?
            .map_err(py_datafusion_err)?
            .map_err(py_datafusion_err)?;

        Ok(stream.into_iter().map(PyRecordBatchStream::new).collect())
    }

    fn to_pandas(&self, py: Python<'_>) -> PyResult<PyObject> {
        let table = self.to_arrow_table(py)?;

        let result = table.call_method0(py, "to_pandas")?;
        Ok(result)
    }

    fn to_pylist(&self, py: Python<'_>) -> PyResult<PyObject> {
        let table = self.to_arrow_table(py)?;

        let result = table.call_method0(py, "to_pylist")?;
        Ok(result)
    }

    fn to_pydict(&self, py: Python) -> PyResult<PyObject> {
        let table = self.to_arrow_table(py)?;

        let result = table.call_method0(py, "to_pydict")?;
        Ok(result)
    }

    fn to_polars(&self, py: Python<'_>) -> PyResult<PyObject> {
        let table = self.to_arrow_table(py)?;
        let dataframe = py.import("polars")?.getattr("DataFrame")?;
        let args = PyTuple::new(py, &[table])?;
        let result: PyObject = dataframe.call1(args)?.into();
        Ok(result)
    }

    fn count(&self, py: Python) -> PyDataFusionResult<usize> {
        Ok(wait_for_future(py, self.df.as_ref().clone().count())??)
    }

    #[pyo3(signature = (value, columns=None))]
    fn fill_null(
        &self,
        value: PyObject,
        columns: Option<Vec<PyBackedStr>>,
        py: Python,
    ) -> PyDataFusionResult<Self> {
        let scalar_value = py_obj_to_scalar_value(py, value)?;

        let cols = match columns {
            Some(col_names) => col_names.iter().map(|c| c.to_string()).collect(),
            None => Vec::new(),
        };

        let df = self.df.as_ref().clone().fill_null(scalar_value, cols)?;
        Ok(Self::new(df))
    }
}

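/// Collects the DataFrame and prints a pretty-formatted table through Python's
/// built-in `print`, so the output also shows up in notebook environments.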
fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
    let batches = wait_for_future(py, df.collect())??;
    let result = if batches.is_empty() {
        "DataFrame has no rows".to_string()
    } else {
        match pretty::pretty_format_batches(&batches) {
            Ok(batch) => format!("DataFrame()\n{batch}"),
            Err(err) => format!("Error: {:?}", err.to_string()),
        }
    };

    let print = py.import("builtins")?.getattr("print")?;
    print.call1((result,))?;
    Ok(())
}

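/// Merges `from_schema` with `to_schema` and keeps only the fields requested by
/// `to_schema`, producing the schema used when exporting an Arrow stream.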
fn project_schema(from_schema: Schema, to_schema: Schema) -> Result<Schema, ArrowError> {
    let merged_schema = Schema::try_merge(vec![from_schema, to_schema.clone()])?;

    let project_indices: Vec<usize> = to_schema
        .fields
        .iter()
        .map(|field| field.name())
        .filter_map(|field_name| merged_schema.index_of(field_name).ok())
        .collect();

    merged_schema.project(&project_indices)
}

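/// Casts a record batch to the given schema, filling missing nullable columns
/// with nulls and returning an error when a required column cannot be cast or
/// is absent.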
fn record_batch_into_schema(
    record_batch: RecordBatch,
    schema: &Schema,
) -> Result<RecordBatch, ArrowError> {
    let schema = Arc::new(schema.clone());
    let base_schema = record_batch.schema();
    if base_schema.fields().is_empty() {
        return Ok(RecordBatch::new_empty(schema));
    }

    let array_size = record_batch.column(0).len();
    let mut data_arrays = Vec::with_capacity(schema.fields().len());

    for field in schema.fields() {
        let desired_data_type = field.data_type();
        if let Some(original_data) = record_batch.column_by_name(field.name()) {
            let original_data_type = original_data.data_type();

            if can_cast_types(original_data_type, desired_data_type) {
                data_arrays.push(arrow::compute::kernels::cast(
                    original_data,
                    desired_data_type,
                )?);
            } else if field.is_nullable() {
                data_arrays.push(new_null_array(desired_data_type, array_size));
            } else {
                return Err(ArrowError::CastError(format!(
                    "Attempting to cast to non-nullable and non-castable field {} during schema projection.",
                    field.name()
                )));
            }
        } else {
            if !field.is_nullable() {
                return Err(ArrowError::CastError(format!(
                    "Attempting to set null to non-nullable field {} during schema projection.",
                    field.name()
                )));
            }
            data_arrays.push(new_null_array(desired_data_type, array_size));
        }
    }

    RecordBatch::try_new(schema, data_arrays)
}

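/// Executes the DataFrame and collects only as many record batches as the
/// formatter's `max_bytes`, `min_rows`, and `repr_rows` limits allow. Returns
/// the collected batches and a flag indicating whether more data remains
/// beyond what was collected.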
async fn collect_record_batches_to_display(
    df: DataFrame,
    config: FormatterConfig,
) -> Result<(Vec<RecordBatch>, bool), DataFusionError> {
    let FormatterConfig {
        max_bytes,
        min_rows,
        repr_rows,
    } = config;

    let partitioned_stream = df.execute_stream_partitioned().await?;
    let mut stream = futures::stream::iter(partitioned_stream).flatten();
    let mut size_estimate_so_far = 0;
    let mut rows_so_far = 0;
    let mut record_batches = Vec::default();
    let mut has_more = false;

    while (size_estimate_so_far < max_bytes && rows_so_far < repr_rows) || rows_so_far < min_rows {
        let mut rb = match stream.next().await {
            None => {
                break;
            }
            Some(Ok(r)) => r,
            Some(Err(e)) => return Err(e),
        };

        let mut rows_in_rb = rb.num_rows();
        if rows_in_rb > 0 {
            size_estimate_so_far += rb.get_array_memory_size();

            if size_estimate_so_far > max_bytes {
                let ratio = max_bytes as f32 / size_estimate_so_far as f32;
                let total_rows = rows_in_rb + rows_so_far;

                let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize;
                if reduced_row_num < min_rows {
                    reduced_row_num = min_rows.min(total_rows);
                }

                let limited_rows_this_rb = reduced_row_num - rows_so_far;
                if limited_rows_this_rb < rows_in_rb {
                    rows_in_rb = limited_rows_this_rb;
                    rb = rb.slice(0, limited_rows_this_rb);
                    has_more = true;
                }
            }

            if rows_in_rb + rows_so_far > repr_rows {
                rb = rb.slice(0, repr_rows - rows_so_far);
                has_more = true;
            }

            rows_so_far += rb.num_rows();
            record_batches.push(rb);
        }
    }

    if record_batches.is_empty() {
        return Ok((Vec::default(), false));
    }

    if !has_more {
        has_more = match stream.try_next().await {
            Ok(None) => false,
            Ok(Some(_)) => true,
            Err(_) => false,
        };
    }

    Ok((record_batches, has_more))
}