datafusion_python/
dataframe.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::ffi::CString;
19use std::sync::Arc;
20
21use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader};
22use arrow::compute::can_cast_types;
23use arrow::error::ArrowError;
24use arrow::ffi::FFI_ArrowSchema;
25use arrow::ffi_stream::FFI_ArrowArrayStream;
26use arrow::util::display::{ArrayFormatter, FormatOptions};
27use datafusion::arrow::datatypes::Schema;
28use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
29use datafusion::arrow::util::pretty;
30use datafusion::common::UnnestOptions;
31use datafusion::config::{CsvOptions, TableParquetOptions};
32use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
33use datafusion::datasource::TableProvider;
34use datafusion::error::DataFusionError;
35use datafusion::execution::SendableRecordBatchStream;
36use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
37use datafusion::prelude::*;
38use futures::{StreamExt, TryStreamExt};
39use pyo3::exceptions::PyValueError;
40use pyo3::prelude::*;
41use pyo3::pybacked::PyBackedStr;
42use pyo3::types::{PyCapsule, PyTuple, PyTupleMethods};
43use tokio::task::JoinHandle;
44
45use crate::catalog::PyTable;
46use crate::errors::{py_datafusion_err, PyDataFusionError};
47use crate::expr::sort_expr::to_sort_expressions;
48use crate::physical_plan::PyExecutionPlan;
49use crate::record_batch::PyRecordBatchStream;
50use crate::sql::logical::PyLogicalPlan;
51use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
52use crate::{
53    errors::PyDataFusionResult,
54    expr::{sort_expr::PySortExpr, PyExpr},
55};
56
57// https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
58// - we have not decided on the table_provider approach yet
59// this is an interim implementation
60#[pyclass(name = "TableProvider", module = "datafusion")]
61pub struct PyTableProvider {
62    provider: Arc<dyn TableProvider>,
63}
64
65impl PyTableProvider {
66    pub fn new(provider: Arc<dyn TableProvider>) -> Self {
67        Self { provider }
68    }
69
70    pub fn as_table(&self) -> PyTable {
71        let table_provider: Arc<dyn TableProvider> = self.provider.clone();
72        PyTable::new(table_provider)
73    }
74}
/// Display byte budget: 2 MiB (2 * 1024 * 1024 bytes). Not referenced in this part of the file;
/// presumably consumed by the record-batch display helper — confirm against its definition.
const MAX_TABLE_BYTES_TO_DISPLAY: usize = 2 << 20;
/// Row count handed to `collect_record_batches_to_display` by the HTML repr.
const MIN_TABLE_ROWS_TO_DISPLAY: usize = 20;
/// Cells whose text is longer than this are shown truncated in the HTML repr,
/// together with a button that expands them to the full text.
const MAX_LENGTH_CELL_WITHOUT_MINIMIZE: usize = 25;
78
79/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
80/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
81/// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment.
82#[pyclass(name = "DataFrame", module = "datafusion", subclass)]
83#[derive(Clone)]
84pub struct PyDataFrame {
85    df: Arc<DataFrame>,
86}
87
88impl PyDataFrame {
89    /// creates a new PyDataFrame
90    pub fn new(df: DataFrame) -> Self {
91        Self { df: Arc::new(df) }
92    }
93}
94
95#[pymethods]
96impl PyDataFrame {
97    /// Enable selection for `df[col]`, `df[col1, col2, col3]`, and `df[[col1, col2, col3]]`
98    fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
99        if let Ok(key) = key.extract::<PyBackedStr>() {
100            // df[col]
101            self.select_columns(vec![key])
102        } else if let Ok(tuple) = key.downcast::<PyTuple>() {
103            // df[col1, col2, col3]
104            let keys = tuple
105                .iter()
106                .map(|item| item.extract::<PyBackedStr>())
107                .collect::<PyResult<Vec<PyBackedStr>>>()?;
108            self.select_columns(keys)
109        } else if let Ok(keys) = key.extract::<Vec<PyBackedStr>>() {
110            // df[[col1, col2, col3]]
111            self.select_columns(keys)
112        } else {
113            let message = "DataFrame can only be indexed by string index or indices".to_string();
114            Err(PyDataFusionError::Common(message))
115        }
116    }
117
118    fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
119        let (batches, has_more) = wait_for_future(
120            py,
121            collect_record_batches_to_display(self.df.as_ref().clone(), 10, 10),
122        )?;
123        if batches.is_empty() {
124            // This should not be reached, but do it for safety since we index into the vector below
125            return Ok("No data to display".to_string());
126        }
127
128        let batches_as_displ =
129            pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?;
130
131        let additional_str = match has_more {
132            true => "\nData truncated.",
133            false => "",
134        };
135
136        Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}"))
137    }
138
139    fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
140        let (batches, has_more) = wait_for_future(
141            py,
142            collect_record_batches_to_display(
143                self.df.as_ref().clone(),
144                MIN_TABLE_ROWS_TO_DISPLAY,
145                usize::MAX,
146            ),
147        )?;
148        if batches.is_empty() {
149            // This should not be reached, but do it for safety since we index into the vector below
150            return Ok("No data to display".to_string());
151        }
152
153        let table_uuid = uuid::Uuid::new_v4().to_string();
154
155        let mut html_str = "
156        <style>
157            .expandable-container {
158                display: inline-block;
159                max-width: 200px;
160            }
161            .expandable {
162                white-space: nowrap;
163                overflow: hidden;
164                text-overflow: ellipsis;
165                display: block;
166            }
167            .full-text {
168                display: none;
169                white-space: normal;
170            }
171            .expand-btn {
172                cursor: pointer;
173                color: blue;
174                text-decoration: underline;
175                border: none;
176                background: none;
177                font-size: inherit;
178                display: block;
179                margin-top: 5px;
180            }
181        </style>
182
183        <div style=\"width: 100%; max-width: 1000px; max-height: 300px; overflow: auto; border: 1px solid #ccc;\">
184            <table style=\"border-collapse: collapse; min-width: 100%\">
185                <thead>\n".to_string();
186
187        let schema = batches[0].schema();
188
189        let mut header = Vec::new();
190        for field in schema.fields() {
191            header.push(format!("<th style='border: 1px solid black; padding: 8px; text-align: left; background-color: #f2f2f2; white-space: nowrap; min-width: fit-content; max-width: fit-content;'>{}</th>", field.name()));
192        }
193        let header_str = header.join("");
194        html_str.push_str(&format!("<tr>{}</tr></thead><tbody>\n", header_str));
195
196        let batch_formatters = batches
197            .iter()
198            .map(|batch| {
199                batch
200                    .columns()
201                    .iter()
202                    .map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default()))
203                    .map(|c| {
204                        c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))
205                    })
206                    .collect::<Result<Vec<_>, _>>()
207            })
208            .collect::<Result<Vec<_>, _>>()?;
209
210        let rows_per_batch = batches.iter().map(|batch| batch.num_rows());
211
212        // We need to build up row by row for html
213        let mut table_row = 0;
214        for (batch_formatter, num_rows_in_batch) in batch_formatters.iter().zip(rows_per_batch) {
215            for batch_row in 0..num_rows_in_batch {
216                table_row += 1;
217                let mut cells = Vec::new();
218                for (col, formatter) in batch_formatter.iter().enumerate() {
219                    let cell_data = formatter.value(batch_row).to_string();
220                    // From testing, primitive data types do not typically get larger than 21 characters
221                    if cell_data.len() > MAX_LENGTH_CELL_WITHOUT_MINIMIZE {
222                        let short_cell_data = &cell_data[0..MAX_LENGTH_CELL_WITHOUT_MINIMIZE];
223                        cells.push(format!("
224                            <td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>
225                                <div class=\"expandable-container\">
226                                    <span class=\"expandable\" id=\"{table_uuid}-min-text-{table_row}-{col}\">{short_cell_data}</span>
227                                    <span class=\"full-text\" id=\"{table_uuid}-full-text-{table_row}-{col}\">{cell_data}</span>
228                                    <button class=\"expand-btn\" onclick=\"toggleDataFrameCellText('{table_uuid}',{table_row},{col})\">...</button>
229                                </div>
230                            </td>"));
231                    } else {
232                        cells.push(format!("<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>", formatter.value(batch_row)));
233                    }
234                }
235                let row_str = cells.join("");
236                html_str.push_str(&format!("<tr>{}</tr>\n", row_str));
237            }
238        }
239        html_str.push_str("</tbody></table></div>\n");
240
241        html_str.push_str("
242            <script>
243            function toggleDataFrameCellText(table_uuid, row, col) {
244                var shortText = document.getElementById(table_uuid + \"-min-text-\" + row + \"-\" + col);
245                var fullText = document.getElementById(table_uuid + \"-full-text-\" + row + \"-\" + col);
246                var button = event.target;
247
248                if (fullText.style.display === \"none\") {
249                    shortText.style.display = \"none\";
250                    fullText.style.display = \"inline\";
251                    button.textContent = \"(less)\";
252                } else {
253                    shortText.style.display = \"inline\";
254                    fullText.style.display = \"none\";
255                    button.textContent = \"...\";
256                }
257            }
258            </script>
259        ");
260
261        if has_more {
262            html_str.push_str("Data truncated due to size.");
263        }
264
265        Ok(html_str)
266    }
267
268    /// Calculate summary statistics for a DataFrame
269    fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
270        let df = self.df.as_ref().clone();
271        let stat_df = wait_for_future(py, df.describe())?;
272        Ok(Self::new(stat_df))
273    }
274
275    /// Returns the schema from the logical plan
276    fn schema(&self) -> PyArrowType<Schema> {
277        PyArrowType(self.df.schema().into())
278    }
279
280    /// Convert this DataFrame into a Table that can be used in register_table
281    /// By convention, into_... methods consume self and return the new object.
282    /// Disabling the clippy lint, so we can use &self
283    /// because we're working with Python bindings
284    /// where objects are shared
285    /// https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
286    /// - we have not decided on the table_provider approach yet
287    #[allow(clippy::wrong_self_convention)]
288    fn into_view(&self) -> PyDataFusionResult<PyTable> {
289        // Call the underlying Rust DataFrame::into_view method.
290        // Note that the Rust method consumes self; here we clone the inner Arc<DataFrame>
291        // so that we don’t invalidate this PyDataFrame.
292        let table_provider = self.df.as_ref().clone().into_view();
293        let table_provider = PyTableProvider::new(table_provider);
294
295        Ok(table_provider.as_table())
296    }
297
298    #[pyo3(signature = (*args))]
299    fn select_columns(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
300        let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
301        let df = self.df.as_ref().clone().select_columns(&args)?;
302        Ok(Self::new(df))
303    }
304
305    #[pyo3(signature = (*args))]
306    fn select(&self, args: Vec<PyExpr>) -> PyDataFusionResult<Self> {
307        let expr = args.into_iter().map(|e| e.into()).collect();
308        let df = self.df.as_ref().clone().select(expr)?;
309        Ok(Self::new(df))
310    }
311
312    #[pyo3(signature = (*args))]
313    fn drop(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
314        let cols = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
315        let df = self.df.as_ref().clone().drop_columns(&cols)?;
316        Ok(Self::new(df))
317    }
318
319    fn filter(&self, predicate: PyExpr) -> PyDataFusionResult<Self> {
320        let df = self.df.as_ref().clone().filter(predicate.into())?;
321        Ok(Self::new(df))
322    }
323
324    fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult<Self> {
325        let df = self.df.as_ref().clone().with_column(name, expr.into())?;
326        Ok(Self::new(df))
327    }
328
329    fn with_columns(&self, exprs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
330        let mut df = self.df.as_ref().clone();
331        for expr in exprs {
332            let expr: Expr = expr.into();
333            let name = format!("{}", expr.schema_name());
334            df = df.with_column(name.as_str(), expr)?
335        }
336        Ok(Self::new(df))
337    }
338
339    /// Rename one column by applying a new projection. This is a no-op if the column to be
340    /// renamed does not exist.
341    fn with_column_renamed(&self, old_name: &str, new_name: &str) -> PyDataFusionResult<Self> {
342        let df = self
343            .df
344            .as_ref()
345            .clone()
346            .with_column_renamed(old_name, new_name)?;
347        Ok(Self::new(df))
348    }
349
350    fn aggregate(&self, group_by: Vec<PyExpr>, aggs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
351        let group_by = group_by.into_iter().map(|e| e.into()).collect();
352        let aggs = aggs.into_iter().map(|e| e.into()).collect();
353        let df = self.df.as_ref().clone().aggregate(group_by, aggs)?;
354        Ok(Self::new(df))
355    }
356
357    #[pyo3(signature = (*exprs))]
358    fn sort(&self, exprs: Vec<PySortExpr>) -> PyDataFusionResult<Self> {
359        let exprs = to_sort_expressions(exprs);
360        let df = self.df.as_ref().clone().sort(exprs)?;
361        Ok(Self::new(df))
362    }
363
364    #[pyo3(signature = (count, offset=0))]
365    fn limit(&self, count: usize, offset: usize) -> PyDataFusionResult<Self> {
366        let df = self.df.as_ref().clone().limit(offset, Some(count))?;
367        Ok(Self::new(df))
368    }
369
370    /// Executes the plan, returning a list of `RecordBatch`es.
371    /// Unless some order is specified in the plan, there is no
372    /// guarantee of the order of the result.
373    fn collect(&self, py: Python) -> PyResult<Vec<PyObject>> {
374        let batches = wait_for_future(py, self.df.as_ref().clone().collect())
375            .map_err(PyDataFusionError::from)?;
376        // cannot use PyResult<Vec<RecordBatch>> return type due to
377        // https://github.com/PyO3/pyo3/issues/1813
378        batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect()
379    }
380
381    /// Cache DataFrame.
382    fn cache(&self, py: Python) -> PyDataFusionResult<Self> {
383        let df = wait_for_future(py, self.df.as_ref().clone().cache())?;
384        Ok(Self::new(df))
385    }
386
387    /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
388    /// maintaining the input partitioning.
389    fn collect_partitioned(&self, py: Python) -> PyResult<Vec<Vec<PyObject>>> {
390        let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())
391            .map_err(PyDataFusionError::from)?;
392
393        batches
394            .into_iter()
395            .map(|rbs| rbs.into_iter().map(|rb| rb.to_pyarrow(py)).collect())
396            .collect()
397    }
398
399    /// Print the result, 20 lines by default
400    #[pyo3(signature = (num=20))]
401    fn show(&self, py: Python, num: usize) -> PyDataFusionResult<()> {
402        let df = self.df.as_ref().clone().limit(0, Some(num))?;
403        print_dataframe(py, df)
404    }
405
406    /// Filter out duplicate rows
407    fn distinct(&self) -> PyDataFusionResult<Self> {
408        let df = self.df.as_ref().clone().distinct()?;
409        Ok(Self::new(df))
410    }
411
412    fn join(
413        &self,
414        right: PyDataFrame,
415        how: &str,
416        left_on: Vec<PyBackedStr>,
417        right_on: Vec<PyBackedStr>,
418    ) -> PyDataFusionResult<Self> {
419        let join_type = match how {
420            "inner" => JoinType::Inner,
421            "left" => JoinType::Left,
422            "right" => JoinType::Right,
423            "full" => JoinType::Full,
424            "semi" => JoinType::LeftSemi,
425            "anti" => JoinType::LeftAnti,
426            how => {
427                return Err(PyDataFusionError::Common(format!(
428                    "The join type {how} does not exist or is not implemented"
429                )));
430            }
431        };
432
433        let left_keys = left_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
434        let right_keys = right_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
435
436        let df = self.df.as_ref().clone().join(
437            right.df.as_ref().clone(),
438            join_type,
439            &left_keys,
440            &right_keys,
441            None,
442        )?;
443        Ok(Self::new(df))
444    }
445
446    fn join_on(
447        &self,
448        right: PyDataFrame,
449        on_exprs: Vec<PyExpr>,
450        how: &str,
451    ) -> PyDataFusionResult<Self> {
452        let join_type = match how {
453            "inner" => JoinType::Inner,
454            "left" => JoinType::Left,
455            "right" => JoinType::Right,
456            "full" => JoinType::Full,
457            "semi" => JoinType::LeftSemi,
458            "anti" => JoinType::LeftAnti,
459            how => {
460                return Err(PyDataFusionError::Common(format!(
461                    "The join type {how} does not exist or is not implemented"
462                )));
463            }
464        };
465        let exprs: Vec<Expr> = on_exprs.into_iter().map(|e| e.into()).collect();
466
467        let df = self
468            .df
469            .as_ref()
470            .clone()
471            .join_on(right.df.as_ref().clone(), join_type, exprs)?;
472        Ok(Self::new(df))
473    }
474
475    /// Print the query plan
476    #[pyo3(signature = (verbose=false, analyze=false))]
477    fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyDataFusionResult<()> {
478        let df = self.df.as_ref().clone().explain(verbose, analyze)?;
479        print_dataframe(py, df)
480    }
481
482    /// Get the logical plan for this `DataFrame`
483    fn logical_plan(&self) -> PyResult<PyLogicalPlan> {
484        Ok(self.df.as_ref().clone().logical_plan().clone().into())
485    }
486
487    /// Get the optimized logical plan for this `DataFrame`
488    fn optimized_logical_plan(&self) -> PyDataFusionResult<PyLogicalPlan> {
489        Ok(self.df.as_ref().clone().into_optimized_plan()?.into())
490    }
491
492    /// Get the execution plan for this `DataFrame`
493    fn execution_plan(&self, py: Python) -> PyDataFusionResult<PyExecutionPlan> {
494        let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())?;
495        Ok(plan.into())
496    }
497
498    /// Repartition a `DataFrame` based on a logical partitioning scheme.
499    fn repartition(&self, num: usize) -> PyDataFusionResult<Self> {
500        let new_df = self
501            .df
502            .as_ref()
503            .clone()
504            .repartition(Partitioning::RoundRobinBatch(num))?;
505        Ok(Self::new(new_df))
506    }
507
508    /// Repartition a `DataFrame` based on a logical partitioning scheme.
509    #[pyo3(signature = (*args, num))]
510    fn repartition_by_hash(&self, args: Vec<PyExpr>, num: usize) -> PyDataFusionResult<Self> {
511        let expr = args.into_iter().map(|py_expr| py_expr.into()).collect();
512        let new_df = self
513            .df
514            .as_ref()
515            .clone()
516            .repartition(Partitioning::Hash(expr, num))?;
517        Ok(Self::new(new_df))
518    }
519
520    /// Calculate the union of two `DataFrame`s, preserving duplicate rows.The
521    /// two `DataFrame`s must have exactly the same schema
522    #[pyo3(signature = (py_df, distinct=false))]
523    fn union(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult<Self> {
524        let new_df = if distinct {
525            self.df
526                .as_ref()
527                .clone()
528                .union_distinct(py_df.df.as_ref().clone())?
529        } else {
530            self.df.as_ref().clone().union(py_df.df.as_ref().clone())?
531        };
532
533        Ok(Self::new(new_df))
534    }
535
536    /// Calculate the distinct union of two `DataFrame`s.  The
537    /// two `DataFrame`s must have exactly the same schema
538    fn union_distinct(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
539        let new_df = self
540            .df
541            .as_ref()
542            .clone()
543            .union_distinct(py_df.df.as_ref().clone())?;
544        Ok(Self::new(new_df))
545    }
546
547    #[pyo3(signature = (column, preserve_nulls=true))]
548    fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyDataFusionResult<Self> {
549        // TODO: expose RecursionUnnestOptions
550        // REF: https://github.com/apache/datafusion/pull/11577
551        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
552        let df = self
553            .df
554            .as_ref()
555            .clone()
556            .unnest_columns_with_options(&[column], unnest_options)?;
557        Ok(Self::new(df))
558    }
559
560    #[pyo3(signature = (columns, preserve_nulls=true))]
561    fn unnest_columns(
562        &self,
563        columns: Vec<String>,
564        preserve_nulls: bool,
565    ) -> PyDataFusionResult<Self> {
566        // TODO: expose RecursionUnnestOptions
567        // REF: https://github.com/apache/datafusion/pull/11577
568        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
569        let cols = columns.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
570        let df = self
571            .df
572            .as_ref()
573            .clone()
574            .unnest_columns_with_options(&cols, unnest_options)?;
575        Ok(Self::new(df))
576    }
577
578    /// Calculate the intersection of two `DataFrame`s.  The two `DataFrame`s must have exactly the same schema
579    fn intersect(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
580        let new_df = self
581            .df
582            .as_ref()
583            .clone()
584            .intersect(py_df.df.as_ref().clone())?;
585        Ok(Self::new(new_df))
586    }
587
588    /// Calculate the exception of two `DataFrame`s.  The two `DataFrame`s must have exactly the same schema
589    fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
590        let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?;
591        Ok(Self::new(new_df))
592    }
593
594    /// Write a `DataFrame` to a CSV file.
595    fn write_csv(&self, path: &str, with_header: bool, py: Python) -> PyDataFusionResult<()> {
596        let csv_options = CsvOptions {
597            has_header: Some(with_header),
598            ..Default::default()
599        };
600        wait_for_future(
601            py,
602            self.df.as_ref().clone().write_csv(
603                path,
604                DataFrameWriteOptions::new(),
605                Some(csv_options),
606            ),
607        )?;
608        Ok(())
609    }
610
611    /// Write a `DataFrame` to a Parquet file.
612    #[pyo3(signature = (
613        path,
614        compression="zstd",
615        compression_level=None
616        ))]
617    fn write_parquet(
618        &self,
619        path: &str,
620        compression: &str,
621        compression_level: Option<u32>,
622        py: Python,
623    ) -> PyDataFusionResult<()> {
624        fn verify_compression_level(cl: Option<u32>) -> Result<u32, PyErr> {
625            cl.ok_or(PyValueError::new_err("compression_level is not defined"))
626        }
627
628        let _validated = match compression.to_lowercase().as_str() {
629            "snappy" => Compression::SNAPPY,
630            "gzip" => Compression::GZIP(
631                GzipLevel::try_new(compression_level.unwrap_or(6))
632                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
633            ),
634            "brotli" => Compression::BROTLI(
635                BrotliLevel::try_new(verify_compression_level(compression_level)?)
636                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
637            ),
638            "zstd" => Compression::ZSTD(
639                ZstdLevel::try_new(verify_compression_level(compression_level)? as i32)
640                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
641            ),
642            "lzo" => Compression::LZO,
643            "lz4" => Compression::LZ4,
644            "lz4_raw" => Compression::LZ4_RAW,
645            "uncompressed" => Compression::UNCOMPRESSED,
646            _ => {
647                return Err(PyDataFusionError::Common(format!(
648                    "Unrecognized compression type {compression}"
649                )));
650            }
651        };
652
653        let mut compression_string = compression.to_string();
654        if let Some(level) = compression_level {
655            compression_string.push_str(&format!("({level})"));
656        }
657
658        let mut options = TableParquetOptions::default();
659        options.global.compression = Some(compression_string);
660
661        wait_for_future(
662            py,
663            self.df.as_ref().clone().write_parquet(
664                path,
665                DataFrameWriteOptions::new(),
666                Option::from(options),
667            ),
668        )?;
669        Ok(())
670    }
671
672    /// Executes a query and writes the results to a partitioned JSON file.
673    fn write_json(&self, path: &str, py: Python) -> PyDataFusionResult<()> {
674        wait_for_future(
675            py,
676            self.df
677                .as_ref()
678                .clone()
679                .write_json(path, DataFrameWriteOptions::new(), None),
680        )?;
681        Ok(())
682    }
683
684    /// Convert to Arrow Table
685    /// Collect the batches and pass to Arrow Table
686    fn to_arrow_table(&self, py: Python<'_>) -> PyResult<PyObject> {
687        let batches = self.collect(py)?.into_pyobject(py)?;
688        let schema = self.schema().into_pyobject(py)?;
689
690        // Instantiate pyarrow Table object and use its from_batches method
691        let table_class = py.import("pyarrow")?.getattr("Table")?;
692        let args = PyTuple::new(py, &[batches, schema])?;
693        let table: PyObject = table_class.call_method1("from_batches", args)?.into();
694        Ok(table)
695    }
696
697    #[pyo3(signature = (requested_schema=None))]
698    fn __arrow_c_stream__<'py>(
699        &'py mut self,
700        py: Python<'py>,
701        requested_schema: Option<Bound<'py, PyCapsule>>,
702    ) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
703        let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())?;
704        let mut schema: Schema = self.df.schema().to_owned().into();
705
706        if let Some(schema_capsule) = requested_schema {
707            validate_pycapsule(&schema_capsule, "arrow_schema")?;
708
709            let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
710            let desired_schema = Schema::try_from(schema_ptr)?;
711
712            schema = project_schema(schema, desired_schema)?;
713
714            batches = batches
715                .into_iter()
716                .map(|record_batch| record_batch_into_schema(record_batch, &schema))
717                .collect::<Result<Vec<RecordBatch>, ArrowError>>()?;
718        }
719
720        let batches_wrapped = batches.into_iter().map(Ok);
721
722        let reader = RecordBatchIterator::new(batches_wrapped, Arc::new(schema));
723        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
724
725        let ffi_stream = FFI_ArrowArrayStream::new(reader);
726        let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
727        PyCapsule::new(py, ffi_stream, Some(stream_capsule_name)).map_err(PyDataFusionError::from)
728    }
729
730    fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
731        // create a Tokio runtime to run the async code
732        let rt = &get_tokio_runtime().0;
733        let df = self.df.as_ref().clone();
734        let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
735            rt.spawn(async move { df.execute_stream().await });
736        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
737        Ok(PyRecordBatchStream::new(stream?))
738    }
739
740    fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
741        // create a Tokio runtime to run the async code
742        let rt = &get_tokio_runtime().0;
743        let df = self.df.as_ref().clone();
744        let fut: JoinHandle<datafusion::common::Result<Vec<SendableRecordBatchStream>>> =
745            rt.spawn(async move { df.execute_stream_partitioned().await });
746        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
747
748        match stream {
749            Ok(batches) => Ok(batches.into_iter().map(PyRecordBatchStream::new).collect()),
750            _ => Err(PyValueError::new_err(
751                "Unable to execute stream partitioned",
752            )),
753        }
754    }
755
756    /// Convert to pandas dataframe with pyarrow
757    /// Collect the batches, pass to Arrow Table & then convert to Pandas DataFrame
758    fn to_pandas(&self, py: Python<'_>) -> PyResult<PyObject> {
759        let table = self.to_arrow_table(py)?;
760
761        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pandas
762        let result = table.call_method0(py, "to_pandas")?;
763        Ok(result)
764    }
765
766    /// Convert to Python list using pyarrow
767    /// Each list item represents one row encoded as dictionary
768    fn to_pylist(&self, py: Python<'_>) -> PyResult<PyObject> {
769        let table = self.to_arrow_table(py)?;
770
771        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pylist
772        let result = table.call_method0(py, "to_pylist")?;
773        Ok(result)
774    }
775
776    /// Convert to Python dictionary using pyarrow
777    /// Each dictionary key is a column and the dictionary value represents the column values
778    fn to_pydict(&self, py: Python) -> PyResult<PyObject> {
779        let table = self.to_arrow_table(py)?;
780
781        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pydict
782        let result = table.call_method0(py, "to_pydict")?;
783        Ok(result)
784    }
785
786    /// Convert to polars dataframe with pyarrow
787    /// Collect the batches, pass to Arrow Table & then convert to polars DataFrame
788    fn to_polars(&self, py: Python<'_>) -> PyResult<PyObject> {
789        let table = self.to_arrow_table(py)?;
790        let dataframe = py.import("polars")?.getattr("DataFrame")?;
791        let args = PyTuple::new(py, &[table])?;
792        let result: PyObject = dataframe.call1(args)?.into();
793        Ok(result)
794    }
795
796    // Executes this DataFrame to get the total number of rows.
797    fn count(&self, py: Python) -> PyDataFusionResult<usize> {
798        Ok(wait_for_future(py, self.df.as_ref().clone().count())?)
799    }
800}
801
802/// Print DataFrame
803fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
804    // Get string representation of record batches
805    let batches = wait_for_future(py, df.collect())?;
806    let batches_as_string = pretty::pretty_format_batches(&batches);
807    let result = match batches_as_string {
808        Ok(batch) => format!("DataFrame()\n{batch}"),
809        Err(err) => format!("Error: {:?}", err.to_string()),
810    };
811
812    // Import the Python 'builtins' module to access the print function
813    // Note that println! does not print to the Python debug console and is not visible in notebooks for instance
814    let print = py.import("builtins")?.getattr("print")?;
815    print.call1((result,))?;
816    Ok(())
817}
818
819fn project_schema(from_schema: Schema, to_schema: Schema) -> Result<Schema, ArrowError> {
820    let merged_schema = Schema::try_merge(vec![from_schema, to_schema.clone()])?;
821
822    let project_indices: Vec<usize> = to_schema
823        .fields
824        .iter()
825        .map(|field| field.name())
826        .filter_map(|field_name| merged_schema.index_of(field_name).ok())
827        .collect();
828
829    merged_schema.project(&project_indices)
830}
831
832fn record_batch_into_schema(
833    record_batch: RecordBatch,
834    schema: &Schema,
835) -> Result<RecordBatch, ArrowError> {
836    let schema = Arc::new(schema.clone());
837    let base_schema = record_batch.schema();
838    if base_schema.fields().len() == 0 {
839        // Nothing to project
840        return Ok(RecordBatch::new_empty(schema));
841    }
842
843    let array_size = record_batch.column(0).len();
844    let mut data_arrays = Vec::with_capacity(schema.fields().len());
845
846    for field in schema.fields() {
847        let desired_data_type = field.data_type();
848        if let Some(original_data) = record_batch.column_by_name(field.name()) {
849            let original_data_type = original_data.data_type();
850
851            if can_cast_types(original_data_type, desired_data_type) {
852                data_arrays.push(arrow::compute::kernels::cast(
853                    original_data,
854                    desired_data_type,
855                )?);
856            } else if field.is_nullable() {
857                data_arrays.push(new_null_array(desired_data_type, array_size));
858            } else {
859                return Err(ArrowError::CastError(format!("Attempting to cast to non-nullable and non-castable field {} during schema projection.", field.name())));
860            }
861        } else {
862            if !field.is_nullable() {
863                return Err(ArrowError::CastError(format!(
864                    "Attempting to set null to non-nullable field {} during schema projection.",
865                    field.name()
866                )));
867            }
868            data_arrays.push(new_null_array(desired_data_type, array_size));
869        }
870    }
871
872    RecordBatch::try_new(schema, data_arrays)
873}
874
875/// This is a helper function to return the first non-empty record batch from executing a DataFrame.
876/// It additionally returns a bool, which indicates if there are more record batches available.
877/// We do this so we can determine if we should indicate to the user that the data has been
878/// truncated. This collects until we have achived both of these two conditions
879///
880/// - We have collected our minimum number of rows
881/// - We have reached our limit, either data size or maximum number of rows
882///
883/// Otherwise it will return when the stream has exhausted. If you want a specific number of
884/// rows, set min_rows == max_rows.
885async fn collect_record_batches_to_display(
886    df: DataFrame,
887    min_rows: usize,
888    max_rows: usize,
889) -> Result<(Vec<RecordBatch>, bool), DataFusionError> {
890    let partitioned_stream = df.execute_stream_partitioned().await?;
891    let mut stream = futures::stream::iter(partitioned_stream).flatten();
892    let mut size_estimate_so_far = 0;
893    let mut rows_so_far = 0;
894    let mut record_batches = Vec::default();
895    let mut has_more = false;
896
897    while (size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY && rows_so_far < max_rows)
898        || rows_so_far < min_rows
899    {
900        let mut rb = match stream.next().await {
901            None => {
902                break;
903            }
904            Some(Ok(r)) => r,
905            Some(Err(e)) => return Err(e),
906        };
907
908        let mut rows_in_rb = rb.num_rows();
909        if rows_in_rb > 0 {
910            size_estimate_so_far += rb.get_array_memory_size();
911
912            if size_estimate_so_far > MAX_TABLE_BYTES_TO_DISPLAY {
913                let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / size_estimate_so_far as f32;
914                let total_rows = rows_in_rb + rows_so_far;
915
916                let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize;
917                if reduced_row_num < min_rows {
918                    reduced_row_num = min_rows.min(total_rows);
919                }
920
921                let limited_rows_this_rb = reduced_row_num - rows_so_far;
922                if limited_rows_this_rb < rows_in_rb {
923                    rows_in_rb = limited_rows_this_rb;
924                    rb = rb.slice(0, limited_rows_this_rb);
925                    has_more = true;
926                }
927            }
928
929            if rows_in_rb + rows_so_far > max_rows {
930                rb = rb.slice(0, max_rows - rows_so_far);
931                has_more = true;
932            }
933
934            rows_so_far += rb.num_rows();
935            record_batches.push(rb);
936        }
937    }
938
939    if record_batches.is_empty() {
940        return Ok((Vec::default(), false));
941    }
942
943    if !has_more {
944        // Data was not already truncated, so check to see if more record batches remain
945        has_more = match stream.try_next().await {
946            Ok(None) => false, // reached end
947            Ok(Some(_)) => true,
948            Err(_) => false, // Stream disconnected
949        };
950    }
951
952    Ok((record_batches, has_more))
953}