datafusion_python/
dataframe.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::HashMap;
19use std::ffi::{CStr, CString};
20use std::sync::Arc;
21
22use arrow::array::{new_null_array, Array, ArrayRef, RecordBatch, RecordBatchReader};
23use arrow::compute::can_cast_types;
24use arrow::error::ArrowError;
25use arrow::ffi::FFI_ArrowSchema;
26use arrow::ffi_stream::FFI_ArrowArrayStream;
27use arrow::pyarrow::FromPyArrow;
28use cstr::cstr;
29use datafusion::arrow::datatypes::{Schema, SchemaRef};
30use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
31use datafusion::arrow::util::pretty;
32use datafusion::catalog::TableProvider;
33use datafusion::common::UnnestOptions;
34use datafusion::config::{CsvOptions, ParquetColumnOptions, ParquetOptions, TableParquetOptions};
35use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
36use datafusion::error::DataFusionError;
37use datafusion::execution::SendableRecordBatchStream;
38use datafusion::logical_expr::dml::InsertOp;
39use datafusion::logical_expr::SortExpr;
40use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
41use datafusion::prelude::*;
42use futures::{StreamExt, TryStreamExt};
43use parking_lot::Mutex;
44use pyo3::exceptions::PyValueError;
45use pyo3::prelude::*;
46use pyo3::pybacked::PyBackedStr;
47use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
48use pyo3::PyErr;
49
50use crate::errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult};
51use crate::expr::sort_expr::{to_sort_expressions, PySortExpr};
52use crate::expr::PyExpr;
53use crate::physical_plan::PyExecutionPlan;
54use crate::record_batch::{poll_next_batch, PyRecordBatchStream};
55use crate::sql::logical::PyLogicalPlan;
56use crate::table::{PyTable, TempViewTable};
57use crate::utils::{
58    is_ipython_env, py_obj_to_scalar_value, spawn_future, validate_pycapsule, wait_for_future,
59};
60
61/// File-level static CStr for the Arrow array stream capsule name.
62static ARROW_ARRAY_STREAM_NAME: &CStr = cstr!("arrow_array_stream");
63
64// Type aliases to simplify very complex types used in this file and
65// avoid compiler complaints about deeply nested types in struct fields.
66type CachedBatches = Option<(Vec<RecordBatch>, bool)>;
67type SharedCachedBatches = Arc<Mutex<CachedBatches>>;
68
/// Limits that control how a DataFrame is rendered for display.
#[derive(Debug, Clone)]
pub struct FormatterConfig {
    /// Maximum memory in bytes to use for display (default: 2MB)
    pub max_bytes: usize,
    /// Minimum number of rows to display (default: 20)
    pub min_rows: usize,
    /// Number of rows to include in __repr__ output (default: 10)
    pub repr_rows: usize,
}

impl Default for FormatterConfig {
    fn default() -> Self {
        // 2 * 2^20 bytes, i.e. 2MB.
        const TWO_MIB: usize = 2 << 20;
        Self {
            max_bytes: TWO_MIB,
            min_rows: 20,
            repr_rows: 10,
        }
    }
}

impl FormatterConfig {
    /// Checks that every limit is non-zero.
    ///
    /// # Errors
    ///
    /// Returns a message naming the first zero-valued field, checked in the order
    /// `max_bytes`, `min_rows`, `repr_rows`.
    pub fn validate(&self) -> Result<(), String> {
        let checks = [
            (self.max_bytes, "max_bytes must be a positive integer"),
            (self.min_rows, "min_rows must be a positive integer"),
            (self.repr_rows, "repr_rows must be a positive integer"),
        ];

        match checks.iter().find(|(value, _)| *value == 0) {
            Some((_, message)) => Err((*message).to_string()),
            None => Ok(()),
        }
    }
}
112
113/// Holds the Python formatter and its configuration
114struct PythonFormatter<'py> {
115    /// The Python formatter object
116    formatter: Bound<'py, PyAny>,
117    /// The formatter configuration
118    config: FormatterConfig,
119}
120
121/// Get the Python formatter and its configuration
122fn get_python_formatter_with_config(py: Python) -> PyResult<PythonFormatter> {
123    let formatter = import_python_formatter(py)?;
124    let config = build_formatter_config_from_python(&formatter)?;
125    Ok(PythonFormatter { formatter, config })
126}
127
128/// Get the Python formatter from the datafusion.dataframe_formatter module
129fn import_python_formatter(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
130    let formatter_module = py.import("datafusion.dataframe_formatter")?;
131    let get_formatter = formatter_module.getattr("get_formatter")?;
132    get_formatter.call0()
133}
134
135// Helper function to extract attributes with fallback to default
136fn get_attr<'a, T>(py_object: &'a Bound<'a, PyAny>, attr_name: &str, default_value: T) -> T
137where
138    T: for<'py> pyo3::FromPyObject<'py> + Clone,
139{
140    py_object
141        .getattr(attr_name)
142        .and_then(|v| v.extract::<T>())
143        .unwrap_or_else(|_| default_value.clone())
144}
145
146/// Helper function to create a FormatterConfig from a Python formatter object
147fn build_formatter_config_from_python(formatter: &Bound<'_, PyAny>) -> PyResult<FormatterConfig> {
148    let default_config = FormatterConfig::default();
149    let max_bytes = get_attr(formatter, "max_memory_bytes", default_config.max_bytes);
150    let min_rows = get_attr(formatter, "min_rows_display", default_config.min_rows);
151    let repr_rows = get_attr(formatter, "repr_rows", default_config.repr_rows);
152
153    let config = FormatterConfig {
154        max_bytes,
155        min_rows,
156        repr_rows,
157    };
158
159    // Return the validated config, converting String error to PyErr
160    config.validate().map_err(PyValueError::new_err)?;
161    Ok(config)
162}
163
164/// Python mapping of `ParquetOptions` (includes just the writer-related options).
165#[pyclass(frozen, name = "ParquetWriterOptions", module = "datafusion", subclass)]
166#[derive(Clone, Default)]
167pub struct PyParquetWriterOptions {
168    options: ParquetOptions,
169}
170
171#[pymethods]
172impl PyParquetWriterOptions {
173    #[new]
174    #[allow(clippy::too_many_arguments)]
175    pub fn new(
176        data_pagesize_limit: usize,
177        write_batch_size: usize,
178        writer_version: String,
179        skip_arrow_metadata: bool,
180        compression: Option<String>,
181        dictionary_enabled: Option<bool>,
182        dictionary_page_size_limit: usize,
183        statistics_enabled: Option<String>,
184        max_row_group_size: usize,
185        created_by: String,
186        column_index_truncate_length: Option<usize>,
187        statistics_truncate_length: Option<usize>,
188        data_page_row_count_limit: usize,
189        encoding: Option<String>,
190        bloom_filter_on_write: bool,
191        bloom_filter_fpp: Option<f64>,
192        bloom_filter_ndv: Option<u64>,
193        allow_single_file_parallelism: bool,
194        maximum_parallel_row_group_writers: usize,
195        maximum_buffered_record_batches_per_stream: usize,
196    ) -> Self {
197        Self {
198            options: ParquetOptions {
199                data_pagesize_limit,
200                write_batch_size,
201                writer_version,
202                skip_arrow_metadata,
203                compression,
204                dictionary_enabled,
205                dictionary_page_size_limit,
206                statistics_enabled,
207                max_row_group_size,
208                created_by,
209                column_index_truncate_length,
210                statistics_truncate_length,
211                data_page_row_count_limit,
212                encoding,
213                bloom_filter_on_write,
214                bloom_filter_fpp,
215                bloom_filter_ndv,
216                allow_single_file_parallelism,
217                maximum_parallel_row_group_writers,
218                maximum_buffered_record_batches_per_stream,
219                ..Default::default()
220            },
221        }
222    }
223}
224
225/// Python mapping of `ParquetColumnOptions`.
226#[pyclass(frozen, name = "ParquetColumnOptions", module = "datafusion", subclass)]
227#[derive(Clone, Default)]
228pub struct PyParquetColumnOptions {
229    options: ParquetColumnOptions,
230}
231
232#[pymethods]
233impl PyParquetColumnOptions {
234    #[new]
235    pub fn new(
236        bloom_filter_enabled: Option<bool>,
237        encoding: Option<String>,
238        dictionary_enabled: Option<bool>,
239        compression: Option<String>,
240        statistics_enabled: Option<String>,
241        bloom_filter_fpp: Option<f64>,
242        bloom_filter_ndv: Option<u64>,
243    ) -> Self {
244        Self {
245            options: ParquetColumnOptions {
246                bloom_filter_enabled,
247                encoding,
248                dictionary_enabled,
249                compression,
250                statistics_enabled,
251                bloom_filter_fpp,
252                bloom_filter_ndv,
253            },
254        }
255    }
256}
257
258/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
259/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
260/// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment.
261#[pyclass(name = "DataFrame", module = "datafusion", subclass, frozen)]
262#[derive(Clone)]
263pub struct PyDataFrame {
264    df: Arc<DataFrame>,
265
266    // In IPython environment cache batches between __repr__ and _repr_html_ calls.
267    batches: SharedCachedBatches,
268}
269
270impl PyDataFrame {
271    /// creates a new PyDataFrame
272    pub fn new(df: DataFrame) -> Self {
273        Self {
274            df: Arc::new(df),
275            batches: Arc::new(Mutex::new(None)),
276        }
277    }
278
279    /// Return a clone of the inner Arc<DataFrame> for crate-local callers.
280    pub(crate) fn inner_df(&self) -> Arc<DataFrame> {
281        Arc::clone(&self.df)
282    }
283
284    fn prepare_repr_string<'py>(
285        &self,
286        py: Python<'py>,
287        as_html: bool,
288    ) -> PyDataFusionResult<String> {
289        // Get the Python formatter and config
290        let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?;
291
292        let is_ipython = *is_ipython_env(py);
293
294        let (cached_batches, should_cache) = {
295            let mut cache = self.batches.lock();
296            let should_cache = is_ipython && cache.is_none();
297            let batches = cache.take();
298            (batches, should_cache)
299        };
300
301        let (batches, has_more) = match cached_batches {
302            Some(b) => b,
303            None => wait_for_future(
304                py,
305                collect_record_batches_to_display(self.df.as_ref().clone(), config),
306            )??,
307        };
308
309        if batches.is_empty() {
310            // This should not be reached, but do it for safety since we index into the vector below
311            return Ok("No data to display".to_string());
312        }
313
314        let table_uuid = uuid::Uuid::new_v4().to_string();
315
316        // Convert record batches to Py<PyAny> list
317        let py_batches = batches
318            .iter()
319            .map(|rb| rb.to_pyarrow(py))
320            .collect::<PyResult<Vec<Bound<'py, PyAny>>>>()?;
321
322        let py_schema = self.schema().into_pyobject(py)?;
323
324        let kwargs = pyo3::types::PyDict::new(py);
325        let py_batches_list = PyList::new(py, py_batches.as_slice())?;
326        kwargs.set_item("batches", py_batches_list)?;
327        kwargs.set_item("schema", py_schema)?;
328        kwargs.set_item("has_more", has_more)?;
329        kwargs.set_item("table_uuid", table_uuid)?;
330
331        let method_name = match as_html {
332            true => "format_html",
333            false => "format_str",
334        };
335
336        let html_result = formatter.call_method(method_name, (), Some(&kwargs))?;
337        let html_str: String = html_result.extract()?;
338
339        if should_cache {
340            let mut cache = self.batches.lock();
341            *cache = Some((batches.clone(), has_more));
342        }
343
344        Ok(html_str)
345    }
346
347    async fn collect_column_inner(&self, column: &str) -> Result<ArrayRef, DataFusionError> {
348        let batches = self
349            .df
350            .as_ref()
351            .clone()
352            .select_columns(&[column])?
353            .collect()
354            .await?;
355
356        let arrays = batches
357            .iter()
358            .map(|b| b.column(0).as_ref())
359            .collect::<Vec<_>>();
360
361        arrow_select::concat::concat(&arrays).map_err(Into::into)
362    }
363}
364
365/// Synchronous wrapper around partitioned [`SendableRecordBatchStream`]s used
366/// for the `__arrow_c_stream__` implementation.
367///
368/// It drains each partition's stream sequentially, yielding record batches in
369/// their original partition order. When a `projection` is set, each batch is
370/// converted via `record_batch_into_schema` to apply schema changes per batch.
371struct PartitionedDataFrameStreamReader {
372    streams: Vec<SendableRecordBatchStream>,
373    schema: SchemaRef,
374    projection: Option<SchemaRef>,
375    current: usize,
376}
377
378impl Iterator for PartitionedDataFrameStreamReader {
379    type Item = Result<RecordBatch, ArrowError>;
380
381    fn next(&mut self) -> Option<Self::Item> {
382        while self.current < self.streams.len() {
383            let stream = &mut self.streams[self.current];
384            let fut = poll_next_batch(stream);
385            let result = Python::attach(|py| wait_for_future(py, fut));
386
387            match result {
388                Ok(Ok(Some(batch))) => {
389                    let batch = if let Some(ref schema) = self.projection {
390                        match record_batch_into_schema(batch, schema.as_ref()) {
391                            Ok(b) => b,
392                            Err(e) => return Some(Err(e)),
393                        }
394                    } else {
395                        batch
396                    };
397                    return Some(Ok(batch));
398                }
399                Ok(Ok(None)) => {
400                    self.current += 1;
401                    continue;
402                }
403                Ok(Err(e)) => {
404                    return Some(Err(ArrowError::ExternalError(Box::new(e))));
405                }
406                Err(e) => {
407                    return Some(Err(ArrowError::ExternalError(Box::new(e))));
408                }
409            }
410        }
411
412        None
413    }
414}
415
416impl RecordBatchReader for PartitionedDataFrameStreamReader {
417    fn schema(&self) -> SchemaRef {
418        self.schema.clone()
419    }
420}
421
422#[pymethods]
423impl PyDataFrame {
424    /// Enable selection for `df[col]`, `df[col1, col2, col3]`, and `df[[col1, col2, col3]]`
425    fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
426        if let Ok(key) = key.extract::<PyBackedStr>() {
427            // df[col]
428            self.select_columns(vec![key])
429        } else if let Ok(tuple) = key.downcast::<PyTuple>() {
430            // df[col1, col2, col3]
431            let keys = tuple
432                .iter()
433                .map(|item| item.extract::<PyBackedStr>())
434                .collect::<PyResult<Vec<PyBackedStr>>>()?;
435            self.select_columns(keys)
436        } else if let Ok(keys) = key.extract::<Vec<PyBackedStr>>() {
437            // df[[col1, col2, col3]]
438            self.select_columns(keys)
439        } else {
440            let message = "DataFrame can only be indexed by string index or indices".to_string();
441            Err(PyDataFusionError::Common(message))
442        }
443    }
444
445    fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
446        self.prepare_repr_string(py, false)
447    }
448
449    #[staticmethod]
450    #[expect(unused_variables)]
451    fn default_str_repr<'py>(
452        batches: Vec<Bound<'py, PyAny>>,
453        schema: &Bound<'py, PyAny>,
454        has_more: bool,
455        table_uuid: &str,
456    ) -> PyResult<String> {
457        let batches = batches
458            .into_iter()
459            .map(|batch| RecordBatch::from_pyarrow_bound(&batch))
460            .collect::<PyResult<Vec<RecordBatch>>>()?
461            .into_iter()
462            .filter(|batch| batch.num_rows() > 0)
463            .collect::<Vec<_>>();
464
465        if batches.is_empty() {
466            return Ok("No data to display".to_owned());
467        }
468
469        let batches_as_displ =
470            pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?;
471
472        let additional_str = match has_more {
473            true => "\nData truncated.",
474            false => "",
475        };
476
477        Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}"))
478    }
479
480    fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
481        self.prepare_repr_string(py, true)
482    }
483
484    /// Calculate summary statistics for a DataFrame
485    fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
486        let df = self.df.as_ref().clone();
487        let stat_df = wait_for_future(py, df.describe())??;
488        Ok(Self::new(stat_df))
489    }
490
491    /// Returns the schema from the logical plan
492    fn schema(&self) -> PyArrowType<Schema> {
493        PyArrowType(self.df.schema().as_arrow().clone())
494    }
495
496    /// Convert this DataFrame into a Table Provider that can be used in register_table
497    /// By convention, into_... methods consume self and return the new object.
498    /// Disabling the clippy lint, so we can use &self
499    /// because we're working with Python bindings
500    /// where objects are shared
501    #[allow(clippy::wrong_self_convention)]
502    pub fn into_view(&self, temporary: bool) -> PyDataFusionResult<PyTable> {
503        let table_provider = if temporary {
504            Arc::new(TempViewTable::new(Arc::clone(&self.df))) as Arc<dyn TableProvider>
505        } else {
506            // Call the underlying Rust DataFrame::into_view method.
507            // Note that the Rust method consumes self; here we clone the inner Arc<DataFrame>
508            // so that we don't invalidate this PyDataFrame.
509            self.df.as_ref().clone().into_view()
510        };
511        Ok(PyTable::from(table_provider))
512    }
513
514    #[pyo3(signature = (*args))]
515    fn select_columns(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
516        let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
517        let df = self.df.as_ref().clone().select_columns(&args)?;
518        Ok(Self::new(df))
519    }
520
521    #[pyo3(signature = (*args))]
522    fn select_exprs(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
523        let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
524        let df = self.df.as_ref().clone().select_exprs(&args)?;
525        Ok(Self::new(df))
526    }
527
528    #[pyo3(signature = (*args))]
529    fn select(&self, args: Vec<PyExpr>) -> PyDataFusionResult<Self> {
530        let expr: Vec<Expr> = args.into_iter().map(|e| e.into()).collect();
531        let df = self.df.as_ref().clone().select(expr)?;
532        Ok(Self::new(df))
533    }
534
535    #[pyo3(signature = (*args))]
536    fn drop(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
537        let cols = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
538        let df = self.df.as_ref().clone().drop_columns(&cols)?;
539        Ok(Self::new(df))
540    }
541
542    fn filter(&self, predicate: PyExpr) -> PyDataFusionResult<Self> {
543        let df = self.df.as_ref().clone().filter(predicate.into())?;
544        Ok(Self::new(df))
545    }
546
547    fn parse_sql_expr(&self, expr: PyBackedStr) -> PyDataFusionResult<PyExpr> {
548        self.df
549            .as_ref()
550            .parse_sql_expr(&expr)
551            .map(PyExpr::from)
552            .map_err(PyDataFusionError::from)
553    }
554
555    fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult<Self> {
556        let df = self.df.as_ref().clone().with_column(name, expr.into())?;
557        Ok(Self::new(df))
558    }
559
560    fn with_columns(&self, exprs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
561        let mut df = self.df.as_ref().clone();
562        for expr in exprs {
563            let expr: Expr = expr.into();
564            let name = format!("{}", expr.schema_name());
565            df = df.with_column(name.as_str(), expr)?
566        }
567        Ok(Self::new(df))
568    }
569
570    /// Rename one column by applying a new projection. This is a no-op if the column to be
571    /// renamed does not exist.
572    fn with_column_renamed(&self, old_name: &str, new_name: &str) -> PyDataFusionResult<Self> {
573        let df = self
574            .df
575            .as_ref()
576            .clone()
577            .with_column_renamed(old_name, new_name)?;
578        Ok(Self::new(df))
579    }
580
581    fn aggregate(&self, group_by: Vec<PyExpr>, aggs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
582        let group_by = group_by.into_iter().map(|e| e.into()).collect();
583        let aggs = aggs.into_iter().map(|e| e.into()).collect();
584        let df = self.df.as_ref().clone().aggregate(group_by, aggs)?;
585        Ok(Self::new(df))
586    }
587
588    #[pyo3(signature = (*exprs))]
589    fn sort(&self, exprs: Vec<PySortExpr>) -> PyDataFusionResult<Self> {
590        let exprs = to_sort_expressions(exprs);
591        let df = self.df.as_ref().clone().sort(exprs)?;
592        Ok(Self::new(df))
593    }
594
595    #[pyo3(signature = (count, offset=0))]
596    fn limit(&self, count: usize, offset: usize) -> PyDataFusionResult<Self> {
597        let df = self.df.as_ref().clone().limit(offset, Some(count))?;
598        Ok(Self::new(df))
599    }
600
601    /// Executes the plan, returning a list of `RecordBatch`es.
602    /// Unless some order is specified in the plan, there is no
603    /// guarantee of the order of the result.
604    fn collect<'py>(&self, py: Python<'py>) -> PyResult<Vec<Bound<'py, PyAny>>> {
605        let batches = wait_for_future(py, self.df.as_ref().clone().collect())?
606            .map_err(PyDataFusionError::from)?;
607        // cannot use PyResult<Vec<RecordBatch>> return type due to
608        // https://github.com/PyO3/pyo3/issues/1813
609        batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect()
610    }
611
612    /// Cache DataFrame.
613    fn cache(&self, py: Python) -> PyDataFusionResult<Self> {
614        let df = wait_for_future(py, self.df.as_ref().clone().cache())??;
615        Ok(Self::new(df))
616    }
617
618    /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
619    /// maintaining the input partitioning.
620    fn collect_partitioned<'py>(&self, py: Python<'py>) -> PyResult<Vec<Vec<Bound<'py, PyAny>>>> {
621        let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())?
622            .map_err(PyDataFusionError::from)?;
623
624        batches
625            .into_iter()
626            .map(|rbs| rbs.into_iter().map(|rb| rb.to_pyarrow(py)).collect())
627            .collect()
628    }
629
630    fn collect_column<'py>(&self, py: Python<'py>, column: &str) -> PyResult<Bound<'py, PyAny>> {
631        wait_for_future(py, self.collect_column_inner(column))?
632            .map_err(PyDataFusionError::from)?
633            .to_data()
634            .to_pyarrow(py)
635    }
636
637    /// Print the result, 20 lines by default
638    #[pyo3(signature = (num=20))]
639    fn show(&self, py: Python, num: usize) -> PyDataFusionResult<()> {
640        let df = self.df.as_ref().clone().limit(0, Some(num))?;
641        print_dataframe(py, df)
642    }
643
644    /// Filter out duplicate rows
645    fn distinct(&self) -> PyDataFusionResult<Self> {
646        let df = self.df.as_ref().clone().distinct()?;
647        Ok(Self::new(df))
648    }
649
    /// Joins this DataFrame with `right` on the named key columns.
    ///
    /// `how` must be one of `"inner"`, `"left"`, `"right"`, `"full"`, `"semi"`
    /// (left semi) or `"anti"` (left anti); any other value returns an error.
    /// `left_on` and `right_on` are passed to DataFusion's `DataFrame::join` as the
    /// left and right key columns, with no extra filter.
    ///
    /// When `coalesce_keys` is true, each key position whose left and right names
    /// are identical is merged: if the joined schema has exactly two columns with
    /// that unqualified name, they are replaced by a single
    /// `coalesce(first, second)` column aliased to the key name. That column sits
    /// where the first of the two appeared; all other columns are kept unchanged.
    fn join(
        &self,
        right: PyDataFrame,
        how: &str,
        left_on: Vec<PyBackedStr>,
        right_on: Vec<PyBackedStr>,
        coalesce_keys: bool,
    ) -> PyDataFusionResult<Self> {
        let join_type = match how {
            "inner" => JoinType::Inner,
            "left" => JoinType::Left,
            "right" => JoinType::Right,
            "full" => JoinType::Full,
            "semi" => JoinType::LeftSemi,
            "anti" => JoinType::LeftAnti,
            how => {
                return Err(PyDataFusionError::Common(format!(
                    "The join type {how} does not exist or is not implemented"
                )));
            }
        };

        let left_keys = left_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
        let right_keys = right_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();

        let mut df = self.df.as_ref().clone().join(
            right.df.as_ref().clone(),
            join_type,
            &left_keys,
            &right_keys,
            None,
        )?;

        if coalesce_keys {
            // Key positions that use the same column name on both sides.
            let mutual_keys = left_keys
                .iter()
                .zip(right_keys.iter())
                .filter(|(l, r)| l == r)
                .map(|(key, _)| *key)
                .collect::<Vec<_>>();

            // For each shared key name, look up every (qualifier, field) in the
            // joined schema carrying that unqualified name. Only names that
            // resolve to exactly two fields (presumably one per side) are merged.
            let fields_to_coalesce = mutual_keys
                .iter()
                .map(|name| {
                    let qualified_fields = df
                        .logical_plan()
                        .schema()
                        .qualified_fields_with_unqualified_name(name);
                    (*name, qualified_fields)
                })
                .filter(|(_, fields)| fields.len() == 2)
                .collect::<Vec<_>>();

            // Rebuild the projection column by column, substituting one coalesce
            // expression for each pair of shared key columns.
            let expr: Vec<Expr> = df
                .logical_plan()
                .schema()
                .fields()
                .into_iter()
                .enumerate()
                .map(|(idx, _)| df.logical_plan().schema().qualified_field(idx))
                .filter_map(|(qualifier, field)| {
                    if let Some((key_name, qualified_fields)) = fields_to_coalesce
                        .iter()
                        .find(|(_, qf)| qf.contains(&(qualifier, field)))
                    {
                        // Emit the coalesce expression once, at the position of the
                        // first field of the pair. The second field is dropped because
                        // it is already an input to that coalesce.
                        // NOTE(review): the argument order assumes qualified_fields[0]
                        // is the left-side column, i.e. that the joined schema lists
                        // left fields first -- confirm.
                        if (qualifier, field) == qualified_fields[0] {
                            let left_col = Expr::Column(Column::from(qualified_fields[0]));
                            let right_col = Expr::Column(Column::from(qualified_fields[1]));
                            return Some(coalesce(vec![left_col, right_col]).alias(*key_name));
                        }
                        None
                    } else {
                        Some(Expr::Column(Column::from((qualifier, field))))
                    }
                })
                .collect();
            df = df.select(expr)?;
        }

        Ok(Self::new(df))
    }
733
734    fn join_on(
735        &self,
736        right: PyDataFrame,
737        on_exprs: Vec<PyExpr>,
738        how: &str,
739    ) -> PyDataFusionResult<Self> {
740        let join_type = match how {
741            "inner" => JoinType::Inner,
742            "left" => JoinType::Left,
743            "right" => JoinType::Right,
744            "full" => JoinType::Full,
745            "semi" => JoinType::LeftSemi,
746            "anti" => JoinType::LeftAnti,
747            how => {
748                return Err(PyDataFusionError::Common(format!(
749                    "The join type {how} does not exist or is not implemented"
750                )));
751            }
752        };
753        let exprs: Vec<Expr> = on_exprs.into_iter().map(|e| e.into()).collect();
754
755        let df = self
756            .df
757            .as_ref()
758            .clone()
759            .join_on(right.df.as_ref().clone(), join_type, exprs)?;
760        Ok(Self::new(df))
761    }
762
763    /// Print the query plan
764    #[pyo3(signature = (verbose=false, analyze=false))]
765    fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyDataFusionResult<()> {
766        let df = self.df.as_ref().clone().explain(verbose, analyze)?;
767        print_dataframe(py, df)
768    }
769
770    /// Get the logical plan for this `DataFrame`
771    fn logical_plan(&self) -> PyResult<PyLogicalPlan> {
772        Ok(self.df.as_ref().clone().logical_plan().clone().into())
773    }
774
775    /// Get the optimized logical plan for this `DataFrame`
776    fn optimized_logical_plan(&self) -> PyDataFusionResult<PyLogicalPlan> {
777        Ok(self.df.as_ref().clone().into_optimized_plan()?.into())
778    }
779
780    /// Get the execution plan for this `DataFrame`
781    fn execution_plan(&self, py: Python) -> PyDataFusionResult<PyExecutionPlan> {
782        let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())??;
783        Ok(plan.into())
784    }
785
786    /// Repartition a `DataFrame` based on a logical partitioning scheme.
787    fn repartition(&self, num: usize) -> PyDataFusionResult<Self> {
788        let new_df = self
789            .df
790            .as_ref()
791            .clone()
792            .repartition(Partitioning::RoundRobinBatch(num))?;
793        Ok(Self::new(new_df))
794    }
795
796    /// Repartition a `DataFrame` based on a logical partitioning scheme.
797    #[pyo3(signature = (*args, num))]
798    fn repartition_by_hash(&self, args: Vec<PyExpr>, num: usize) -> PyDataFusionResult<Self> {
799        let expr = args.into_iter().map(|py_expr| py_expr.into()).collect();
800        let new_df = self
801            .df
802            .as_ref()
803            .clone()
804            .repartition(Partitioning::Hash(expr, num))?;
805        Ok(Self::new(new_df))
806    }
807
808    /// Calculate the union of two `DataFrame`s, preserving duplicate rows.The
809    /// two `DataFrame`s must have exactly the same schema
810    #[pyo3(signature = (py_df, distinct=false))]
811    fn union(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult<Self> {
812        let new_df = if distinct {
813            self.df
814                .as_ref()
815                .clone()
816                .union_distinct(py_df.df.as_ref().clone())?
817        } else {
818            self.df.as_ref().clone().union(py_df.df.as_ref().clone())?
819        };
820
821        Ok(Self::new(new_df))
822    }
823
824    /// Calculate the distinct union of two `DataFrame`s.  The
825    /// two `DataFrame`s must have exactly the same schema
826    fn union_distinct(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
827        let new_df = self
828            .df
829            .as_ref()
830            .clone()
831            .union_distinct(py_df.df.as_ref().clone())?;
832        Ok(Self::new(new_df))
833    }
834
835    #[pyo3(signature = (column, preserve_nulls=true))]
836    fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyDataFusionResult<Self> {
837        // TODO: expose RecursionUnnestOptions
838        // REF: https://github.com/apache/datafusion/pull/11577
839        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
840        let df = self
841            .df
842            .as_ref()
843            .clone()
844            .unnest_columns_with_options(&[column], unnest_options)?;
845        Ok(Self::new(df))
846    }
847
848    #[pyo3(signature = (columns, preserve_nulls=true))]
849    fn unnest_columns(
850        &self,
851        columns: Vec<String>,
852        preserve_nulls: bool,
853    ) -> PyDataFusionResult<Self> {
854        // TODO: expose RecursionUnnestOptions
855        // REF: https://github.com/apache/datafusion/pull/11577
856        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
857        let cols = columns.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
858        let df = self
859            .df
860            .as_ref()
861            .clone()
862            .unnest_columns_with_options(&cols, unnest_options)?;
863        Ok(Self::new(df))
864    }
865
866    /// Calculate the intersection of two `DataFrame`s.  The two `DataFrame`s must have exactly the same schema
867    fn intersect(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
868        let new_df = self
869            .df
870            .as_ref()
871            .clone()
872            .intersect(py_df.df.as_ref().clone())?;
873        Ok(Self::new(new_df))
874    }
875
876    /// Calculate the exception of two `DataFrame`s.  The two `DataFrame`s must have exactly the same schema
877    fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
878        let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?;
879        Ok(Self::new(new_df))
880    }
881
882    /// Write a `DataFrame` to a CSV file.
883    fn write_csv(
884        &self,
885        py: Python,
886        path: &str,
887        with_header: bool,
888        write_options: Option<PyDataFrameWriteOptions>,
889    ) -> PyDataFusionResult<()> {
890        let csv_options = CsvOptions {
891            has_header: Some(with_header),
892            ..Default::default()
893        };
894        let write_options = write_options
895            .map(DataFrameWriteOptions::from)
896            .unwrap_or_default();
897
898        wait_for_future(
899            py,
900            self.df
901                .as_ref()
902                .clone()
903                .write_csv(path, write_options, Some(csv_options)),
904        )??;
905        Ok(())
906    }
907
908    /// Write a `DataFrame` to a Parquet file.
909    #[pyo3(signature = (
910        path,
911        compression="zstd",
912        compression_level=None,
913        write_options=None,
914        ))]
915    fn write_parquet(
916        &self,
917        path: &str,
918        compression: &str,
919        compression_level: Option<u32>,
920        write_options: Option<PyDataFrameWriteOptions>,
921        py: Python,
922    ) -> PyDataFusionResult<()> {
923        fn verify_compression_level(cl: Option<u32>) -> Result<u32, PyErr> {
924            cl.ok_or(PyValueError::new_err("compression_level is not defined"))
925        }
926
927        let _validated = match compression.to_lowercase().as_str() {
928            "snappy" => Compression::SNAPPY,
929            "gzip" => Compression::GZIP(
930                GzipLevel::try_new(compression_level.unwrap_or(6))
931                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
932            ),
933            "brotli" => Compression::BROTLI(
934                BrotliLevel::try_new(verify_compression_level(compression_level)?)
935                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
936            ),
937            "zstd" => Compression::ZSTD(
938                ZstdLevel::try_new(verify_compression_level(compression_level)? as i32)
939                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
940            ),
941            "lzo" => Compression::LZO,
942            "lz4" => Compression::LZ4,
943            "lz4_raw" => Compression::LZ4_RAW,
944            "uncompressed" => Compression::UNCOMPRESSED,
945            _ => {
946                return Err(PyDataFusionError::Common(format!(
947                    "Unrecognized compression type {compression}"
948                )));
949            }
950        };
951
952        let mut compression_string = compression.to_string();
953        if let Some(level) = compression_level {
954            compression_string.push_str(&format!("({level})"));
955        }
956
957        let mut options = TableParquetOptions::default();
958        options.global.compression = Some(compression_string);
959        let write_options = write_options
960            .map(DataFrameWriteOptions::from)
961            .unwrap_or_default();
962
963        wait_for_future(
964            py,
965            self.df
966                .as_ref()
967                .clone()
968                .write_parquet(path, write_options, Option::from(options)),
969        )??;
970        Ok(())
971    }
972
973    /// Write a `DataFrame` to a Parquet file, using advanced options.
974    fn write_parquet_with_options(
975        &self,
976        path: &str,
977        options: PyParquetWriterOptions,
978        column_specific_options: HashMap<String, PyParquetColumnOptions>,
979        write_options: Option<PyDataFrameWriteOptions>,
980        py: Python,
981    ) -> PyDataFusionResult<()> {
982        let table_options = TableParquetOptions {
983            global: options.options,
984            column_specific_options: column_specific_options
985                .into_iter()
986                .map(|(k, v)| (k, v.options))
987                .collect(),
988            ..Default::default()
989        };
990        let write_options = write_options
991            .map(DataFrameWriteOptions::from)
992            .unwrap_or_default();
993        wait_for_future(
994            py,
995            self.df.as_ref().clone().write_parquet(
996                path,
997                write_options,
998                Option::from(table_options),
999            ),
1000        )??;
1001        Ok(())
1002    }
1003
1004    /// Executes a query and writes the results to a partitioned JSON file.
1005    fn write_json(
1006        &self,
1007        path: &str,
1008        py: Python,
1009        write_options: Option<PyDataFrameWriteOptions>,
1010    ) -> PyDataFusionResult<()> {
1011        let write_options = write_options
1012            .map(DataFrameWriteOptions::from)
1013            .unwrap_or_default();
1014        wait_for_future(
1015            py,
1016            self.df
1017                .as_ref()
1018                .clone()
1019                .write_json(path, write_options, None),
1020        )??;
1021        Ok(())
1022    }
1023
1024    fn write_table(
1025        &self,
1026        py: Python,
1027        table_name: &str,
1028        write_options: Option<PyDataFrameWriteOptions>,
1029    ) -> PyDataFusionResult<()> {
1030        let write_options = write_options
1031            .map(DataFrameWriteOptions::from)
1032            .unwrap_or_default();
1033        wait_for_future(
1034            py,
1035            self.df
1036                .as_ref()
1037                .clone()
1038                .write_table(table_name, write_options),
1039        )??;
1040        Ok(())
1041    }
1042
1043    /// Convert to Arrow Table
1044    /// Collect the batches and pass to Arrow Table
1045    fn to_arrow_table(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
1046        let batches = self.collect(py)?.into_pyobject(py)?;
1047
1048        // only use the DataFrame's schema if there are no batches, otherwise let the schema be
1049        // determined from the batches (avoids some inconsistencies with nullable columns)
1050        let args = if batches.len()? == 0 {
1051            let schema = self.schema().into_pyobject(py)?;
1052            PyTuple::new(py, &[batches, schema])?
1053        } else {
1054            PyTuple::new(py, &[batches])?
1055        };
1056
1057        // Instantiate pyarrow Table object and use its from_batches method
1058        let table_class = py.import("pyarrow")?.getattr("Table")?;
1059        let table: Py<PyAny> = table_class.call_method1("from_batches", args)?.into();
1060        Ok(table)
1061    }
1062
1063    #[pyo3(signature = (requested_schema=None))]
1064    fn __arrow_c_stream__<'py>(
1065        &'py self,
1066        py: Python<'py>,
1067        requested_schema: Option<Bound<'py, PyCapsule>>,
1068    ) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
1069        let df = self.df.as_ref().clone();
1070        let streams = spawn_future(py, async move { df.execute_stream_partitioned().await })?;
1071
1072        let mut schema: Schema = self.df.schema().to_owned().as_arrow().clone();
1073        let mut projection: Option<SchemaRef> = None;
1074
1075        if let Some(schema_capsule) = requested_schema {
1076            validate_pycapsule(&schema_capsule, "arrow_schema")?;
1077
1078            let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
1079            let desired_schema = Schema::try_from(schema_ptr)?;
1080
1081            schema = project_schema(schema, desired_schema)?;
1082            projection = Some(Arc::new(schema.clone()));
1083        }
1084
1085        let schema_ref = Arc::new(schema.clone());
1086
1087        let reader = PartitionedDataFrameStreamReader {
1088            streams,
1089            schema: schema_ref,
1090            projection,
1091            current: 0,
1092        };
1093        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
1094
1095        // Create the Arrow stream and wrap it in a PyCapsule. The default
1096        // destructor provided by PyO3 will drop the stream unless ownership is
1097        // transferred to PyArrow during import.
1098        let stream = FFI_ArrowArrayStream::new(reader);
1099        let name = CString::new(ARROW_ARRAY_STREAM_NAME.to_bytes()).unwrap();
1100        let capsule = PyCapsule::new(py, stream, Some(name))?;
1101        Ok(capsule)
1102    }
1103
1104    fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
1105        let df = self.df.as_ref().clone();
1106        let stream = spawn_future(py, async move { df.execute_stream().await })?;
1107        Ok(PyRecordBatchStream::new(stream))
1108    }
1109
1110    fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
1111        let df = self.df.as_ref().clone();
1112        let streams = spawn_future(py, async move { df.execute_stream_partitioned().await })?;
1113        Ok(streams.into_iter().map(PyRecordBatchStream::new).collect())
1114    }
1115
1116    /// Convert to pandas dataframe with pyarrow
1117    /// Collect the batches, pass to Arrow Table & then convert to Pandas DataFrame
1118    fn to_pandas(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
1119        let table = self.to_arrow_table(py)?;
1120
1121        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pandas
1122        let result = table.call_method0(py, "to_pandas")?;
1123        Ok(result)
1124    }
1125
1126    /// Convert to Python list using pyarrow
1127    /// Each list item represents one row encoded as dictionary
1128    fn to_pylist(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
1129        let table = self.to_arrow_table(py)?;
1130
1131        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pylist
1132        let result = table.call_method0(py, "to_pylist")?;
1133        Ok(result)
1134    }
1135
1136    /// Convert to Python dictionary using pyarrow
1137    /// Each dictionary key is a column and the dictionary value represents the column values
1138    fn to_pydict(&self, py: Python) -> PyResult<Py<PyAny>> {
1139        let table = self.to_arrow_table(py)?;
1140
1141        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pydict
1142        let result = table.call_method0(py, "to_pydict")?;
1143        Ok(result)
1144    }
1145
1146    /// Convert to polars dataframe with pyarrow
1147    /// Collect the batches, pass to Arrow Table & then convert to polars DataFrame
1148    fn to_polars(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
1149        let table = self.to_arrow_table(py)?;
1150        let dataframe = py.import("polars")?.getattr("DataFrame")?;
1151        let args = PyTuple::new(py, &[table])?;
1152        let result: Py<PyAny> = dataframe.call1(args)?.into();
1153        Ok(result)
1154    }
1155
1156    // Executes this DataFrame to get the total number of rows.
1157    fn count(&self, py: Python) -> PyDataFusionResult<usize> {
1158        Ok(wait_for_future(py, self.df.as_ref().clone().count())??)
1159    }
1160
1161    /// Fill null values with a specified value for specific columns
1162    #[pyo3(signature = (value, columns=None))]
1163    fn fill_null(
1164        &self,
1165        value: Py<PyAny>,
1166        columns: Option<Vec<PyBackedStr>>,
1167        py: Python,
1168    ) -> PyDataFusionResult<Self> {
1169        let scalar_value = py_obj_to_scalar_value(py, value)?;
1170
1171        let cols = match columns {
1172            Some(col_names) => col_names.iter().map(|c| c.to_string()).collect(),
1173            None => Vec::new(), // Empty vector means fill null for all columns
1174        };
1175
1176        let df = self.df.as_ref().clone().fill_null(scalar_value, cols)?;
1177        Ok(Self::new(df))
1178    }
1179}
1180
/// Insert operation exposed to Python as `datafusion.InsertOp`.
///
/// Each variant converts one-to-one into DataFusion's `InsertOp`
/// (`APPEND` -> `Append`, `REPLACE` -> `Replace`, `OVERWRITE` -> `Overwrite`).
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[pyclass(frozen, eq, eq_int, name = "InsertOp", module = "datafusion")]
pub enum PyInsertOp {
    APPEND,
    REPLACE,
    OVERWRITE,
}
1188
1189impl From<PyInsertOp> for InsertOp {
1190    fn from(value: PyInsertOp) -> Self {
1191        match value {
1192            PyInsertOp::APPEND => InsertOp::Append,
1193            PyInsertOp::REPLACE => InsertOp::Replace,
1194            PyInsertOp::OVERWRITE => InsertOp::Overwrite,
1195        }
1196    }
1197}
1198
/// Options controlling how a `DataFrame` is written, exposed to Python as
/// `datafusion.DataFrameWriteOptions`. They are converted into DataFusion's
/// `DataFrameWriteOptions` before a write.
#[derive(Debug, Clone)]
#[pyclass(frozen, name = "DataFrameWriteOptions", module = "datafusion")]
pub struct PyDataFrameWriteOptions {
    // Forwarded to `DataFrameWriteOptions::with_insert_operation`.
    insert_operation: InsertOp,
    // Forwarded to `DataFrameWriteOptions::with_single_file_output`.
    single_file_output: bool,
    // Column names forwarded to `DataFrameWriteOptions::with_partition_by`.
    partition_by: Vec<String>,
    // Sort expressions forwarded to `DataFrameWriteOptions::with_sort_by`.
    sort_by: Vec<SortExpr>,
}
1207
1208impl From<PyDataFrameWriteOptions> for DataFrameWriteOptions {
1209    fn from(value: PyDataFrameWriteOptions) -> Self {
1210        DataFrameWriteOptions::new()
1211            .with_insert_operation(value.insert_operation)
1212            .with_single_file_output(value.single_file_output)
1213            .with_partition_by(value.partition_by)
1214            .with_sort_by(value.sort_by)
1215    }
1216}
1217
1218#[pymethods]
1219impl PyDataFrameWriteOptions {
1220    #[new]
1221    fn new(
1222        insert_operation: Option<PyInsertOp>,
1223        single_file_output: bool,
1224        partition_by: Option<Vec<String>>,
1225        sort_by: Option<Vec<PySortExpr>>,
1226    ) -> Self {
1227        let insert_operation = insert_operation.map(Into::into).unwrap_or(InsertOp::Append);
1228        let sort_by = sort_by
1229            .unwrap_or_default()
1230            .into_iter()
1231            .map(Into::into)
1232            .collect();
1233        Self {
1234            insert_operation,
1235            single_file_output,
1236            partition_by: partition_by.unwrap_or_default(),
1237            sort_by,
1238        }
1239    }
1240}
1241
1242/// Print DataFrame
1243fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
1244    // Get string representation of record batches
1245    let batches = wait_for_future(py, df.collect())??;
1246    let result = if batches.is_empty() {
1247        "DataFrame has no rows".to_string()
1248    } else {
1249        match pretty::pretty_format_batches(&batches) {
1250            Ok(batch) => format!("DataFrame()\n{batch}"),
1251            Err(err) => format!("Error: {:?}", err.to_string()),
1252        }
1253    };
1254
1255    // Import the Python 'builtins' module to access the print function
1256    // Note that println! does not print to the Python debug console and is not visible in notebooks for instance
1257    let print = py.import("builtins")?.getattr("print")?;
1258    print.call1((result,))?;
1259    Ok(())
1260}
1261
1262fn project_schema(from_schema: Schema, to_schema: Schema) -> Result<Schema, ArrowError> {
1263    let merged_schema = Schema::try_merge(vec![from_schema, to_schema.clone()])?;
1264
1265    let project_indices: Vec<usize> = to_schema
1266        .fields
1267        .iter()
1268        .map(|field| field.name())
1269        .filter_map(|field_name| merged_schema.index_of(field_name).ok())
1270        .collect();
1271
1272    merged_schema.project(&project_indices)
1273}
1274// NOTE: `arrow::compute::cast` in combination with `RecordBatch::try_select` or
1275// DataFusion's `schema::cast_record_batch` do not fully cover the required
1276// transformations here. They will not create missing columns and may insert
1277// nulls for non-nullable fields without erroring. To maintain current behavior
1278// we perform the casting and null checks manually.
1279fn record_batch_into_schema(
1280    record_batch: RecordBatch,
1281    schema: &Schema,
1282) -> Result<RecordBatch, ArrowError> {
1283    let schema = Arc::new(schema.clone());
1284    let base_schema = record_batch.schema();
1285    if base_schema.fields().is_empty() {
1286        // Nothing to project
1287        return Ok(RecordBatch::new_empty(schema));
1288    }
1289
1290    let array_size = record_batch.column(0).len();
1291    let mut data_arrays = Vec::with_capacity(schema.fields().len());
1292
1293    for field in schema.fields() {
1294        let desired_data_type = field.data_type();
1295        if let Some(original_data) = record_batch.column_by_name(field.name()) {
1296            let original_data_type = original_data.data_type();
1297
1298            if can_cast_types(original_data_type, desired_data_type) {
1299                data_arrays.push(arrow::compute::kernels::cast(
1300                    original_data,
1301                    desired_data_type,
1302                )?);
1303            } else if field.is_nullable() {
1304                data_arrays.push(new_null_array(desired_data_type, array_size));
1305            } else {
1306                return Err(ArrowError::CastError(format!("Attempting to cast to non-nullable and non-castable field {} during schema projection.", field.name())));
1307            }
1308        } else {
1309            if !field.is_nullable() {
1310                return Err(ArrowError::CastError(format!(
1311                    "Attempting to set null to non-nullable field {} during schema projection.",
1312                    field.name()
1313                )));
1314            }
1315            data_arrays.push(new_null_array(desired_data_type, array_size));
1316        }
1317    }
1318
1319    RecordBatch::try_new(schema, data_arrays)
1320}
1321
1322/// This is a helper function to return the first non-empty record batch from executing a DataFrame.
1323/// It additionally returns a bool, which indicates if there are more record batches available.
1324/// We do this so we can determine if we should indicate to the user that the data has been
1325/// truncated. This collects until we have archived both of these two conditions
1326///
1327/// - We have collected our minimum number of rows
1328/// - We have reached our limit, either data size or maximum number of rows
1329///
1330/// Otherwise it will return when the stream has exhausted. If you want a specific number of
1331/// rows, set min_rows == max_rows.
1332async fn collect_record_batches_to_display(
1333    df: DataFrame,
1334    config: FormatterConfig,
1335) -> Result<(Vec<RecordBatch>, bool), DataFusionError> {
1336    let FormatterConfig {
1337        max_bytes,
1338        min_rows,
1339        repr_rows,
1340    } = config;
1341
1342    let partitioned_stream = df.execute_stream_partitioned().await?;
1343    let mut stream = futures::stream::iter(partitioned_stream).flatten();
1344    let mut size_estimate_so_far = 0;
1345    let mut rows_so_far = 0;
1346    let mut record_batches = Vec::default();
1347    let mut has_more = false;
1348
1349    // ensure minimum rows even if memory/row limits are hit
1350    while (size_estimate_so_far < max_bytes && rows_so_far < repr_rows) || rows_so_far < min_rows {
1351        let mut rb = match stream.next().await {
1352            None => {
1353                break;
1354            }
1355            Some(Ok(r)) => r,
1356            Some(Err(e)) => return Err(e),
1357        };
1358
1359        let mut rows_in_rb = rb.num_rows();
1360        if rows_in_rb > 0 {
1361            size_estimate_so_far += rb.get_array_memory_size();
1362
1363            if size_estimate_so_far > max_bytes {
1364                let ratio = max_bytes as f32 / size_estimate_so_far as f32;
1365                let total_rows = rows_in_rb + rows_so_far;
1366
1367                let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize;
1368                if reduced_row_num < min_rows {
1369                    reduced_row_num = min_rows.min(total_rows);
1370                }
1371
1372                let limited_rows_this_rb = reduced_row_num - rows_so_far;
1373                if limited_rows_this_rb < rows_in_rb {
1374                    rows_in_rb = limited_rows_this_rb;
1375                    rb = rb.slice(0, limited_rows_this_rb);
1376                    has_more = true;
1377                }
1378            }
1379
1380            if rows_in_rb + rows_so_far > repr_rows {
1381                rb = rb.slice(0, repr_rows - rows_so_far);
1382                has_more = true;
1383            }
1384
1385            rows_so_far += rb.num_rows();
1386            record_batches.push(rb);
1387        }
1388    }
1389
1390    if record_batches.is_empty() {
1391        return Ok((Vec::default(), false));
1392    }
1393
1394    if !has_more {
1395        // Data was not already truncated, so check to see if more record batches remain
1396        has_more = match stream.try_next().await {
1397            Ok(None) => false, // reached end
1398            Ok(Some(_)) => true,
1399            Err(_) => false, // Stream disconnected
1400        };
1401    }
1402
1403    Ok((record_batches, has_more))
1404}