// datafusion_python/dataframe.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::HashMap;
19use std::ffi::{CStr, CString};
20use std::str::FromStr;
21use std::sync::Arc;
22
23use arrow::array::{Array, ArrayRef, RecordBatch, RecordBatchReader, new_null_array};
24use arrow::compute::can_cast_types;
25use arrow::error::ArrowError;
26use arrow::ffi::FFI_ArrowSchema;
27use arrow::ffi_stream::FFI_ArrowArrayStream;
28use arrow::pyarrow::FromPyArrow;
29use cstr::cstr;
30use datafusion::arrow::datatypes::{Schema, SchemaRef};
31use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
32use datafusion::arrow::util::pretty;
33use datafusion::catalog::TableProvider;
34use datafusion::common::UnnestOptions;
35use datafusion::config::{CsvOptions, ParquetColumnOptions, ParquetOptions, TableParquetOptions};
36use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
37use datafusion::error::DataFusionError;
38use datafusion::execution::SendableRecordBatchStream;
39use datafusion::logical_expr::SortExpr;
40use datafusion::logical_expr::dml::InsertOp;
41use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
42use datafusion::prelude::*;
43use futures::{StreamExt, TryStreamExt};
44use parking_lot::Mutex;
45use pyo3::PyErr;
46use pyo3::exceptions::PyValueError;
47use pyo3::prelude::*;
48use pyo3::pybacked::PyBackedStr;
49use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
50
51use crate::common::data_type::PyScalarValue;
52use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err};
53use crate::expr::PyExpr;
54use crate::expr::sort_expr::{PySortExpr, to_sort_expressions};
55use crate::physical_plan::PyExecutionPlan;
56use crate::record_batch::{PyRecordBatchStream, poll_next_batch};
57use crate::sql::logical::PyLogicalPlan;
58use crate::table::{PyTable, TempViewTable};
59use crate::utils::{is_ipython_env, spawn_future, validate_pycapsule, wait_for_future};
60
/// File-level static CStr for the Arrow array stream capsule name.
/// This is the standard capsule name required by the Arrow PyCapsule interface.
static ARROW_ARRAY_STREAM_NAME: &CStr = cstr!("arrow_array_stream");

// Type aliases to simplify very complex types used in this file and
// avoid compiler complaints about deeply nested types in struct fields.
//
// `CachedBatches` holds the record batches collected for display plus a flag
// indicating whether more data existed beyond what was collected.
type CachedBatches = Option<(Vec<RecordBatch>, bool)>;
// Lock-protected, shareable cache so `__repr__` and `_repr_html_` can reuse
// one collected result instead of executing the plan twice.
type SharedCachedBatches = Arc<Mutex<CachedBatches>>;
68
/// Configuration for DataFrame display formatting
#[derive(Debug, Clone)]
pub struct FormatterConfig {
    /// Maximum memory in bytes to use for display (default: 2MB)
    pub max_bytes: usize,
    /// Minimum number of rows to display (default: 10)
    pub min_rows: usize,
    /// Maximum number of rows to include in __repr__ output (default: 10)
    pub max_rows: usize,
}

impl Default for FormatterConfig {
    fn default() -> Self {
        const DEFAULT_ROWS: usize = 10;
        Self {
            max_bytes: 2 * 1024 * 1024, // 2 MiB
            min_rows: DEFAULT_ROWS,
            max_rows: DEFAULT_ROWS,
        }
    }
}

impl FormatterConfig {
    /// Validates that all configuration values are positive integers and that
    /// `min_rows` does not exceed `max_rows`.
    ///
    /// # Returns
    ///
    /// `Ok(())` if all values are valid, or an `Err` with a descriptive error message.
    pub fn validate(&self) -> Result<(), String> {
        // Every size/row setting must be strictly positive.
        let required_positive = [
            (self.max_bytes, "max_bytes"),
            (self.min_rows, "min_rows"),
            (self.max_rows, "max_rows"),
        ];
        for (value, name) in required_positive {
            if value == 0 {
                return Err(format!("{name} must be a positive integer"));
            }
        }

        if self.min_rows > self.max_rows {
            return Err("min_rows must be less than or equal to max_rows".to_string());
        }

        Ok(())
    }
}
116
/// Holds the Python formatter and its configuration.
///
/// The `'py` lifetime ties the held formatter object to the GIL token it was
/// obtained under.
struct PythonFormatter<'py> {
    /// The Python formatter object
    formatter: Bound<'py, PyAny>,
    /// The formatter configuration
    config: FormatterConfig,
}
124
125/// Get the Python formatter and its configuration
126fn get_python_formatter_with_config(py: Python) -> PyResult<PythonFormatter> {
127    let formatter = import_python_formatter(py)?;
128    let config = build_formatter_config_from_python(&formatter)?;
129    Ok(PythonFormatter { formatter, config })
130}
131
132/// Get the Python formatter from the datafusion.dataframe_formatter module
133fn import_python_formatter(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> {
134    let formatter_module = py.import("datafusion.dataframe_formatter")?;
135    let get_formatter = formatter_module.getattr("get_formatter")?;
136    get_formatter.call0()
137}
138
139// Helper function to extract attributes with fallback to default
140fn get_attr<'a, T>(py_object: &'a Bound<'a, PyAny>, attr_name: &str, default_value: T) -> T
141where
142    T: for<'py> pyo3::FromPyObject<'py> + Clone,
143{
144    py_object
145        .getattr(attr_name)
146        .and_then(|v| v.extract::<T>())
147        .unwrap_or_else(|_| default_value.clone())
148}
149
150/// Helper function to create a FormatterConfig from a Python formatter object
151fn build_formatter_config_from_python(formatter: &Bound<'_, PyAny>) -> PyResult<FormatterConfig> {
152    let default_config = FormatterConfig::default();
153    let max_bytes = get_attr(formatter, "max_memory_bytes", default_config.max_bytes);
154    let min_rows = get_attr(formatter, "min_rows", default_config.min_rows);
155
156    // Backward compatibility: Try max_rows first (new name), fall back to repr_rows (deprecated),
157    // then use default. This ensures backward compatibility with custom formatter implementations
158    // during the deprecation period.
159    let max_rows = get_attr(formatter, "max_rows", 0usize);
160    let max_rows = if max_rows > 0 {
161        // max_rows attribute exists and has a value
162        max_rows
163    } else {
164        // Try the deprecated repr_rows attribute
165        let repr_rows = get_attr(formatter, "repr_rows", 0usize);
166        if repr_rows > 0 {
167            repr_rows
168        } else {
169            // Use default
170            default_config.max_rows
171        }
172    };
173
174    let config = FormatterConfig {
175        max_bytes,
176        min_rows,
177        max_rows,
178    };
179
180    // Return the validated config, converting String error to PyErr
181    config.validate().map_err(PyValueError::new_err)?;
182    Ok(config)
183}
184
/// Python mapping of `ParquetOptions` (includes just the writer-related options).
#[pyclass(frozen, name = "ParquetWriterOptions", module = "datafusion", subclass)]
#[derive(Clone, Default)]
pub struct PyParquetWriterOptions {
    // Writer-related subset of DataFusion's Parquet configuration.
    options: ParquetOptions,
}

#[pymethods]
impl PyParquetWriterOptions {
    /// Construct the writer options from individual settings.
    ///
    /// Each argument maps one-to-one onto the field of the same name in
    /// `ParquetOptions`; any field not listed here keeps its
    /// `ParquetOptions::default()` value. `writer_version` is parsed from its
    /// string form and raises a Python error for unrecognized versions.
    #[new]
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        data_pagesize_limit: usize,
        write_batch_size: usize,
        writer_version: &str,
        skip_arrow_metadata: bool,
        compression: Option<String>,
        dictionary_enabled: Option<bool>,
        dictionary_page_size_limit: usize,
        statistics_enabled: Option<String>,
        max_row_group_size: usize,
        created_by: String,
        column_index_truncate_length: Option<usize>,
        statistics_truncate_length: Option<usize>,
        data_page_row_count_limit: usize,
        encoding: Option<String>,
        bloom_filter_on_write: bool,
        bloom_filter_fpp: Option<f64>,
        bloom_filter_ndv: Option<u64>,
        allow_single_file_parallelism: bool,
        maximum_parallel_row_group_writers: usize,
        maximum_buffered_record_batches_per_stream: usize,
    ) -> PyResult<Self> {
        // Parse the textual writer version (e.g. "1.0"/"2.0") up front so an
        // invalid value fails before any options are built.
        let writer_version =
            datafusion::common::parquet_config::DFParquetWriterVersion::from_str(writer_version)
                .map_err(py_datafusion_err)?;
        Ok(Self {
            options: ParquetOptions {
                data_pagesize_limit,
                write_batch_size,
                writer_version,
                skip_arrow_metadata,
                compression,
                dictionary_enabled,
                dictionary_page_size_limit,
                statistics_enabled,
                max_row_group_size,
                created_by,
                column_index_truncate_length,
                statistics_truncate_length,
                data_page_row_count_limit,
                encoding,
                bloom_filter_on_write,
                bloom_filter_fpp,
                bloom_filter_ndv,
                allow_single_file_parallelism,
                maximum_parallel_row_group_writers,
                maximum_buffered_record_batches_per_stream,
                // Remaining (non-writer) fields keep their defaults.
                ..Default::default()
            },
        })
    }
}
248
249/// Python mapping of `ParquetColumnOptions`.
250#[pyclass(frozen, name = "ParquetColumnOptions", module = "datafusion", subclass)]
251#[derive(Clone, Default)]
252pub struct PyParquetColumnOptions {
253    options: ParquetColumnOptions,
254}
255
256#[pymethods]
257impl PyParquetColumnOptions {
258    #[new]
259    pub fn new(
260        bloom_filter_enabled: Option<bool>,
261        encoding: Option<String>,
262        dictionary_enabled: Option<bool>,
263        compression: Option<String>,
264        statistics_enabled: Option<String>,
265        bloom_filter_fpp: Option<f64>,
266        bloom_filter_ndv: Option<u64>,
267    ) -> Self {
268        Self {
269            options: ParquetColumnOptions {
270                bloom_filter_enabled,
271                encoding,
272                dictionary_enabled,
273                compression,
274                statistics_enabled,
275                bloom_filter_fpp,
276                bloom_filter_ndv,
277            },
278        }
279    }
280}
281
/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
/// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment.
#[pyclass(name = "DataFrame", module = "datafusion", subclass, frozen)]
#[derive(Clone)]
pub struct PyDataFrame {
    /// The wrapped DataFusion DataFrame; `Arc` makes cloning this wrapper cheap.
    df: Arc<DataFrame>,

    // In IPython environment cache batches between __repr__ and _repr_html_ calls.
    batches: SharedCachedBatches,
}
293
impl PyDataFrame {
    /// creates a new PyDataFrame with an empty display cache
    pub fn new(df: DataFrame) -> Self {
        Self {
            df: Arc::new(df),
            batches: Arc::new(Mutex::new(None)),
        }
    }

    /// Return a clone of the inner Arc<DataFrame> for crate-local callers.
    pub(crate) fn inner_df(&self) -> Arc<DataFrame> {
        Arc::clone(&self.df)
    }

    /// Render this DataFrame through the Python formatter, as HTML when
    /// `as_html` is true and as plain text otherwise.
    ///
    /// In an IPython environment the collected batches are cached so a paired
    /// `__repr__` / `_repr_html_` call does not execute the plan twice: the
    /// first call populates the cache and the second call consumes (takes) it.
    fn prepare_repr_string<'py>(
        &self,
        py: Python<'py>,
        as_html: bool,
    ) -> PyDataFusionResult<String> {
        // Get the Python formatter and config
        let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?;

        let is_ipython = *is_ipython_env(py);

        // Take any cached batches. We only re-cache at the end when running
        // under IPython AND the cache was empty (i.e. this is the first of a
        // __repr__/_repr_html_ pair); a consumed cache is not repopulated.
        let (cached_batches, should_cache) = {
            let mut cache = self.batches.lock();
            let should_cache = is_ipython && cache.is_none();
            let batches = cache.take();
            (batches, should_cache)
        };

        // Execute the plan only when no cached result was available.
        let (batches, has_more) = match cached_batches {
            Some(b) => b,
            None => wait_for_future(
                py,
                collect_record_batches_to_display(self.df.as_ref().clone(), config),
            )??,
        };

        if batches.is_empty() {
            // This should not be reached, but do it for safety since we index into the vector below
            return Ok("No data to display".to_string());
        }

        // Fresh id passed to the formatter to uniquely identify this table's output.
        let table_uuid = uuid::Uuid::new_v4().to_string();

        // Convert record batches to Py<PyAny> list
        let py_batches = batches
            .iter()
            .map(|rb| rb.to_pyarrow(py))
            .collect::<PyResult<Vec<Bound<'py, PyAny>>>>()?;

        let py_schema = self.schema().into_pyobject(py)?;

        // Keyword arguments expected by the Python formatter's format_* methods.
        let kwargs = pyo3::types::PyDict::new(py);
        let py_batches_list = PyList::new(py, py_batches.as_slice())?;
        kwargs.set_item("batches", py_batches_list)?;
        kwargs.set_item("schema", py_schema)?;
        kwargs.set_item("has_more", has_more)?;
        kwargs.set_item("table_uuid", table_uuid)?;

        let method_name = match as_html {
            true => "format_html",
            false => "format_str",
        };

        let html_result = formatter.call_method(method_name, (), Some(&kwargs))?;
        let html_str: String = html_result.extract()?;

        // Populate the cache only on the first IPython call (see above).
        if should_cache {
            let mut cache = self.batches.lock();
            *cache = Some((batches.clone(), has_more));
        }

        Ok(html_str)
    }

    /// Collect a single column across all result batches and concatenate it
    /// into one contiguous Arrow array.
    async fn collect_column_inner(&self, column: &str) -> Result<ArrayRef, DataFusionError> {
        let batches = self
            .df
            .as_ref()
            .clone()
            .select_columns(&[column])?
            .collect()
            .await?;

        // Each batch contributes exactly one column (index 0) after the
        // single-column projection above.
        let arrays = batches
            .iter()
            .map(|b| b.column(0).as_ref())
            .collect::<Vec<_>>();

        arrow_select::concat::concat(&arrays).map_err(Into::into)
    }
}
388
/// Synchronous wrapper around partitioned [`SendableRecordBatchStream`]s used
/// for the `__arrow_c_stream__` implementation.
///
/// It drains each partition's stream sequentially, yielding record batches in
/// their original partition order. When a `projection` is set, each batch is
/// converted via `record_batch_into_schema` to apply schema changes per batch.
struct PartitionedDataFrameStreamReader {
    /// One stream per partition, drained in order.
    streams: Vec<SendableRecordBatchStream>,
    /// Schema reported to consumers of the reader.
    schema: SchemaRef,
    /// Optional target schema applied per batch; `None` passes batches through.
    projection: Option<SchemaRef>,
    /// Index of the partition currently being drained.
    current: usize,
}
401
402impl Iterator for PartitionedDataFrameStreamReader {
403    type Item = Result<RecordBatch, ArrowError>;
404
405    fn next(&mut self) -> Option<Self::Item> {
406        while self.current < self.streams.len() {
407            let stream = &mut self.streams[self.current];
408            let fut = poll_next_batch(stream);
409            let result = Python::attach(|py| wait_for_future(py, fut));
410
411            match result {
412                Ok(Ok(Some(batch))) => {
413                    let batch = if let Some(ref schema) = self.projection {
414                        match record_batch_into_schema(batch, schema.as_ref()) {
415                            Ok(b) => b,
416                            Err(e) => return Some(Err(e)),
417                        }
418                    } else {
419                        batch
420                    };
421                    return Some(Ok(batch));
422                }
423                Ok(Ok(None)) => {
424                    self.current += 1;
425                    continue;
426                }
427                Ok(Err(e)) => {
428                    return Some(Err(ArrowError::ExternalError(Box::new(e))));
429                }
430                Err(e) => {
431                    return Some(Err(ArrowError::ExternalError(Box::new(e))));
432                }
433            }
434        }
435
436        None
437    }
438}
439
440impl RecordBatchReader for PartitionedDataFrameStreamReader {
441    fn schema(&self) -> SchemaRef {
442        self.schema.clone()
443    }
444}
445
446#[pymethods]
447impl PyDataFrame {
448    /// Enable selection for `df[col]`, `df[col1, col2, col3]`, and `df[[col1, col2, col3]]`
449    fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
450        if let Ok(key) = key.extract::<PyBackedStr>() {
451            // df[col]
452            self.select_columns(vec![key])
453        } else if let Ok(tuple) = key.downcast::<PyTuple>() {
454            // df[col1, col2, col3]
455            let keys = tuple
456                .iter()
457                .map(|item| item.extract::<PyBackedStr>())
458                .collect::<PyResult<Vec<PyBackedStr>>>()?;
459            self.select_columns(keys)
460        } else if let Ok(keys) = key.extract::<Vec<PyBackedStr>>() {
461            // df[[col1, col2, col3]]
462            self.select_columns(keys)
463        } else {
464            let message = "DataFrame can only be indexed by string index or indices".to_string();
465            Err(PyDataFusionError::Common(message))
466        }
467    }
468
469    fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
470        self.prepare_repr_string(py, false)
471    }
472
473    #[staticmethod]
474    #[expect(unused_variables)]
475    fn default_str_repr<'py>(
476        batches: Vec<Bound<'py, PyAny>>,
477        schema: &Bound<'py, PyAny>,
478        has_more: bool,
479        table_uuid: &str,
480    ) -> PyResult<String> {
481        let batches = batches
482            .into_iter()
483            .map(|batch| RecordBatch::from_pyarrow_bound(&batch))
484            .collect::<PyResult<Vec<RecordBatch>>>()?
485            .into_iter()
486            .filter(|batch| batch.num_rows() > 0)
487            .collect::<Vec<_>>();
488
489        if batches.is_empty() {
490            return Ok("No data to display".to_owned());
491        }
492
493        let batches_as_displ =
494            pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?;
495
496        let additional_str = match has_more {
497            true => "\nData truncated.",
498            false => "",
499        };
500
501        Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}"))
502    }
503
504    fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
505        self.prepare_repr_string(py, true)
506    }
507
508    /// Calculate summary statistics for a DataFrame
509    fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
510        let df = self.df.as_ref().clone();
511        let stat_df = wait_for_future(py, df.describe())??;
512        Ok(Self::new(stat_df))
513    }
514
515    /// Returns the schema from the logical plan
516    fn schema(&self) -> PyArrowType<Schema> {
517        PyArrowType(self.df.schema().as_arrow().clone())
518    }
519
520    /// Convert this DataFrame into a Table Provider that can be used in register_table
521    /// By convention, into_... methods consume self and return the new object.
522    /// Disabling the clippy lint, so we can use &self
523    /// because we're working with Python bindings
524    /// where objects are shared
525    #[allow(clippy::wrong_self_convention)]
526    pub fn into_view(&self, temporary: bool) -> PyDataFusionResult<PyTable> {
527        let table_provider = if temporary {
528            Arc::new(TempViewTable::new(Arc::clone(&self.df))) as Arc<dyn TableProvider>
529        } else {
530            // Call the underlying Rust DataFrame::into_view method.
531            // Note that the Rust method consumes self; here we clone the inner Arc<DataFrame>
532            // so that we don't invalidate this PyDataFrame.
533            self.df.as_ref().clone().into_view()
534        };
535        Ok(PyTable::from(table_provider))
536    }
537
538    #[pyo3(signature = (*args))]
539    fn select_columns(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
540        let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
541        let df = self.df.as_ref().clone().select_columns(&args)?;
542        Ok(Self::new(df))
543    }
544
545    #[pyo3(signature = (*args))]
546    fn select_exprs(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
547        let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
548        let df = self.df.as_ref().clone().select_exprs(&args)?;
549        Ok(Self::new(df))
550    }
551
552    #[pyo3(signature = (*args))]
553    fn select(&self, args: Vec<PyExpr>) -> PyDataFusionResult<Self> {
554        let expr: Vec<Expr> = args.into_iter().map(|e| e.into()).collect();
555        let df = self.df.as_ref().clone().select(expr)?;
556        Ok(Self::new(df))
557    }
558
559    #[pyo3(signature = (*args))]
560    fn drop(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
561        let cols = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
562        let df = self.df.as_ref().clone().drop_columns(&cols)?;
563        Ok(Self::new(df))
564    }
565
566    fn filter(&self, predicate: PyExpr) -> PyDataFusionResult<Self> {
567        let df = self.df.as_ref().clone().filter(predicate.into())?;
568        Ok(Self::new(df))
569    }
570
571    fn parse_sql_expr(&self, expr: PyBackedStr) -> PyDataFusionResult<PyExpr> {
572        self.df
573            .as_ref()
574            .parse_sql_expr(&expr)
575            .map(PyExpr::from)
576            .map_err(PyDataFusionError::from)
577    }
578
579    fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult<Self> {
580        let df = self.df.as_ref().clone().with_column(name, expr.into())?;
581        Ok(Self::new(df))
582    }
583
584    fn with_columns(&self, exprs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
585        let mut df = self.df.as_ref().clone();
586        for expr in exprs {
587            let expr: Expr = expr.into();
588            let name = format!("{}", expr.schema_name());
589            df = df.with_column(name.as_str(), expr)?
590        }
591        Ok(Self::new(df))
592    }
593
594    /// Rename one column by applying a new projection. This is a no-op if the column to be
595    /// renamed does not exist.
596    fn with_column_renamed(&self, old_name: &str, new_name: &str) -> PyDataFusionResult<Self> {
597        let df = self
598            .df
599            .as_ref()
600            .clone()
601            .with_column_renamed(old_name, new_name)?;
602        Ok(Self::new(df))
603    }
604
605    fn aggregate(&self, group_by: Vec<PyExpr>, aggs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
606        let group_by = group_by.into_iter().map(|e| e.into()).collect();
607        let aggs = aggs.into_iter().map(|e| e.into()).collect();
608        let df = self.df.as_ref().clone().aggregate(group_by, aggs)?;
609        Ok(Self::new(df))
610    }
611
612    #[pyo3(signature = (*exprs))]
613    fn sort(&self, exprs: Vec<PySortExpr>) -> PyDataFusionResult<Self> {
614        let exprs = to_sort_expressions(exprs);
615        let df = self.df.as_ref().clone().sort(exprs)?;
616        Ok(Self::new(df))
617    }
618
619    #[pyo3(signature = (count, offset=0))]
620    fn limit(&self, count: usize, offset: usize) -> PyDataFusionResult<Self> {
621        let df = self.df.as_ref().clone().limit(offset, Some(count))?;
622        Ok(Self::new(df))
623    }
624
625    /// Executes the plan, returning a list of `RecordBatch`es.
626    /// Unless some order is specified in the plan, there is no
627    /// guarantee of the order of the result.
628    fn collect<'py>(&self, py: Python<'py>) -> PyResult<Vec<Bound<'py, PyAny>>> {
629        let batches = wait_for_future(py, self.df.as_ref().clone().collect())?
630            .map_err(PyDataFusionError::from)?;
631        // cannot use PyResult<Vec<RecordBatch>> return type due to
632        // https://github.com/PyO3/pyo3/issues/1813
633        batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect()
634    }
635
636    /// Cache DataFrame.
637    fn cache(&self, py: Python) -> PyDataFusionResult<Self> {
638        let df = wait_for_future(py, self.df.as_ref().clone().cache())??;
639        Ok(Self::new(df))
640    }
641
642    /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
643    /// maintaining the input partitioning.
644    fn collect_partitioned<'py>(&self, py: Python<'py>) -> PyResult<Vec<Vec<Bound<'py, PyAny>>>> {
645        let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())?
646            .map_err(PyDataFusionError::from)?;
647
648        batches
649            .into_iter()
650            .map(|rbs| rbs.into_iter().map(|rb| rb.to_pyarrow(py)).collect())
651            .collect()
652    }
653
654    fn collect_column<'py>(&self, py: Python<'py>, column: &str) -> PyResult<Bound<'py, PyAny>> {
655        wait_for_future(py, self.collect_column_inner(column))?
656            .map_err(PyDataFusionError::from)?
657            .to_data()
658            .to_pyarrow(py)
659    }
660
661    /// Print the result, 20 lines by default
662    #[pyo3(signature = (num=20))]
663    fn show(&self, py: Python, num: usize) -> PyDataFusionResult<()> {
664        let df = self.df.as_ref().clone().limit(0, Some(num))?;
665        print_dataframe(py, df)
666    }
667
668    /// Filter out duplicate rows
669    fn distinct(&self) -> PyDataFusionResult<Self> {
670        let df = self.df.as_ref().clone().distinct()?;
671        Ok(Self::new(df))
672    }
673
    /// Join this DataFrame with `right` on the given key columns.
    ///
    /// `how` selects the join type ("inner", "left", "right", "full", "semi",
    /// "anti"); any other value is an error. When `coalesce_keys` is true,
    /// join keys that have the same name on both sides are merged into a
    /// single output column via `coalesce(left, right)`.
    fn join(
        &self,
        right: PyDataFrame,
        how: &str,
        left_on: Vec<PyBackedStr>,
        right_on: Vec<PyBackedStr>,
        coalesce_keys: bool,
    ) -> PyDataFusionResult<Self> {
        let join_type = match how {
            "inner" => JoinType::Inner,
            "left" => JoinType::Left,
            "right" => JoinType::Right,
            "full" => JoinType::Full,
            "semi" => JoinType::LeftSemi,
            "anti" => JoinType::LeftAnti,
            how => {
                return Err(PyDataFusionError::Common(format!(
                    "The join type {how} does not exist or is not implemented"
                )));
            }
        };

        let left_keys = left_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
        let right_keys = right_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();

        let mut df = self.df.as_ref().clone().join(
            right.df.as_ref().clone(),
            join_type,
            &left_keys,
            &right_keys,
            None,
        )?;

        if coalesce_keys {
            // Keys paired at the same position with identical names on both
            // sides are candidates for coalescing into one output column.
            let mutual_keys = left_keys
                .iter()
                .zip(right_keys.iter())
                .filter(|(l, r)| l == r)
                .map(|(key, _)| *key)
                .collect::<Vec<_>>();

            // Only coalesce keys that appear exactly twice in the joined
            // schema (once per side); anything else passes through unchanged.
            let fields_to_coalesce = mutual_keys
                .iter()
                .map(|name| {
                    let qualified_fields = df
                        .logical_plan()
                        .schema()
                        .qualified_fields_with_unqualified_name(name);
                    (*name, qualified_fields)
                })
                .filter(|(_, fields)| fields.len() == 2)
                .collect::<Vec<_>>();

            // Rebuild the projection: non-key fields pass through; each
            // mutual key becomes a single `coalesce(left, right) AS key`.
            let expr: Vec<Expr> = df
                .logical_plan()
                .schema()
                .fields()
                .into_iter()
                .enumerate()
                .map(|(idx, _)| df.logical_plan().schema().qualified_field(idx))
                .filter_map(|(qualifier, field)| {
                    if let Some((key_name, qualified_fields)) = fields_to_coalesce
                        .iter()
                        .find(|(_, qf)| qf.contains(&(qualifier, field)))
                    {
                        // Only add the coalesce expression once (when we encounter the first field)
                        // Skip the second field (it's already included in to coalesce)
                        if (qualifier, field) == qualified_fields[0] {
                            let left_col = Expr::Column(Column::from(qualified_fields[0]));
                            let right_col = Expr::Column(Column::from(qualified_fields[1]));
                            return Some(coalesce(vec![left_col, right_col]).alias(*key_name));
                        }
                        None
                    } else {
                        Some(Expr::Column(Column::from((qualifier, field))))
                    }
                })
                .collect();
            df = df.select(expr)?;
        }

        Ok(Self::new(df))
    }
757
758    fn join_on(
759        &self,
760        right: PyDataFrame,
761        on_exprs: Vec<PyExpr>,
762        how: &str,
763    ) -> PyDataFusionResult<Self> {
764        let join_type = match how {
765            "inner" => JoinType::Inner,
766            "left" => JoinType::Left,
767            "right" => JoinType::Right,
768            "full" => JoinType::Full,
769            "semi" => JoinType::LeftSemi,
770            "anti" => JoinType::LeftAnti,
771            how => {
772                return Err(PyDataFusionError::Common(format!(
773                    "The join type {how} does not exist or is not implemented"
774                )));
775            }
776        };
777        let exprs: Vec<Expr> = on_exprs.into_iter().map(|e| e.into()).collect();
778
779        let df = self
780            .df
781            .as_ref()
782            .clone()
783            .join_on(right.df.as_ref().clone(), join_type, exprs)?;
784        Ok(Self::new(df))
785    }
786
787    /// Print the query plan
788    #[pyo3(signature = (verbose=false, analyze=false))]
789    fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyDataFusionResult<()> {
790        let df = self.df.as_ref().clone().explain(verbose, analyze)?;
791        print_dataframe(py, df)
792    }
793
794    /// Get the logical plan for this `DataFrame`
795    fn logical_plan(&self) -> PyResult<PyLogicalPlan> {
796        Ok(self.df.as_ref().clone().logical_plan().clone().into())
797    }
798
799    /// Get the optimized logical plan for this `DataFrame`
800    fn optimized_logical_plan(&self) -> PyDataFusionResult<PyLogicalPlan> {
801        Ok(self.df.as_ref().clone().into_optimized_plan()?.into())
802    }
803
804    /// Get the execution plan for this `DataFrame`
805    fn execution_plan(&self, py: Python) -> PyDataFusionResult<PyExecutionPlan> {
806        let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())??;
807        Ok(plan.into())
808    }
809
810    /// Repartition a `DataFrame` based on a logical partitioning scheme.
811    fn repartition(&self, num: usize) -> PyDataFusionResult<Self> {
812        let new_df = self
813            .df
814            .as_ref()
815            .clone()
816            .repartition(Partitioning::RoundRobinBatch(num))?;
817        Ok(Self::new(new_df))
818    }
819
820    /// Repartition a `DataFrame` based on a logical partitioning scheme.
821    #[pyo3(signature = (*args, num))]
822    fn repartition_by_hash(&self, args: Vec<PyExpr>, num: usize) -> PyDataFusionResult<Self> {
823        let expr = args.into_iter().map(|py_expr| py_expr.into()).collect();
824        let new_df = self
825            .df
826            .as_ref()
827            .clone()
828            .repartition(Partitioning::Hash(expr, num))?;
829        Ok(Self::new(new_df))
830    }
831
832    /// Calculate the union of two `DataFrame`s, preserving duplicate rows.The
833    /// two `DataFrame`s must have exactly the same schema
834    #[pyo3(signature = (py_df, distinct=false))]
835    fn union(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult<Self> {
836        let new_df = if distinct {
837            self.df
838                .as_ref()
839                .clone()
840                .union_distinct(py_df.df.as_ref().clone())?
841        } else {
842            self.df.as_ref().clone().union(py_df.df.as_ref().clone())?
843        };
844
845        Ok(Self::new(new_df))
846    }
847
848    /// Calculate the distinct union of two `DataFrame`s.  The
849    /// two `DataFrame`s must have exactly the same schema
850    fn union_distinct(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
851        let new_df = self
852            .df
853            .as_ref()
854            .clone()
855            .union_distinct(py_df.df.as_ref().clone())?;
856        Ok(Self::new(new_df))
857    }
858
859    #[pyo3(signature = (column, preserve_nulls=true))]
860    fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyDataFusionResult<Self> {
861        // TODO: expose RecursionUnnestOptions
862        // REF: https://github.com/apache/datafusion/pull/11577
863        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
864        let df = self
865            .df
866            .as_ref()
867            .clone()
868            .unnest_columns_with_options(&[column], unnest_options)?;
869        Ok(Self::new(df))
870    }
871
872    #[pyo3(signature = (columns, preserve_nulls=true))]
873    fn unnest_columns(
874        &self,
875        columns: Vec<String>,
876        preserve_nulls: bool,
877    ) -> PyDataFusionResult<Self> {
878        // TODO: expose RecursionUnnestOptions
879        // REF: https://github.com/apache/datafusion/pull/11577
880        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
881        let cols = columns.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
882        let df = self
883            .df
884            .as_ref()
885            .clone()
886            .unnest_columns_with_options(&cols, unnest_options)?;
887        Ok(Self::new(df))
888    }
889
890    /// Calculate the intersection of two `DataFrame`s.  The two `DataFrame`s must have exactly the same schema
891    fn intersect(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
892        let new_df = self
893            .df
894            .as_ref()
895            .clone()
896            .intersect(py_df.df.as_ref().clone())?;
897        Ok(Self::new(new_df))
898    }
899
900    /// Calculate the exception of two `DataFrame`s.  The two `DataFrame`s must have exactly the same schema
901    fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
902        let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?;
903        Ok(Self::new(new_df))
904    }
905
906    /// Write a `DataFrame` to a CSV file.
907    fn write_csv(
908        &self,
909        py: Python,
910        path: &str,
911        with_header: bool,
912        write_options: Option<PyDataFrameWriteOptions>,
913    ) -> PyDataFusionResult<()> {
914        let csv_options = CsvOptions {
915            has_header: Some(with_header),
916            ..Default::default()
917        };
918        let write_options = write_options
919            .map(DataFrameWriteOptions::from)
920            .unwrap_or_default();
921
922        wait_for_future(
923            py,
924            self.df
925                .as_ref()
926                .clone()
927                .write_csv(path, write_options, Some(csv_options)),
928        )??;
929        Ok(())
930    }
931
932    /// Write a `DataFrame` to a Parquet file.
933    #[pyo3(signature = (
934        path,
935        compression="zstd",
936        compression_level=None,
937        write_options=None,
938        ))]
939    fn write_parquet(
940        &self,
941        path: &str,
942        compression: &str,
943        compression_level: Option<u32>,
944        write_options: Option<PyDataFrameWriteOptions>,
945        py: Python,
946    ) -> PyDataFusionResult<()> {
947        fn verify_compression_level(cl: Option<u32>) -> Result<u32, PyErr> {
948            cl.ok_or(PyValueError::new_err("compression_level is not defined"))
949        }
950
951        let _validated = match compression.to_lowercase().as_str() {
952            "snappy" => Compression::SNAPPY,
953            "gzip" => Compression::GZIP(
954                GzipLevel::try_new(compression_level.unwrap_or(6))
955                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
956            ),
957            "brotli" => Compression::BROTLI(
958                BrotliLevel::try_new(verify_compression_level(compression_level)?)
959                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
960            ),
961            "zstd" => Compression::ZSTD(
962                ZstdLevel::try_new(verify_compression_level(compression_level)? as i32)
963                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
964            ),
965            "lzo" => Compression::LZO,
966            "lz4" => Compression::LZ4,
967            "lz4_raw" => Compression::LZ4_RAW,
968            "uncompressed" => Compression::UNCOMPRESSED,
969            _ => {
970                return Err(PyDataFusionError::Common(format!(
971                    "Unrecognized compression type {compression}"
972                )));
973            }
974        };
975
976        let mut compression_string = compression.to_string();
977        if let Some(level) = compression_level {
978            compression_string.push_str(&format!("({level})"));
979        }
980
981        let mut options = TableParquetOptions::default();
982        options.global.compression = Some(compression_string);
983        let write_options = write_options
984            .map(DataFrameWriteOptions::from)
985            .unwrap_or_default();
986
987        wait_for_future(
988            py,
989            self.df
990                .as_ref()
991                .clone()
992                .write_parquet(path, write_options, Option::from(options)),
993        )??;
994        Ok(())
995    }
996
997    /// Write a `DataFrame` to a Parquet file, using advanced options.
998    fn write_parquet_with_options(
999        &self,
1000        path: &str,
1001        options: PyParquetWriterOptions,
1002        column_specific_options: HashMap<String, PyParquetColumnOptions>,
1003        write_options: Option<PyDataFrameWriteOptions>,
1004        py: Python,
1005    ) -> PyDataFusionResult<()> {
1006        let table_options = TableParquetOptions {
1007            global: options.options,
1008            column_specific_options: column_specific_options
1009                .into_iter()
1010                .map(|(k, v)| (k, v.options))
1011                .collect(),
1012            ..Default::default()
1013        };
1014        let write_options = write_options
1015            .map(DataFrameWriteOptions::from)
1016            .unwrap_or_default();
1017        wait_for_future(
1018            py,
1019            self.df.as_ref().clone().write_parquet(
1020                path,
1021                write_options,
1022                Option::from(table_options),
1023            ),
1024        )??;
1025        Ok(())
1026    }
1027
1028    /// Executes a query and writes the results to a partitioned JSON file.
1029    fn write_json(
1030        &self,
1031        path: &str,
1032        py: Python,
1033        write_options: Option<PyDataFrameWriteOptions>,
1034    ) -> PyDataFusionResult<()> {
1035        let write_options = write_options
1036            .map(DataFrameWriteOptions::from)
1037            .unwrap_or_default();
1038        wait_for_future(
1039            py,
1040            self.df
1041                .as_ref()
1042                .clone()
1043                .write_json(path, write_options, None),
1044        )??;
1045        Ok(())
1046    }
1047
1048    fn write_table(
1049        &self,
1050        py: Python,
1051        table_name: &str,
1052        write_options: Option<PyDataFrameWriteOptions>,
1053    ) -> PyDataFusionResult<()> {
1054        let write_options = write_options
1055            .map(DataFrameWriteOptions::from)
1056            .unwrap_or_default();
1057        wait_for_future(
1058            py,
1059            self.df
1060                .as_ref()
1061                .clone()
1062                .write_table(table_name, write_options),
1063        )??;
1064        Ok(())
1065    }
1066
    /// Convert to Arrow Table
    /// Collect the batches and pass to Arrow Table
    fn to_arrow_table(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
        // Materialize all record batches as a Python list of pyarrow batches.
        let batches = self.collect(py)?.into_pyobject(py)?;

        // only use the DataFrame's schema if there are no batches, otherwise let the schema be
        // determined from the batches (avoids some inconsistencies with nullable columns)
        let args = if batches.len()? == 0 {
            let schema = self.schema().into_pyobject(py)?;
            PyTuple::new(py, &[batches, schema])?
        } else {
            PyTuple::new(py, &[batches])?
        };

        // Instantiate pyarrow Table object and use its from_batches method
        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let table: Py<PyAny> = table_class.call_method1("from_batches", args)?.into();
        Ok(table)
    }
1086
    /// Arrow PyCapsule interface: expose this DataFrame's results as an
    /// `ArrowArrayStream` capsule so pyarrow/polars can import them zero-copy.
    #[pyo3(signature = (requested_schema=None))]
    fn __arrow_c_stream__<'py>(
        &'py self,
        py: Python<'py>,
        requested_schema: Option<Bound<'py, PyCapsule>>,
    ) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
        // Start execution now, one stream per partition; batches are pulled
        // lazily as the consumer reads from the capsule.
        let df = self.df.as_ref().clone();
        let streams = spawn_future(py, async move { df.execute_stream_partitioned().await })?;

        let mut schema: Schema = self.df.schema().to_owned().as_arrow().clone();
        let mut projection: Option<SchemaRef> = None;

        // The consumer may request a (sub)schema; project our output onto it.
        if let Some(schema_capsule) = requested_schema {
            validate_pycapsule(&schema_capsule, "arrow_schema")?;

            // SAFETY: the capsule was validated above to carry the
            // "arrow_schema" name, so it holds a valid FFI_ArrowSchema.
            let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
            let desired_schema = Schema::try_from(schema_ptr)?;

            schema = project_schema(schema, desired_schema)?;
            projection = Some(Arc::new(schema.clone()));
        }

        let schema_ref = Arc::new(schema.clone());

        // Reader that walks the partition streams in order, applying the
        // projection (if any) to each batch it yields.
        let reader = PartitionedDataFrameStreamReader {
            streams,
            schema: schema_ref,
            projection,
            current: 0,
        };
        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);

        // Create the Arrow stream and wrap it in a PyCapsule. The default
        // destructor provided by PyO3 will drop the stream unless ownership is
        // transferred to PyArrow during import.
        let stream = FFI_ArrowArrayStream::new(reader);
        let name = CString::new(ARROW_ARRAY_STREAM_NAME.to_bytes()).unwrap();
        let capsule = PyCapsule::new(py, stream, Some(name))?;
        Ok(capsule)
    }
1127
1128    fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
1129        let df = self.df.as_ref().clone();
1130        let stream = spawn_future(py, async move { df.execute_stream().await })?;
1131        Ok(PyRecordBatchStream::new(stream))
1132    }
1133
1134    fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
1135        let df = self.df.as_ref().clone();
1136        let streams = spawn_future(py, async move { df.execute_stream_partitioned().await })?;
1137        Ok(streams.into_iter().map(PyRecordBatchStream::new).collect())
1138    }
1139
1140    /// Convert to pandas dataframe with pyarrow
1141    /// Collect the batches, pass to Arrow Table & then convert to Pandas DataFrame
1142    fn to_pandas(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
1143        let table = self.to_arrow_table(py)?;
1144
1145        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pandas
1146        let result = table.call_method0(py, "to_pandas")?;
1147        Ok(result)
1148    }
1149
1150    /// Convert to Python list using pyarrow
1151    /// Each list item represents one row encoded as dictionary
1152    fn to_pylist(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
1153        let table = self.to_arrow_table(py)?;
1154
1155        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pylist
1156        let result = table.call_method0(py, "to_pylist")?;
1157        Ok(result)
1158    }
1159
1160    /// Convert to Python dictionary using pyarrow
1161    /// Each dictionary key is a column and the dictionary value represents the column values
1162    fn to_pydict(&self, py: Python) -> PyResult<Py<PyAny>> {
1163        let table = self.to_arrow_table(py)?;
1164
1165        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pydict
1166        let result = table.call_method0(py, "to_pydict")?;
1167        Ok(result)
1168    }
1169
1170    /// Convert to polars dataframe with pyarrow
1171    /// Collect the batches, pass to Arrow Table & then convert to polars DataFrame
1172    fn to_polars(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
1173        let table = self.to_arrow_table(py)?;
1174        let dataframe = py.import("polars")?.getattr("DataFrame")?;
1175        let args = PyTuple::new(py, &[table])?;
1176        let result: Py<PyAny> = dataframe.call1(args)?.into();
1177        Ok(result)
1178    }
1179
1180    // Executes this DataFrame to get the total number of rows.
1181    fn count(&self, py: Python) -> PyDataFusionResult<usize> {
1182        Ok(wait_for_future(py, self.df.as_ref().clone().count())??)
1183    }
1184
1185    /// Fill null values with a specified value for specific columns
1186    #[pyo3(signature = (value, columns=None))]
1187    fn fill_null(
1188        &self,
1189        value: Py<PyAny>,
1190        columns: Option<Vec<PyBackedStr>>,
1191        py: Python,
1192    ) -> PyDataFusionResult<Self> {
1193        let scalar_value: PyScalarValue = value.extract(py)?;
1194
1195        let cols = match columns {
1196            Some(col_names) => col_names.iter().map(|c| c.to_string()).collect(),
1197            None => Vec::new(), // Empty vector means fill null for all columns
1198        };
1199
1200        let df = self.df.as_ref().clone().fill_null(scalar_value.0, cols)?;
1201        Ok(Self::new(df))
1202    }
1203}
1204
/// Python-facing mirror of DataFusion's `InsertOp`, selecting how rows are
/// written into an existing table by the `write_table` API.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[pyclass(frozen, eq, eq_int, name = "InsertOp", module = "datafusion")]
pub enum PyInsertOp {
    /// Add rows to the table, keeping existing data.
    APPEND,
    /// Replace matching rows with the new data.
    REPLACE,
    /// Drop existing data and write the new rows.
    OVERWRITE,
}
1212
// Unwrap the Python enum into DataFusion's native `InsertOp`; the mapping is
// one-to-one by variant name.
impl From<PyInsertOp> for InsertOp {
    fn from(value: PyInsertOp) -> Self {
        match value {
            PyInsertOp::APPEND => InsertOp::Append,
            PyInsertOp::REPLACE => InsertOp::Replace,
            PyInsertOp::OVERWRITE => InsertOp::Overwrite,
        }
    }
}
1222
/// Python-facing wrapper around `DataFrameWriteOptions`, carrying the
/// settings shared by the `write_csv`/`write_parquet`/`write_json`/
/// `write_table` methods.
#[derive(Debug, Clone)]
#[pyclass(frozen, name = "DataFrameWriteOptions", module = "datafusion")]
pub struct PyDataFrameWriteOptions {
    // How rows are inserted into an existing table (append/replace/overwrite).
    insert_operation: InsertOp,
    // Forwarded to `DataFrameWriteOptions::with_single_file_output`.
    single_file_output: bool,
    // Column names used to partition the written output.
    partition_by: Vec<String>,
    // Sort order applied to the data as it is written.
    sort_by: Vec<SortExpr>,
}
1231
1232impl From<PyDataFrameWriteOptions> for DataFrameWriteOptions {
1233    fn from(value: PyDataFrameWriteOptions) -> Self {
1234        DataFrameWriteOptions::new()
1235            .with_insert_operation(value.insert_operation)
1236            .with_single_file_output(value.single_file_output)
1237            .with_partition_by(value.partition_by)
1238            .with_sort_by(value.sort_by)
1239    }
1240}
1241
1242#[pymethods]
1243impl PyDataFrameWriteOptions {
1244    #[new]
1245    fn new(
1246        insert_operation: Option<PyInsertOp>,
1247        single_file_output: bool,
1248        partition_by: Option<Vec<String>>,
1249        sort_by: Option<Vec<PySortExpr>>,
1250    ) -> Self {
1251        let insert_operation = insert_operation.map(Into::into).unwrap_or(InsertOp::Append);
1252        let sort_by = sort_by
1253            .unwrap_or_default()
1254            .into_iter()
1255            .map(Into::into)
1256            .collect();
1257        Self {
1258            insert_operation,
1259            single_file_output,
1260            partition_by: partition_by.unwrap_or_default(),
1261            sort_by,
1262        }
1263    }
1264}
1265
1266/// Print DataFrame
1267fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
1268    // Get string representation of record batches
1269    let batches = wait_for_future(py, df.collect())??;
1270    let result = if batches.is_empty() {
1271        "DataFrame has no rows".to_string()
1272    } else {
1273        match pretty::pretty_format_batches(&batches) {
1274            Ok(batch) => format!("DataFrame()\n{batch}"),
1275            Err(err) => format!("Error: {:?}", err.to_string()),
1276        }
1277    };
1278
1279    // Import the Python 'builtins' module to access the print function
1280    // Note that println! does not print to the Python debug console and is not visible in notebooks for instance
1281    let print = py.import("builtins")?.getattr("print")?;
1282    print.call1((result,))?;
1283    Ok(())
1284}
1285
1286fn project_schema(from_schema: Schema, to_schema: Schema) -> Result<Schema, ArrowError> {
1287    let merged_schema = Schema::try_merge(vec![from_schema, to_schema.clone()])?;
1288
1289    let project_indices: Vec<usize> = to_schema
1290        .fields
1291        .iter()
1292        .map(|field| field.name())
1293        .filter_map(|field_name| merged_schema.index_of(field_name).ok())
1294        .collect();
1295
1296    merged_schema.project(&project_indices)
1297}
1298// NOTE: `arrow::compute::cast` in combination with `RecordBatch::try_select` or
1299// DataFusion's `schema::cast_record_batch` do not fully cover the required
1300// transformations here. They will not create missing columns and may insert
1301// nulls for non-nullable fields without erroring. To maintain current behavior
1302// we perform the casting and null checks manually.
1303fn record_batch_into_schema(
1304    record_batch: RecordBatch,
1305    schema: &Schema,
1306) -> Result<RecordBatch, ArrowError> {
1307    let schema = Arc::new(schema.clone());
1308    let base_schema = record_batch.schema();
1309    if base_schema.fields().is_empty() {
1310        // Nothing to project
1311        return Ok(RecordBatch::new_empty(schema));
1312    }
1313
1314    let array_size = record_batch.column(0).len();
1315    let mut data_arrays = Vec::with_capacity(schema.fields().len());
1316
1317    for field in schema.fields() {
1318        let desired_data_type = field.data_type();
1319        if let Some(original_data) = record_batch.column_by_name(field.name()) {
1320            let original_data_type = original_data.data_type();
1321
1322            if can_cast_types(original_data_type, desired_data_type) {
1323                data_arrays.push(arrow::compute::kernels::cast(
1324                    original_data,
1325                    desired_data_type,
1326                )?);
1327            } else if field.is_nullable() {
1328                data_arrays.push(new_null_array(desired_data_type, array_size));
1329            } else {
1330                return Err(ArrowError::CastError(format!(
1331                    "Attempting to cast to non-nullable and non-castable field {} during schema projection.",
1332                    field.name()
1333                )));
1334            }
1335        } else {
1336            if !field.is_nullable() {
1337                return Err(ArrowError::CastError(format!(
1338                    "Attempting to set null to non-nullable field {} during schema projection.",
1339                    field.name()
1340                )));
1341            }
1342            data_arrays.push(new_null_array(desired_data_type, array_size));
1343        }
1344    }
1345
1346    RecordBatch::try_new(schema, data_arrays)
1347}
1348
1349/// This is a helper function to return the first non-empty record batch from executing a DataFrame.
1350/// It additionally returns a bool, which indicates if there are more record batches available.
1351/// We do this so we can determine if we should indicate to the user that the data has been
1352/// truncated. This collects until we have archived both of these two conditions
1353///
1354/// - We have collected our minimum number of rows
1355/// - We have reached our limit, either data size or maximum number of rows
1356///
1357/// Otherwise it will return when the stream has exhausted. If you want a specific number of
1358/// rows, set min_rows == max_rows.
1359async fn collect_record_batches_to_display(
1360    df: DataFrame,
1361    config: FormatterConfig,
1362) -> Result<(Vec<RecordBatch>, bool), DataFusionError> {
1363    let FormatterConfig {
1364        max_bytes,
1365        min_rows,
1366        max_rows,
1367    } = config;
1368
1369    let partitioned_stream = df.execute_stream_partitioned().await?;
1370    let mut stream = futures::stream::iter(partitioned_stream).flatten();
1371    let mut size_estimate_so_far = 0;
1372    let mut rows_so_far = 0;
1373    let mut record_batches = Vec::default();
1374    let mut has_more = false;
1375
1376    // Collect rows until we hit a limit (memory or max_rows) OR reach the guaranteed minimum.
1377    // The minimum rows constraint overrides both memory and row limits to ensure a baseline
1378    // of data is always displayed, even if it temporarily exceeds those limits.
1379    // This provides better UX by guaranteeing users see at least min_rows rows.
1380    while (size_estimate_so_far < max_bytes && rows_so_far < max_rows) || rows_so_far < min_rows {
1381        let mut rb = match stream.next().await {
1382            None => {
1383                break;
1384            }
1385            Some(Ok(r)) => r,
1386            Some(Err(e)) => return Err(e),
1387        };
1388
1389        let mut rows_in_rb = rb.num_rows();
1390        if rows_in_rb > 0 {
1391            size_estimate_so_far += rb.get_array_memory_size();
1392
1393            // When memory limit is exceeded, scale back row count proportionally to stay within budget
1394            if size_estimate_so_far > max_bytes {
1395                let ratio = max_bytes as f32 / size_estimate_so_far as f32;
1396                let total_rows = rows_in_rb + rows_so_far;
1397
1398                // Calculate reduced rows maintaining the memory/data proportion
1399                let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize;
1400                // Ensure we always respect the minimum rows guarantee
1401                if reduced_row_num < min_rows {
1402                    reduced_row_num = min_rows.min(total_rows);
1403                }
1404
1405                let limited_rows_this_rb = reduced_row_num - rows_so_far;
1406                if limited_rows_this_rb < rows_in_rb {
1407                    rows_in_rb = limited_rows_this_rb;
1408                    rb = rb.slice(0, limited_rows_this_rb);
1409                    has_more = true;
1410                }
1411            }
1412
1413            if rows_in_rb + rows_so_far > max_rows {
1414                rb = rb.slice(0, max_rows - rows_so_far);
1415                has_more = true;
1416            }
1417
1418            rows_so_far += rb.num_rows();
1419            record_batches.push(rb);
1420        }
1421    }
1422
1423    if record_batches.is_empty() {
1424        return Ok((Vec::default(), false));
1425    }
1426
1427    if !has_more {
1428        // Data was not already truncated, so check to see if more record batches remain
1429        has_more = match stream.try_next().await {
1430            Ok(None) => false, // reached end
1431            Ok(Some(_)) => true,
1432            Err(_) => false, // Stream disconnected
1433        };
1434    }
1435
1436    Ok((record_batches, has_more))
1437}