datafusion_python/
dataframe.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::HashMap;
19use std::ffi::CString;
20use std::sync::Arc;
21
22use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader};
23use arrow::compute::can_cast_types;
24use arrow::error::ArrowError;
25use arrow::ffi::FFI_ArrowSchema;
26use arrow::ffi_stream::FFI_ArrowArrayStream;
27use arrow::pyarrow::FromPyArrow;
28use datafusion::arrow::datatypes::Schema;
29use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
30use datafusion::arrow::util::pretty;
31use datafusion::common::UnnestOptions;
32use datafusion::config::{CsvOptions, ParquetColumnOptions, ParquetOptions, TableParquetOptions};
33use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
34use datafusion::error::DataFusionError;
35use datafusion::execution::SendableRecordBatchStream;
36use datafusion::logical_expr::dml::InsertOp;
37use datafusion::logical_expr::SortExpr;
38use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
39use datafusion::prelude::*;
40use futures::{StreamExt, TryStreamExt};
41use pyo3::exceptions::PyValueError;
42use pyo3::prelude::*;
43use pyo3::pybacked::PyBackedStr;
44use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
45use tokio::task::JoinHandle;
46
47use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError};
48use crate::expr::sort_expr::to_sort_expressions;
49use crate::physical_plan::PyExecutionPlan;
50use crate::record_batch::PyRecordBatchStream;
51use crate::sql::logical::PyLogicalPlan;
52use crate::table::PyTable;
53use crate::utils::{
54    get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future,
55};
56use crate::{
57    errors::PyDataFusionResult,
58    expr::{sort_expr::PySortExpr, PyExpr},
59};
60
61use parking_lot::Mutex;
62
// Type aliases to simplify very complex types used in this file and
// avoid compiler complaints about deeply nested types in struct fields
// (presumably clippy's `type_complexity` lint).
//
// `CachedBatches` is the display cache entry: the collected record batches plus the
// `has_more` flag handed to the formatter alongside them. `None` means nothing is cached.
type CachedBatches = Option<(Vec<RecordBatch>, bool)>;
// Shared, mutex-protected handle to the cache; clones of the `Arc` observe the same entry.
type SharedCachedBatches = Arc<Mutex<CachedBatches>>;
67
/// Configuration for DataFrame display formatting.
///
/// Every limit must be strictly positive; see [`FormatterConfig::validate`].
#[derive(Debug, Clone)]
pub struct FormatterConfig {
    /// Maximum memory in bytes to use for display (default: 2 MiB).
    pub max_bytes: usize,
    /// Minimum number of rows to display (default: 20).
    pub min_rows: usize,
    /// Number of rows to include in `__repr__` output (default: 10).
    pub repr_rows: usize,
}

impl Default for FormatterConfig {
    fn default() -> Self {
        const TWO_MIB: usize = 2 << 20;
        FormatterConfig {
            repr_rows: 10,
            min_rows: 20,
            max_bytes: TWO_MIB,
        }
    }
}

impl FormatterConfig {
    /// Checks that every limit in this configuration is a positive integer.
    ///
    /// # Errors
    ///
    /// Returns a message naming the first zero-valued field, checked in the order
    /// `max_bytes`, `min_rows`, `repr_rows`.
    pub fn validate(&self) -> Result<(), String> {
        let fields = [
            ("max_bytes", self.max_bytes),
            ("min_rows", self.min_rows),
            ("repr_rows", self.repr_rows),
        ];
        match fields.iter().find(|(_, value)| *value == 0) {
            Some((name, _)) => Err(format!("{name} must be a positive integer")),
            None => Ok(()),
        }
    }
}
111
/// Holds the Python formatter and its configuration.
///
/// `'py` is the GIL lifetime the formatter handle is bound to.
struct PythonFormatter<'py> {
    /// The Python formatter object, as returned by
    /// `datafusion.dataframe_formatter.get_formatter()`.
    formatter: Bound<'py, PyAny>,
    /// The display limits read from the formatter's attributes.
    config: FormatterConfig,
}
119
120/// Get the Python formatter and its configuration
121fn get_python_formatter_with_config(py: Python) -> PyResult<PythonFormatter> {
122    let formatter = import_python_formatter(py)?;
123    let config = build_formatter_config_from_python(&formatter)?;
124    Ok(PythonFormatter { formatter, config })
125}
126
127/// Get the Python formatter from the datafusion.dataframe_formatter module
128fn import_python_formatter(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
129    let formatter_module = py.import("datafusion.dataframe_formatter")?;
130    let get_formatter = formatter_module.getattr("get_formatter")?;
131    get_formatter.call0()
132}
133
134// Helper function to extract attributes with fallback to default
135fn get_attr<'a, T>(py_object: &'a Bound<'a, PyAny>, attr_name: &str, default_value: T) -> T
136where
137    T: for<'py> pyo3::FromPyObject<'py> + Clone,
138{
139    py_object
140        .getattr(attr_name)
141        .and_then(|v| v.extract::<T>())
142        .unwrap_or_else(|_| default_value.clone())
143}
144
145/// Helper function to create a FormatterConfig from a Python formatter object
146fn build_formatter_config_from_python(formatter: &Bound<'_, PyAny>) -> PyResult<FormatterConfig> {
147    let default_config = FormatterConfig::default();
148    let max_bytes = get_attr(formatter, "max_memory_bytes", default_config.max_bytes);
149    let min_rows = get_attr(formatter, "min_rows_display", default_config.min_rows);
150    let repr_rows = get_attr(formatter, "repr_rows", default_config.repr_rows);
151
152    let config = FormatterConfig {
153        max_bytes,
154        min_rows,
155        repr_rows,
156    };
157
158    // Return the validated config, converting String error to PyErr
159    config.validate().map_err(PyValueError::new_err)?;
160    Ok(config)
161}
162
163/// Python mapping of `ParquetOptions` (includes just the writer-related options).
164#[pyclass(frozen, name = "ParquetWriterOptions", module = "datafusion", subclass)]
165#[derive(Clone, Default)]
166pub struct PyParquetWriterOptions {
167    options: ParquetOptions,
168}
169
170#[pymethods]
171impl PyParquetWriterOptions {
172    #[new]
173    #[allow(clippy::too_many_arguments)]
174    pub fn new(
175        data_pagesize_limit: usize,
176        write_batch_size: usize,
177        writer_version: String,
178        skip_arrow_metadata: bool,
179        compression: Option<String>,
180        dictionary_enabled: Option<bool>,
181        dictionary_page_size_limit: usize,
182        statistics_enabled: Option<String>,
183        max_row_group_size: usize,
184        created_by: String,
185        column_index_truncate_length: Option<usize>,
186        statistics_truncate_length: Option<usize>,
187        data_page_row_count_limit: usize,
188        encoding: Option<String>,
189        bloom_filter_on_write: bool,
190        bloom_filter_fpp: Option<f64>,
191        bloom_filter_ndv: Option<u64>,
192        allow_single_file_parallelism: bool,
193        maximum_parallel_row_group_writers: usize,
194        maximum_buffered_record_batches_per_stream: usize,
195    ) -> Self {
196        Self {
197            options: ParquetOptions {
198                data_pagesize_limit,
199                write_batch_size,
200                writer_version,
201                skip_arrow_metadata,
202                compression,
203                dictionary_enabled,
204                dictionary_page_size_limit,
205                statistics_enabled,
206                max_row_group_size,
207                created_by,
208                column_index_truncate_length,
209                statistics_truncate_length,
210                data_page_row_count_limit,
211                encoding,
212                bloom_filter_on_write,
213                bloom_filter_fpp,
214                bloom_filter_ndv,
215                allow_single_file_parallelism,
216                maximum_parallel_row_group_writers,
217                maximum_buffered_record_batches_per_stream,
218                ..Default::default()
219            },
220        }
221    }
222}
223
224/// Python mapping of `ParquetColumnOptions`.
225#[pyclass(frozen, name = "ParquetColumnOptions", module = "datafusion", subclass)]
226#[derive(Clone, Default)]
227pub struct PyParquetColumnOptions {
228    options: ParquetColumnOptions,
229}
230
231#[pymethods]
232impl PyParquetColumnOptions {
233    #[new]
234    pub fn new(
235        bloom_filter_enabled: Option<bool>,
236        encoding: Option<String>,
237        dictionary_enabled: Option<bool>,
238        compression: Option<String>,
239        statistics_enabled: Option<String>,
240        bloom_filter_fpp: Option<f64>,
241        bloom_filter_ndv: Option<u64>,
242    ) -> Self {
243        Self {
244            options: ParquetColumnOptions {
245                bloom_filter_enabled,
246                encoding,
247                dictionary_enabled,
248                compression,
249                statistics_enabled,
250                bloom_filter_fpp,
251                bloom_filter_ndv,
252            },
253        }
254    }
255}
256
/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
/// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment.
///
/// The class is `frozen`, so methods only get `&self`. The display cache therefore relies on
/// interior mutability through a mutex. Cloning is cheap because both fields are `Arc`s, and
/// clones share the same `DataFrame` and the same display cache.
#[pyclass(name = "DataFrame", module = "datafusion", subclass, frozen)]
#[derive(Clone)]
pub struct PyDataFrame {
    // The wrapped DataFusion DataFrame. Methods clone the inner value before calling
    // DataFusion APIs that take `self` by value, so this object stays usable.
    df: Arc<DataFrame>,

    // In IPython environment cache batches between __repr__ and _repr_html_ calls.
    batches: SharedCachedBatches,
}
268
269impl PyDataFrame {
270    /// creates a new PyDataFrame
271    pub fn new(df: DataFrame) -> Self {
272        Self {
273            df: Arc::new(df),
274            batches: Arc::new(Mutex::new(None)),
275        }
276    }
277
278    /// Return a clone of the inner Arc<DataFrame> for crate-local callers.
279    pub(crate) fn inner_df(&self) -> Arc<DataFrame> {
280        Arc::clone(&self.df)
281    }
282
283    fn prepare_repr_string(&self, py: Python, as_html: bool) -> PyDataFusionResult<String> {
284        // Get the Python formatter and config
285        let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?;
286
287        let is_ipython = *is_ipython_env(py);
288
289        let (cached_batches, should_cache) = {
290            let mut cache = self.batches.lock();
291            let should_cache = is_ipython && cache.is_none();
292            let batches = cache.take();
293            (batches, should_cache)
294        };
295
296        let (batches, has_more) = match cached_batches {
297            Some(b) => b,
298            None => wait_for_future(
299                py,
300                collect_record_batches_to_display(self.df.as_ref().clone(), config),
301            )??,
302        };
303
304        if batches.is_empty() {
305            // This should not be reached, but do it for safety since we index into the vector below
306            return Ok("No data to display".to_string());
307        }
308
309        let table_uuid = uuid::Uuid::new_v4().to_string();
310
311        // Convert record batches to PyObject list
312        let py_batches = batches
313            .iter()
314            .map(|rb| rb.to_pyarrow(py))
315            .collect::<PyResult<Vec<PyObject>>>()?;
316
317        let py_schema = self.schema().into_pyobject(py)?;
318
319        let kwargs = pyo3::types::PyDict::new(py);
320        let py_batches_list = PyList::new(py, py_batches.as_slice())?;
321        kwargs.set_item("batches", py_batches_list)?;
322        kwargs.set_item("schema", py_schema)?;
323        kwargs.set_item("has_more", has_more)?;
324        kwargs.set_item("table_uuid", table_uuid)?;
325
326        let method_name = match as_html {
327            true => "format_html",
328            false => "format_str",
329        };
330
331        let html_result = formatter.call_method(method_name, (), Some(&kwargs))?;
332        let html_str: String = html_result.extract()?;
333
334        if should_cache {
335            let mut cache = self.batches.lock();
336            *cache = Some((batches.clone(), has_more));
337        }
338
339        Ok(html_str)
340    }
341}
342
343#[pymethods]
344impl PyDataFrame {
345    /// Enable selection for `df[col]`, `df[col1, col2, col3]`, and `df[[col1, col2, col3]]`
346    fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
347        if let Ok(key) = key.extract::<PyBackedStr>() {
348            // df[col]
349            self.select_columns(vec![key])
350        } else if let Ok(tuple) = key.downcast::<PyTuple>() {
351            // df[col1, col2, col3]
352            let keys = tuple
353                .iter()
354                .map(|item| item.extract::<PyBackedStr>())
355                .collect::<PyResult<Vec<PyBackedStr>>>()?;
356            self.select_columns(keys)
357        } else if let Ok(keys) = key.extract::<Vec<PyBackedStr>>() {
358            // df[[col1, col2, col3]]
359            self.select_columns(keys)
360        } else {
361            let message = "DataFrame can only be indexed by string index or indices".to_string();
362            Err(PyDataFusionError::Common(message))
363        }
364    }
365
366    fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
367        self.prepare_repr_string(py, false)
368    }
369
370    #[staticmethod]
371    #[expect(unused_variables)]
372    fn default_str_repr<'py>(
373        batches: Vec<Bound<'py, PyAny>>,
374        schema: &Bound<'py, PyAny>,
375        has_more: bool,
376        table_uuid: &str,
377    ) -> PyResult<String> {
378        let batches = batches
379            .into_iter()
380            .map(|batch| RecordBatch::from_pyarrow_bound(&batch))
381            .collect::<PyResult<Vec<RecordBatch>>>()?
382            .into_iter()
383            .filter(|batch| batch.num_rows() > 0)
384            .collect::<Vec<_>>();
385
386        if batches.is_empty() {
387            return Ok("No data to display".to_owned());
388        }
389
390        let batches_as_displ =
391            pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?;
392
393        let additional_str = match has_more {
394            true => "\nData truncated.",
395            false => "",
396        };
397
398        Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}"))
399    }
400
401    fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
402        self.prepare_repr_string(py, true)
403    }
404
405    /// Calculate summary statistics for a DataFrame
406    fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
407        let df = self.df.as_ref().clone();
408        let stat_df = wait_for_future(py, df.describe())??;
409        Ok(Self::new(stat_df))
410    }
411
412    /// Returns the schema from the logical plan
413    fn schema(&self) -> PyArrowType<Schema> {
414        PyArrowType(self.df.schema().into())
415    }
416
417    /// Convert this DataFrame into a Table Provider that can be used in register_table
418    /// By convention, into_... methods consume self and return the new object.
419    /// Disabling the clippy lint, so we can use &self
420    /// because we're working with Python bindings
421    /// where objects are shared
422    #[allow(clippy::wrong_self_convention)]
423    pub fn into_view(&self) -> PyDataFusionResult<PyTable> {
424        // Call the underlying Rust DataFrame::into_view method.
425        // Note that the Rust method consumes self; here we clone the inner Arc<DataFrame>
426        // so that we don't invalidate this PyDataFrame.
427        let table_provider = self.df.as_ref().clone().into_view();
428        Ok(PyTable::from(table_provider))
429    }
430
431    #[pyo3(signature = (*args))]
432    fn select_columns(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
433        let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
434        let df = self.df.as_ref().clone().select_columns(&args)?;
435        Ok(Self::new(df))
436    }
437
438    #[pyo3(signature = (*args))]
439    fn select(&self, args: Vec<PyExpr>) -> PyDataFusionResult<Self> {
440        let expr: Vec<Expr> = args.into_iter().map(|e| e.into()).collect();
441        let df = self.df.as_ref().clone().select(expr)?;
442        Ok(Self::new(df))
443    }
444
445    #[pyo3(signature = (*args))]
446    fn drop(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
447        let cols = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
448        let df = self.df.as_ref().clone().drop_columns(&cols)?;
449        Ok(Self::new(df))
450    }
451
452    fn filter(&self, predicate: PyExpr) -> PyDataFusionResult<Self> {
453        let df = self.df.as_ref().clone().filter(predicate.into())?;
454        Ok(Self::new(df))
455    }
456
457    fn parse_sql_expr(&self, expr: PyBackedStr) -> PyDataFusionResult<PyExpr> {
458        self.df
459            .as_ref()
460            .parse_sql_expr(&expr)
461            .map(|e| PyExpr::from(e))
462            .map_err(PyDataFusionError::from)
463    }
464
465    fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult<Self> {
466        let df = self.df.as_ref().clone().with_column(name, expr.into())?;
467        Ok(Self::new(df))
468    }
469
470    fn with_columns(&self, exprs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
471        let mut df = self.df.as_ref().clone();
472        for expr in exprs {
473            let expr: Expr = expr.into();
474            let name = format!("{}", expr.schema_name());
475            df = df.with_column(name.as_str(), expr)?
476        }
477        Ok(Self::new(df))
478    }
479
480    /// Rename one column by applying a new projection. This is a no-op if the column to be
481    /// renamed does not exist.
482    fn with_column_renamed(&self, old_name: &str, new_name: &str) -> PyDataFusionResult<Self> {
483        let df = self
484            .df
485            .as_ref()
486            .clone()
487            .with_column_renamed(old_name, new_name)?;
488        Ok(Self::new(df))
489    }
490
491    fn aggregate(&self, group_by: Vec<PyExpr>, aggs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
492        let group_by = group_by.into_iter().map(|e| e.into()).collect();
493        let aggs = aggs.into_iter().map(|e| e.into()).collect();
494        let df = self.df.as_ref().clone().aggregate(group_by, aggs)?;
495        Ok(Self::new(df))
496    }
497
498    #[pyo3(signature = (*exprs))]
499    fn sort(&self, exprs: Vec<PySortExpr>) -> PyDataFusionResult<Self> {
500        let exprs = to_sort_expressions(exprs);
501        let df = self.df.as_ref().clone().sort(exprs)?;
502        Ok(Self::new(df))
503    }
504
505    #[pyo3(signature = (count, offset=0))]
506    fn limit(&self, count: usize, offset: usize) -> PyDataFusionResult<Self> {
507        let df = self.df.as_ref().clone().limit(offset, Some(count))?;
508        Ok(Self::new(df))
509    }
510
511    /// Executes the plan, returning a list of `RecordBatch`es.
512    /// Unless some order is specified in the plan, there is no
513    /// guarantee of the order of the result.
514    fn collect(&self, py: Python) -> PyResult<Vec<PyObject>> {
515        let batches = wait_for_future(py, self.df.as_ref().clone().collect())?
516            .map_err(PyDataFusionError::from)?;
517        // cannot use PyResult<Vec<RecordBatch>> return type due to
518        // https://github.com/PyO3/pyo3/issues/1813
519        batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect()
520    }
521
522    /// Cache DataFrame.
523    fn cache(&self, py: Python) -> PyDataFusionResult<Self> {
524        let df = wait_for_future(py, self.df.as_ref().clone().cache())??;
525        Ok(Self::new(df))
526    }
527
528    /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
529    /// maintaining the input partitioning.
530    fn collect_partitioned(&self, py: Python) -> PyResult<Vec<Vec<PyObject>>> {
531        let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())?
532            .map_err(PyDataFusionError::from)?;
533
534        batches
535            .into_iter()
536            .map(|rbs| rbs.into_iter().map(|rb| rb.to_pyarrow(py)).collect())
537            .collect()
538    }
539
540    /// Print the result, 20 lines by default
541    #[pyo3(signature = (num=20))]
542    fn show(&self, py: Python, num: usize) -> PyDataFusionResult<()> {
543        let df = self.df.as_ref().clone().limit(0, Some(num))?;
544        print_dataframe(py, df)
545    }
546
547    /// Filter out duplicate rows
548    fn distinct(&self) -> PyDataFusionResult<Self> {
549        let df = self.df.as_ref().clone().distinct()?;
550        Ok(Self::new(df))
551    }
552
553    fn join(
554        &self,
555        right: PyDataFrame,
556        how: &str,
557        left_on: Vec<PyBackedStr>,
558        right_on: Vec<PyBackedStr>,
559    ) -> PyDataFusionResult<Self> {
560        let join_type = match how {
561            "inner" => JoinType::Inner,
562            "left" => JoinType::Left,
563            "right" => JoinType::Right,
564            "full" => JoinType::Full,
565            "semi" => JoinType::LeftSemi,
566            "anti" => JoinType::LeftAnti,
567            how => {
568                return Err(PyDataFusionError::Common(format!(
569                    "The join type {how} does not exist or is not implemented"
570                )));
571            }
572        };
573
574        let left_keys = left_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
575        let right_keys = right_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
576
577        let df = self.df.as_ref().clone().join(
578            right.df.as_ref().clone(),
579            join_type,
580            &left_keys,
581            &right_keys,
582            None,
583        )?;
584        Ok(Self::new(df))
585    }
586
587    fn join_on(
588        &self,
589        right: PyDataFrame,
590        on_exprs: Vec<PyExpr>,
591        how: &str,
592    ) -> PyDataFusionResult<Self> {
593        let join_type = match how {
594            "inner" => JoinType::Inner,
595            "left" => JoinType::Left,
596            "right" => JoinType::Right,
597            "full" => JoinType::Full,
598            "semi" => JoinType::LeftSemi,
599            "anti" => JoinType::LeftAnti,
600            how => {
601                return Err(PyDataFusionError::Common(format!(
602                    "The join type {how} does not exist or is not implemented"
603                )));
604            }
605        };
606        let exprs: Vec<Expr> = on_exprs.into_iter().map(|e| e.into()).collect();
607
608        let df = self
609            .df
610            .as_ref()
611            .clone()
612            .join_on(right.df.as_ref().clone(), join_type, exprs)?;
613        Ok(Self::new(df))
614    }
615
616    /// Print the query plan
617    #[pyo3(signature = (verbose=false, analyze=false))]
618    fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyDataFusionResult<()> {
619        let df = self.df.as_ref().clone().explain(verbose, analyze)?;
620        print_dataframe(py, df)
621    }
622
623    /// Get the logical plan for this `DataFrame`
624    fn logical_plan(&self) -> PyResult<PyLogicalPlan> {
625        Ok(self.df.as_ref().clone().logical_plan().clone().into())
626    }
627
628    /// Get the optimized logical plan for this `DataFrame`
629    fn optimized_logical_plan(&self) -> PyDataFusionResult<PyLogicalPlan> {
630        Ok(self.df.as_ref().clone().into_optimized_plan()?.into())
631    }
632
633    /// Get the execution plan for this `DataFrame`
634    fn execution_plan(&self, py: Python) -> PyDataFusionResult<PyExecutionPlan> {
635        let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())??;
636        Ok(plan.into())
637    }
638
639    /// Repartition a `DataFrame` based on a logical partitioning scheme.
640    fn repartition(&self, num: usize) -> PyDataFusionResult<Self> {
641        let new_df = self
642            .df
643            .as_ref()
644            .clone()
645            .repartition(Partitioning::RoundRobinBatch(num))?;
646        Ok(Self::new(new_df))
647    }
648
649    /// Repartition a `DataFrame` based on a logical partitioning scheme.
650    #[pyo3(signature = (*args, num))]
651    fn repartition_by_hash(&self, args: Vec<PyExpr>, num: usize) -> PyDataFusionResult<Self> {
652        let expr = args.into_iter().map(|py_expr| py_expr.into()).collect();
653        let new_df = self
654            .df
655            .as_ref()
656            .clone()
657            .repartition(Partitioning::Hash(expr, num))?;
658        Ok(Self::new(new_df))
659    }
660
661    /// Calculate the union of two `DataFrame`s, preserving duplicate rows.The
662    /// two `DataFrame`s must have exactly the same schema
663    #[pyo3(signature = (py_df, distinct=false))]
664    fn union(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult<Self> {
665        let new_df = if distinct {
666            self.df
667                .as_ref()
668                .clone()
669                .union_distinct(py_df.df.as_ref().clone())?
670        } else {
671            self.df.as_ref().clone().union(py_df.df.as_ref().clone())?
672        };
673
674        Ok(Self::new(new_df))
675    }
676
677    /// Calculate the distinct union of two `DataFrame`s.  The
678    /// two `DataFrame`s must have exactly the same schema
679    fn union_distinct(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
680        let new_df = self
681            .df
682            .as_ref()
683            .clone()
684            .union_distinct(py_df.df.as_ref().clone())?;
685        Ok(Self::new(new_df))
686    }
687
688    #[pyo3(signature = (column, preserve_nulls=true))]
689    fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyDataFusionResult<Self> {
690        // TODO: expose RecursionUnnestOptions
691        // REF: https://github.com/apache/datafusion/pull/11577
692        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
693        let df = self
694            .df
695            .as_ref()
696            .clone()
697            .unnest_columns_with_options(&[column], unnest_options)?;
698        Ok(Self::new(df))
699    }
700
701    #[pyo3(signature = (columns, preserve_nulls=true))]
702    fn unnest_columns(
703        &self,
704        columns: Vec<String>,
705        preserve_nulls: bool,
706    ) -> PyDataFusionResult<Self> {
707        // TODO: expose RecursionUnnestOptions
708        // REF: https://github.com/apache/datafusion/pull/11577
709        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
710        let cols = columns.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
711        let df = self
712            .df
713            .as_ref()
714            .clone()
715            .unnest_columns_with_options(&cols, unnest_options)?;
716        Ok(Self::new(df))
717    }
718
719    /// Calculate the intersection of two `DataFrame`s.  The two `DataFrame`s must have exactly the same schema
720    fn intersect(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
721        let new_df = self
722            .df
723            .as_ref()
724            .clone()
725            .intersect(py_df.df.as_ref().clone())?;
726        Ok(Self::new(new_df))
727    }
728
729    /// Calculate the exception of two `DataFrame`s.  The two `DataFrame`s must have exactly the same schema
730    fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
731        let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?;
732        Ok(Self::new(new_df))
733    }
734
735    /// Write a `DataFrame` to a CSV file.
736    fn write_csv(
737        &self,
738        py: Python,
739        path: &str,
740        with_header: bool,
741        write_options: Option<PyDataFrameWriteOptions>,
742    ) -> PyDataFusionResult<()> {
743        let csv_options = CsvOptions {
744            has_header: Some(with_header),
745            ..Default::default()
746        };
747        let write_options = write_options
748            .map(DataFrameWriteOptions::from)
749            .unwrap_or_default();
750
751        wait_for_future(
752            py,
753            self.df
754                .as_ref()
755                .clone()
756                .write_csv(path, write_options, Some(csv_options)),
757        )??;
758        Ok(())
759    }
760
761    /// Write a `DataFrame` to a Parquet file.
762    #[pyo3(signature = (
763        path,
764        compression="zstd",
765        compression_level=None,
766        write_options=None,
767        ))]
768    fn write_parquet(
769        &self,
770        path: &str,
771        compression: &str,
772        compression_level: Option<u32>,
773        write_options: Option<PyDataFrameWriteOptions>,
774        py: Python,
775    ) -> PyDataFusionResult<()> {
776        fn verify_compression_level(cl: Option<u32>) -> Result<u32, PyErr> {
777            cl.ok_or(PyValueError::new_err("compression_level is not defined"))
778        }
779
780        let _validated = match compression.to_lowercase().as_str() {
781            "snappy" => Compression::SNAPPY,
782            "gzip" => Compression::GZIP(
783                GzipLevel::try_new(compression_level.unwrap_or(6))
784                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
785            ),
786            "brotli" => Compression::BROTLI(
787                BrotliLevel::try_new(verify_compression_level(compression_level)?)
788                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
789            ),
790            "zstd" => Compression::ZSTD(
791                ZstdLevel::try_new(verify_compression_level(compression_level)? as i32)
792                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
793            ),
794            "lzo" => Compression::LZO,
795            "lz4" => Compression::LZ4,
796            "lz4_raw" => Compression::LZ4_RAW,
797            "uncompressed" => Compression::UNCOMPRESSED,
798            _ => {
799                return Err(PyDataFusionError::Common(format!(
800                    "Unrecognized compression type {compression}"
801                )));
802            }
803        };
804
805        let mut compression_string = compression.to_string();
806        if let Some(level) = compression_level {
807            compression_string.push_str(&format!("({level})"));
808        }
809
810        let mut options = TableParquetOptions::default();
811        options.global.compression = Some(compression_string);
812        let write_options = write_options
813            .map(DataFrameWriteOptions::from)
814            .unwrap_or_default();
815
816        wait_for_future(
817            py,
818            self.df
819                .as_ref()
820                .clone()
821                .write_parquet(path, write_options, Option::from(options)),
822        )??;
823        Ok(())
824    }
825
826    /// Write a `DataFrame` to a Parquet file, using advanced options.
827    fn write_parquet_with_options(
828        &self,
829        path: &str,
830        options: PyParquetWriterOptions,
831        column_specific_options: HashMap<String, PyParquetColumnOptions>,
832        write_options: Option<PyDataFrameWriteOptions>,
833        py: Python,
834    ) -> PyDataFusionResult<()> {
835        let table_options = TableParquetOptions {
836            global: options.options,
837            column_specific_options: column_specific_options
838                .into_iter()
839                .map(|(k, v)| (k, v.options))
840                .collect(),
841            ..Default::default()
842        };
843        let write_options = write_options
844            .map(DataFrameWriteOptions::from)
845            .unwrap_or_default();
846        wait_for_future(
847            py,
848            self.df.as_ref().clone().write_parquet(
849                path,
850                write_options,
851                Option::from(table_options),
852            ),
853        )??;
854        Ok(())
855    }
856
857    /// Executes a query and writes the results to a partitioned JSON file.
858    fn write_json(
859        &self,
860        path: &str,
861        py: Python,
862        write_options: Option<PyDataFrameWriteOptions>,
863    ) -> PyDataFusionResult<()> {
864        let write_options = write_options
865            .map(DataFrameWriteOptions::from)
866            .unwrap_or_default();
867        wait_for_future(
868            py,
869            self.df
870                .as_ref()
871                .clone()
872                .write_json(path, write_options, None),
873        )??;
874        Ok(())
875    }
876
877    fn write_table(
878        &self,
879        py: Python,
880        table_name: &str,
881        write_options: Option<PyDataFrameWriteOptions>,
882    ) -> PyDataFusionResult<()> {
883        let write_options = write_options
884            .map(DataFrameWriteOptions::from)
885            .unwrap_or_default();
886        wait_for_future(
887            py,
888            self.df
889                .as_ref()
890                .clone()
891                .write_table(table_name, write_options),
892        )??;
893        Ok(())
894    }
895
896    /// Convert to Arrow Table
897    /// Collect the batches and pass to Arrow Table
898    fn to_arrow_table(&self, py: Python<'_>) -> PyResult<PyObject> {
899        let batches = self.collect(py)?.into_pyobject(py)?;
900        let schema = self.schema().into_pyobject(py)?;
901
902        // Instantiate pyarrow Table object and use its from_batches method
903        let table_class = py.import("pyarrow")?.getattr("Table")?;
904        let args = PyTuple::new(py, &[batches, schema])?;
905        let table: PyObject = table_class.call_method1("from_batches", args)?.into();
906        Ok(table)
907    }
908
    /// Implement the Arrow PyCapsule stream protocol (`__arrow_c_stream__`).
    ///
    /// The query is executed eagerly: every batch is collected into memory before the
    /// stream is created. If `requested_schema` is provided, it must be a capsule named
    /// `"arrow_schema"`. The output schema is then `project_schema(current, requested)`, and
    /// each collected batch is converted with `record_batch_into_schema`.
    ///
    /// Returns a PyCapsule named `"arrow_array_stream"` that owns an `FFI_ArrowArrayStream`
    /// iterating over the collected batches.
    #[pyo3(signature = (requested_schema=None))]
    fn __arrow_c_stream__<'py>(
        &'py self,
        py: Python<'py>,
        requested_schema: Option<Bound<'py, PyCapsule>>,
    ) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
        // Execute the whole query up front; the exported stream replays these batches.
        let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())??;
        let mut schema: Schema = self.df.schema().to_owned().into();

        if let Some(schema_capsule) = requested_schema {
            validate_pycapsule(&schema_capsule, "arrow_schema")?;

            // SAFETY: relies on `validate_pycapsule` above having confirmed this is an
            // "arrow_schema" capsule, which the Arrow PyCapsule interface defines as holding
            // an `FFI_ArrowSchema`. NOTE(review): confirm validate_pycapsule checks the name.
            let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
            let desired_schema = Schema::try_from(schema_ptr)?;

            // Keep only the requested fields, in the requested order.
            schema = project_schema(schema, desired_schema)?;

            batches = batches
                .into_iter()
                .map(|record_batch| record_batch_into_schema(record_batch, &schema))
                .collect::<Result<Vec<RecordBatch>, ArrowError>>()?;
        }

        let batches_wrapped = batches.into_iter().map(Ok);

        let reader = RecordBatchIterator::new(batches_wrapped, Arc::new(schema));
        // Erase the concrete iterator type into the boxed reader the FFI stream consumes.
        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);

        let ffi_stream = FFI_ArrowArrayStream::new(reader);
        // Capsule name mandated for array streams; the literal has no interior NUL, so the
        // unwrap cannot fail.
        let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
        PyCapsule::new(py, ffi_stream, Some(stream_capsule_name)).map_err(PyDataFusionError::from)
    }
941
942    fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
943        // create a Tokio runtime to run the async code
944        let rt = &get_tokio_runtime().0;
945        let df = self.df.as_ref().clone();
946        let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
947            rt.spawn(async move { df.execute_stream().await });
948        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
949        Ok(PyRecordBatchStream::new(stream))
950    }
951
952    fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
953        // create a Tokio runtime to run the async code
954        let rt = &get_tokio_runtime().0;
955        let df = self.df.as_ref().clone();
956        let fut: JoinHandle<datafusion::common::Result<Vec<SendableRecordBatchStream>>> =
957            rt.spawn(async move { df.execute_stream_partitioned().await });
958        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })?
959            .map_err(py_datafusion_err)?
960            .map_err(py_datafusion_err)?;
961
962        Ok(stream.into_iter().map(PyRecordBatchStream::new).collect())
963    }
964
965    /// Convert to pandas dataframe with pyarrow
966    /// Collect the batches, pass to Arrow Table & then convert to Pandas DataFrame
967    fn to_pandas(&self, py: Python<'_>) -> PyResult<PyObject> {
968        let table = self.to_arrow_table(py)?;
969
970        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pandas
971        let result = table.call_method0(py, "to_pandas")?;
972        Ok(result)
973    }
974
975    /// Convert to Python list using pyarrow
976    /// Each list item represents one row encoded as dictionary
977    fn to_pylist(&self, py: Python<'_>) -> PyResult<PyObject> {
978        let table = self.to_arrow_table(py)?;
979
980        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pylist
981        let result = table.call_method0(py, "to_pylist")?;
982        Ok(result)
983    }
984
985    /// Convert to Python dictionary using pyarrow
986    /// Each dictionary key is a column and the dictionary value represents the column values
987    fn to_pydict(&self, py: Python) -> PyResult<PyObject> {
988        let table = self.to_arrow_table(py)?;
989
990        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pydict
991        let result = table.call_method0(py, "to_pydict")?;
992        Ok(result)
993    }
994
995    /// Convert to polars dataframe with pyarrow
996    /// Collect the batches, pass to Arrow Table & then convert to polars DataFrame
997    fn to_polars(&self, py: Python<'_>) -> PyResult<PyObject> {
998        let table = self.to_arrow_table(py)?;
999        let dataframe = py.import("polars")?.getattr("DataFrame")?;
1000        let args = PyTuple::new(py, &[table])?;
1001        let result: PyObject = dataframe.call1(args)?.into();
1002        Ok(result)
1003    }
1004
1005    // Executes this DataFrame to get the total number of rows.
1006    fn count(&self, py: Python) -> PyDataFusionResult<usize> {
1007        Ok(wait_for_future(py, self.df.as_ref().clone().count())??)
1008    }
1009
1010    /// Fill null values with a specified value for specific columns
1011    #[pyo3(signature = (value, columns=None))]
1012    fn fill_null(
1013        &self,
1014        value: PyObject,
1015        columns: Option<Vec<PyBackedStr>>,
1016        py: Python,
1017    ) -> PyDataFusionResult<Self> {
1018        let scalar_value = py_obj_to_scalar_value(py, value)?;
1019
1020        let cols = match columns {
1021            Some(col_names) => col_names.iter().map(|c| c.to_string()).collect(),
1022            None => Vec::new(), // Empty vector means fill null for all columns
1023        };
1024
1025        let df = self.df.as_ref().clone().fill_null(scalar_value, cols)?;
1026        Ok(Self::new(df))
1027    }
1028}
1029
/// Python-facing insert mode, exposed as `datafusion.InsertOp`.
///
/// Each variant maps one-to-one onto DataFusion's [`InsertOp`] via `From<PyInsertOp>`.
/// NOTE: declaration order fixes the implicit discriminants (visible to Python through
/// `eq_int`) and the derived `PartialOrd`/`Ord`, so do not reorder the variants.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[pyclass(frozen, eq, eq_int, name = "InsertOp", module = "datafusion")]
pub enum PyInsertOp {
    /// Maps to `InsertOp::Append`.
    APPEND,
    /// Maps to `InsertOp::Replace`.
    REPLACE,
    /// Maps to `InsertOp::Overwrite`.
    OVERWRITE,
}
1037
1038impl From<PyInsertOp> for InsertOp {
1039    fn from(value: PyInsertOp) -> Self {
1040        match value {
1041            PyInsertOp::APPEND => InsertOp::Append,
1042            PyInsertOp::REPLACE => InsertOp::Replace,
1043            PyInsertOp::OVERWRITE => InsertOp::Overwrite,
1044        }
1045    }
1046}
1047
/// Python-facing write settings (`datafusion.DataFrameWriteOptions`).
///
/// Converted into DataFusion's [`DataFrameWriteOptions`] (via `From`) by the `write_*`
/// methods before a write is executed.
#[derive(Debug, Clone)]
#[pyclass(frozen, name = "DataFrameWriteOptions", module = "datafusion")]
pub struct PyDataFrameWriteOptions {
    /// Insert mode, forwarded to `with_insert_operation`.
    insert_operation: InsertOp,
    /// Forwarded to `with_single_file_output`.
    single_file_output: bool,
    /// Column names forwarded to `with_partition_by`.
    partition_by: Vec<String>,
    /// Sort expressions forwarded to `with_sort_by`.
    sort_by: Vec<SortExpr>,
}
1056
1057impl From<PyDataFrameWriteOptions> for DataFrameWriteOptions {
1058    fn from(value: PyDataFrameWriteOptions) -> Self {
1059        DataFrameWriteOptions::new()
1060            .with_insert_operation(value.insert_operation)
1061            .with_single_file_output(value.single_file_output)
1062            .with_partition_by(value.partition_by)
1063            .with_sort_by(value.sort_by)
1064    }
1065}
1066
1067#[pymethods]
1068impl PyDataFrameWriteOptions {
1069    #[new]
1070    fn new(
1071        insert_operation: Option<PyInsertOp>,
1072        single_file_output: bool,
1073        partition_by: Option<Vec<String>>,
1074        sort_by: Option<Vec<PySortExpr>>,
1075    ) -> Self {
1076        let insert_operation = insert_operation.map(Into::into).unwrap_or(InsertOp::Append);
1077        let sort_by = sort_by
1078            .unwrap_or_default()
1079            .into_iter()
1080            .map(Into::into)
1081            .collect();
1082        Self {
1083            insert_operation,
1084            single_file_output,
1085            partition_by: partition_by.unwrap_or_default(),
1086            sort_by,
1087        }
1088    }
1089}
1090
1091/// Print DataFrame
1092fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
1093    // Get string representation of record batches
1094    let batches = wait_for_future(py, df.collect())??;
1095    let result = if batches.is_empty() {
1096        "DataFrame has no rows".to_string()
1097    } else {
1098        match pretty::pretty_format_batches(&batches) {
1099            Ok(batch) => format!("DataFrame()\n{batch}"),
1100            Err(err) => format!("Error: {:?}", err.to_string()),
1101        }
1102    };
1103
1104    // Import the Python 'builtins' module to access the print function
1105    // Note that println! does not print to the Python debug console and is not visible in notebooks for instance
1106    let print = py.import("builtins")?.getattr("print")?;
1107    print.call1((result,))?;
1108    Ok(())
1109}
1110
1111fn project_schema(from_schema: Schema, to_schema: Schema) -> Result<Schema, ArrowError> {
1112    let merged_schema = Schema::try_merge(vec![from_schema, to_schema.clone()])?;
1113
1114    let project_indices: Vec<usize> = to_schema
1115        .fields
1116        .iter()
1117        .map(|field| field.name())
1118        .filter_map(|field_name| merged_schema.index_of(field_name).ok())
1119        .collect();
1120
1121    merged_schema.project(&project_indices)
1122}
1123
1124fn record_batch_into_schema(
1125    record_batch: RecordBatch,
1126    schema: &Schema,
1127) -> Result<RecordBatch, ArrowError> {
1128    let schema = Arc::new(schema.clone());
1129    let base_schema = record_batch.schema();
1130    if base_schema.fields().is_empty() {
1131        // Nothing to project
1132        return Ok(RecordBatch::new_empty(schema));
1133    }
1134
1135    let array_size = record_batch.column(0).len();
1136    let mut data_arrays = Vec::with_capacity(schema.fields().len());
1137
1138    for field in schema.fields() {
1139        let desired_data_type = field.data_type();
1140        if let Some(original_data) = record_batch.column_by_name(field.name()) {
1141            let original_data_type = original_data.data_type();
1142
1143            if can_cast_types(original_data_type, desired_data_type) {
1144                data_arrays.push(arrow::compute::kernels::cast(
1145                    original_data,
1146                    desired_data_type,
1147                )?);
1148            } else if field.is_nullable() {
1149                data_arrays.push(new_null_array(desired_data_type, array_size));
1150            } else {
1151                return Err(ArrowError::CastError(format!("Attempting to cast to non-nullable and non-castable field {} during schema projection.", field.name())));
1152            }
1153        } else {
1154            if !field.is_nullable() {
1155                return Err(ArrowError::CastError(format!(
1156                    "Attempting to set null to non-nullable field {} during schema projection.",
1157                    field.name()
1158                )));
1159            }
1160            data_arrays.push(new_null_array(desired_data_type, array_size));
1161        }
1162    }
1163
1164    RecordBatch::try_new(schema, data_arrays)
1165}
1166
1167/// This is a helper function to return the first non-empty record batch from executing a DataFrame.
1168/// It additionally returns a bool, which indicates if there are more record batches available.
1169/// We do this so we can determine if we should indicate to the user that the data has been
1170/// truncated. This collects until we have archived both of these two conditions
1171///
1172/// - We have collected our minimum number of rows
1173/// - We have reached our limit, either data size or maximum number of rows
1174///
1175/// Otherwise it will return when the stream has exhausted. If you want a specific number of
1176/// rows, set min_rows == max_rows.
1177async fn collect_record_batches_to_display(
1178    df: DataFrame,
1179    config: FormatterConfig,
1180) -> Result<(Vec<RecordBatch>, bool), DataFusionError> {
1181    let FormatterConfig {
1182        max_bytes,
1183        min_rows,
1184        repr_rows,
1185    } = config;
1186
1187    let partitioned_stream = df.execute_stream_partitioned().await?;
1188    let mut stream = futures::stream::iter(partitioned_stream).flatten();
1189    let mut size_estimate_so_far = 0;
1190    let mut rows_so_far = 0;
1191    let mut record_batches = Vec::default();
1192    let mut has_more = false;
1193
1194    // ensure minimum rows even if memory/row limits are hit
1195    while (size_estimate_so_far < max_bytes && rows_so_far < repr_rows) || rows_so_far < min_rows {
1196        let mut rb = match stream.next().await {
1197            None => {
1198                break;
1199            }
1200            Some(Ok(r)) => r,
1201            Some(Err(e)) => return Err(e),
1202        };
1203
1204        let mut rows_in_rb = rb.num_rows();
1205        if rows_in_rb > 0 {
1206            size_estimate_so_far += rb.get_array_memory_size();
1207
1208            if size_estimate_so_far > max_bytes {
1209                let ratio = max_bytes as f32 / size_estimate_so_far as f32;
1210                let total_rows = rows_in_rb + rows_so_far;
1211
1212                let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize;
1213                if reduced_row_num < min_rows {
1214                    reduced_row_num = min_rows.min(total_rows);
1215                }
1216
1217                let limited_rows_this_rb = reduced_row_num - rows_so_far;
1218                if limited_rows_this_rb < rows_in_rb {
1219                    rows_in_rb = limited_rows_this_rb;
1220                    rb = rb.slice(0, limited_rows_this_rb);
1221                    has_more = true;
1222                }
1223            }
1224
1225            if rows_in_rb + rows_so_far > repr_rows {
1226                rb = rb.slice(0, repr_rows - rows_so_far);
1227                has_more = true;
1228            }
1229
1230            rows_so_far += rb.num_rows();
1231            record_batches.push(rb);
1232        }
1233    }
1234
1235    if record_batches.is_empty() {
1236        return Ok((Vec::default(), false));
1237    }
1238
1239    if !has_more {
1240        // Data was not already truncated, so check to see if more record batches remain
1241        has_more = match stream.try_next().await {
1242            Ok(None) => false, // reached end
1243            Ok(Some(_)) => true,
1244            Err(_) => false, // Stream disconnected
1245        };
1246    }
1247
1248    Ok((record_batches, has_more))
1249}