datafusion_python/
dataframe.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::HashMap;
19use std::ffi::CString;
20use std::sync::Arc;
21
22use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader};
23use arrow::compute::can_cast_types;
24use arrow::error::ArrowError;
25use arrow::ffi::FFI_ArrowSchema;
26use arrow::ffi_stream::FFI_ArrowArrayStream;
27use arrow::pyarrow::FromPyArrow;
28use datafusion::arrow::datatypes::Schema;
29use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
30use datafusion::arrow::util::pretty;
31use datafusion::common::UnnestOptions;
32use datafusion::config::{CsvOptions, ParquetColumnOptions, ParquetOptions, TableParquetOptions};
33use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
34use datafusion::datasource::TableProvider;
35use datafusion::error::DataFusionError;
36use datafusion::execution::SendableRecordBatchStream;
37use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
38use datafusion::prelude::*;
39use datafusion_ffi::table_provider::FFI_TableProvider;
40use futures::{StreamExt, TryStreamExt};
41use pyo3::exceptions::PyValueError;
42use pyo3::prelude::*;
43use pyo3::pybacked::PyBackedStr;
44use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
45use tokio::task::JoinHandle;
46
47use crate::catalog::PyTable;
48use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError};
49use crate::expr::sort_expr::to_sort_expressions;
50use crate::physical_plan::PyExecutionPlan;
51use crate::record_batch::PyRecordBatchStream;
52use crate::sql::logical::PyLogicalPlan;
53use crate::utils::{
54    get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future,
55};
56use crate::{
57    errors::PyDataFusionResult,
58    expr::{sort_expr::PySortExpr, PyExpr},
59};
60
// https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
// - we have not decided on the table_provider approach yet
// this is an interim implementation
//
// Python-visible wrapper (exposed as `datafusion.TableProvider`) around a
// reference-counted DataFusion `TableProvider` trait object.
#[pyclass(name = "TableProvider", module = "datafusion")]
pub struct PyTableProvider {
    // Shared handle to the wrapped provider; cloning it only bumps the refcount.
    provider: Arc<dyn TableProvider + Send>,
}
68
69impl PyTableProvider {
70    pub fn new(provider: Arc<dyn TableProvider>) -> Self {
71        Self { provider }
72    }
73
74    pub fn as_table(&self) -> PyTable {
75        let table_provider: Arc<dyn TableProvider> = self.provider.clone();
76        PyTable::new(table_provider)
77    }
78}
79
80#[pymethods]
81impl PyTableProvider {
82    fn __datafusion_table_provider__<'py>(
83        &self,
84        py: Python<'py>,
85    ) -> PyResult<Bound<'py, PyCapsule>> {
86        let name = CString::new("datafusion_table_provider").unwrap();
87
88        let runtime = get_tokio_runtime().0.handle().clone();
89        let provider = FFI_TableProvider::new(Arc::clone(&self.provider), false, Some(runtime));
90
91        PyCapsule::new(py, provider, Some(name.clone()))
92    }
93}
94
/// Configuration for DataFrame display formatting
#[derive(Debug, Clone)]
pub struct FormatterConfig {
    /// Maximum memory in bytes to use for display (default: 2MB)
    pub max_bytes: usize,
    /// Minimum number of rows to display (default: 20)
    pub min_rows: usize,
    /// Number of rows to include in __repr__ output (default: 10)
    pub repr_rows: usize,
}

impl Default for FormatterConfig {
    /// Builds a configuration holding the documented default limits.
    fn default() -> Self {
        const DEFAULT_MAX_BYTES: usize = 2 * 1024 * 1024; // 2MB
        const DEFAULT_MIN_ROWS: usize = 20;
        const DEFAULT_REPR_ROWS: usize = 10;

        FormatterConfig {
            max_bytes: DEFAULT_MAX_BYTES,
            min_rows: DEFAULT_MIN_ROWS,
            repr_rows: DEFAULT_REPR_ROWS,
        }
    }
}

impl FormatterConfig {
    /// Checks that every limit is non-zero.
    ///
    /// # Returns
    ///
    /// `Ok(())` when all values are valid. Otherwise an `Err` naming the first
    /// offending field, checked in the order `max_bytes`, `min_rows`, `repr_rows`.
    pub fn validate(&self) -> Result<(), String> {
        // Table of (value, message) pairs, scanned in declaration order so the
        // first zero field decides which message is reported.
        let checks = [
            (self.max_bytes, "max_bytes must be a positive integer"),
            (self.min_rows, "min_rows must be a positive integer"),
            (self.repr_rows, "repr_rows must be a positive integer"),
        ];

        match checks.iter().find(|(value, _)| *value == 0) {
            Some((_, message)) => Err((*message).to_owned()),
            None => Ok(()),
        }
    }
}
138
/// Holds the Python formatter and its configuration
///
/// Produced by `get_python_formatter_with_config`, which derives `config` from
/// attributes read off `formatter`.
struct PythonFormatter<'py> {
    /// The Python formatter object
    formatter: Bound<'py, PyAny>,
    /// The formatter configuration
    config: FormatterConfig,
}
146
147/// Get the Python formatter and its configuration
148fn get_python_formatter_with_config(py: Python) -> PyResult<PythonFormatter> {
149    let formatter = import_python_formatter(py)?;
150    let config = build_formatter_config_from_python(&formatter)?;
151    Ok(PythonFormatter { formatter, config })
152}
153
154/// Get the Python formatter from the datafusion.dataframe_formatter module
155fn import_python_formatter(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
156    let formatter_module = py.import("datafusion.dataframe_formatter")?;
157    let get_formatter = formatter_module.getattr("get_formatter")?;
158    get_formatter.call0()
159}
160
161// Helper function to extract attributes with fallback to default
162fn get_attr<'a, T>(py_object: &'a Bound<'a, PyAny>, attr_name: &str, default_value: T) -> T
163where
164    T: for<'py> pyo3::FromPyObject<'py> + Clone,
165{
166    py_object
167        .getattr(attr_name)
168        .and_then(|v| v.extract::<T>())
169        .unwrap_or_else(|_| default_value.clone())
170}
171
172/// Helper function to create a FormatterConfig from a Python formatter object
173fn build_formatter_config_from_python(formatter: &Bound<'_, PyAny>) -> PyResult<FormatterConfig> {
174    let default_config = FormatterConfig::default();
175    let max_bytes = get_attr(formatter, "max_memory_bytes", default_config.max_bytes);
176    let min_rows = get_attr(formatter, "min_rows_display", default_config.min_rows);
177    let repr_rows = get_attr(formatter, "repr_rows", default_config.repr_rows);
178
179    let config = FormatterConfig {
180        max_bytes,
181        min_rows,
182        repr_rows,
183    };
184
185    // Return the validated config, converting String error to PyErr
186    config.validate().map_err(PyValueError::new_err)?;
187    Ok(config)
188}
189
/// Python mapping of `ParquetOptions` (includes just the writer-related options).
///
/// Exposed to Python as `datafusion.ParquetWriterOptions`.
#[pyclass(name = "ParquetWriterOptions", module = "datafusion", subclass)]
#[derive(Clone, Default)]
pub struct PyParquetWriterOptions {
    // The wrapped DataFusion options; used as the `global` section of
    // `TableParquetOptions` when writing.
    options: ParquetOptions,
}
196
197#[pymethods]
198impl PyParquetWriterOptions {
199    #[new]
200    #[allow(clippy::too_many_arguments)]
201    pub fn new(
202        data_pagesize_limit: usize,
203        write_batch_size: usize,
204        writer_version: String,
205        skip_arrow_metadata: bool,
206        compression: Option<String>,
207        dictionary_enabled: Option<bool>,
208        dictionary_page_size_limit: usize,
209        statistics_enabled: Option<String>,
210        max_row_group_size: usize,
211        created_by: String,
212        column_index_truncate_length: Option<usize>,
213        statistics_truncate_length: Option<usize>,
214        data_page_row_count_limit: usize,
215        encoding: Option<String>,
216        bloom_filter_on_write: bool,
217        bloom_filter_fpp: Option<f64>,
218        bloom_filter_ndv: Option<u64>,
219        allow_single_file_parallelism: bool,
220        maximum_parallel_row_group_writers: usize,
221        maximum_buffered_record_batches_per_stream: usize,
222    ) -> Self {
223        Self {
224            options: ParquetOptions {
225                data_pagesize_limit,
226                write_batch_size,
227                writer_version,
228                skip_arrow_metadata,
229                compression,
230                dictionary_enabled,
231                dictionary_page_size_limit,
232                statistics_enabled,
233                max_row_group_size,
234                created_by,
235                column_index_truncate_length,
236                statistics_truncate_length,
237                data_page_row_count_limit,
238                encoding,
239                bloom_filter_on_write,
240                bloom_filter_fpp,
241                bloom_filter_ndv,
242                allow_single_file_parallelism,
243                maximum_parallel_row_group_writers,
244                maximum_buffered_record_batches_per_stream,
245                ..Default::default()
246            },
247        }
248    }
249}
250
/// Python mapping of `ParquetColumnOptions`.
///
/// Exposed to Python as `datafusion.ParquetColumnOptions`.
#[pyclass(name = "ParquetColumnOptions", module = "datafusion", subclass)]
#[derive(Clone, Default)]
pub struct PyParquetColumnOptions {
    // The wrapped per-column options; used as values of
    // `TableParquetOptions::column_specific_options` when writing.
    options: ParquetColumnOptions,
}
257
258#[pymethods]
259impl PyParquetColumnOptions {
260    #[new]
261    pub fn new(
262        bloom_filter_enabled: Option<bool>,
263        encoding: Option<String>,
264        dictionary_enabled: Option<bool>,
265        compression: Option<String>,
266        statistics_enabled: Option<String>,
267        bloom_filter_fpp: Option<f64>,
268        bloom_filter_ndv: Option<u64>,
269    ) -> Self {
270        Self {
271            options: ParquetColumnOptions {
272                bloom_filter_enabled,
273                encoding,
274                dictionary_enabled,
275                compression,
276                statistics_enabled,
277                bloom_filter_fpp,
278                bloom_filter_ndv,
279            },
280        }
281    }
282}
283
/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
/// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment.
#[pyclass(name = "DataFrame", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PyDataFrame {
    // The wrapped DataFusion DataFrame, shared behind an `Arc`.
    df: Arc<DataFrame>,

    // In IPython environment cache batches between __repr__ and _repr_html_ calls.
    // The tuple is `(batches, has_more)` as stored by `prepare_repr_string`.
    batches: Option<(Vec<RecordBatch>, bool)>,
}
295
296impl PyDataFrame {
297    /// creates a new PyDataFrame
298    pub fn new(df: DataFrame) -> Self {
299        Self {
300            df: Arc::new(df),
301            batches: None,
302        }
303    }
304
305    fn prepare_repr_string(&mut self, py: Python, as_html: bool) -> PyDataFusionResult<String> {
306        // Get the Python formatter and config
307        let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?;
308
309        let should_cache = *is_ipython_env(py) && self.batches.is_none();
310        let (batches, has_more) = match self.batches.take() {
311            Some(b) => b,
312            None => wait_for_future(
313                py,
314                collect_record_batches_to_display(self.df.as_ref().clone(), config),
315            )??,
316        };
317
318        if batches.is_empty() {
319            // This should not be reached, but do it for safety since we index into the vector below
320            return Ok("No data to display".to_string());
321        }
322
323        let table_uuid = uuid::Uuid::new_v4().to_string();
324
325        // Convert record batches to PyObject list
326        let py_batches = batches
327            .iter()
328            .map(|rb| rb.to_pyarrow(py))
329            .collect::<PyResult<Vec<PyObject>>>()?;
330
331        let py_schema = self.schema().into_pyobject(py)?;
332
333        let kwargs = pyo3::types::PyDict::new(py);
334        let py_batches_list = PyList::new(py, py_batches.as_slice())?;
335        kwargs.set_item("batches", py_batches_list)?;
336        kwargs.set_item("schema", py_schema)?;
337        kwargs.set_item("has_more", has_more)?;
338        kwargs.set_item("table_uuid", table_uuid)?;
339
340        let method_name = match as_html {
341            true => "format_html",
342            false => "format_str",
343        };
344
345        let html_result = formatter.call_method(method_name, (), Some(&kwargs))?;
346        let html_str: String = html_result.extract()?;
347
348        if should_cache {
349            self.batches = Some((batches, has_more));
350        }
351
352        Ok(html_str)
353    }
354}
355
356#[pymethods]
357impl PyDataFrame {
358    /// Enable selection for `df[col]`, `df[col1, col2, col3]`, and `df[[col1, col2, col3]]`
359    fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
360        if let Ok(key) = key.extract::<PyBackedStr>() {
361            // df[col]
362            self.select_columns(vec![key])
363        } else if let Ok(tuple) = key.downcast::<PyTuple>() {
364            // df[col1, col2, col3]
365            let keys = tuple
366                .iter()
367                .map(|item| item.extract::<PyBackedStr>())
368                .collect::<PyResult<Vec<PyBackedStr>>>()?;
369            self.select_columns(keys)
370        } else if let Ok(keys) = key.extract::<Vec<PyBackedStr>>() {
371            // df[[col1, col2, col3]]
372            self.select_columns(keys)
373        } else {
374            let message = "DataFrame can only be indexed by string index or indices".to_string();
375            Err(PyDataFusionError::Common(message))
376        }
377    }
378
    /// Plain-text representation; delegates to `prepare_repr_string` with
    /// `as_html = false`. Takes `&mut self` because that helper may update the
    /// cached display batches.
    fn __repr__(&mut self, py: Python) -> PyDataFusionResult<String> {
        self.prepare_repr_string(py, false)
    }
382
383    #[staticmethod]
384    #[expect(unused_variables)]
385    fn default_str_repr<'py>(
386        batches: Vec<Bound<'py, PyAny>>,
387        schema: &Bound<'py, PyAny>,
388        has_more: bool,
389        table_uuid: &str,
390    ) -> PyResult<String> {
391        let batches = batches
392            .into_iter()
393            .map(|batch| RecordBatch::from_pyarrow_bound(&batch))
394            .collect::<PyResult<Vec<RecordBatch>>>()?
395            .into_iter()
396            .filter(|batch| batch.num_rows() > 0)
397            .collect::<Vec<_>>();
398
399        if batches.is_empty() {
400            return Ok("No data to display".to_owned());
401        }
402
403        let batches_as_displ =
404            pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?;
405
406        let additional_str = match has_more {
407            true => "\nData truncated.",
408            false => "",
409        };
410
411        Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}"))
412    }
413
    /// HTML representation; delegates to `prepare_repr_string` with
    /// `as_html = true`. Takes `&mut self` because that helper may update the
    /// cached display batches.
    fn _repr_html_(&mut self, py: Python) -> PyDataFusionResult<String> {
        self.prepare_repr_string(py, true)
    }
417
418    /// Calculate summary statistics for a DataFrame
419    fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
420        let df = self.df.as_ref().clone();
421        let stat_df = wait_for_future(py, df.describe())??;
422        Ok(Self::new(stat_df))
423    }
424
425    /// Returns the schema from the logical plan
426    fn schema(&self) -> PyArrowType<Schema> {
427        PyArrowType(self.df.schema().into())
428    }
429
430    /// Convert this DataFrame into a Table that can be used in register_table
431    /// By convention, into_... methods consume self and return the new object.
432    /// Disabling the clippy lint, so we can use &self
433    /// because we're working with Python bindings
434    /// where objects are shared
435    /// https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
436    /// - we have not decided on the table_provider approach yet
437    #[allow(clippy::wrong_self_convention)]
438    fn into_view(&self) -> PyDataFusionResult<PyTable> {
439        // Call the underlying Rust DataFrame::into_view method.
440        // Note that the Rust method consumes self; here we clone the inner Arc<DataFrame>
441        // so that we don’t invalidate this PyDataFrame.
442        let table_provider = self.df.as_ref().clone().into_view();
443        let table_provider = PyTableProvider::new(table_provider);
444
445        Ok(table_provider.as_table())
446    }
447
448    #[pyo3(signature = (*args))]
449    fn select_columns(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
450        let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
451        let df = self.df.as_ref().clone().select_columns(&args)?;
452        Ok(Self::new(df))
453    }
454
455    #[pyo3(signature = (*args))]
456    fn select(&self, args: Vec<PyExpr>) -> PyDataFusionResult<Self> {
457        let expr: Vec<Expr> = args.into_iter().map(|e| e.into()).collect();
458        let df = self.df.as_ref().clone().select(expr)?;
459        Ok(Self::new(df))
460    }
461
462    #[pyo3(signature = (*args))]
463    fn drop(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
464        let cols = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
465        let df = self.df.as_ref().clone().drop_columns(&cols)?;
466        Ok(Self::new(df))
467    }
468
469    fn filter(&self, predicate: PyExpr) -> PyDataFusionResult<Self> {
470        let df = self.df.as_ref().clone().filter(predicate.into())?;
471        Ok(Self::new(df))
472    }
473
474    fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult<Self> {
475        let df = self.df.as_ref().clone().with_column(name, expr.into())?;
476        Ok(Self::new(df))
477    }
478
479    fn with_columns(&self, exprs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
480        let mut df = self.df.as_ref().clone();
481        for expr in exprs {
482            let expr: Expr = expr.into();
483            let name = format!("{}", expr.schema_name());
484            df = df.with_column(name.as_str(), expr)?
485        }
486        Ok(Self::new(df))
487    }
488
489    /// Rename one column by applying a new projection. This is a no-op if the column to be
490    /// renamed does not exist.
491    fn with_column_renamed(&self, old_name: &str, new_name: &str) -> PyDataFusionResult<Self> {
492        let df = self
493            .df
494            .as_ref()
495            .clone()
496            .with_column_renamed(old_name, new_name)?;
497        Ok(Self::new(df))
498    }
499
500    fn aggregate(&self, group_by: Vec<PyExpr>, aggs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
501        let group_by = group_by.into_iter().map(|e| e.into()).collect();
502        let aggs = aggs.into_iter().map(|e| e.into()).collect();
503        let df = self.df.as_ref().clone().aggregate(group_by, aggs)?;
504        Ok(Self::new(df))
505    }
506
507    #[pyo3(signature = (*exprs))]
508    fn sort(&self, exprs: Vec<PySortExpr>) -> PyDataFusionResult<Self> {
509        let exprs = to_sort_expressions(exprs);
510        let df = self.df.as_ref().clone().sort(exprs)?;
511        Ok(Self::new(df))
512    }
513
514    #[pyo3(signature = (count, offset=0))]
515    fn limit(&self, count: usize, offset: usize) -> PyDataFusionResult<Self> {
516        let df = self.df.as_ref().clone().limit(offset, Some(count))?;
517        Ok(Self::new(df))
518    }
519
520    /// Executes the plan, returning a list of `RecordBatch`es.
521    /// Unless some order is specified in the plan, there is no
522    /// guarantee of the order of the result.
523    fn collect(&self, py: Python) -> PyResult<Vec<PyObject>> {
524        let batches = wait_for_future(py, self.df.as_ref().clone().collect())?
525            .map_err(PyDataFusionError::from)?;
526        // cannot use PyResult<Vec<RecordBatch>> return type due to
527        // https://github.com/PyO3/pyo3/issues/1813
528        batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect()
529    }
530
531    /// Cache DataFrame.
532    fn cache(&self, py: Python) -> PyDataFusionResult<Self> {
533        let df = wait_for_future(py, self.df.as_ref().clone().cache())??;
534        Ok(Self::new(df))
535    }
536
537    /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
538    /// maintaining the input partitioning.
539    fn collect_partitioned(&self, py: Python) -> PyResult<Vec<Vec<PyObject>>> {
540        let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())?
541            .map_err(PyDataFusionError::from)?;
542
543        batches
544            .into_iter()
545            .map(|rbs| rbs.into_iter().map(|rb| rb.to_pyarrow(py)).collect())
546            .collect()
547    }
548
549    /// Print the result, 20 lines by default
550    #[pyo3(signature = (num=20))]
551    fn show(&self, py: Python, num: usize) -> PyDataFusionResult<()> {
552        let df = self.df.as_ref().clone().limit(0, Some(num))?;
553        print_dataframe(py, df)
554    }
555
556    /// Filter out duplicate rows
557    fn distinct(&self) -> PyDataFusionResult<Self> {
558        let df = self.df.as_ref().clone().distinct()?;
559        Ok(Self::new(df))
560    }
561
562    fn join(
563        &self,
564        right: PyDataFrame,
565        how: &str,
566        left_on: Vec<PyBackedStr>,
567        right_on: Vec<PyBackedStr>,
568    ) -> PyDataFusionResult<Self> {
569        let join_type = match how {
570            "inner" => JoinType::Inner,
571            "left" => JoinType::Left,
572            "right" => JoinType::Right,
573            "full" => JoinType::Full,
574            "semi" => JoinType::LeftSemi,
575            "anti" => JoinType::LeftAnti,
576            how => {
577                return Err(PyDataFusionError::Common(format!(
578                    "The join type {how} does not exist or is not implemented"
579                )));
580            }
581        };
582
583        let left_keys = left_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
584        let right_keys = right_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
585
586        let df = self.df.as_ref().clone().join(
587            right.df.as_ref().clone(),
588            join_type,
589            &left_keys,
590            &right_keys,
591            None,
592        )?;
593        Ok(Self::new(df))
594    }
595
596    fn join_on(
597        &self,
598        right: PyDataFrame,
599        on_exprs: Vec<PyExpr>,
600        how: &str,
601    ) -> PyDataFusionResult<Self> {
602        let join_type = match how {
603            "inner" => JoinType::Inner,
604            "left" => JoinType::Left,
605            "right" => JoinType::Right,
606            "full" => JoinType::Full,
607            "semi" => JoinType::LeftSemi,
608            "anti" => JoinType::LeftAnti,
609            how => {
610                return Err(PyDataFusionError::Common(format!(
611                    "The join type {how} does not exist or is not implemented"
612                )));
613            }
614        };
615        let exprs: Vec<Expr> = on_exprs.into_iter().map(|e| e.into()).collect();
616
617        let df = self
618            .df
619            .as_ref()
620            .clone()
621            .join_on(right.df.as_ref().clone(), join_type, exprs)?;
622        Ok(Self::new(df))
623    }
624
625    /// Print the query plan
626    #[pyo3(signature = (verbose=false, analyze=false))]
627    fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyDataFusionResult<()> {
628        let df = self.df.as_ref().clone().explain(verbose, analyze)?;
629        print_dataframe(py, df)
630    }
631
632    /// Get the logical plan for this `DataFrame`
633    fn logical_plan(&self) -> PyResult<PyLogicalPlan> {
634        Ok(self.df.as_ref().clone().logical_plan().clone().into())
635    }
636
637    /// Get the optimized logical plan for this `DataFrame`
638    fn optimized_logical_plan(&self) -> PyDataFusionResult<PyLogicalPlan> {
639        Ok(self.df.as_ref().clone().into_optimized_plan()?.into())
640    }
641
642    /// Get the execution plan for this `DataFrame`
643    fn execution_plan(&self, py: Python) -> PyDataFusionResult<PyExecutionPlan> {
644        let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())??;
645        Ok(plan.into())
646    }
647
648    /// Repartition a `DataFrame` based on a logical partitioning scheme.
649    fn repartition(&self, num: usize) -> PyDataFusionResult<Self> {
650        let new_df = self
651            .df
652            .as_ref()
653            .clone()
654            .repartition(Partitioning::RoundRobinBatch(num))?;
655        Ok(Self::new(new_df))
656    }
657
658    /// Repartition a `DataFrame` based on a logical partitioning scheme.
659    #[pyo3(signature = (*args, num))]
660    fn repartition_by_hash(&self, args: Vec<PyExpr>, num: usize) -> PyDataFusionResult<Self> {
661        let expr = args.into_iter().map(|py_expr| py_expr.into()).collect();
662        let new_df = self
663            .df
664            .as_ref()
665            .clone()
666            .repartition(Partitioning::Hash(expr, num))?;
667        Ok(Self::new(new_df))
668    }
669
670    /// Calculate the union of two `DataFrame`s, preserving duplicate rows.The
671    /// two `DataFrame`s must have exactly the same schema
672    #[pyo3(signature = (py_df, distinct=false))]
673    fn union(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult<Self> {
674        let new_df = if distinct {
675            self.df
676                .as_ref()
677                .clone()
678                .union_distinct(py_df.df.as_ref().clone())?
679        } else {
680            self.df.as_ref().clone().union(py_df.df.as_ref().clone())?
681        };
682
683        Ok(Self::new(new_df))
684    }
685
686    /// Calculate the distinct union of two `DataFrame`s.  The
687    /// two `DataFrame`s must have exactly the same schema
688    fn union_distinct(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
689        let new_df = self
690            .df
691            .as_ref()
692            .clone()
693            .union_distinct(py_df.df.as_ref().clone())?;
694        Ok(Self::new(new_df))
695    }
696
697    #[pyo3(signature = (column, preserve_nulls=true))]
698    fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyDataFusionResult<Self> {
699        // TODO: expose RecursionUnnestOptions
700        // REF: https://github.com/apache/datafusion/pull/11577
701        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
702        let df = self
703            .df
704            .as_ref()
705            .clone()
706            .unnest_columns_with_options(&[column], unnest_options)?;
707        Ok(Self::new(df))
708    }
709
710    #[pyo3(signature = (columns, preserve_nulls=true))]
711    fn unnest_columns(
712        &self,
713        columns: Vec<String>,
714        preserve_nulls: bool,
715    ) -> PyDataFusionResult<Self> {
716        // TODO: expose RecursionUnnestOptions
717        // REF: https://github.com/apache/datafusion/pull/11577
718        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
719        let cols = columns.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
720        let df = self
721            .df
722            .as_ref()
723            .clone()
724            .unnest_columns_with_options(&cols, unnest_options)?;
725        Ok(Self::new(df))
726    }
727
728    /// Calculate the intersection of two `DataFrame`s.  The two `DataFrame`s must have exactly the same schema
729    fn intersect(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
730        let new_df = self
731            .df
732            .as_ref()
733            .clone()
734            .intersect(py_df.df.as_ref().clone())?;
735        Ok(Self::new(new_df))
736    }
737
738    /// Calculate the exception of two `DataFrame`s.  The two `DataFrame`s must have exactly the same schema
739    fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
740        let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?;
741        Ok(Self::new(new_df))
742    }
743
744    /// Write a `DataFrame` to a CSV file.
745    fn write_csv(&self, path: &str, with_header: bool, py: Python) -> PyDataFusionResult<()> {
746        let csv_options = CsvOptions {
747            has_header: Some(with_header),
748            ..Default::default()
749        };
750        wait_for_future(
751            py,
752            self.df.as_ref().clone().write_csv(
753                path,
754                DataFrameWriteOptions::new(),
755                Some(csv_options),
756            ),
757        )??;
758        Ok(())
759    }
760
761    /// Write a `DataFrame` to a Parquet file.
762    #[pyo3(signature = (
763        path,
764        compression="zstd",
765        compression_level=None
766        ))]
767    fn write_parquet(
768        &self,
769        path: &str,
770        compression: &str,
771        compression_level: Option<u32>,
772        py: Python,
773    ) -> PyDataFusionResult<()> {
774        fn verify_compression_level(cl: Option<u32>) -> Result<u32, PyErr> {
775            cl.ok_or(PyValueError::new_err("compression_level is not defined"))
776        }
777
778        let _validated = match compression.to_lowercase().as_str() {
779            "snappy" => Compression::SNAPPY,
780            "gzip" => Compression::GZIP(
781                GzipLevel::try_new(compression_level.unwrap_or(6))
782                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
783            ),
784            "brotli" => Compression::BROTLI(
785                BrotliLevel::try_new(verify_compression_level(compression_level)?)
786                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
787            ),
788            "zstd" => Compression::ZSTD(
789                ZstdLevel::try_new(verify_compression_level(compression_level)? as i32)
790                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
791            ),
792            "lzo" => Compression::LZO,
793            "lz4" => Compression::LZ4,
794            "lz4_raw" => Compression::LZ4_RAW,
795            "uncompressed" => Compression::UNCOMPRESSED,
796            _ => {
797                return Err(PyDataFusionError::Common(format!(
798                    "Unrecognized compression type {compression}"
799                )));
800            }
801        };
802
803        let mut compression_string = compression.to_string();
804        if let Some(level) = compression_level {
805            compression_string.push_str(&format!("({level})"));
806        }
807
808        let mut options = TableParquetOptions::default();
809        options.global.compression = Some(compression_string);
810
811        wait_for_future(
812            py,
813            self.df.as_ref().clone().write_parquet(
814                path,
815                DataFrameWriteOptions::new(),
816                Option::from(options),
817            ),
818        )??;
819        Ok(())
820    }
821
822    /// Write a `DataFrame` to a Parquet file, using advanced options.
823    fn write_parquet_with_options(
824        &self,
825        path: &str,
826        options: PyParquetWriterOptions,
827        column_specific_options: HashMap<String, PyParquetColumnOptions>,
828        py: Python,
829    ) -> PyDataFusionResult<()> {
830        let table_options = TableParquetOptions {
831            global: options.options,
832            column_specific_options: column_specific_options
833                .into_iter()
834                .map(|(k, v)| (k, v.options))
835                .collect(),
836            ..Default::default()
837        };
838
839        wait_for_future(
840            py,
841            self.df.as_ref().clone().write_parquet(
842                path,
843                DataFrameWriteOptions::new(),
844                Option::from(table_options),
845            ),
846        )??;
847        Ok(())
848    }
849
850    /// Executes a query and writes the results to a partitioned JSON file.
851    fn write_json(&self, path: &str, py: Python) -> PyDataFusionResult<()> {
852        wait_for_future(
853            py,
854            self.df
855                .as_ref()
856                .clone()
857                .write_json(path, DataFrameWriteOptions::new(), None),
858        )??;
859        Ok(())
860    }
861
862    /// Convert to Arrow Table
863    /// Collect the batches and pass to Arrow Table
864    fn to_arrow_table(&self, py: Python<'_>) -> PyResult<PyObject> {
865        let batches = self.collect(py)?.into_pyobject(py)?;
866        let schema = self.schema().into_pyobject(py)?;
867
868        // Instantiate pyarrow Table object and use its from_batches method
869        let table_class = py.import("pyarrow")?.getattr("Table")?;
870        let args = PyTuple::new(py, &[batches, schema])?;
871        let table: PyObject = table_class.call_method1("from_batches", args)?.into();
872        Ok(table)
873    }
874
875    #[pyo3(signature = (requested_schema=None))]
876    fn __arrow_c_stream__<'py>(
877        &'py mut self,
878        py: Python<'py>,
879        requested_schema: Option<Bound<'py, PyCapsule>>,
880    ) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
881        let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())??;
882        let mut schema: Schema = self.df.schema().to_owned().into();
883
884        if let Some(schema_capsule) = requested_schema {
885            validate_pycapsule(&schema_capsule, "arrow_schema")?;
886
887            let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
888            let desired_schema = Schema::try_from(schema_ptr)?;
889
890            schema = project_schema(schema, desired_schema)?;
891
892            batches = batches
893                .into_iter()
894                .map(|record_batch| record_batch_into_schema(record_batch, &schema))
895                .collect::<Result<Vec<RecordBatch>, ArrowError>>()?;
896        }
897
898        let batches_wrapped = batches.into_iter().map(Ok);
899
900        let reader = RecordBatchIterator::new(batches_wrapped, Arc::new(schema));
901        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
902
903        let ffi_stream = FFI_ArrowArrayStream::new(reader);
904        let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
905        PyCapsule::new(py, ffi_stream, Some(stream_capsule_name)).map_err(PyDataFusionError::from)
906    }
907
908    fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
909        // create a Tokio runtime to run the async code
910        let rt = &get_tokio_runtime().0;
911        let df = self.df.as_ref().clone();
912        let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
913            rt.spawn(async move { df.execute_stream().await });
914        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
915        Ok(PyRecordBatchStream::new(stream))
916    }
917
918    fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
919        // create a Tokio runtime to run the async code
920        let rt = &get_tokio_runtime().0;
921        let df = self.df.as_ref().clone();
922        let fut: JoinHandle<datafusion::common::Result<Vec<SendableRecordBatchStream>>> =
923            rt.spawn(async move { df.execute_stream_partitioned().await });
924        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })?
925            .map_err(py_datafusion_err)?
926            .map_err(py_datafusion_err)?;
927
928        Ok(stream.into_iter().map(PyRecordBatchStream::new).collect())
929    }
930
931    /// Convert to pandas dataframe with pyarrow
932    /// Collect the batches, pass to Arrow Table & then convert to Pandas DataFrame
933    fn to_pandas(&self, py: Python<'_>) -> PyResult<PyObject> {
934        let table = self.to_arrow_table(py)?;
935
936        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pandas
937        let result = table.call_method0(py, "to_pandas")?;
938        Ok(result)
939    }
940
941    /// Convert to Python list using pyarrow
942    /// Each list item represents one row encoded as dictionary
943    fn to_pylist(&self, py: Python<'_>) -> PyResult<PyObject> {
944        let table = self.to_arrow_table(py)?;
945
946        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pylist
947        let result = table.call_method0(py, "to_pylist")?;
948        Ok(result)
949    }
950
951    /// Convert to Python dictionary using pyarrow
952    /// Each dictionary key is a column and the dictionary value represents the column values
953    fn to_pydict(&self, py: Python) -> PyResult<PyObject> {
954        let table = self.to_arrow_table(py)?;
955
956        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pydict
957        let result = table.call_method0(py, "to_pydict")?;
958        Ok(result)
959    }
960
961    /// Convert to polars dataframe with pyarrow
962    /// Collect the batches, pass to Arrow Table & then convert to polars DataFrame
963    fn to_polars(&self, py: Python<'_>) -> PyResult<PyObject> {
964        let table = self.to_arrow_table(py)?;
965        let dataframe = py.import("polars")?.getattr("DataFrame")?;
966        let args = PyTuple::new(py, &[table])?;
967        let result: PyObject = dataframe.call1(args)?.into();
968        Ok(result)
969    }
970
971    // Executes this DataFrame to get the total number of rows.
972    fn count(&self, py: Python) -> PyDataFusionResult<usize> {
973        Ok(wait_for_future(py, self.df.as_ref().clone().count())??)
974    }
975
976    /// Fill null values with a specified value for specific columns
977    #[pyo3(signature = (value, columns=None))]
978    fn fill_null(
979        &self,
980        value: PyObject,
981        columns: Option<Vec<PyBackedStr>>,
982        py: Python,
983    ) -> PyDataFusionResult<Self> {
984        let scalar_value = py_obj_to_scalar_value(py, value)?;
985
986        let cols = match columns {
987            Some(col_names) => col_names.iter().map(|c| c.to_string()).collect(),
988            None => Vec::new(), // Empty vector means fill null for all columns
989        };
990
991        let df = self.df.as_ref().clone().fill_null(scalar_value, cols)?;
992        Ok(Self::new(df))
993    }
994}
995
996/// Print DataFrame
997fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
998    // Get string representation of record batches
999    let batches = wait_for_future(py, df.collect())??;
1000    let result = if batches.is_empty() {
1001        "DataFrame has no rows".to_string()
1002    } else {
1003        match pretty::pretty_format_batches(&batches) {
1004            Ok(batch) => format!("DataFrame()\n{batch}"),
1005            Err(err) => format!("Error: {:?}", err.to_string()),
1006        }
1007    };
1008
1009    // Import the Python 'builtins' module to access the print function
1010    // Note that println! does not print to the Python debug console and is not visible in notebooks for instance
1011    let print = py.import("builtins")?.getattr("print")?;
1012    print.call1((result,))?;
1013    Ok(())
1014}
1015
1016fn project_schema(from_schema: Schema, to_schema: Schema) -> Result<Schema, ArrowError> {
1017    let merged_schema = Schema::try_merge(vec![from_schema, to_schema.clone()])?;
1018
1019    let project_indices: Vec<usize> = to_schema
1020        .fields
1021        .iter()
1022        .map(|field| field.name())
1023        .filter_map(|field_name| merged_schema.index_of(field_name).ok())
1024        .collect();
1025
1026    merged_schema.project(&project_indices)
1027}
1028
1029fn record_batch_into_schema(
1030    record_batch: RecordBatch,
1031    schema: &Schema,
1032) -> Result<RecordBatch, ArrowError> {
1033    let schema = Arc::new(schema.clone());
1034    let base_schema = record_batch.schema();
1035    if base_schema.fields().is_empty() {
1036        // Nothing to project
1037        return Ok(RecordBatch::new_empty(schema));
1038    }
1039
1040    let array_size = record_batch.column(0).len();
1041    let mut data_arrays = Vec::with_capacity(schema.fields().len());
1042
1043    for field in schema.fields() {
1044        let desired_data_type = field.data_type();
1045        if let Some(original_data) = record_batch.column_by_name(field.name()) {
1046            let original_data_type = original_data.data_type();
1047
1048            if can_cast_types(original_data_type, desired_data_type) {
1049                data_arrays.push(arrow::compute::kernels::cast(
1050                    original_data,
1051                    desired_data_type,
1052                )?);
1053            } else if field.is_nullable() {
1054                data_arrays.push(new_null_array(desired_data_type, array_size));
1055            } else {
1056                return Err(ArrowError::CastError(format!("Attempting to cast to non-nullable and non-castable field {} during schema projection.", field.name())));
1057            }
1058        } else {
1059            if !field.is_nullable() {
1060                return Err(ArrowError::CastError(format!(
1061                    "Attempting to set null to non-nullable field {} during schema projection.",
1062                    field.name()
1063                )));
1064            }
1065            data_arrays.push(new_null_array(desired_data_type, array_size));
1066        }
1067    }
1068
1069    RecordBatch::try_new(schema, data_arrays)
1070}
1071
1072/// This is a helper function to return the first non-empty record batch from executing a DataFrame.
1073/// It additionally returns a bool, which indicates if there are more record batches available.
1074/// We do this so we can determine if we should indicate to the user that the data has been
1075/// truncated. This collects until we have archived both of these two conditions
1076///
1077/// - We have collected our minimum number of rows
1078/// - We have reached our limit, either data size or maximum number of rows
1079///
1080/// Otherwise it will return when the stream has exhausted. If you want a specific number of
1081/// rows, set min_rows == max_rows.
1082async fn collect_record_batches_to_display(
1083    df: DataFrame,
1084    config: FormatterConfig,
1085) -> Result<(Vec<RecordBatch>, bool), DataFusionError> {
1086    let FormatterConfig {
1087        max_bytes,
1088        min_rows,
1089        repr_rows,
1090    } = config;
1091
1092    let partitioned_stream = df.execute_stream_partitioned().await?;
1093    let mut stream = futures::stream::iter(partitioned_stream).flatten();
1094    let mut size_estimate_so_far = 0;
1095    let mut rows_so_far = 0;
1096    let mut record_batches = Vec::default();
1097    let mut has_more = false;
1098
1099    // ensure minimum rows even if memory/row limits are hit
1100    while (size_estimate_so_far < max_bytes && rows_so_far < repr_rows) || rows_so_far < min_rows {
1101        let mut rb = match stream.next().await {
1102            None => {
1103                break;
1104            }
1105            Some(Ok(r)) => r,
1106            Some(Err(e)) => return Err(e),
1107        };
1108
1109        let mut rows_in_rb = rb.num_rows();
1110        if rows_in_rb > 0 {
1111            size_estimate_so_far += rb.get_array_memory_size();
1112
1113            if size_estimate_so_far > max_bytes {
1114                let ratio = max_bytes as f32 / size_estimate_so_far as f32;
1115                let total_rows = rows_in_rb + rows_so_far;
1116
1117                let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize;
1118                if reduced_row_num < min_rows {
1119                    reduced_row_num = min_rows.min(total_rows);
1120                }
1121
1122                let limited_rows_this_rb = reduced_row_num - rows_so_far;
1123                if limited_rows_this_rb < rows_in_rb {
1124                    rows_in_rb = limited_rows_this_rb;
1125                    rb = rb.slice(0, limited_rows_this_rb);
1126                    has_more = true;
1127                }
1128            }
1129
1130            if rows_in_rb + rows_so_far > repr_rows {
1131                rb = rb.slice(0, repr_rows - rows_so_far);
1132                has_more = true;
1133            }
1134
1135            rows_so_far += rb.num_rows();
1136            record_batches.push(rb);
1137        }
1138    }
1139
1140    if record_batches.is_empty() {
1141        return Ok((Vec::default(), false));
1142    }
1143
1144    if !has_more {
1145        // Data was not already truncated, so check to see if more record batches remain
1146        has_more = match stream.try_next().await {
1147            Ok(None) => false, // reached end
1148            Ok(Some(_)) => true,
1149            Err(_) => false, // Stream disconnected
1150        };
1151    }
1152
1153    Ok((record_batches, has_more))
1154}