use std::ffi::CString;
use std::sync::Arc;

use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow::compute::can_cast_types;
use arrow::error::ArrowError;
use arrow::ffi::FFI_ArrowSchema;
use arrow::ffi_stream::FFI_ArrowArrayStream;
use arrow::util::display::{ArrayFormatter, FormatOptions};
use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
use datafusion::arrow::util::pretty;
use datafusion::common::UnnestOptions;
use datafusion::config::{CsvOptions, TableParquetOptions};
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
use datafusion::datasource::TableProvider;
use datafusion::error::DataFusionError;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
use datafusion::prelude::*;
use futures::{StreamExt, TryStreamExt};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::{PyCapsule, PyTuple, PyTupleMethods};
use tokio::task::JoinHandle;

use crate::catalog::PyTable;
use crate::errors::{py_datafusion_err, PyDataFusionError};
use crate::expr::sort_expr::to_sort_expressions;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
use crate::sql::logical::PyLogicalPlan;
use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
use crate::{
    errors::PyDataFusionResult,
    expr::{sort_expr::PySortExpr, PyExpr},
};

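/// Wraps an `Arc<dyn TableProvider>` so that a DataFrame converted into a view can be
/// handed back to Python and registered with a `SessionContext` as a table.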
#[pyclass(name = "TableProvider", module = "datafusion")]
pub struct PyTableProvider {
    provider: Arc<dyn TableProvider>,
}

impl PyTableProvider {
    pub fn new(provider: Arc<dyn TableProvider>) -> Self {
        Self { provider }
    }

    pub fn as_table(&self) -> PyTable {
        let table_provider: Arc<dyn TableProvider> = self.provider.clone();
        PyTable::new(table_provider)
    }
}
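// Limits used by `__repr__` and `_repr_html_` when deciding how much data to collect for display.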
const MAX_TABLE_BYTES_TO_DISPLAY: usize = 2 * 1024 * 1024;
const MIN_TABLE_ROWS_TO_DISPLAY: usize = 20;
const MAX_LENGTH_CELL_WITHOUT_MINIMIZE: usize = 25;

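/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
/// Use it to build a plan and `.collect()` to execute the plan and collect the result.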
#[pyclass(name = "DataFrame", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PyDataFrame {
    df: Arc<DataFrame>,
}

impl PyDataFrame {
    pub fn new(df: DataFrame) -> Self {
        Self { df: Arc::new(df) }
    }
}

#[pymethods]
impl PyDataFrame {
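    /// Enable selection with `df[col]`, `df[col1, col2, col3]`, and `df[[col1, col2, col3]]`.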
    fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
        if let Ok(key) = key.extract::<PyBackedStr>() {
            self.select_columns(vec![key])
        } else if let Ok(tuple) = key.downcast::<PyTuple>() {
            let keys = tuple
                .iter()
                .map(|item| item.extract::<PyBackedStr>())
                .collect::<PyResult<Vec<PyBackedStr>>>()?;
            self.select_columns(keys)
        } else if let Ok(keys) = key.extract::<Vec<PyBackedStr>>() {
            self.select_columns(keys)
        } else {
            let message = "DataFrame can only be indexed by string index or indices".to_string();
            Err(PyDataFusionError::Common(message))
        }
    }

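    /// Plain-text representation used by Python's `repr()`: collects a small sample of
    /// the DataFrame and appends a note when the output is truncated.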
    fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
        let (batches, has_more) = wait_for_future(
            py,
            collect_record_batches_to_display(self.df.as_ref().clone(), 10, 10),
        )?;
        if batches.is_empty() {
            return Ok("No data to display".to_string());
        }

        let batches_as_displ =
            pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?;

        let additional_str = match has_more {
            true => "\nData truncated.",
            false => "",
        };

        Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}"))
    }

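    /// HTML representation used by Jupyter: collects up to MAX_TABLE_BYTES_TO_DISPLAY of data
    /// (but at least MIN_TABLE_ROWS_TO_DISPLAY rows when available) and collapses long cell
    /// values behind an expand/collapse button.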
    fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
        let (batches, has_more) = wait_for_future(
            py,
            collect_record_batches_to_display(
                self.df.as_ref().clone(),
                MIN_TABLE_ROWS_TO_DISPLAY,
                usize::MAX,
            ),
        )?;
        if batches.is_empty() {
            return Ok("No data to display".to_string());
        }

        let table_uuid = uuid::Uuid::new_v4().to_string();

        let mut html_str = "
            <style>
            .expandable-container {
                display: inline-block;
                max-width: 200px;
            }
            .expandable {
                white-space: nowrap;
                overflow: hidden;
                text-overflow: ellipsis;
                display: block;
            }
            .full-text {
                display: none;
                white-space: normal;
            }
            .expand-btn {
                cursor: pointer;
                color: blue;
                text-decoration: underline;
                border: none;
                background: none;
                font-size: inherit;
                display: block;
                margin-top: 5px;
            }
            </style>

            <div style=\"width: 100%; max-width: 1000px; max-height: 300px; overflow: auto; border: 1px solid #ccc;\">
                <table style=\"border-collapse: collapse; min-width: 100%\">
                    <thead>\n".to_string();

        let schema = batches[0].schema();

        let mut header = Vec::new();
        for field in schema.fields() {
            header.push(format!("<th style='border: 1px solid black; padding: 8px; text-align: left; background-color: #f2f2f2; white-space: nowrap; min-width: fit-content; max-width: fit-content;'>{}</th>", field.name()));
        }
        let header_str = header.join("");
        html_str.push_str(&format!("<tr>{}</tr></thead><tbody>\n", header_str));

        let batch_formatters = batches
            .iter()
            .map(|batch| {
                batch
                    .columns()
                    .iter()
                    .map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default()))
                    .map(|c| {
                        c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))
                    })
                    .collect::<Result<Vec<_>, _>>()
            })
            .collect::<Result<Vec<_>, _>>()?;

        let rows_per_batch = batches.iter().map(|batch| batch.num_rows());

        let mut table_row = 0;
        for (batch_formatter, num_rows_in_batch) in batch_formatters.iter().zip(rows_per_batch) {
            for batch_row in 0..num_rows_in_batch {
                table_row += 1;
                let mut cells = Vec::new();
                for (col, formatter) in batch_formatter.iter().enumerate() {
                    let cell_data = formatter.value(batch_row).to_string();
                    if cell_data.len() > MAX_LENGTH_CELL_WITHOUT_MINIMIZE {
                        let short_cell_data = &cell_data[0..MAX_LENGTH_CELL_WITHOUT_MINIMIZE];
                        cells.push(format!("
                            <td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>
                                <div class=\"expandable-container\">
                                    <span class=\"expandable\" id=\"{table_uuid}-min-text-{table_row}-{col}\">{short_cell_data}</span>
                                    <span class=\"full-text\" id=\"{table_uuid}-full-text-{table_row}-{col}\">{cell_data}</span>
                                    <button class=\"expand-btn\" onclick=\"toggleDataFrameCellText('{table_uuid}',{table_row},{col})\">...</button>
                                </div>
                            </td>"));
                    } else {
                        cells.push(format!("<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>", formatter.value(batch_row)));
                    }
                }
                let row_str = cells.join("");
                html_str.push_str(&format!("<tr>{}</tr>\n", row_str));
            }
        }
        html_str.push_str("</tbody></table></div>\n");

        html_str.push_str("
            <script>
            function toggleDataFrameCellText(table_uuid, row, col) {
                var shortText = document.getElementById(table_uuid + \"-min-text-\" + row + \"-\" + col);
                var fullText = document.getElementById(table_uuid + \"-full-text-\" + row + \"-\" + col);
                var button = event.target;

                if (fullText.style.display === \"none\") {
                    shortText.style.display = \"none\";
                    fullText.style.display = \"inline\";
                    button.textContent = \"(less)\";
                } else {
                    shortText.style.display = \"inline\";
                    fullText.style.display = \"none\";
                    button.textContent = \"...\";
                }
            }
            </script>
        ");

        if has_more {
            html_str.push_str("Data truncated due to size.");
        }

        Ok(html_str)
    }

    fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
        let df = self.df.as_ref().clone();
        let stat_df = wait_for_future(py, df.describe())?;
        Ok(Self::new(stat_df))
    }

    fn schema(&self) -> PyArrowType<Schema> {
        PyArrowType(self.df.schema().into())
    }

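    /// Convert this DataFrame into a view (`TableProvider`) that can be registered with a
    /// `SessionContext` and queried with SQL.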
    #[allow(clippy::wrong_self_convention)]
    fn into_view(&self) -> PyDataFusionResult<PyTable> {
        let table_provider = self.df.as_ref().clone().into_view();
        let table_provider = PyTableProvider::new(table_provider);

        Ok(table_provider.as_table())
    }

    #[pyo3(signature = (*args))]
    fn select_columns(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
        let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
        let df = self.df.as_ref().clone().select_columns(&args)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (*args))]
    fn select(&self, args: Vec<PyExpr>) -> PyDataFusionResult<Self> {
        let expr = args.into_iter().map(|e| e.into()).collect();
        let df = self.df.as_ref().clone().select(expr)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (*args))]
    fn drop(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
        let cols = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
        let df = self.df.as_ref().clone().drop_columns(&cols)?;
        Ok(Self::new(df))
    }

    fn filter(&self, predicate: PyExpr) -> PyDataFusionResult<Self> {
        let df = self.df.as_ref().clone().filter(predicate.into())?;
        Ok(Self::new(df))
    }

    fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult<Self> {
        let df = self.df.as_ref().clone().with_column(name, expr.into())?;
        Ok(Self::new(df))
    }

    fn with_columns(&self, exprs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
        let mut df = self.df.as_ref().clone();
        for expr in exprs {
            let expr: Expr = expr.into();
            let name = format!("{}", expr.schema_name());
            df = df.with_column(name.as_str(), expr)?
        }
        Ok(Self::new(df))
    }

    fn with_column_renamed(&self, old_name: &str, new_name: &str) -> PyDataFusionResult<Self> {
        let df = self
            .df
            .as_ref()
            .clone()
            .with_column_renamed(old_name, new_name)?;
        Ok(Self::new(df))
    }

    fn aggregate(&self, group_by: Vec<PyExpr>, aggs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
        let group_by = group_by.into_iter().map(|e| e.into()).collect();
        let aggs = aggs.into_iter().map(|e| e.into()).collect();
        let df = self.df.as_ref().clone().aggregate(group_by, aggs)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (*exprs))]
    fn sort(&self, exprs: Vec<PySortExpr>) -> PyDataFusionResult<Self> {
        let exprs = to_sort_expressions(exprs);
        let df = self.df.as_ref().clone().sort(exprs)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (count, offset=0))]
    fn limit(&self, count: usize, offset: usize) -> PyDataFusionResult<Self> {
        let df = self.df.as_ref().clone().limit(offset, Some(count))?;
        Ok(Self::new(df))
    }

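    /// Executes the plan, returning a list of `RecordBatch`es converted to PyArrow.
    /// Unless an order is specified in the plan, the order of the result is not guaranteed.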
    fn collect(&self, py: Python) -> PyResult<Vec<PyObject>> {
        let batches = wait_for_future(py, self.df.as_ref().clone().collect())
            .map_err(PyDataFusionError::from)?;
        batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect()
    }

    fn cache(&self, py: Python) -> PyDataFusionResult<Self> {
        let df = wait_for_future(py, self.df.as_ref().clone().cache())?;
        Ok(Self::new(df))
    }

    fn collect_partitioned(&self, py: Python) -> PyResult<Vec<Vec<PyObject>>> {
        let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())
            .map_err(PyDataFusionError::from)?;

        batches
            .into_iter()
            .map(|rbs| rbs.into_iter().map(|rb| rb.to_pyarrow(py)).collect())
            .collect()
    }

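    /// Print the first `num` rows (20 by default) to stdout.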
    #[pyo3(signature = (num=20))]
    fn show(&self, py: Python, num: usize) -> PyDataFusionResult<()> {
        let df = self.df.as_ref().clone().limit(0, Some(num))?;
        print_dataframe(py, df)
    }

    fn distinct(&self) -> PyDataFusionResult<Self> {
        let df = self.df.as_ref().clone().distinct()?;
        Ok(Self::new(df))
    }

    fn join(
        &self,
        right: PyDataFrame,
        how: &str,
        left_on: Vec<PyBackedStr>,
        right_on: Vec<PyBackedStr>,
    ) -> PyDataFusionResult<Self> {
        let join_type = match how {
            "inner" => JoinType::Inner,
            "left" => JoinType::Left,
            "right" => JoinType::Right,
            "full" => JoinType::Full,
            "semi" => JoinType::LeftSemi,
            "anti" => JoinType::LeftAnti,
            how => {
                return Err(PyDataFusionError::Common(format!(
                    "The join type {how} does not exist or is not implemented"
                )));
            }
        };

        let left_keys = left_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
        let right_keys = right_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();

        let df = self.df.as_ref().clone().join(
            right.df.as_ref().clone(),
            join_type,
            &left_keys,
            &right_keys,
            None,
        )?;
        Ok(Self::new(df))
    }

    fn join_on(
        &self,
        right: PyDataFrame,
        on_exprs: Vec<PyExpr>,
        how: &str,
    ) -> PyDataFusionResult<Self> {
        let join_type = match how {
            "inner" => JoinType::Inner,
            "left" => JoinType::Left,
            "right" => JoinType::Right,
            "full" => JoinType::Full,
            "semi" => JoinType::LeftSemi,
            "anti" => JoinType::LeftAnti,
            how => {
                return Err(PyDataFusionError::Common(format!(
                    "The join type {how} does not exist or is not implemented"
                )));
            }
        };
        let exprs: Vec<Expr> = on_exprs.into_iter().map(|e| e.into()).collect();

        let df = self
            .df
            .as_ref()
            .clone()
            .join_on(right.df.as_ref().clone(), join_type, exprs)?;
        Ok(Self::new(df))
    }

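    /// Print the query plan to stdout; with `analyze=True` the plan is executed and
    /// annotated with runtime metrics.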
    #[pyo3(signature = (verbose=false, analyze=false))]
    fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyDataFusionResult<()> {
        let df = self.df.as_ref().clone().explain(verbose, analyze)?;
        print_dataframe(py, df)
    }

    fn logical_plan(&self) -> PyResult<PyLogicalPlan> {
        Ok(self.df.as_ref().clone().logical_plan().clone().into())
    }

    fn optimized_logical_plan(&self) -> PyDataFusionResult<PyLogicalPlan> {
        Ok(self.df.as_ref().clone().into_optimized_plan()?.into())
    }

    fn execution_plan(&self, py: Python) -> PyDataFusionResult<PyExecutionPlan> {
        let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())?;
        Ok(plan.into())
    }

    fn repartition(&self, num: usize) -> PyDataFusionResult<Self> {
        let new_df = self
            .df
            .as_ref()
            .clone()
            .repartition(Partitioning::RoundRobinBatch(num))?;
        Ok(Self::new(new_df))
    }

    #[pyo3(signature = (*args, num))]
    fn repartition_by_hash(&self, args: Vec<PyExpr>, num: usize) -> PyDataFusionResult<Self> {
        let expr = args.into_iter().map(|py_expr| py_expr.into()).collect();
        let new_df = self
            .df
            .as_ref()
            .clone()
            .repartition(Partitioning::Hash(expr, num))?;
        Ok(Self::new(new_df))
    }

    #[pyo3(signature = (py_df, distinct=false))]
    fn union(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult<Self> {
        let new_df = if distinct {
            self.df
                .as_ref()
                .clone()
                .union_distinct(py_df.df.as_ref().clone())?
        } else {
            self.df.as_ref().clone().union(py_df.df.as_ref().clone())?
        };

        Ok(Self::new(new_df))
    }

    fn union_distinct(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
        let new_df = self
            .df
            .as_ref()
            .clone()
            .union_distinct(py_df.df.as_ref().clone())?;
        Ok(Self::new(new_df))
    }

    #[pyo3(signature = (column, preserve_nulls=true))]
    fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyDataFusionResult<Self> {
        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
        let df = self
            .df
            .as_ref()
            .clone()
            .unnest_columns_with_options(&[column], unnest_options)?;
        Ok(Self::new(df))
    }

    #[pyo3(signature = (columns, preserve_nulls=true))]
    fn unnest_columns(
        &self,
        columns: Vec<String>,
        preserve_nulls: bool,
    ) -> PyDataFusionResult<Self> {
        let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls);
        let cols = columns.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
        let df = self
            .df
            .as_ref()
            .clone()
            .unnest_columns_with_options(&cols, unnest_options)?;
        Ok(Self::new(df))
    }

    fn intersect(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
        let new_df = self
            .df
            .as_ref()
            .clone()
            .intersect(py_df.df.as_ref().clone())?;
        Ok(Self::new(new_df))
    }

    fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
        let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?;
        Ok(Self::new(new_df))
    }

    fn write_csv(&self, path: &str, with_header: bool, py: Python) -> PyDataFusionResult<()> {
        let csv_options = CsvOptions {
            has_header: Some(with_header),
            ..Default::default()
        };
        wait_for_future(
            py,
            self.df.as_ref().clone().write_csv(
                path,
                DataFrameWriteOptions::new(),
                Some(csv_options),
            ),
        )?;
        Ok(())
    }

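    /// Write the DataFrame to a Parquet file. The compression codec is validated against
    /// the supported set (snappy, gzip, brotli, zstd, lzo, lz4, lz4_raw, uncompressed);
    /// brotli and zstd require an explicit `compression_level`, while gzip defaults to level 6.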
    #[pyo3(signature = (
        path,
        compression="zstd",
        compression_level=None
    ))]
    fn write_parquet(
        &self,
        path: &str,
        compression: &str,
        compression_level: Option<u32>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        fn verify_compression_level(cl: Option<u32>) -> Result<u32, PyErr> {
            cl.ok_or(PyValueError::new_err("compression_level is not defined"))
        }

        let _validated = match compression.to_lowercase().as_str() {
            "snappy" => Compression::SNAPPY,
            "gzip" => Compression::GZIP(
                GzipLevel::try_new(compression_level.unwrap_or(6))
                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
            ),
            "brotli" => Compression::BROTLI(
                BrotliLevel::try_new(verify_compression_level(compression_level)?)
                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
            ),
            "zstd" => Compression::ZSTD(
                ZstdLevel::try_new(verify_compression_level(compression_level)? as i32)
                    .map_err(|e| PyValueError::new_err(format!("{e}")))?,
            ),
            "lzo" => Compression::LZO,
            "lz4" => Compression::LZ4,
            "lz4_raw" => Compression::LZ4_RAW,
            "uncompressed" => Compression::UNCOMPRESSED,
            _ => {
                return Err(PyDataFusionError::Common(format!(
                    "Unrecognized compression type {compression}"
                )));
            }
        };

        let mut compression_string = compression.to_string();
        if let Some(level) = compression_level {
            compression_string.push_str(&format!("({level})"));
        }

        let mut options = TableParquetOptions::default();
        options.global.compression = Some(compression_string);

        wait_for_future(
            py,
            self.df.as_ref().clone().write_parquet(
                path,
                DataFrameWriteOptions::new(),
                Option::from(options),
            ),
        )?;
        Ok(())
    }

    fn write_json(&self, path: &str, py: Python) -> PyDataFusionResult<()> {
        wait_for_future(
            py,
            self.df
                .as_ref()
                .clone()
                .write_json(path, DataFrameWriteOptions::new(), None),
        )?;
        Ok(())
    }

    fn to_arrow_table(&self, py: Python<'_>) -> PyResult<PyObject> {
        let batches = self.collect(py)?.into_pyobject(py)?;
        let schema = self.schema().into_pyobject(py)?;

        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[batches, schema])?;
        let table: PyObject = table_class.call_method1("from_batches", args)?.into();
        Ok(table)
    }

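    /// Arrow PyCapsule interface (`__arrow_c_stream__`): collects the DataFrame and exposes
    /// the batches as an Arrow C stream. If a requested schema capsule is provided, the
    /// batches are projected and cast to that schema first.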
    #[pyo3(signature = (requested_schema=None))]
    fn __arrow_c_stream__<'py>(
        &'py mut self,
        py: Python<'py>,
        requested_schema: Option<Bound<'py, PyCapsule>>,
    ) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
        let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())?;
        let mut schema: Schema = self.df.schema().to_owned().into();

        if let Some(schema_capsule) = requested_schema {
            validate_pycapsule(&schema_capsule, "arrow_schema")?;

            let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
            let desired_schema = Schema::try_from(schema_ptr)?;

            schema = project_schema(schema, desired_schema)?;

            batches = batches
                .into_iter()
                .map(|record_batch| record_batch_into_schema(record_batch, &schema))
                .collect::<Result<Vec<RecordBatch>, ArrowError>>()?;
        }

        let batches_wrapped = batches.into_iter().map(Ok);

        let reader = RecordBatchIterator::new(batches_wrapped, Arc::new(schema));
        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);

        let ffi_stream = FFI_ArrowArrayStream::new(reader);
        let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
        PyCapsule::new(py, ffi_stream, Some(stream_capsule_name)).map_err(PyDataFusionError::from)
    }

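    /// Execute the plan and return a record batch stream. Execution is spawned onto the
    /// shared Tokio runtime so the resulting stream can be consumed incrementally from Python.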
    fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
        let rt = &get_tokio_runtime().0;
        let df = self.df.as_ref().clone();
        let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
            rt.spawn(async move { df.execute_stream().await });
        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
        Ok(PyRecordBatchStream::new(stream?))
    }

    fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
        let rt = &get_tokio_runtime().0;
        let df = self.df.as_ref().clone();
        let fut: JoinHandle<datafusion::common::Result<Vec<SendableRecordBatchStream>>> =
            rt.spawn(async move { df.execute_stream_partitioned().await });
        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;

        match stream {
            Ok(batches) => Ok(batches.into_iter().map(PyRecordBatchStream::new).collect()),
            _ => Err(PyValueError::new_err(
                "Unable to execute stream partitioned",
            )),
        }
    }

    fn to_pandas(&self, py: Python<'_>) -> PyResult<PyObject> {
        let table = self.to_arrow_table(py)?;

        let result = table.call_method0(py, "to_pandas")?;
        Ok(result)
    }

    fn to_pylist(&self, py: Python<'_>) -> PyResult<PyObject> {
        let table = self.to_arrow_table(py)?;

        let result = table.call_method0(py, "to_pylist")?;
        Ok(result)
    }

    fn to_pydict(&self, py: Python) -> PyResult<PyObject> {
        let table = self.to_arrow_table(py)?;

        let result = table.call_method0(py, "to_pydict")?;
        Ok(result)
    }

    fn to_polars(&self, py: Python<'_>) -> PyResult<PyObject> {
        let table = self.to_arrow_table(py)?;
        let dataframe = py.import("polars")?.getattr("DataFrame")?;
        let args = PyTuple::new(py, &[table])?;
        let result: PyObject = dataframe.call1(args)?.into();
        Ok(result)
    }

    fn count(&self, py: Python) -> PyDataFusionResult<usize> {
        Ok(wait_for_future(py, self.df.as_ref().clone().count())?)
    }
}

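/// Collect a DataFrame and print the formatted batches through Python's built-in `print`,
/// so the output goes to Python's stdout.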
fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
    let batches = wait_for_future(py, df.collect())?;
    let batches_as_string = pretty::pretty_format_batches(&batches);
    let result = match batches_as_string {
        Ok(batch) => format!("DataFrame()\n{batch}"),
        Err(err) => format!("Error: {:?}", err.to_string()),
    };

    let print = py.import("builtins")?.getattr("print")?;
    print.call1((result,))?;
    Ok(())
}

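/// Merge the DataFrame schema with a requested schema, keeping only the requested fields
/// in their requested order. Used by `__arrow_c_stream__` to honor a schema request.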
fn project_schema(from_schema: Schema, to_schema: Schema) -> Result<Schema, ArrowError> {
    let merged_schema = Schema::try_merge(vec![from_schema, to_schema.clone()])?;

    let project_indices: Vec<usize> = to_schema
        .fields
        .iter()
        .map(|field| field.name())
        .filter_map(|field_name| merged_schema.index_of(field_name).ok())
        .collect();

    merged_schema.project(&project_indices)
}

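/// Cast a `RecordBatch` to the given schema: castable columns are cast, missing or
/// non-castable columns are filled with nulls when the target field is nullable, and an
/// error is returned otherwise.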
fn record_batch_into_schema(
    record_batch: RecordBatch,
    schema: &Schema,
) -> Result<RecordBatch, ArrowError> {
    let schema = Arc::new(schema.clone());
    let base_schema = record_batch.schema();
    if base_schema.fields().len() == 0 {
        return Ok(RecordBatch::new_empty(schema));
    }

    let array_size = record_batch.column(0).len();
    let mut data_arrays = Vec::with_capacity(schema.fields().len());

    for field in schema.fields() {
        let desired_data_type = field.data_type();
        if let Some(original_data) = record_batch.column_by_name(field.name()) {
            let original_data_type = original_data.data_type();

            if can_cast_types(original_data_type, desired_data_type) {
                data_arrays.push(arrow::compute::kernels::cast(
                    original_data,
                    desired_data_type,
                )?);
            } else if field.is_nullable() {
                data_arrays.push(new_null_array(desired_data_type, array_size));
            } else {
                return Err(ArrowError::CastError(format!("Attempting to cast to non-nullable and non-castable field {} during schema projection.", field.name())));
            }
        } else {
            if !field.is_nullable() {
                return Err(ArrowError::CastError(format!(
                    "Attempting to set null to non-nullable field {} during schema projection.",
                    field.name()
                )));
            }
            data_arrays.push(new_null_array(desired_data_type, array_size));
        }
    }

    RecordBatch::try_new(schema, data_arrays)
}

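/// Collect record batches for display, capping how much data is buffered. At least
/// `min_rows` rows are collected when available (even past the byte budget) and at most
/// `max_rows`; collection also stops once MAX_TABLE_BYTES_TO_DISPLAY is exceeded. The
/// returned flag is true when more rows remain beyond what was collected.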
async fn collect_record_batches_to_display(
    df: DataFrame,
    min_rows: usize,
    max_rows: usize,
) -> Result<(Vec<RecordBatch>, bool), DataFusionError> {
    let partitioned_stream = df.execute_stream_partitioned().await?;
    let mut stream = futures::stream::iter(partitioned_stream).flatten();
    let mut size_estimate_so_far = 0;
    let mut rows_so_far = 0;
    let mut record_batches = Vec::default();
    let mut has_more = false;

    while (size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY && rows_so_far < max_rows)
        || rows_so_far < min_rows
    {
        let mut rb = match stream.next().await {
            None => {
                break;
            }
            Some(Ok(r)) => r,
            Some(Err(e)) => return Err(e),
        };

        let mut rows_in_rb = rb.num_rows();
        if rows_in_rb > 0 {
            size_estimate_so_far += rb.get_array_memory_size();

            if size_estimate_so_far > MAX_TABLE_BYTES_TO_DISPLAY {
                let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / size_estimate_so_far as f32;
                let total_rows = rows_in_rb + rows_so_far;

                let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize;
                if reduced_row_num < min_rows {
                    reduced_row_num = min_rows.min(total_rows);
                }

                let limited_rows_this_rb = reduced_row_num - rows_so_far;
                if limited_rows_this_rb < rows_in_rb {
                    rows_in_rb = limited_rows_this_rb;
                    rb = rb.slice(0, limited_rows_this_rb);
                    has_more = true;
                }
            }

            if rows_in_rb + rows_so_far > max_rows {
                rb = rb.slice(0, max_rows - rows_so_far);
                has_more = true;
            }

            rows_so_far += rb.num_rows();
            record_batches.push(rb);
        }
    }

    if record_batches.is_empty() {
        return Ok((Vec::default(), false));
    }

    if !has_more {
        // Data was not truncated above, so check whether any record batches remain in the stream
        has_more = match stream.try_next().await {
            Ok(None) => false, // reached end of stream
            Ok(Some(_)) => true,
            Err(_) => false, // stream disconnected
        };
    }

    Ok((record_batches, has_more))
}