datafusion_python/
dataframe.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::ffi::CString;
19use std::sync::Arc;
20
21use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader};
22use arrow::compute::can_cast_types;
23use arrow::error::ArrowError;
24use arrow::ffi::FFI_ArrowSchema;
25use arrow::ffi_stream::FFI_ArrowArrayStream;
26use datafusion::arrow::datatypes::Schema;
27use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
28use datafusion::arrow::util::pretty;
29use datafusion::common::UnnestOptions;
30use datafusion::config::{CsvOptions, TableParquetOptions};
31use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
32use datafusion::datasource::TableProvider;
33use datafusion::error::DataFusionError;
34use datafusion::execution::SendableRecordBatchStream;
35use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
36use datafusion::prelude::*;
37use datafusion_ffi::table_provider::FFI_TableProvider;
38use futures::{StreamExt, TryStreamExt};
39use pyo3::exceptions::PyValueError;
40use pyo3::prelude::*;
41use pyo3::pybacked::PyBackedStr;
42use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
43use tokio::task::JoinHandle;
44
45use crate::catalog::PyTable;
46use crate::errors::{py_datafusion_err, PyDataFusionError};
47use crate::expr::sort_expr::to_sort_expressions;
48use crate::physical_plan::PyExecutionPlan;
49use crate::record_batch::PyRecordBatchStream;
50use crate::sql::logical::PyLogicalPlan;
51use crate::utils::{
52    get_tokio_runtime, py_obj_to_scalar_value, validate_pycapsule, wait_for_future,
53};
54use crate::{
55    errors::PyDataFusionResult,
56    expr::{sort_expr::PySortExpr, PyExpr},
57};
58
59// https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
60// - we have not decided on the table_provider approach yet
61// this is an interim implementation
62#[pyclass(name = "TableProvider", module = "datafusion")]
63pub struct PyTableProvider {
64    provider: Arc<dyn TableProvider + Send>,
65}
66
67impl PyTableProvider {
68    pub fn new(provider: Arc<dyn TableProvider>) -> Self {
69        Self { provider }
70    }
71
72    pub fn as_table(&self) -> PyTable {
73        let table_provider: Arc<dyn TableProvider> = self.provider.clone();
74        PyTable::new(table_provider)
75    }
76}
77
78#[pymethods]
79impl PyTableProvider {
80    fn __datafusion_table_provider__<'py>(
81        &self,
82        py: Python<'py>,
83    ) -> PyResult<Bound<'py, PyCapsule>> {
84        let name = CString::new("datafusion_table_provider").unwrap();
85
86        let runtime = get_tokio_runtime().0.handle().clone();
87        let provider = FFI_TableProvider::new(Arc::clone(&self.provider), false, Some(runtime));
88
89        PyCapsule::new(py, provider, Some(name.clone()))
90    }
91}
92
/// Limits that control how much of a DataFrame is rendered for display.
#[derive(Debug, Clone)]
pub struct FormatterConfig {
    /// Maximum memory in bytes to use for display (default: 2MB)
    pub max_bytes: usize,
    /// Minimum number of rows to display (default: 20)
    pub min_rows: usize,
    /// Number of rows to include in __repr__ output (default: 10)
    pub repr_rows: usize,
}

impl Default for FormatterConfig {
    /// Returns the default limits: 2MB of data, at least 20 rows, 10 rows for `__repr__`.
    fn default() -> Self {
        const TWO_MEGABYTES: usize = 2 * 1024 * 1024;

        Self {
            max_bytes: TWO_MEGABYTES,
            min_rows: 20,
            repr_rows: 10,
        }
    }
}

impl FormatterConfig {
    /// Checks that every limit is strictly positive.
    ///
    /// # Errors
    ///
    /// Returns a message naming the first zero-valued field, checked in the order
    /// `max_bytes`, `min_rows`, `repr_rows`.
    pub fn validate(&self) -> Result<(), String> {
        let limits = [
            ("max_bytes", self.max_bytes),
            ("min_rows", self.min_rows),
            ("repr_rows", self.repr_rows),
        ];

        match limits.iter().find(|(_, value)| *value == 0) {
            Some((field, _)) => Err(format!("{field} must be a positive integer")),
            None => Ok(()),
        }
    }
}
136
/// Pairs the Python-side formatter object with the display limits read from it.
///
/// Produced by `get_python_formatter_with_config`; the `'py` lifetime ties the
/// formatter reference to the Python token it was obtained with.
struct PythonFormatter<'py> {
    /// The object returned by `datafusion.html_formatter.get_formatter()`
    formatter: Bound<'py, PyAny>,
    /// Display limits extracted from the formatter's attributes (with defaults)
    config: FormatterConfig,
}
144
145/// Get the Python formatter and its configuration
146fn get_python_formatter_with_config(py: Python) -> PyResult<PythonFormatter> {
147    let formatter = import_python_formatter(py)?;
148    let config = build_formatter_config_from_python(&formatter)?;
149    Ok(PythonFormatter { formatter, config })
150}
151
152/// Get the Python formatter from the datafusion.html_formatter module
153fn import_python_formatter(py: Python) -> PyResult<Bound<'_, PyAny>> {
154    let formatter_module = py.import("datafusion.html_formatter")?;
155    let get_formatter = formatter_module.getattr("get_formatter")?;
156    get_formatter.call0()
157}
158
159// Helper function to extract attributes with fallback to default
160fn get_attr<'a, T>(py_object: &'a Bound<'a, PyAny>, attr_name: &str, default_value: T) -> T
161where
162    T: for<'py> pyo3::FromPyObject<'py> + Clone,
163{
164    py_object
165        .getattr(attr_name)
166        .and_then(|v| v.extract::<T>())
167        .unwrap_or_else(|_| default_value.clone())
168}
169
170/// Helper function to create a FormatterConfig from a Python formatter object
171fn build_formatter_config_from_python(formatter: &Bound<'_, PyAny>) -> PyResult<FormatterConfig> {
172    let default_config = FormatterConfig::default();
173    let max_bytes = get_attr(formatter, "max_memory_bytes", default_config.max_bytes);
174    let min_rows = get_attr(formatter, "min_rows_display", default_config.min_rows);
175    let repr_rows = get_attr(formatter, "repr_rows", default_config.repr_rows);
176
177    let config = FormatterConfig {
178        max_bytes,
179        min_rows,
180        repr_rows,
181    };
182
183    // Return the validated config, converting String error to PyErr
184    config.validate().map_err(PyValueError::new_err)?;
185    Ok(config)
186}
187
188/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
189/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
190/// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment.
191#[pyclass(name = "DataFrame", module = "datafusion", subclass)]
192#[derive(Clone)]
193pub struct PyDataFrame {
194    df: Arc<DataFrame>,
195}
196
197impl PyDataFrame {
198    /// creates a new PyDataFrame
199    pub fn new(df: DataFrame) -> Self {
200        Self { df: Arc::new(df) }
201    }
202}
203
204#[pymethods]
205impl PyDataFrame {
206    /// Enable selection for `df[col]`, `df[col1, col2, col3]`, and `df[[col1, col2, col3]]`
207    fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
208        if let Ok(key) = key.extract::<PyBackedStr>() {
209            // df[col]
210            self.select_columns(vec![key])
211        } else if let Ok(tuple) = key.downcast::<PyTuple>() {
212            // df[col1, col2, col3]
213            let keys = tuple
214                .iter()
215                .map(|item| item.extract::<PyBackedStr>())
216                .collect::<PyResult<Vec<PyBackedStr>>>()?;
217            self.select_columns(keys)
218        } else if let Ok(keys) = key.extract::<Vec<PyBackedStr>>() {
219            // df[[col1, col2, col3]]
220            self.select_columns(keys)
221        } else {
222            let message = "DataFrame can only be indexed by string index or indices".to_string();
223            Err(PyDataFusionError::Common(message))
224        }
225    }
226
227    fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
228        // Get the Python formatter config
229        let PythonFormatter {
230            formatter: _,
231            config,
232        } = get_python_formatter_with_config(py)?;
233        let (batches, has_more) = wait_for_future(
234            py,
235            collect_record_batches_to_display(self.df.as_ref().clone(), config),
236        )?;
237        if batches.is_empty() {
238            // This should not be reached, but do it for safety since we index into the vector below
239            return Ok("No data to display".to_string());
240        }
241
242        let batches_as_displ =
243            pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?;
244
245        let additional_str = match has_more {
246            true => "\nData truncated.",
247            false => "",
248        };
249
250        Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}"))
251    }
252
253    fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
254        // Get the Python formatter and config
255        let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?;
256        let (batches, has_more) = wait_for_future(
257            py,
258            collect_record_batches_to_display(self.df.as_ref().clone(), config),
259        )?;
260        if batches.is_empty() {
261            // This should not be reached, but do it for safety since we index into the vector below
262            return Ok("No data to display".to_string());
263        }
264
265        let table_uuid = uuid::Uuid::new_v4().to_string();
266
267        // Convert record batches to PyObject list
268        let py_batches = batches
269            .into_iter()
270            .map(|rb| rb.to_pyarrow(py))
271            .collect::<PyResult<Vec<PyObject>>>()?;
272
273        let py_schema = self.schema().into_pyobject(py)?;
274
275        let kwargs = pyo3::types::PyDict::new(py);
276        let py_batches_list = PyList::new(py, py_batches.as_slice())?;
277        kwargs.set_item("batches", py_batches_list)?;
278        kwargs.set_item("schema", py_schema)?;
279        kwargs.set_item("has_more", has_more)?;
280        kwargs.set_item("table_uuid", table_uuid)?;
281
282        let html_result = formatter.call_method("format_html", (), Some(&kwargs))?;
283        let html_str: String = html_result.extract()?;
284
285        Ok(html_str)
286    }
287
288    /// Calculate summary statistics for a DataFrame
289    fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
290        let df = self.df.as_ref().clone();
291        let stat_df = wait_for_future(py, df.describe())?;
292        Ok(Self::new(stat_df))
293    }
294
295    /// Returns the schema from the logical plan
296    fn schema(&self) -> PyArrowType<Schema> {
297        PyArrowType(self.df.schema().into())
298    }
299
300    /// Convert this DataFrame into a Table that can be used in register_table
301    /// By convention, into_... methods consume self and return the new object.
302    /// Disabling the clippy lint, so we can use &self
303    /// because we're working with Python bindings
304    /// where objects are shared
305    /// https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
306    /// - we have not decided on the table_provider approach yet
307    #[allow(clippy::wrong_self_convention)]
308    fn into_view(&self) -> PyDataFusionResult<PyTable> {
309        // Call the underlying Rust DataFrame::into_view method.
310        // Note that the Rust method consumes self; here we clone the inner Arc<DataFrame>
311        // so that we don’t invalidate this PyDataFrame.
312        let table_provider = self.df.as_ref().clone().into_view();
313        let table_provider = PyTableProvider::new(table_provider);
314
315        Ok(table_provider.as_table())
316    }
317
318    #[pyo3(signature = (*args))]
319    fn select_columns(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
320        let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
321        let df = self.df.as_ref().clone().select_columns(&args)?;
322        Ok(Self::new(df))
323    }
324
325    #[pyo3(signature = (*args))]
326    fn select(&self, args: Vec<PyExpr>) -> PyDataFusionResult<Self> {
327        let expr: Vec<Expr> = args.into_iter().map(|e| e.into()).collect();
328        let df = self.df.as_ref().clone().select(expr)?;
329        Ok(Self::new(df))
330    }
331
332    #[pyo3(signature = (*args))]
333    fn drop(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
334        let cols = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
335        let df = self.df.as_ref().clone().drop_columns(&cols)?;
336        Ok(Self::new(df))
337    }
338
339    fn filter(&self, predicate: PyExpr) -> PyDataFusionResult<Self> {
340        let df = self.df.as_ref().clone().filter(predicate.into())?;
341        Ok(Self::new(df))
342    }
343
344    fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult<Self> {
345        let df = self.df.as_ref().clone().with_column(name, expr.into())?;
346        Ok(Self::new(df))
347    }
348
349    fn with_columns(&self, exprs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
350        let mut df = self.df.as_ref().clone();
351        for expr in exprs {
352            let expr: Expr = expr.into();
353            let name = format!("{}", expr.schema_name());
354            df = df.with_column(name.as_str(), expr)?
355        }
356        Ok(Self::new(df))
357    }
358
359    /// Rename one column by applying a new projection. This is a no-op if the column to be
360    /// renamed does not exist.
361    fn with_column_renamed(&self, old_name: &str, new_name: &str) -> PyDataFusionResult<Self> {
362        let df = self
363            .df
364            .as_ref()
365            .clone()
366            .with_column_renamed(old_name, new_name)?;
367        Ok(Self::new(df))
368    }
369
370    fn aggregate(&self, group_by: Vec<PyExpr>, aggs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
371        let group_by = group_by.into_iter().map(|e| e.into()).collect();
372        let aggs = aggs.into_iter().map(|e| e.into()).collect();
373        let df = self.df.as_ref().clone().aggregate(group_by, aggs)?;
374        Ok(Self::new(df))
375    }
376
377    #[pyo3(signature = (*exprs))]
378    fn sort(&self, exprs: Vec<PySortExpr>) -> PyDataFusionResult<Self> {
379        let exprs = to_sort_expressions(exprs);
380        let df = self.df.as_ref().clone().sort(exprs)?;
381        Ok(Self::new(df))
382    }
383
384    #[pyo3(signature = (count, offset=0))]
385    fn limit(&self, count: usize, offset: usize) -> PyDataFusionResult<Self> {
386        let df = self.df.as_ref().clone().limit(offset, Some(count))?;
387        Ok(Self::new(df))
388    }
389
390    /// Executes the plan, returning a list of `RecordBatch`es.
391    /// Unless some order is specified in the plan, there is no
392    /// guarantee of the order of the result.
393    fn collect(&self, py: Python) -> PyResult<Vec<PyObject>> {
394        let batches = wait_for_future(py, self.df.as_ref().clone().collect())
395            .map_err(PyDataFusionError::from)?;
396        // cannot use PyResult<Vec<RecordBatch>> return type due to
397        // https://github.com/PyO3/pyo3/issues/1813
398        batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect()
399    }
400
401    /// Cache DataFrame.
402    fn cache(&self, py: Python) -> PyDataFusionResult<Self> {
403        let df = wait_for_future(py, self.df.as_ref().clone().cache())?;
404        Ok(Self::new(df))
405    }
406
407    /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
408    /// maintaining the input partitioning.
409    fn collect_partitioned(&self, py: Python) -> PyResult<Vec<Vec<PyObject>>> {
410        let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())
411            .map_err(PyDataFusionError::from)?;
412
413        batches
414            .into_iter()
415            .map(|rbs| rbs.into_iter().map(|rb| rb.to_pyarrow(py)).collect())
416            .collect()
417    }
418
419    /// Print the result, 20 lines by default
420    #[pyo3(signature = (num=20))]
421    fn show(&self, py: Python, num: usize) -> PyDataFusionResult<()> {
422        let df = self.df.as_ref().clone().limit(0, Some(num))?;
423        print_dataframe(py, df)
424    }
425
426    /// Filter out duplicate rows
427    fn distinct(&self) -> PyDataFusionResult<Self> {
428        let df = self.df.as_ref().clone().distinct()?;
429        Ok(Self::new(df))
430    }
431
432    fn join(
433        &self,
434        right: PyDataFrame,
435        how: &str,
436        left_on: Vec<PyBackedStr>,
437        right_on: Vec<PyBackedStr>,
438    ) -> PyDataFusionResult<Self> {
439        let join_type = match how {
440            "inner" => JoinType::Inner,
441            "left" => JoinType::Left,
442            "right" => JoinType::Right,
443            "full" => JoinType::Full,
444            "semi" => JoinType::LeftSemi,
445            "anti" => JoinType::LeftAnti,
446            how => {
447                return Err(PyDataFusionError::Common(format!(
448                    "The join type {how} does not exist or is not implemented"
449                )));
450            }
451        };
452
453        let left_keys = left_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
454        let right_keys = right_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
455
456        let df = self.df.as_ref().clone().join(
457            right.df.as_ref().clone(),
458            join_type,
459            &left_keys,
460            &right_keys,
461            None,
462        )?;
463        Ok(Self::new(df))
464    }
465
466    fn join_on(
467        &self,
468        right: PyDataFrame,
469        on_exprs: Vec<PyExpr>,
470        how: &str,
471    ) -> PyDataFusionResult<Self> {
472        let join_type = match how {
473            "inner" => JoinType::Inner,
474            "left" => JoinType::Left,
475            "right" => JoinType::Right,
476            "full" => JoinType::Full,
477            "semi" => JoinType::LeftSemi,
478            "anti" => JoinType::LeftAnti,
479            how => {
480                return Err(PyDataFusionError::Common(format!(
481                    "The join type {how} does not exist or is not implemented"
482                )));
483            }
484        };
485        let exprs: Vec<Expr> = on_exprs.into_iter().map(|e| e.into()).collect();
486
487        let df = self
488            .df
489            .as_ref()
490            .clone()
491            .join_on(right.df.as_ref().clone(), join_type, exprs)?;
492        Ok(Self::new(df))
493    }
494
495    /// Print the query plan
496    #[pyo3(signature = (verbose=false, analyze=false))]
497    fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyDataFusionResult<()> {
498        let df = self.df.as_ref().clone().explain(verbose, analyze)?;
499        print_dataframe(py, df)
500    }
501
502    /// Get the logical plan for this `DataFrame`
503    fn logical_plan(&self) -> PyResult<PyLogicalPlan> {
504        Ok(self.df.as_ref().clone().logical_plan().clone().into())
505    }
506
507    /// Get the optimized logical plan for this `DataFrame`
508    fn optimized_logical_plan(&self) -> PyDataFusionResult<PyLogicalPlan> {
509        Ok(self.df.as_ref().clone().into_optimized_plan()?.into())
510    }
511
512    /// Get the execution plan for this `DataFrame`
513    fn execution_plan(&self, py: Python) -> PyDataFusionResult<PyExecutionPlan> {
514        let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())?;
515        Ok(plan.into())
516    }
517
518    /// Repartition a `DataFrame` based on a logical partitioning scheme.
519    fn repartition(&self, num: usize) -> PyDataFusionResult<Self> {
520        let new_df = self
521            .df
522            .as_ref()
523            .clone()
524            .repartition(Partitioning::RoundRobinBatch(num))?;
525        Ok(Self::new(new_df))
526    }
527
528    /// Repartition a `DataFrame` based on a logical partitioning scheme.
529    #[pyo3(signature = (*args, num))]
530    fn repartition_by_hash(&self, args: Vec<PyExpr>, num: usize) -> PyDataFusionResult<Self> {
531        let expr = args.into_iter().map(|py_expr| py_expr.into()).collect();
532        let new_df = self
533            .df
534            .as_ref()
535            .clone()
536            .repartition(Partitioning::Hash(expr, num))?;
537        Ok(Self::new(new_df))
538    }
539
540    /// Calculate the union of two `DataFrame`s, preserving duplicate rows.The
541    /// two `DataFrame`s must have exactly the same schema
542    #[pyo3(signature = (py_df, distinct=false))]
543    fn union(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult<Self> {
544        let new_df = if distinct {
545            self.df
546                .as_ref()
547                .clone()
548                .union_distinct(py_df.df.as_ref().clone())?
549        } else {
550            self.df.as_ref().clone().union(py_df.df.as_ref().clone())?
551        };
552
553        Ok(Self::new(new_df))
554    }
555
556    /// Calculate the distinct union of two `DataFrame`s.  The
557    /// two `DataFrame`s must have exactly the same schema
558    fn union_distinct(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
559        let new_df = self
560            .df
561            .as_ref()
562            .clone()
563            .union_distinct(py_df.df.as_ref().clone())?;
564        Ok(Self::new(new_df))
565    }
566
567    #[pyo3(signature = (column, preserve_nulls=true))]
568    fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyDataFusionResult<Self> {
569        // TODO: expose RecursionUnnestOptions
570        // REF: https://github.com/apache/datafusion/pull/11577
571        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
572        let df = self
573            .df
574            .as_ref()
575            .clone()
576            .unnest_columns_with_options(&[column], unnest_options)?;
577        Ok(Self::new(df))
578    }
579
580    #[pyo3(signature = (columns, preserve_nulls=true))]
581    fn unnest_columns(
582        &self,
583        columns: Vec<String>,
584        preserve_nulls: bool,
585    ) -> PyDataFusionResult<Self> {
586        // TODO: expose RecursionUnnestOptions
587        // REF: https://github.com/apache/datafusion/pull/11577
588        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
589        let cols = columns.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
590        let df = self
591            .df
592            .as_ref()
593            .clone()
594            .unnest_columns_with_options(&cols, unnest_options)?;
595        Ok(Self::new(df))
596    }
597
598    /// Calculate the intersection of two `DataFrame`s.  The two `DataFrame`s must have exactly the same schema
599    fn intersect(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
600        let new_df = self
601            .df
602            .as_ref()
603            .clone()
604            .intersect(py_df.df.as_ref().clone())?;
605        Ok(Self::new(new_df))
606    }
607
608    /// Calculate the exception of two `DataFrame`s.  The two `DataFrame`s must have exactly the same schema
609    fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
610        let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?;
611        Ok(Self::new(new_df))
612    }
613
614    /// Write a `DataFrame` to a CSV file.
615    fn write_csv(&self, path: &str, with_header: bool, py: Python) -> PyDataFusionResult<()> {
616        let csv_options = CsvOptions {
617            has_header: Some(with_header),
618            ..Default::default()
619        };
620        wait_for_future(
621            py,
622            self.df.as_ref().clone().write_csv(
623                path,
624                DataFrameWriteOptions::new(),
625                Some(csv_options),
626            ),
627        )?;
628        Ok(())
629    }
630
631    /// Write a `DataFrame` to a Parquet file.
632    #[pyo3(signature = (
633        path,
634        compression="zstd",
635        compression_level=None
636        ))]
637    fn write_parquet(
638        &self,
639        path: &str,
640        compression: &str,
641        compression_level: Option<u32>,
642        py: Python,
643    ) -> PyDataFusionResult<()> {
644        fn verify_compression_level(cl: Option<u32>) -> Result<u32, PyErr> {
645            cl.ok_or(PyValueError::new_err("compression_level is not defined"))
646        }
647
648        let _validated = match compression.to_lowercase().as_str() {
649            "snappy" => Compression::SNAPPY,
650            "gzip" => Compression::GZIP(
651                GzipLevel::try_new(compression_level.unwrap_or(6))
652                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
653            ),
654            "brotli" => Compression::BROTLI(
655                BrotliLevel::try_new(verify_compression_level(compression_level)?)
656                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
657            ),
658            "zstd" => Compression::ZSTD(
659                ZstdLevel::try_new(verify_compression_level(compression_level)? as i32)
660                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
661            ),
662            "lzo" => Compression::LZO,
663            "lz4" => Compression::LZ4,
664            "lz4_raw" => Compression::LZ4_RAW,
665            "uncompressed" => Compression::UNCOMPRESSED,
666            _ => {
667                return Err(PyDataFusionError::Common(format!(
668                    "Unrecognized compression type {compression}"
669                )));
670            }
671        };
672
673        let mut compression_string = compression.to_string();
674        if let Some(level) = compression_level {
675            compression_string.push_str(&format!("({level})"));
676        }
677
678        let mut options = TableParquetOptions::default();
679        options.global.compression = Some(compression_string);
680
681        wait_for_future(
682            py,
683            self.df.as_ref().clone().write_parquet(
684                path,
685                DataFrameWriteOptions::new(),
686                Option::from(options),
687            ),
688        )?;
689        Ok(())
690    }
691
692    /// Executes a query and writes the results to a partitioned JSON file.
693    fn write_json(&self, path: &str, py: Python) -> PyDataFusionResult<()> {
694        wait_for_future(
695            py,
696            self.df
697                .as_ref()
698                .clone()
699                .write_json(path, DataFrameWriteOptions::new(), None),
700        )?;
701        Ok(())
702    }
703
704    /// Convert to Arrow Table
705    /// Collect the batches and pass to Arrow Table
706    fn to_arrow_table(&self, py: Python<'_>) -> PyResult<PyObject> {
707        let batches = self.collect(py)?.into_pyobject(py)?;
708        let schema = self.schema().into_pyobject(py)?;
709
710        // Instantiate pyarrow Table object and use its from_batches method
711        let table_class = py.import("pyarrow")?.getattr("Table")?;
712        let args = PyTuple::new(py, &[batches, schema])?;
713        let table: PyObject = table_class.call_method1("from_batches", args)?.into();
714        Ok(table)
715    }
716
717    #[pyo3(signature = (requested_schema=None))]
718    fn __arrow_c_stream__<'py>(
719        &'py mut self,
720        py: Python<'py>,
721        requested_schema: Option<Bound<'py, PyCapsule>>,
722    ) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
723        let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())?;
724        let mut schema: Schema = self.df.schema().to_owned().into();
725
726        if let Some(schema_capsule) = requested_schema {
727            validate_pycapsule(&schema_capsule, "arrow_schema")?;
728
729            let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
730            let desired_schema = Schema::try_from(schema_ptr)?;
731
732            schema = project_schema(schema, desired_schema)?;
733
734            batches = batches
735                .into_iter()
736                .map(|record_batch| record_batch_into_schema(record_batch, &schema))
737                .collect::<Result<Vec<RecordBatch>, ArrowError>>()?;
738        }
739
740        let batches_wrapped = batches.into_iter().map(Ok);
741
742        let reader = RecordBatchIterator::new(batches_wrapped, Arc::new(schema));
743        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
744
745        let ffi_stream = FFI_ArrowArrayStream::new(reader);
746        let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
747        PyCapsule::new(py, ffi_stream, Some(stream_capsule_name)).map_err(PyDataFusionError::from)
748    }
749
750    fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
751        // create a Tokio runtime to run the async code
752        let rt = &get_tokio_runtime().0;
753        let df = self.df.as_ref().clone();
754        let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
755            rt.spawn(async move { df.execute_stream().await });
756        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
757        Ok(PyRecordBatchStream::new(stream?))
758    }
759
760    fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
761        // create a Tokio runtime to run the async code
762        let rt = &get_tokio_runtime().0;
763        let df = self.df.as_ref().clone();
764        let fut: JoinHandle<datafusion::common::Result<Vec<SendableRecordBatchStream>>> =
765            rt.spawn(async move { df.execute_stream_partitioned().await });
766        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
767
768        match stream {
769            Ok(batches) => Ok(batches.into_iter().map(PyRecordBatchStream::new).collect()),
770            _ => Err(PyValueError::new_err(
771                "Unable to execute stream partitioned",
772            )),
773        }
774    }
775
776    /// Convert to pandas dataframe with pyarrow
777    /// Collect the batches, pass to Arrow Table & then convert to Pandas DataFrame
778    fn to_pandas(&self, py: Python<'_>) -> PyResult<PyObject> {
779        let table = self.to_arrow_table(py)?;
780
781        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pandas
782        let result = table.call_method0(py, "to_pandas")?;
783        Ok(result)
784    }
785
786    /// Convert to Python list using pyarrow
787    /// Each list item represents one row encoded as dictionary
788    fn to_pylist(&self, py: Python<'_>) -> PyResult<PyObject> {
789        let table = self.to_arrow_table(py)?;
790
791        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pylist
792        let result = table.call_method0(py, "to_pylist")?;
793        Ok(result)
794    }
795
796    /// Convert to Python dictionary using pyarrow
797    /// Each dictionary key is a column and the dictionary value represents the column values
798    fn to_pydict(&self, py: Python) -> PyResult<PyObject> {
799        let table = self.to_arrow_table(py)?;
800
801        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pydict
802        let result = table.call_method0(py, "to_pydict")?;
803        Ok(result)
804    }
805
806    /// Convert to polars dataframe with pyarrow
807    /// Collect the batches, pass to Arrow Table & then convert to polars DataFrame
808    fn to_polars(&self, py: Python<'_>) -> PyResult<PyObject> {
809        let table = self.to_arrow_table(py)?;
810        let dataframe = py.import("polars")?.getattr("DataFrame")?;
811        let args = PyTuple::new(py, &[table])?;
812        let result: PyObject = dataframe.call1(args)?.into();
813        Ok(result)
814    }
815
816    // Executes this DataFrame to get the total number of rows.
817    fn count(&self, py: Python) -> PyDataFusionResult<usize> {
818        Ok(wait_for_future(py, self.df.as_ref().clone().count())?)
819    }
820
821    /// Fill null values with a specified value for specific columns
822    #[pyo3(signature = (value, columns=None))]
823    fn fill_null(
824        &self,
825        value: PyObject,
826        columns: Option<Vec<PyBackedStr>>,
827        py: Python,
828    ) -> PyDataFusionResult<Self> {
829        let scalar_value = py_obj_to_scalar_value(py, value)?;
830
831        let cols = match columns {
832            Some(col_names) => col_names.iter().map(|c| c.to_string()).collect(),
833            None => Vec::new(), // Empty vector means fill null for all columns
834        };
835
836        let df = self.df.as_ref().clone().fill_null(scalar_value, cols)?;
837        Ok(Self::new(df))
838    }
839}
840
841/// Print DataFrame
842fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
843    // Get string representation of record batches
844    let batches = wait_for_future(py, df.collect())?;
845    let batches_as_string = pretty::pretty_format_batches(&batches);
846    let result = match batches_as_string {
847        Ok(batch) => format!("DataFrame()\n{batch}"),
848        Err(err) => format!("Error: {:?}", err.to_string()),
849    };
850
851    // Import the Python 'builtins' module to access the print function
852    // Note that println! does not print to the Python debug console and is not visible in notebooks for instance
853    let print = py.import("builtins")?.getattr("print")?;
854    print.call1((result,))?;
855    Ok(())
856}
857
858fn project_schema(from_schema: Schema, to_schema: Schema) -> Result<Schema, ArrowError> {
859    let merged_schema = Schema::try_merge(vec![from_schema, to_schema.clone()])?;
860
861    let project_indices: Vec<usize> = to_schema
862        .fields
863        .iter()
864        .map(|field| field.name())
865        .filter_map(|field_name| merged_schema.index_of(field_name).ok())
866        .collect();
867
868    merged_schema.project(&project_indices)
869}
870
871fn record_batch_into_schema(
872    record_batch: RecordBatch,
873    schema: &Schema,
874) -> Result<RecordBatch, ArrowError> {
875    let schema = Arc::new(schema.clone());
876    let base_schema = record_batch.schema();
877    if base_schema.fields().is_empty() {
878        // Nothing to project
879        return Ok(RecordBatch::new_empty(schema));
880    }
881
882    let array_size = record_batch.column(0).len();
883    let mut data_arrays = Vec::with_capacity(schema.fields().len());
884
885    for field in schema.fields() {
886        let desired_data_type = field.data_type();
887        if let Some(original_data) = record_batch.column_by_name(field.name()) {
888            let original_data_type = original_data.data_type();
889
890            if can_cast_types(original_data_type, desired_data_type) {
891                data_arrays.push(arrow::compute::kernels::cast(
892                    original_data,
893                    desired_data_type,
894                )?);
895            } else if field.is_nullable() {
896                data_arrays.push(new_null_array(desired_data_type, array_size));
897            } else {
898                return Err(ArrowError::CastError(format!("Attempting to cast to non-nullable and non-castable field {} during schema projection.", field.name())));
899            }
900        } else {
901            if !field.is_nullable() {
902                return Err(ArrowError::CastError(format!(
903                    "Attempting to set null to non-nullable field {} during schema projection.",
904                    field.name()
905                )));
906            }
907            data_arrays.push(new_null_array(desired_data_type, array_size));
908        }
909    }
910
911    RecordBatch::try_new(schema, data_arrays)
912}
913
914/// This is a helper function to return the first non-empty record batch from executing a DataFrame.
915/// It additionally returns a bool, which indicates if there are more record batches available.
916/// We do this so we can determine if we should indicate to the user that the data has been
917/// truncated. This collects until we have achived both of these two conditions
918///
919/// - We have collected our minimum number of rows
920/// - We have reached our limit, either data size or maximum number of rows
921///
922/// Otherwise it will return when the stream has exhausted. If you want a specific number of
923/// rows, set min_rows == max_rows.
924async fn collect_record_batches_to_display(
925    df: DataFrame,
926    config: FormatterConfig,
927) -> Result<(Vec<RecordBatch>, bool), DataFusionError> {
928    let FormatterConfig {
929        max_bytes,
930        min_rows,
931        repr_rows,
932    } = config;
933
934    let partitioned_stream = df.execute_stream_partitioned().await?;
935    let mut stream = futures::stream::iter(partitioned_stream).flatten();
936    let mut size_estimate_so_far = 0;
937    let mut rows_so_far = 0;
938    let mut record_batches = Vec::default();
939    let mut has_more = false;
940
941    // ensure minimum rows even if memory/row limits are hit
942    while (size_estimate_so_far < max_bytes && rows_so_far < repr_rows) || rows_so_far < min_rows {
943        let mut rb = match stream.next().await {
944            None => {
945                break;
946            }
947            Some(Ok(r)) => r,
948            Some(Err(e)) => return Err(e),
949        };
950
951        let mut rows_in_rb = rb.num_rows();
952        if rows_in_rb > 0 {
953            size_estimate_so_far += rb.get_array_memory_size();
954
955            if size_estimate_so_far > max_bytes {
956                let ratio = max_bytes as f32 / size_estimate_so_far as f32;
957                let total_rows = rows_in_rb + rows_so_far;
958
959                let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize;
960                if reduced_row_num < min_rows {
961                    reduced_row_num = min_rows.min(total_rows);
962                }
963
964                let limited_rows_this_rb = reduced_row_num - rows_so_far;
965                if limited_rows_this_rb < rows_in_rb {
966                    rows_in_rb = limited_rows_this_rb;
967                    rb = rb.slice(0, limited_rows_this_rb);
968                    has_more = true;
969                }
970            }
971
972            if rows_in_rb + rows_so_far > repr_rows {
973                rb = rb.slice(0, repr_rows - rows_so_far);
974                has_more = true;
975            }
976
977            rows_so_far += rb.num_rows();
978            record_batches.push(rb);
979        }
980    }
981
982    if record_batches.is_empty() {
983        return Ok((Vec::default(), false));
984    }
985
986    if !has_more {
987        // Data was not already truncated, so check to see if more record batches remain
988        has_more = match stream.try_next().await {
989            Ok(None) => false, // reached end
990            Ok(Some(_)) => true,
991            Err(_) => false, // Stream disconnected
992        };
993    }
994
995    Ok((record_batches, has_more))
996}