datafusion_python/
dataframe.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::ffi::CString;
19use std::sync::Arc;
20
21use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader};
22use arrow::compute::can_cast_types;
23use arrow::error::ArrowError;
24use arrow::ffi::FFI_ArrowSchema;
25use arrow::ffi_stream::FFI_ArrowArrayStream;
26use arrow::util::display::{ArrayFormatter, FormatOptions};
27use datafusion::arrow::datatypes::Schema;
28use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
29use datafusion::arrow::util::pretty;
30use datafusion::common::UnnestOptions;
31use datafusion::config::{CsvOptions, TableParquetOptions};
32use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
33use datafusion::execution::SendableRecordBatchStream;
34use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
35use datafusion::prelude::*;
36use pyo3::exceptions::PyValueError;
37use pyo3::prelude::*;
38use pyo3::pybacked::PyBackedStr;
39use pyo3::types::{PyCapsule, PyTuple, PyTupleMethods};
40use tokio::task::JoinHandle;
41
42use crate::errors::{py_datafusion_err, PyDataFusionError};
43use crate::expr::sort_expr::to_sort_expressions;
44use crate::physical_plan::PyExecutionPlan;
45use crate::record_batch::PyRecordBatchStream;
46use crate::sql::logical::PyLogicalPlan;
47use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
48use crate::{
49    errors::PyDataFusionResult,
50    expr::{sort_expr::PySortExpr, PyExpr},
51};
52
53/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
54/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
55/// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment.
56#[pyclass(name = "DataFrame", module = "datafusion", subclass)]
57#[derive(Clone)]
58pub struct PyDataFrame {
59    df: Arc<DataFrame>,
60}
61
62impl PyDataFrame {
63    /// creates a new PyDataFrame
64    pub fn new(df: DataFrame) -> Self {
65        Self { df: Arc::new(df) }
66    }
67}
68
69#[pymethods]
70impl PyDataFrame {
71    /// Enable selection for `df[col]`, `df[col1, col2, col3]`, and `df[[col1, col2, col3]]`
72    fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
73        if let Ok(key) = key.extract::<PyBackedStr>() {
74            // df[col]
75            self.select_columns(vec![key])
76        } else if let Ok(tuple) = key.downcast::<PyTuple>() {
77            // df[col1, col2, col3]
78            let keys = tuple
79                .iter()
80                .map(|item| item.extract::<PyBackedStr>())
81                .collect::<PyResult<Vec<PyBackedStr>>>()?;
82            self.select_columns(keys)
83        } else if let Ok(keys) = key.extract::<Vec<PyBackedStr>>() {
84            // df[[col1, col2, col3]]
85            self.select_columns(keys)
86        } else {
87            let message = "DataFrame can only be indexed by string index or indices".to_string();
88            Err(PyDataFusionError::Common(message))
89        }
90    }
91
92    fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
93        let df = self.df.as_ref().clone().limit(0, Some(10))?;
94        let batches = wait_for_future(py, df.collect())?;
95        let batches_as_string = pretty::pretty_format_batches(&batches);
96        match batches_as_string {
97            Ok(batch) => Ok(format!("DataFrame()\n{batch}")),
98            Err(err) => Ok(format!("Error: {:?}", err.to_string())),
99        }
100    }
101
102    fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
103        let mut html_str = "<table border='1'>\n".to_string();
104
105        let df = self.df.as_ref().clone().limit(0, Some(10))?;
106        let batches = wait_for_future(py, df.collect())?;
107
108        if batches.is_empty() {
109            html_str.push_str("</table>\n");
110            return Ok(html_str);
111        }
112
113        let schema = batches[0].schema();
114
115        let mut header = Vec::new();
116        for field in schema.fields() {
117            header.push(format!("<th>{}</td>", field.name()));
118        }
119        let header_str = header.join("");
120        html_str.push_str(&format!("<tr>{}</tr>\n", header_str));
121
122        for batch in batches {
123            let formatters = batch
124                .columns()
125                .iter()
126                .map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default()))
127                .map(|c| {
128                    c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))
129                })
130                .collect::<Result<Vec<_>, _>>()?;
131
132            for row in 0..batch.num_rows() {
133                let mut cells = Vec::new();
134                for formatter in &formatters {
135                    cells.push(format!("<td>{}</td>", formatter.value(row)));
136                }
137                let row_str = cells.join("");
138                html_str.push_str(&format!("<tr>{}</tr>\n", row_str));
139            }
140        }
141
142        html_str.push_str("</table>\n");
143
144        Ok(html_str)
145    }
146
147    /// Calculate summary statistics for a DataFrame
148    fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
149        let df = self.df.as_ref().clone();
150        let stat_df = wait_for_future(py, df.describe())?;
151        Ok(Self::new(stat_df))
152    }
153
154    /// Returns the schema from the logical plan
155    fn schema(&self) -> PyArrowType<Schema> {
156        PyArrowType(self.df.schema().into())
157    }
158
159    #[pyo3(signature = (*args))]
160    fn select_columns(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
161        let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
162        let df = self.df.as_ref().clone().select_columns(&args)?;
163        Ok(Self::new(df))
164    }
165
166    #[pyo3(signature = (*args))]
167    fn select(&self, args: Vec<PyExpr>) -> PyDataFusionResult<Self> {
168        let expr = args.into_iter().map(|e| e.into()).collect();
169        let df = self.df.as_ref().clone().select(expr)?;
170        Ok(Self::new(df))
171    }
172
173    #[pyo3(signature = (*args))]
174    fn drop(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
175        let cols = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
176        let df = self.df.as_ref().clone().drop_columns(&cols)?;
177        Ok(Self::new(df))
178    }
179
180    fn filter(&self, predicate: PyExpr) -> PyDataFusionResult<Self> {
181        let df = self.df.as_ref().clone().filter(predicate.into())?;
182        Ok(Self::new(df))
183    }
184
185    fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult<Self> {
186        let df = self.df.as_ref().clone().with_column(name, expr.into())?;
187        Ok(Self::new(df))
188    }
189
190    fn with_columns(&self, exprs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
191        let mut df = self.df.as_ref().clone();
192        for expr in exprs {
193            let expr: Expr = expr.into();
194            let name = format!("{}", expr.schema_name());
195            df = df.with_column(name.as_str(), expr)?
196        }
197        Ok(Self::new(df))
198    }
199
200    /// Rename one column by applying a new projection. This is a no-op if the column to be
201    /// renamed does not exist.
202    fn with_column_renamed(&self, old_name: &str, new_name: &str) -> PyDataFusionResult<Self> {
203        let df = self
204            .df
205            .as_ref()
206            .clone()
207            .with_column_renamed(old_name, new_name)?;
208        Ok(Self::new(df))
209    }
210
211    fn aggregate(&self, group_by: Vec<PyExpr>, aggs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
212        let group_by = group_by.into_iter().map(|e| e.into()).collect();
213        let aggs = aggs.into_iter().map(|e| e.into()).collect();
214        let df = self.df.as_ref().clone().aggregate(group_by, aggs)?;
215        Ok(Self::new(df))
216    }
217
218    #[pyo3(signature = (*exprs))]
219    fn sort(&self, exprs: Vec<PySortExpr>) -> PyDataFusionResult<Self> {
220        let exprs = to_sort_expressions(exprs);
221        let df = self.df.as_ref().clone().sort(exprs)?;
222        Ok(Self::new(df))
223    }
224
225    #[pyo3(signature = (count, offset=0))]
226    fn limit(&self, count: usize, offset: usize) -> PyDataFusionResult<Self> {
227        let df = self.df.as_ref().clone().limit(offset, Some(count))?;
228        Ok(Self::new(df))
229    }
230
231    /// Executes the plan, returning a list of `RecordBatch`es.
232    /// Unless some order is specified in the plan, there is no
233    /// guarantee of the order of the result.
234    fn collect(&self, py: Python) -> PyResult<Vec<PyObject>> {
235        let batches = wait_for_future(py, self.df.as_ref().clone().collect())
236            .map_err(PyDataFusionError::from)?;
237        // cannot use PyResult<Vec<RecordBatch>> return type due to
238        // https://github.com/PyO3/pyo3/issues/1813
239        batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect()
240    }
241
242    /// Cache DataFrame.
243    fn cache(&self, py: Python) -> PyDataFusionResult<Self> {
244        let df = wait_for_future(py, self.df.as_ref().clone().cache())?;
245        Ok(Self::new(df))
246    }
247
248    /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
249    /// maintaining the input partitioning.
250    fn collect_partitioned(&self, py: Python) -> PyResult<Vec<Vec<PyObject>>> {
251        let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())
252            .map_err(PyDataFusionError::from)?;
253
254        batches
255            .into_iter()
256            .map(|rbs| rbs.into_iter().map(|rb| rb.to_pyarrow(py)).collect())
257            .collect()
258    }
259
260    /// Print the result, 20 lines by default
261    #[pyo3(signature = (num=20))]
262    fn show(&self, py: Python, num: usize) -> PyDataFusionResult<()> {
263        let df = self.df.as_ref().clone().limit(0, Some(num))?;
264        print_dataframe(py, df)
265    }
266
267    /// Filter out duplicate rows
268    fn distinct(&self) -> PyDataFusionResult<Self> {
269        let df = self.df.as_ref().clone().distinct()?;
270        Ok(Self::new(df))
271    }
272
273    fn join(
274        &self,
275        right: PyDataFrame,
276        how: &str,
277        left_on: Vec<PyBackedStr>,
278        right_on: Vec<PyBackedStr>,
279    ) -> PyDataFusionResult<Self> {
280        let join_type = match how {
281            "inner" => JoinType::Inner,
282            "left" => JoinType::Left,
283            "right" => JoinType::Right,
284            "full" => JoinType::Full,
285            "semi" => JoinType::LeftSemi,
286            "anti" => JoinType::LeftAnti,
287            how => {
288                return Err(PyDataFusionError::Common(format!(
289                    "The join type {how} does not exist or is not implemented"
290                )));
291            }
292        };
293
294        let left_keys = left_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
295        let right_keys = right_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
296
297        let df = self.df.as_ref().clone().join(
298            right.df.as_ref().clone(),
299            join_type,
300            &left_keys,
301            &right_keys,
302            None,
303        )?;
304        Ok(Self::new(df))
305    }
306
307    fn join_on(
308        &self,
309        right: PyDataFrame,
310        on_exprs: Vec<PyExpr>,
311        how: &str,
312    ) -> PyDataFusionResult<Self> {
313        let join_type = match how {
314            "inner" => JoinType::Inner,
315            "left" => JoinType::Left,
316            "right" => JoinType::Right,
317            "full" => JoinType::Full,
318            "semi" => JoinType::LeftSemi,
319            "anti" => JoinType::LeftAnti,
320            how => {
321                return Err(PyDataFusionError::Common(format!(
322                    "The join type {how} does not exist or is not implemented"
323                )));
324            }
325        };
326        let exprs: Vec<Expr> = on_exprs.into_iter().map(|e| e.into()).collect();
327
328        let df = self
329            .df
330            .as_ref()
331            .clone()
332            .join_on(right.df.as_ref().clone(), join_type, exprs)?;
333        Ok(Self::new(df))
334    }
335
336    /// Print the query plan
337    #[pyo3(signature = (verbose=false, analyze=false))]
338    fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyDataFusionResult<()> {
339        let df = self.df.as_ref().clone().explain(verbose, analyze)?;
340        print_dataframe(py, df)
341    }
342
343    /// Get the logical plan for this `DataFrame`
344    fn logical_plan(&self) -> PyResult<PyLogicalPlan> {
345        Ok(self.df.as_ref().clone().logical_plan().clone().into())
346    }
347
348    /// Get the optimized logical plan for this `DataFrame`
349    fn optimized_logical_plan(&self) -> PyDataFusionResult<PyLogicalPlan> {
350        Ok(self.df.as_ref().clone().into_optimized_plan()?.into())
351    }
352
353    /// Get the execution plan for this `DataFrame`
354    fn execution_plan(&self, py: Python) -> PyDataFusionResult<PyExecutionPlan> {
355        let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())?;
356        Ok(plan.into())
357    }
358
359    /// Repartition a `DataFrame` based on a logical partitioning scheme.
360    fn repartition(&self, num: usize) -> PyDataFusionResult<Self> {
361        let new_df = self
362            .df
363            .as_ref()
364            .clone()
365            .repartition(Partitioning::RoundRobinBatch(num))?;
366        Ok(Self::new(new_df))
367    }
368
369    /// Repartition a `DataFrame` based on a logical partitioning scheme.
370    #[pyo3(signature = (*args, num))]
371    fn repartition_by_hash(&self, args: Vec<PyExpr>, num: usize) -> PyDataFusionResult<Self> {
372        let expr = args.into_iter().map(|py_expr| py_expr.into()).collect();
373        let new_df = self
374            .df
375            .as_ref()
376            .clone()
377            .repartition(Partitioning::Hash(expr, num))?;
378        Ok(Self::new(new_df))
379    }
380
381    /// Calculate the union of two `DataFrame`s, preserving duplicate rows.The
382    /// two `DataFrame`s must have exactly the same schema
383    #[pyo3(signature = (py_df, distinct=false))]
384    fn union(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult<Self> {
385        let new_df = if distinct {
386            self.df
387                .as_ref()
388                .clone()
389                .union_distinct(py_df.df.as_ref().clone())?
390        } else {
391            self.df.as_ref().clone().union(py_df.df.as_ref().clone())?
392        };
393
394        Ok(Self::new(new_df))
395    }
396
397    /// Calculate the distinct union of two `DataFrame`s.  The
398    /// two `DataFrame`s must have exactly the same schema
399    fn union_distinct(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
400        let new_df = self
401            .df
402            .as_ref()
403            .clone()
404            .union_distinct(py_df.df.as_ref().clone())?;
405        Ok(Self::new(new_df))
406    }
407
408    #[pyo3(signature = (column, preserve_nulls=true))]
409    fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyDataFusionResult<Self> {
410        // TODO: expose RecursionUnnestOptions
411        // REF: https://github.com/apache/datafusion/pull/11577
412        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
413        let df = self
414            .df
415            .as_ref()
416            .clone()
417            .unnest_columns_with_options(&[column], unnest_options)?;
418        Ok(Self::new(df))
419    }
420
421    #[pyo3(signature = (columns, preserve_nulls=true))]
422    fn unnest_columns(
423        &self,
424        columns: Vec<String>,
425        preserve_nulls: bool,
426    ) -> PyDataFusionResult<Self> {
427        // TODO: expose RecursionUnnestOptions
428        // REF: https://github.com/apache/datafusion/pull/11577
429        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
430        let cols = columns.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
431        let df = self
432            .df
433            .as_ref()
434            .clone()
435            .unnest_columns_with_options(&cols, unnest_options)?;
436        Ok(Self::new(df))
437    }
438
439    /// Calculate the intersection of two `DataFrame`s.  The two `DataFrame`s must have exactly the same schema
440    fn intersect(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
441        let new_df = self
442            .df
443            .as_ref()
444            .clone()
445            .intersect(py_df.df.as_ref().clone())?;
446        Ok(Self::new(new_df))
447    }
448
449    /// Calculate the exception of two `DataFrame`s.  The two `DataFrame`s must have exactly the same schema
450    fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
451        let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?;
452        Ok(Self::new(new_df))
453    }
454
455    /// Write a `DataFrame` to a CSV file.
456    fn write_csv(&self, path: &str, with_header: bool, py: Python) -> PyDataFusionResult<()> {
457        let csv_options = CsvOptions {
458            has_header: Some(with_header),
459            ..Default::default()
460        };
461        wait_for_future(
462            py,
463            self.df.as_ref().clone().write_csv(
464                path,
465                DataFrameWriteOptions::new(),
466                Some(csv_options),
467            ),
468        )?;
469        Ok(())
470    }
471
472    /// Write a `DataFrame` to a Parquet file.
473    #[pyo3(signature = (
474        path,
475        compression="zstd",
476        compression_level=None
477        ))]
478    fn write_parquet(
479        &self,
480        path: &str,
481        compression: &str,
482        compression_level: Option<u32>,
483        py: Python,
484    ) -> PyDataFusionResult<()> {
485        fn verify_compression_level(cl: Option<u32>) -> Result<u32, PyErr> {
486            cl.ok_or(PyValueError::new_err("compression_level is not defined"))
487        }
488
489        let _validated = match compression.to_lowercase().as_str() {
490            "snappy" => Compression::SNAPPY,
491            "gzip" => Compression::GZIP(
492                GzipLevel::try_new(compression_level.unwrap_or(6))
493                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
494            ),
495            "brotli" => Compression::BROTLI(
496                BrotliLevel::try_new(verify_compression_level(compression_level)?)
497                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
498            ),
499            "zstd" => Compression::ZSTD(
500                ZstdLevel::try_new(verify_compression_level(compression_level)? as i32)
501                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
502            ),
503            "lzo" => Compression::LZO,
504            "lz4" => Compression::LZ4,
505            "lz4_raw" => Compression::LZ4_RAW,
506            "uncompressed" => Compression::UNCOMPRESSED,
507            _ => {
508                return Err(PyDataFusionError::Common(format!(
509                    "Unrecognized compression type {compression}"
510                )));
511            }
512        };
513
514        let mut compression_string = compression.to_string();
515        if let Some(level) = compression_level {
516            compression_string.push_str(&format!("({level})"));
517        }
518
519        let mut options = TableParquetOptions::default();
520        options.global.compression = Some(compression_string);
521
522        wait_for_future(
523            py,
524            self.df.as_ref().clone().write_parquet(
525                path,
526                DataFrameWriteOptions::new(),
527                Option::from(options),
528            ),
529        )?;
530        Ok(())
531    }
532
533    /// Executes a query and writes the results to a partitioned JSON file.
534    fn write_json(&self, path: &str, py: Python) -> PyDataFusionResult<()> {
535        wait_for_future(
536            py,
537            self.df
538                .as_ref()
539                .clone()
540                .write_json(path, DataFrameWriteOptions::new(), None),
541        )?;
542        Ok(())
543    }
544
545    /// Convert to Arrow Table
546    /// Collect the batches and pass to Arrow Table
547    fn to_arrow_table(&self, py: Python<'_>) -> PyResult<PyObject> {
548        let batches = self.collect(py)?.to_object(py);
549        let schema: PyObject = self.schema().into_pyobject(py)?.to_object(py);
550
551        // Instantiate pyarrow Table object and use its from_batches method
552        let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
553        let args = PyTuple::new_bound(py, &[batches, schema]);
554        let table: PyObject = table_class.call_method1("from_batches", args)?.into();
555        Ok(table)
556    }
557
558    #[pyo3(signature = (requested_schema=None))]
559    fn __arrow_c_stream__<'py>(
560        &'py mut self,
561        py: Python<'py>,
562        requested_schema: Option<Bound<'py, PyCapsule>>,
563    ) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
564        let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())?;
565        let mut schema: Schema = self.df.schema().to_owned().into();
566
567        if let Some(schema_capsule) = requested_schema {
568            validate_pycapsule(&schema_capsule, "arrow_schema")?;
569
570            let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
571            let desired_schema = Schema::try_from(schema_ptr)?;
572
573            schema = project_schema(schema, desired_schema)?;
574
575            batches = batches
576                .into_iter()
577                .map(|record_batch| record_batch_into_schema(record_batch, &schema))
578                .collect::<Result<Vec<RecordBatch>, ArrowError>>()?;
579        }
580
581        let batches_wrapped = batches.into_iter().map(Ok);
582
583        let reader = RecordBatchIterator::new(batches_wrapped, Arc::new(schema));
584        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
585
586        let ffi_stream = FFI_ArrowArrayStream::new(reader);
587        let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
588        PyCapsule::new_bound(py, ffi_stream, Some(stream_capsule_name))
589            .map_err(PyDataFusionError::from)
590    }
591
592    fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
593        // create a Tokio runtime to run the async code
594        let rt = &get_tokio_runtime().0;
595        let df = self.df.as_ref().clone();
596        let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
597            rt.spawn(async move { df.execute_stream().await });
598        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
599        Ok(PyRecordBatchStream::new(stream?))
600    }
601
602    fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
603        // create a Tokio runtime to run the async code
604        let rt = &get_tokio_runtime().0;
605        let df = self.df.as_ref().clone();
606        let fut: JoinHandle<datafusion::common::Result<Vec<SendableRecordBatchStream>>> =
607            rt.spawn(async move { df.execute_stream_partitioned().await });
608        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
609
610        match stream {
611            Ok(batches) => Ok(batches.into_iter().map(PyRecordBatchStream::new).collect()),
612            _ => Err(PyValueError::new_err(
613                "Unable to execute stream partitioned",
614            )),
615        }
616    }
617
618    /// Convert to pandas dataframe with pyarrow
619    /// Collect the batches, pass to Arrow Table & then convert to Pandas DataFrame
620    fn to_pandas(&self, py: Python<'_>) -> PyResult<PyObject> {
621        let table = self.to_arrow_table(py)?;
622
623        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pandas
624        let result = table.call_method0(py, "to_pandas")?;
625        Ok(result)
626    }
627
628    /// Convert to Python list using pyarrow
629    /// Each list item represents one row encoded as dictionary
630    fn to_pylist(&self, py: Python<'_>) -> PyResult<PyObject> {
631        let table = self.to_arrow_table(py)?;
632
633        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pylist
634        let result = table.call_method0(py, "to_pylist")?;
635        Ok(result)
636    }
637
638    /// Convert to Python dictionary using pyarrow
639    /// Each dictionary key is a column and the dictionary value represents the column values
640    fn to_pydict(&self, py: Python) -> PyResult<PyObject> {
641        let table = self.to_arrow_table(py)?;
642
643        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pydict
644        let result = table.call_method0(py, "to_pydict")?;
645        Ok(result)
646    }
647
648    /// Convert to polars dataframe with pyarrow
649    /// Collect the batches, pass to Arrow Table & then convert to polars DataFrame
650    fn to_polars(&self, py: Python<'_>) -> PyResult<PyObject> {
651        let table = self.to_arrow_table(py)?;
652        let dataframe = py.import_bound("polars")?.getattr("DataFrame")?;
653        let args = PyTuple::new_bound(py, &[table]);
654        let result: PyObject = dataframe.call1(args)?.into();
655        Ok(result)
656    }
657
658    // Executes this DataFrame to get the total number of rows.
659    fn count(&self, py: Python) -> PyDataFusionResult<usize> {
660        Ok(wait_for_future(py, self.df.as_ref().clone().count())?)
661    }
662}
663
664/// Print DataFrame
665fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
666    // Get string representation of record batches
667    let batches = wait_for_future(py, df.collect())?;
668    let batches_as_string = pretty::pretty_format_batches(&batches);
669    let result = match batches_as_string {
670        Ok(batch) => format!("DataFrame()\n{batch}"),
671        Err(err) => format!("Error: {:?}", err.to_string()),
672    };
673
674    // Import the Python 'builtins' module to access the print function
675    // Note that println! does not print to the Python debug console and is not visible in notebooks for instance
676    let print = py.import_bound("builtins")?.getattr("print")?;
677    print.call1((result,))?;
678    Ok(())
679}
680
681fn project_schema(from_schema: Schema, to_schema: Schema) -> Result<Schema, ArrowError> {
682    let merged_schema = Schema::try_merge(vec![from_schema, to_schema.clone()])?;
683
684    let project_indices: Vec<usize> = to_schema
685        .fields
686        .iter()
687        .map(|field| field.name())
688        .filter_map(|field_name| merged_schema.index_of(field_name).ok())
689        .collect();
690
691    merged_schema.project(&project_indices)
692}
693
694fn record_batch_into_schema(
695    record_batch: RecordBatch,
696    schema: &Schema,
697) -> Result<RecordBatch, ArrowError> {
698    let schema = Arc::new(schema.clone());
699    let base_schema = record_batch.schema();
700    if base_schema.fields().len() == 0 {
701        // Nothing to project
702        return Ok(RecordBatch::new_empty(schema));
703    }
704
705    let array_size = record_batch.column(0).len();
706    let mut data_arrays = Vec::with_capacity(schema.fields().len());
707
708    for field in schema.fields() {
709        let desired_data_type = field.data_type();
710        if let Some(original_data) = record_batch.column_by_name(field.name()) {
711            let original_data_type = original_data.data_type();
712
713            if can_cast_types(original_data_type, desired_data_type) {
714                data_arrays.push(arrow::compute::kernels::cast(
715                    original_data,
716                    desired_data_type,
717                )?);
718            } else if field.is_nullable() {
719                data_arrays.push(new_null_array(desired_data_type, array_size));
720            } else {
721                return Err(ArrowError::CastError(format!("Attempting to cast to non-nullable and non-castable field {} during schema projection.", field.name())));
722            }
723        } else {
724            if !field.is_nullable() {
725                return Err(ArrowError::CastError(format!(
726                    "Attempting to set null to non-nullable field {} during schema projection.",
727                    field.name()
728                )));
729            }
730            data_arrays.push(new_null_array(desired_data_type, array_size));
731        }
732    }
733
734    RecordBatch::try_new(schema, data_arrays)
735}