use std::ffi::CString;
use std::sync::Arc;

use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow::compute::can_cast_types;
use arrow::error::ArrowError;
use arrow::ffi::FFI_ArrowSchema;
use arrow::ffi_stream::FFI_ArrowArrayStream;
use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
use datafusion::arrow::util::pretty;
use datafusion::common::UnnestOptions;
use datafusion::config::{CsvOptions, TableParquetOptions};
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
use datafusion::datasource::TableProvider;
use datafusion::error::DataFusionError;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
use datafusion::prelude::*;
use datafusion_ffi::table_provider::FFI_TableProvider;
use futures::{StreamExt, TryStreamExt};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
use tokio::task::JoinHandle;

use crate::catalog::PyTable;
use crate::errors::{py_datafusion_err, PyDataFusionError};
use crate::expr::sort_expr::to_sort_expressions;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
use crate::sql::logical::PyLogicalPlan;
use crate::utils::{
    get_tokio_runtime, py_obj_to_scalar_value, validate_pycapsule, wait_for_future,
};
use crate::{
    errors::PyDataFusionResult,
    expr::{sort_expr::PySortExpr, PyExpr},
};

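/// Python-facing wrapper around an `Arc<dyn TableProvider>`. It is exposed to
/// Python as `datafusion.TableProvider` and can be handed back to DataFusion
/// via the FFI capsule returned by `__datafusion_table_provider__`.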
#[pyclass(name = "TableProvider", module = "datafusion")]
pub struct PyTableProvider {
    provider: Arc<dyn TableProvider + Send>,
}

impl PyTableProvider {
    pub fn new(provider: Arc<dyn TableProvider>) -> Self {
        Self { provider }
    }

    pub fn as_table(&self) -> PyTable {
        let table_provider: Arc<dyn TableProvider> = self.provider.clone();
        PyTable::new(table_provider)
    }
}

#[pymethods]
impl PyTableProvider {
    fn __datafusion_table_provider__<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, PyCapsule>> {
        let name = CString::new("datafusion_table_provider").unwrap();

        let runtime = get_tokio_runtime().0.handle().clone();
        let provider = FFI_TableProvider::new(Arc::clone(&self.provider), false, Some(runtime));

        PyCapsule::new(py, provider, Some(name.clone()))
    }
}

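/// Limits used when collecting record batches for display. `max_bytes` caps the
/// estimated memory of the collected batches, `min_rows` is the minimum number of
/// rows to gather before the byte cap is enforced, and `repr_rows` is the maximum
/// number of rows shown by `__repr__`.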
#[derive(Debug, Clone)]
pub struct FormatterConfig {
    pub max_bytes: usize,
    pub min_rows: usize,
    pub repr_rows: usize,
}

impl Default for FormatterConfig {
    fn default() -> Self {
        Self {
            max_bytes: 2 * 1024 * 1024,
            min_rows: 20,
            repr_rows: 10,
        }
    }
}

impl FormatterConfig {
    pub fn validate(&self) -> Result<(), String> {
        if self.max_bytes == 0 {
            return Err("max_bytes must be a positive integer".to_string());
        }

        if self.min_rows == 0 {
            return Err("min_rows must be a positive integer".to_string());
        }

        if self.repr_rows == 0 {
            return Err("repr_rows must be a positive integer".to_string());
        }

        Ok(())
    }
}

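/// A formatter instance obtained from `datafusion.html_formatter`, paired with
/// the display configuration read from its attributes.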
struct PythonFormatter<'py> {
    formatter: Bound<'py, PyAny>,
    config: FormatterConfig,
}

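/// Import the Python formatter and build its [`FormatterConfig`] in one step.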
fn get_python_formatter_with_config(py: Python) -> PyResult<PythonFormatter> {
    let formatter = import_python_formatter(py)?;
    let config = build_formatter_config_from_python(&formatter)?;
    Ok(PythonFormatter { formatter, config })
}

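/// Import `datafusion.html_formatter` and return the formatter produced by its
/// `get_formatter()` helper.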
fn import_python_formatter(py: Python) -> PyResult<Bound<'_, PyAny>> {
    let formatter_module = py.import("datafusion.html_formatter")?;
    let get_formatter = formatter_module.getattr("get_formatter")?;
    get_formatter.call0()
}

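/// Read an attribute from a Python object, falling back to `default_value` when
/// the attribute is missing or cannot be extracted as `T`.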
fn get_attr<'a, T>(py_object: &'a Bound<'a, PyAny>, attr_name: &str, default_value: T) -> T
where
    T: for<'py> pyo3::FromPyObject<'py> + Clone,
{
    py_object
        .getattr(attr_name)
        .and_then(|v| v.extract::<T>())
        .unwrap_or_else(|_| default_value.clone())
}

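/// Build a validated [`FormatterConfig`] from the formatter's `max_memory_bytes`,
/// `min_rows_display`, and `repr_rows` attributes.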
fn build_formatter_config_from_python(formatter: &Bound<'_, PyAny>) -> PyResult<FormatterConfig> {
    let default_config = FormatterConfig::default();
    let max_bytes = get_attr(formatter, "max_memory_bytes", default_config.max_bytes);
    let min_rows = get_attr(formatter, "min_rows_display", default_config.min_rows);
    let repr_rows = get_attr(formatter, "repr_rows", default_config.repr_rows);

    let config = FormatterConfig {
        max_bytes,
        min_rows,
        repr_rows,
    };

    config.validate().map_err(PyValueError::new_err)?;
    Ok(config)
}

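/// Python-facing `DataFrame`: a thin wrapper around a DataFusion [`DataFrame`]
/// (a logical plan plus session state). Most methods compose a new plan and
/// return a new `PyDataFrame`; execution only happens in methods such as
/// `collect`, `show`, `execute_stream`, and the `write_*` functions.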
#[pyclass(name = "DataFrame", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PyDataFrame {
    df: Arc<DataFrame>,
}

impl PyDataFrame {
    pub fn new(df: DataFrame) -> Self {
        Self { df: Arc::new(df) }
    }
}

#[pymethods]
impl PyDataFrame {
    fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
        if let Ok(key) = key.extract::<PyBackedStr>() {
            self.select_columns(vec![key])
        } else if let Ok(tuple) = key.downcast::<PyTuple>() {
            let keys = tuple
                .iter()
                .map(|item| item.extract::<PyBackedStr>())
                .collect::<PyResult<Vec<PyBackedStr>>>()?;
            self.select_columns(keys)
        } else if let Ok(keys) = key.extract::<Vec<PyBackedStr>>() {
            self.select_columns(keys)
        } else {
            let message = "DataFrame can only be indexed by string index or indices".to_string();
            Err(PyDataFusionError::Common(message))
        }
    }

    fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
        let PythonFormatter {
            formatter: _,
            config,
        } = get_python_formatter_with_config(py)?;
        let (batches, has_more) = wait_for_future(
            py,
            collect_record_batches_to_display(self.df.as_ref().clone(), config),
        )?;
        if batches.is_empty() {
            return Ok("No data to display".to_string());
        }

        let batches_as_displ =
            pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?;

        let additional_str = match has_more {
            true => "\nData truncated.",
            false => "",
        };

        Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}"))
    }

    fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
        let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?;
        let (batches, has_more) = wait_for_future(
            py,
            collect_record_batches_to_display(self.df.as_ref().clone(), config),
        )?;
        if batches.is_empty() {
            return Ok("No data to display".to_string());
        }

        let table_uuid = uuid::Uuid::new_v4().to_string();

        let py_batches = batches
            .into_iter()
            .map(|rb| rb.to_pyarrow(py))
            .collect::<PyResult<Vec<PyObject>>>()?;

        let py_schema = self.schema().into_pyobject(py)?;

        let kwargs = pyo3::types::PyDict::new(py);
        let py_batches_list = PyList::new(py, py_batches.as_slice())?;
        kwargs.set_item("batches", py_batches_list)?;
        kwargs.set_item("schema", py_schema)?;
        kwargs.set_item("has_more", has_more)?;
        kwargs.set_item("table_uuid", table_uuid)?;

        let html_result = formatter.call_method("format_html", (), Some(&kwargs))?;
        let html_str: String = html_result.extract()?;

        Ok(html_str)
    }

    fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
        let df = self.df.as_ref().clone();
        let stat_df = wait_for_future(py, df.describe())?;
        Ok(Self::new(stat_df))
    }

    fn schema(&self) -> PyArrowType<Schema> {
        PyArrowType(self.df.schema().into())
    }

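    /// Convert this DataFrame into a view table provider that can be registered
    /// as a table. Takes `&self` rather than `self` so the Python object remains
    /// usable afterwards, hence the `wrong_self_convention` allowance below.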
    #[allow(clippy::wrong_self_convention)]
    fn into_view(&self) -> PyDataFusionResult<PyTable> {
        let table_provider = self.df.as_ref().clone().into_view();
        let table_provider = PyTableProvider::new(table_provider);

        Ok(table_provider.as_table())
    }

    #[pyo3(signature = (*args))]
    fn select_columns(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
        let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
        let df = self.df.as_ref().clone().select_columns(&args)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (*args))]
    fn select(&self, args: Vec<PyExpr>) -> PyDataFusionResult<Self> {
        let expr: Vec<Expr> = args.into_iter().map(|e| e.into()).collect();
        let df = self.df.as_ref().clone().select(expr)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (*args))]
    fn drop(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
        let cols = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
        let df = self.df.as_ref().clone().drop_columns(&cols)?;
        Ok(Self::new(df))
    }

    fn filter(&self, predicate: PyExpr) -> PyDataFusionResult<Self> {
        let df = self.df.as_ref().clone().filter(predicate.into())?;
        Ok(Self::new(df))
    }

    fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult<Self> {
        let df = self.df.as_ref().clone().with_column(name, expr.into())?;
        Ok(Self::new(df))
    }

    fn with_columns(&self, exprs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
        let mut df = self.df.as_ref().clone();
        for expr in exprs {
            let expr: Expr = expr.into();
            let name = format!("{}", expr.schema_name());
            df = df.with_column(name.as_str(), expr)?
        }
        Ok(Self::new(df))
    }

    fn with_column_renamed(&self, old_name: &str, new_name: &str) -> PyDataFusionResult<Self> {
        let df = self
            .df
            .as_ref()
            .clone()
            .with_column_renamed(old_name, new_name)?;
        Ok(Self::new(df))
    }

    fn aggregate(&self, group_by: Vec<PyExpr>, aggs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
        let group_by = group_by.into_iter().map(|e| e.into()).collect();
        let aggs = aggs.into_iter().map(|e| e.into()).collect();
        let df = self.df.as_ref().clone().aggregate(group_by, aggs)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (*exprs))]
    fn sort(&self, exprs: Vec<PySortExpr>) -> PyDataFusionResult<Self> {
        let exprs = to_sort_expressions(exprs);
        let df = self.df.as_ref().clone().sort(exprs)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (count, offset=0))]
    fn limit(&self, count: usize, offset: usize) -> PyDataFusionResult<Self> {
        let df = self.df.as_ref().clone().limit(offset, Some(count))?;
        Ok(Self::new(df))
    }

    fn collect(&self, py: Python) -> PyResult<Vec<PyObject>> {
        let batches = wait_for_future(py, self.df.as_ref().clone().collect())
            .map_err(PyDataFusionError::from)?;
        batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect()
    }

    fn cache(&self, py: Python) -> PyDataFusionResult<Self> {
        let df = wait_for_future(py, self.df.as_ref().clone().cache())?;
        Ok(Self::new(df))
    }

    fn collect_partitioned(&self, py: Python) -> PyResult<Vec<Vec<PyObject>>> {
        let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())
            .map_err(PyDataFusionError::from)?;

        batches
            .into_iter()
            .map(|rbs| rbs.into_iter().map(|rb| rb.to_pyarrow(py)).collect())
            .collect()
    }

    #[pyo3(signature = (num=20))]
    fn show(&self, py: Python, num: usize) -> PyDataFusionResult<()> {
        let df = self.df.as_ref().clone().limit(0, Some(num))?;
        print_dataframe(py, df)
    }

    fn distinct(&self) -> PyDataFusionResult<Self> {
        let df = self.df.as_ref().clone().distinct()?;
        Ok(Self::new(df))
    }

    fn join(
        &self,
        right: PyDataFrame,
        how: &str,
        left_on: Vec<PyBackedStr>,
        right_on: Vec<PyBackedStr>,
    ) -> PyDataFusionResult<Self> {
        let join_type = match how {
            "inner" => JoinType::Inner,
            "left" => JoinType::Left,
            "right" => JoinType::Right,
            "full" => JoinType::Full,
            "semi" => JoinType::LeftSemi,
            "anti" => JoinType::LeftAnti,
            how => {
                return Err(PyDataFusionError::Common(format!(
                    "The join type {how} does not exist or is not implemented"
                )));
            }
        };

        let left_keys = left_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
        let right_keys = right_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();

        let df = self.df.as_ref().clone().join(
            right.df.as_ref().clone(),
            join_type,
            &left_keys,
            &right_keys,
            None,
        )?;
        Ok(Self::new(df))
    }

    fn join_on(
        &self,
        right: PyDataFrame,
        on_exprs: Vec<PyExpr>,
        how: &str,
    ) -> PyDataFusionResult<Self> {
        let join_type = match how {
            "inner" => JoinType::Inner,
            "left" => JoinType::Left,
            "right" => JoinType::Right,
            "full" => JoinType::Full,
            "semi" => JoinType::LeftSemi,
            "anti" => JoinType::LeftAnti,
            how => {
                return Err(PyDataFusionError::Common(format!(
                    "The join type {how} does not exist or is not implemented"
                )));
            }
        };
        let exprs: Vec<Expr> = on_exprs.into_iter().map(|e| e.into()).collect();

        let df = self
            .df
            .as_ref()
            .clone()
            .join_on(right.df.as_ref().clone(), join_type, exprs)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (verbose=false, analyze=false))]
    fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyDataFusionResult<()> {
        let df = self.df.as_ref().clone().explain(verbose, analyze)?;
        print_dataframe(py, df)
    }

    fn logical_plan(&self) -> PyResult<PyLogicalPlan> {
        Ok(self.df.as_ref().clone().logical_plan().clone().into())
    }

    fn optimized_logical_plan(&self) -> PyDataFusionResult<PyLogicalPlan> {
        Ok(self.df.as_ref().clone().into_optimized_plan()?.into())
    }

    fn execution_plan(&self, py: Python) -> PyDataFusionResult<PyExecutionPlan> {
        let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())?;
        Ok(plan.into())
    }

    fn repartition(&self, num: usize) -> PyDataFusionResult<Self> {
        let new_df = self
            .df
            .as_ref()
            .clone()
            .repartition(Partitioning::RoundRobinBatch(num))?;
        Ok(Self::new(new_df))
    }

    #[pyo3(signature = (*args, num))]
    fn repartition_by_hash(&self, args: Vec<PyExpr>, num: usize) -> PyDataFusionResult<Self> {
        let expr = args.into_iter().map(|py_expr| py_expr.into()).collect();
        let new_df = self
            .df
            .as_ref()
            .clone()
            .repartition(Partitioning::Hash(expr, num))?;
        Ok(Self::new(new_df))
    }

    #[pyo3(signature = (py_df, distinct=false))]
    fn union(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult<Self> {
        let new_df = if distinct {
            self.df
                .as_ref()
                .clone()
                .union_distinct(py_df.df.as_ref().clone())?
        } else {
            self.df.as_ref().clone().union(py_df.df.as_ref().clone())?
        };

        Ok(Self::new(new_df))
    }

    fn union_distinct(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
        let new_df = self
            .df
            .as_ref()
            .clone()
            .union_distinct(py_df.df.as_ref().clone())?;
        Ok(Self::new(new_df))
    }

    #[pyo3(signature = (column, preserve_nulls=true))]
    fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyDataFusionResult<Self> {
        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
        let df = self
            .df
            .as_ref()
            .clone()
            .unnest_columns_with_options(&[column], unnest_options)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (columns, preserve_nulls=true))]
    fn unnest_columns(
        &self,
        columns: Vec<String>,
        preserve_nulls: bool,
    ) -> PyDataFusionResult<Self> {
        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
        let cols = columns.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
        let df = self
            .df
            .as_ref()
            .clone()
            .unnest_columns_with_options(&cols, unnest_options)?;
        Ok(Self::new(df))
    }

    fn intersect(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
        let new_df = self
            .df
            .as_ref()
            .clone()
            .intersect(py_df.df.as_ref().clone())?;
        Ok(Self::new(new_df))
    }

    fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
        let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?;
        Ok(Self::new(new_df))
    }

    fn write_csv(&self, path: &str, with_header: bool, py: Python) -> PyDataFusionResult<()> {
        let csv_options = CsvOptions {
            has_header: Some(with_header),
            ..Default::default()
        };
        wait_for_future(
            py,
            self.df.as_ref().clone().write_csv(
                path,
                DataFrameWriteOptions::new(),
                Some(csv_options),
            ),
        )?;
        Ok(())
    }

    #[pyo3(signature = (
        path,
        compression="zstd",
        compression_level=None
        ))]
    fn write_parquet(
        &self,
        path: &str,
        compression: &str,
        compression_level: Option<u32>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        fn verify_compression_level(cl: Option<u32>) -> Result<u32, PyErr> {
            cl.ok_or(PyValueError::new_err("compression_level is not defined"))
        }

        let _validated = match compression.to_lowercase().as_str() {
            "snappy" => Compression::SNAPPY,
            "gzip" => Compression::GZIP(
                GzipLevel::try_new(compression_level.unwrap_or(6))
                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
            ),
            "brotli" => Compression::BROTLI(
                BrotliLevel::try_new(verify_compression_level(compression_level)?)
                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
            ),
            "zstd" => Compression::ZSTD(
                ZstdLevel::try_new(verify_compression_level(compression_level)? as i32)
                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
            ),
            "lzo" => Compression::LZO,
            "lz4" => Compression::LZ4,
            "lz4_raw" => Compression::LZ4_RAW,
            "uncompressed" => Compression::UNCOMPRESSED,
            _ => {
                return Err(PyDataFusionError::Common(format!(
                    "Unrecognized compression type {compression}"
                )));
            }
        };

        let mut compression_string = compression.to_string();
        if let Some(level) = compression_level {
            compression_string.push_str(&format!("({level})"));
        }

        let mut options = TableParquetOptions::default();
        options.global.compression = Some(compression_string);

        wait_for_future(
            py,
            self.df.as_ref().clone().write_parquet(
                path,
                DataFrameWriteOptions::new(),
                Option::from(options),
            ),
        )?;
        Ok(())
    }

    fn write_json(&self, path: &str, py: Python) -> PyDataFusionResult<()> {
        wait_for_future(
            py,
            self.df
                .as_ref()
                .clone()
                .write_json(path, DataFrameWriteOptions::new(), None),
        )?;
        Ok(())
    }

    fn to_arrow_table(&self, py: Python<'_>) -> PyResult<PyObject> {
        let batches = self.collect(py)?.into_pyobject(py)?;
        let schema = self.schema().into_pyobject(py)?;

        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[batches, schema])?;
        let table: PyObject = table_class.call_method1("from_batches", args)?.into();
        Ok(table)
    }

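    /// Arrow PyCapsule interface (`__arrow_c_stream__`). All batches are
    /// collected eagerly and wrapped in an `FFI_ArrowArrayStream`; if a
    /// `requested_schema` capsule is supplied, the batches are first projected
    /// and cast to that schema.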
    #[pyo3(signature = (requested_schema=None))]
    fn __arrow_c_stream__<'py>(
        &'py mut self,
        py: Python<'py>,
        requested_schema: Option<Bound<'py, PyCapsule>>,
    ) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
        let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())?;
        let mut schema: Schema = self.df.schema().to_owned().into();

        if let Some(schema_capsule) = requested_schema {
            validate_pycapsule(&schema_capsule, "arrow_schema")?;

            let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
            let desired_schema = Schema::try_from(schema_ptr)?;

            schema = project_schema(schema, desired_schema)?;

            batches = batches
                .into_iter()
                .map(|record_batch| record_batch_into_schema(record_batch, &schema))
                .collect::<Result<Vec<RecordBatch>, ArrowError>>()?;
        }

        let batches_wrapped = batches.into_iter().map(Ok);

        let reader = RecordBatchIterator::new(batches_wrapped, Arc::new(schema));
        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);

        let ffi_stream = FFI_ArrowArrayStream::new(reader);
        let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
        PyCapsule::new(py, ffi_stream, Some(stream_capsule_name)).map_err(PyDataFusionError::from)
    }

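    /// Begin executing this plan and return a record batch stream. The future is
    /// spawned on the global Tokio runtime so the stream can outlive this call.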
    fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
        let rt = &get_tokio_runtime().0;
        let df = self.df.as_ref().clone();
        let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
            rt.spawn(async move { df.execute_stream().await });
        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
        Ok(PyRecordBatchStream::new(stream?))
    }

    fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
        let rt = &get_tokio_runtime().0;
        let df = self.df.as_ref().clone();
        let fut: JoinHandle<datafusion::common::Result<Vec<SendableRecordBatchStream>>> =
            rt.spawn(async move { df.execute_stream_partitioned().await });
        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;

        match stream {
            Ok(batches) => Ok(batches.into_iter().map(PyRecordBatchStream::new).collect()),
            _ => Err(PyValueError::new_err(
                "Unable to execute stream partitioned",
            )),
        }
    }

    fn to_pandas(&self, py: Python<'_>) -> PyResult<PyObject> {
        let table = self.to_arrow_table(py)?;

        let result = table.call_method0(py, "to_pandas")?;
        Ok(result)
    }

    fn to_pylist(&self, py: Python<'_>) -> PyResult<PyObject> {
        let table = self.to_arrow_table(py)?;

        let result = table.call_method0(py, "to_pylist")?;
        Ok(result)
    }

    fn to_pydict(&self, py: Python) -> PyResult<PyObject> {
        let table = self.to_arrow_table(py)?;

        let result = table.call_method0(py, "to_pydict")?;
        Ok(result)
    }

    fn to_polars(&self, py: Python<'_>) -> PyResult<PyObject> {
        let table = self.to_arrow_table(py)?;
        let dataframe = py.import("polars")?.getattr("DataFrame")?;
        let args = PyTuple::new(py, &[table])?;
        let result: PyObject = dataframe.call1(args)?.into();
        Ok(result)
    }

    fn count(&self, py: Python) -> PyDataFusionResult<usize> {
        Ok(wait_for_future(py, self.df.as_ref().clone().count())?)
    }

    #[pyo3(signature = (value, columns=None))]
    fn fill_null(
        &self,
        value: PyObject,
        columns: Option<Vec<PyBackedStr>>,
        py: Python,
    ) -> PyDataFusionResult<Self> {
        let scalar_value = py_obj_to_scalar_value(py, value)?;

        let cols = match columns {
            Some(col_names) => col_names.iter().map(|c| c.to_string()).collect(),
            None => Vec::new(),
        };

        let df = self.df.as_ref().clone().fill_null(scalar_value, cols)?;
        Ok(Self::new(df))
    }
}

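/// Print a DataFrame to Python's stdout by collecting it and rendering the
/// batches with Arrow's pretty printer; formatting errors are rendered into the
/// printed output rather than raised.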
fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
    let batches = wait_for_future(py, df.collect())?;
    let batches_as_string = pretty::pretty_format_batches(&batches);
    let result = match batches_as_string {
        Ok(batch) => format!("DataFrame()\n{batch}"),
        Err(err) => format!("Error: {:?}", err.to_string()),
    };

    let print = py.import("builtins")?.getattr("print")?;
    print.call1((result,))?;
    Ok(())
}

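/// Merge `from_schema` with `to_schema` and keep only the fields requested by
/// `to_schema`, returning the projected schema.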
fn project_schema(from_schema: Schema, to_schema: Schema) -> Result<Schema, ArrowError> {
    let merged_schema = Schema::try_merge(vec![from_schema, to_schema.clone()])?;

    let project_indices: Vec<usize> = to_schema
        .fields
        .iter()
        .map(|field| field.name())
        .filter_map(|field_name| merged_schema.index_of(field_name).ok())
        .collect();

    merged_schema.project(&project_indices)
}

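/// Rebuild a record batch against `schema`: existing columns are cast to the
/// target type when possible, missing or uncastable nullable columns are filled
/// with nulls, and non-nullable mismatches produce a `CastError`.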
fn record_batch_into_schema(
    record_batch: RecordBatch,
    schema: &Schema,
) -> Result<RecordBatch, ArrowError> {
    let schema = Arc::new(schema.clone());
    let base_schema = record_batch.schema();
    if base_schema.fields().is_empty() {
        return Ok(RecordBatch::new_empty(schema));
    }

    let array_size = record_batch.column(0).len();
    let mut data_arrays = Vec::with_capacity(schema.fields().len());

    for field in schema.fields() {
        let desired_data_type = field.data_type();
        if let Some(original_data) = record_batch.column_by_name(field.name()) {
            let original_data_type = original_data.data_type();

            if can_cast_types(original_data_type, desired_data_type) {
                data_arrays.push(arrow::compute::kernels::cast(
                    original_data,
                    desired_data_type,
                )?);
            } else if field.is_nullable() {
                data_arrays.push(new_null_array(desired_data_type, array_size));
            } else {
                return Err(ArrowError::CastError(format!("Attempting to cast to non-nullable and non-castable field {} during schema projection.", field.name())));
            }
        } else {
            if !field.is_nullable() {
                return Err(ArrowError::CastError(format!(
                    "Attempting to set null to non-nullable field {} during schema projection.",
                    field.name()
                )));
            }
            data_arrays.push(new_null_array(desired_data_type, array_size));
        }
    }

    RecordBatch::try_new(schema, data_arrays)
}

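/// Collect record batches for display, honoring the [`FormatterConfig`] limits:
/// batches are gathered until the estimated size exceeds `max_bytes` or the row
/// count reaches `repr_rows` (always keeping at least `min_rows` when available),
/// slicing the final batch if necessary. Returns the batches together with a flag
/// indicating whether more data exists than is shown.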
async fn collect_record_batches_to_display(
    df: DataFrame,
    config: FormatterConfig,
) -> Result<(Vec<RecordBatch>, bool), DataFusionError> {
    let FormatterConfig {
        max_bytes,
        min_rows,
        repr_rows,
    } = config;

    let partitioned_stream = df.execute_stream_partitioned().await?;
    let mut stream = futures::stream::iter(partitioned_stream).flatten();
    let mut size_estimate_so_far = 0;
    let mut rows_so_far = 0;
    let mut record_batches = Vec::default();
    let mut has_more = false;

    while (size_estimate_so_far < max_bytes && rows_so_far < repr_rows) || rows_so_far < min_rows {
        let mut rb = match stream.next().await {
            None => {
                break;
            }
            Some(Ok(r)) => r,
            Some(Err(e)) => return Err(e),
        };

        let mut rows_in_rb = rb.num_rows();
        if rows_in_rb > 0 {
            size_estimate_so_far += rb.get_array_memory_size();

            if size_estimate_so_far > max_bytes {
                let ratio = max_bytes as f32 / size_estimate_so_far as f32;
                let total_rows = rows_in_rb + rows_so_far;

                let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize;
                if reduced_row_num < min_rows {
                    reduced_row_num = min_rows.min(total_rows);
                }

                let limited_rows_this_rb = reduced_row_num - rows_so_far;
                if limited_rows_this_rb < rows_in_rb {
                    rows_in_rb = limited_rows_this_rb;
                    rb = rb.slice(0, limited_rows_this_rb);
                    has_more = true;
                }
            }

            if rows_in_rb + rows_so_far > repr_rows {
                rb = rb.slice(0, repr_rows - rows_so_far);
                has_more = true;
            }

            rows_so_far += rb.num_rows();
            record_batches.push(rb);
        }
    }

    if record_batches.is_empty() {
        return Ok((Vec::default(), false));
    }

    if !has_more {
        has_more = match stream.try_next().await {
            Ok(None) => false,
            Ok(Some(_)) => true,
            Err(_) => false,
        };
    }

    Ok((record_batches, has_more))
}