use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;

use arrow::array::RecordBatchReader;
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::FromPyArrow;
use datafusion::execution::session_state::SessionStateBuilder;
use object_store::ObjectStore;
use url::Url;
use uuid::Uuid;

use pyo3::exceptions::{PyKeyError, PyValueError};
use pyo3::prelude::*;

use crate::catalog::{PyCatalog, PyTable};
use crate::dataframe::PyDataFrame;
use crate::dataset::Dataset;
use crate::errors::{py_datafusion_err, PyDataFusionResult};
use crate::expr::sort_expr::PySortExpr;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
use crate::sql::exceptions::py_value_err;
use crate::sql::logical::PyLogicalPlan;
use crate::store::StorageContexts;
use crate::udaf::PyAggregateUDF;
use crate::udf::PyScalarUDF;
use crate::udtf::PyTableFunction;
use crate::udwf::PyWindowUDF;
use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future};
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::TableReference;
use datafusion::common::{exec_err, ScalarValue};
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
use datafusion::datasource::file_format::parquet::ParquetFormat;
use datafusion::datasource::listing::{
    ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
};
use datafusion::datasource::MemTable;
use datafusion::datasource::TableProvider;
use datafusion::execution::context::{
    DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext,
};
use datafusion::execution::disk_manager::DiskManagerConfig;
use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
use datafusion::execution::options::ReadOptions;
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
use datafusion::physical_plan::SendableRecordBatchStream;
use datafusion::prelude::{
    AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
};
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
use tokio::task::JoinHandle;

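/// Configuration options for a `SessionContext`.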
#[pyclass(name = "SessionConfig", module = "datafusion", subclass)]
#[derive(Clone, Default)]
pub struct PySessionConfig {
    pub config: SessionConfig,
}

impl From<SessionConfig> for PySessionConfig {
    fn from(config: SessionConfig) -> Self {
        Self { config }
    }
}

#[pymethods]
impl PySessionConfig {
    #[pyo3(signature = (config_options=None))]
    #[new]
    fn new(config_options: Option<HashMap<String, String>>) -> Self {
        let mut config = SessionConfig::new();
        if let Some(hash_map) = config_options {
            for (k, v) in &hash_map {
                config = config.set(k, &ScalarValue::Utf8(Some(v.clone())));
            }
        }

        Self { config }
    }

    fn with_create_default_catalog_and_schema(&self, enabled: bool) -> Self {
        Self::from(
            self.config
                .clone()
                .with_create_default_catalog_and_schema(enabled),
        )
    }

    fn with_default_catalog_and_schema(&self, catalog: &str, schema: &str) -> Self {
        Self::from(
            self.config
                .clone()
                .with_default_catalog_and_schema(catalog, schema),
        )
    }

    fn with_information_schema(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_information_schema(enabled))
    }

    fn with_batch_size(&self, batch_size: usize) -> Self {
        Self::from(self.config.clone().with_batch_size(batch_size))
    }

    fn with_target_partitions(&self, target_partitions: usize) -> Self {
        Self::from(
            self.config
                .clone()
                .with_target_partitions(target_partitions),
        )
    }

    fn with_repartition_aggregations(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_aggregations(enabled))
    }

    fn with_repartition_joins(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_joins(enabled))
    }

    fn with_repartition_windows(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_windows(enabled))
    }

    fn with_repartition_sorts(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_sorts(enabled))
    }

    fn with_repartition_file_scans(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_file_scans(enabled))
    }

    fn with_repartition_file_min_size(&self, size: usize) -> Self {
        Self::from(self.config.clone().with_repartition_file_min_size(size))
    }

    fn with_parquet_pruning(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_parquet_pruning(enabled))
    }

    fn set(&self, key: &str, value: &str) -> Self {
        Self::from(self.config.clone().set_str(key, value))
    }
}

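/// Runtime environment options (memory pools, disk manager, temp paths) for a `SessionContext`.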
#[pyclass(name = "RuntimeEnvBuilder", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PyRuntimeEnvBuilder {
    pub builder: RuntimeEnvBuilder,
}

#[pymethods]
impl PyRuntimeEnvBuilder {
    #[new]
    fn new() -> Self {
        Self {
            builder: RuntimeEnvBuilder::default(),
        }
    }

    fn with_disk_manager_disabled(&self) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_disk_manager(DiskManagerConfig::Disabled);
        Self { builder }
    }

    fn with_disk_manager_os(&self) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_disk_manager(DiskManagerConfig::NewOs);
        Self { builder }
    }

    fn with_disk_manager_specified(&self, paths: Vec<String>) -> Self {
        let builder = self.builder.clone();
        let paths = paths.iter().map(|s| s.into()).collect();
        let builder = builder.with_disk_manager(DiskManagerConfig::NewSpecified(paths));
        Self { builder }
    }

    fn with_unbounded_memory_pool(&self) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_memory_pool(Arc::new(UnboundedMemoryPool::default()));
        Self { builder }
    }

    fn with_fair_spill_pool(&self, size: usize) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_memory_pool(Arc::new(FairSpillPool::new(size)));
        Self { builder }
    }

    fn with_greedy_memory_pool(&self, size: usize) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_memory_pool(Arc::new(GreedyMemoryPool::new(size)));
        Self { builder }
    }

    fn with_temp_file_path(&self, path: &str) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_temp_file_path(path);
        Self { builder }
    }
}

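/// Options controlling which kinds of SQL statements a `SessionContext` will accept.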
#[pyclass(name = "SQLOptions", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PySQLOptions {
    pub options: SQLOptions,
}

impl From<SQLOptions> for PySQLOptions {
    fn from(options: SQLOptions) -> Self {
        Self { options }
    }
}

#[pymethods]
impl PySQLOptions {
    #[new]
    fn new() -> Self {
        let options = SQLOptions::new();
        Self { options }
    }

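    /// Should DDL commands such as `CREATE TABLE` be allowed to run? Defaults to `true`.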
    fn with_allow_ddl(&self, allow: bool) -> Self {
        Self::from(self.options.with_allow_ddl(allow))
    }

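    /// Should DML commands such as `INSERT` and `COPY` be allowed to run? Defaults to `true`.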
    pub fn with_allow_dml(&self, allow: bool) -> Self {
        Self::from(self.options.with_allow_dml(allow))
    }

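    /// Should statements such as `SET VARIABLE` and `BEGIN TRANSACTION` be allowed to run? Defaults to `true`.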
    pub fn with_allow_statements(&self, allow: bool) -> Self {
        Self::from(self.options.with_allow_statements(allow))
    }
}

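/// `PySessionContext` is able to plan and execute DataFusion plans.
/// It has a powerful optimizer, a physical planner for local execution, and a
/// multi-threaded execution engine to perform the execution.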
#[pyclass(name = "SessionContext", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PySessionContext {
    pub ctx: SessionContext,
}

#[pymethods]
impl PySessionContext {
    #[pyo3(signature = (config=None, runtime=None))]
    #[new]
    pub fn new(
        config: Option<PySessionConfig>,
        runtime: Option<PyRuntimeEnvBuilder>,
    ) -> PyDataFusionResult<Self> {
        let config = if let Some(c) = config {
            c.config
        } else {
            SessionConfig::default().with_information_schema(true)
        };
        let runtime_env_builder = if let Some(c) = runtime {
            c.builder
        } else {
            RuntimeEnvBuilder::default()
        };
        let runtime = Arc::new(runtime_env_builder.build()?);
        let session_state = SessionStateBuilder::new()
            .with_config(config)
            .with_runtime_env(runtime)
            .with_default_features()
            .build();
        Ok(PySessionContext {
            ctx: SessionContext::new_with_state(session_state),
        })
    }

    pub fn enable_url_table(&self) -> PyResult<Self> {
        Ok(PySessionContext {
            ctx: self.ctx.clone().enable_url_table(),
        })
    }

    #[classmethod]
    #[pyo3(signature = ())]
    fn global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
        Ok(Self {
            ctx: get_global_ctx().clone(),
        })
    }

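    /// Register an `ObjectStore` under a URL scheme (and optional host) so that
    /// queries on this context can read from it. When no host is given, the
    /// bucket, container, or URL carried by the storage context is used.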
    #[pyo3(signature = (scheme, store, host=None))]
    pub fn register_object_store(
        &mut self,
        scheme: &str,
        store: StorageContexts,
        host: Option<&str>,
    ) -> PyResult<()> {
        let (store, upstream_host): (Arc<dyn ObjectStore>, String) = match store {
            StorageContexts::AmazonS3(s3) => (s3.inner, s3.bucket_name),
            StorageContexts::GoogleCloudStorage(gcs) => (gcs.inner, gcs.bucket_name),
            StorageContexts::MicrosoftAzure(azure) => (azure.inner, azure.container_name),
            StorageContexts::LocalFileSystem(local) => (local.inner, "".to_string()),
            StorageContexts::HTTP(http) => (http.store, http.url),
        };

        let derived_host = if let Some(host) = host {
            host
        } else {
            &upstream_host
        };
        let url_string = format!("{}{}", scheme, derived_host);
        let url = Url::parse(&url_string).map_err(py_datafusion_err)?;
        self.ctx.runtime_env().register_object_store(&url, store);
        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name, path, table_partition_cols=vec![],
                        file_extension=".parquet",
                        schema=None,
                        file_sort_order=None))]
    pub fn register_listing_table(
        &mut self,
        name: &str,
        path: &str,
        table_partition_cols: Vec<(String, String)>,
        file_extension: &str,
        schema: Option<PyArrowType<Schema>>,
        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
            .with_file_extension(file_extension)
            .with_table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
            .with_file_sort_order(
                file_sort_order
                    .unwrap_or_default()
                    .into_iter()
                    .map(|e| e.into_iter().map(|f| f.into()).collect())
                    .collect(),
            );
        let table_path = ListingTableUrl::parse(path)?;
        let resolved_schema: SchemaRef = match schema {
            Some(s) => Arc::new(s.0),
            None => {
                let state = self.ctx.state();
                let schema = options.infer_schema(&state, &table_path);
                wait_for_future(py, schema)?
            }
        };
        let config = ListingTableConfig::new(table_path)
            .with_listing_options(options)
            .with_schema(resolved_schema);
        let table = ListingTable::try_new(config)?;
        self.register_table(
            name,
            &PyTable {
                table: Arc::new(table),
            },
        )?;
        Ok(())
    }

    pub fn register_udtf(&mut self, func: PyTableFunction) {
        let name = func.name.clone();
        let func = Arc::new(func);
        self.ctx.register_udtf(&name, func);
    }

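    /// Returns a `PyDataFrame` whose plan corresponds to the SQL statement.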
    pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
        let result = self.ctx.sql(query);
        let df = wait_for_future(py, result)?;
        Ok(PyDataFrame::new(df))
    }

    #[pyo3(signature = (query, options=None))]
    pub fn sql_with_options(
        &mut self,
        query: &str,
        options: Option<PySQLOptions>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let options = if let Some(options) = options {
            options.options
        } else {
            SQLOptions::new()
        };
        let result = self.ctx.sql_with_options(query, options);
        let df = wait_for_future(py, result)?;
        Ok(PyDataFrame::new(df))
    }

    #[pyo3(signature = (partitions, name=None, schema=None))]
    pub fn create_dataframe(
        &mut self,
        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
        name: Option<&str>,
        schema: Option<PyArrowType<Schema>>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let schema = if let Some(schema) = schema {
            SchemaRef::from(schema.0)
        } else {
            partitions.0[0][0].schema()
        };

        let table = MemTable::try_new(schema, partitions.0)?;

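        // Use the caller-supplied name, or generate a unique one so the batches
        // can be registered as a named in-memory table.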
        let table_name = match name {
            Some(val) => val.to_owned(),
            None => {
                "c".to_owned()
                    + Uuid::new_v4()
                        .simple()
                        .encode_lower(&mut Uuid::encode_buffer())
            }
        };

        self.ctx.register_table(&*table_name, Arc::new(table))?;

        let table = wait_for_future(py, self._table(&table_name))?;

        let df = PyDataFrame::new(table);
        Ok(df)
    }

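    /// Create a DataFrame from an existing logical plan.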
    pub fn create_dataframe_from_logical_plan(&mut self, plan: PyLogicalPlan) -> PyDataFrame {
        PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()))
    }

    #[pyo3(signature = (data, name=None))]
    pub fn from_pylist(
        &mut self,
        data: Bound<'_, PyList>,
        name: Option<&str>,
    ) -> PyResult<PyDataFrame> {
        let py = data.py();

        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[data])?;
        let table = table_class.call_method1("from_pylist", args)?;

        let df = self.from_arrow(table, name, py)?;
        Ok(df)
    }

    #[pyo3(signature = (data, name=None))]
    pub fn from_pydict(
        &mut self,
        data: Bound<'_, PyDict>,
        name: Option<&str>,
    ) -> PyResult<PyDataFrame> {
        let py = data.py();

        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[data])?;
        let table = table_class.call_method1("from_pydict", args)?;

        let df = self.from_arrow(table, name, py)?;
        Ok(df)
    }

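    /// Construct a DataFrame from Arrow data that exposes either an Arrow
    /// stream (e.g. a pyarrow Table or RecordBatchReader) or a single RecordBatch.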
    #[pyo3(signature = (data, name=None))]
    pub fn from_arrow(
        &mut self,
        data: Bound<'_, PyAny>,
        name: Option<&str>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let (schema, batches) =
            if let Ok(stream_reader) = ArrowArrayStreamReader::from_pyarrow_bound(&data) {
                let schema = stream_reader.schema().as_ref().to_owned();
                let batches = stream_reader
                    .collect::<Result<Vec<RecordBatch>, arrow::error::ArrowError>>()?;

                (schema, batches)
            } else if let Ok(array) = RecordBatch::from_pyarrow_bound(&data) {
                (array.schema().as_ref().to_owned(), vec![array])
            } else {
                return Err(crate::errors::PyDataFusionError::Common(
                    "Expected either an Arrow Array or an Arrow Stream in from_arrow().".to_string(),
                ));
            };

        let list_of_batches = PyArrowType::from(vec![batches]);
        self.create_dataframe(list_of_batches, name, Some(schema.into()), py)
    }

    #[allow(clippy::wrong_self_convention)]
    #[pyo3(signature = (data, name=None))]
    pub fn from_pandas(
        &mut self,
        data: Bound<'_, PyAny>,
        name: Option<&str>,
    ) -> PyResult<PyDataFrame> {
        let py = data.py();

        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[data])?;
        let table = table_class.call_method1("from_pandas", args)?;

        let df = self.from_arrow(table, name, py)?;
        Ok(df)
    }

    #[pyo3(signature = (data, name=None))]
    pub fn from_polars(
        &mut self,
        data: Bound<'_, PyAny>,
        name: Option<&str>,
    ) -> PyResult<PyDataFrame> {
        let table = data.call_method0("to_arrow")?;

        let df = self.from_arrow(table, name, data.py())?;
        Ok(df)
    }

    pub fn register_table(&mut self, name: &str, table: &PyTable) -> PyDataFusionResult<()> {
        self.ctx.register_table(name, table.table())?;
        Ok(())
    }

    pub fn deregister_table(&mut self, name: &str) -> PyDataFusionResult<()> {
        self.ctx.deregister_table(name)?;
        Ok(())
    }

    pub fn register_table_provider(
        &mut self,
        name: &str,
        provider: Bound<'_, PyAny>,
    ) -> PyDataFusionResult<()> {
        if provider.hasattr("__datafusion_table_provider__")? {
            let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
            validate_pycapsule(capsule, "datafusion_table_provider")?;

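            // SAFETY: the capsule was validated above to carry a
            // `datafusion_table_provider` FFI_TableProvider pointer.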
            let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
            let provider: ForeignTableProvider = provider.into();

            let _ = self.ctx.register_table(name, Arc::new(provider))?;

            Ok(())
        } else {
            Err(crate::errors::PyDataFusionError::Common(
                "__datafusion_table_provider__ does not exist on Table Provider object."
                    .to_string(),
            ))
        }
    }

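    /// Register partitioned record batches as a named in-memory table,
    /// using the schema of the first batch.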
    pub fn register_record_batches(
        &mut self,
        name: &str,
        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
    ) -> PyDataFusionResult<()> {
        let schema = partitions.0[0][0].schema();
        let table = MemTable::try_new(schema, partitions.0)?;
        self.ctx.register_table(name, Arc::new(table))?;
        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name, path, table_partition_cols=vec![],
                        parquet_pruning=true,
                        file_extension=".parquet",
                        skip_metadata=true,
                        schema=None,
                        file_sort_order=None))]
    pub fn register_parquet(
        &mut self,
        name: &str,
        path: &str,
        table_partition_cols: Vec<(String, String)>,
        parquet_pruning: bool,
        file_extension: &str,
        skip_metadata: bool,
        schema: Option<PyArrowType<Schema>>,
        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let mut options = ParquetReadOptions::default()
            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
            .parquet_pruning(parquet_pruning)
            .skip_metadata(skip_metadata);
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);
        options.file_sort_order = file_sort_order
            .unwrap_or_default()
            .into_iter()
            .map(|e| e.into_iter().map(|f| f.into()).collect())
            .collect();

        let result = self.ctx.register_parquet(name, path, options);
        wait_for_future(py, result)?;
        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name,
                        path,
                        schema=None,
                        has_header=true,
                        delimiter=",",
                        schema_infer_max_records=1000,
                        file_extension=".csv",
                        file_compression_type=None))]
    pub fn register_csv(
        &mut self,
        name: &str,
        path: &Bound<'_, PyAny>,
        schema: Option<PyArrowType<Schema>>,
        has_header: bool,
        delimiter: &str,
        schema_infer_max_records: usize,
        file_extension: &str,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let delimiter = delimiter.as_bytes();
        if delimiter.len() != 1 {
            return Err(crate::errors::PyDataFusionError::PythonError(py_value_err(
                "Delimiter must be a single character",
            )));
        }

        let mut options = CsvReadOptions::new()
            .has_header(has_header)
            .delimiter(delimiter[0])
            .schema_infer_max_records(schema_infer_max_records)
            .file_extension(file_extension)
            .file_compression_type(parse_file_compression_type(file_compression_type)?);
        options.schema = schema.as_ref().map(|x| &x.0);

        if path.is_instance_of::<PyList>() {
            let paths = path.extract::<Vec<String>>()?;
            let result = self.register_csv_from_multiple_paths(name, paths, options);
            wait_for_future(py, result)?;
        } else {
            let path = path.extract::<String>()?;
            let result = self.ctx.register_csv(name, &path, options);
            wait_for_future(py, result)?;
        }

        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name,
                        path,
                        schema=None,
                        schema_infer_max_records=1000,
                        file_extension=".json",
                        table_partition_cols=vec![],
                        file_compression_type=None))]
    pub fn register_json(
        &mut self,
        name: &str,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, String)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;

        let mut options = NdJsonReadOptions::default()
            .file_compression_type(parse_file_compression_type(file_compression_type)?)
            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
        options.schema_infer_max_records = schema_infer_max_records;
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);

        let result = self.ctx.register_json(name, path, options);
        wait_for_future(py, result)?;

        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name,
                        path,
                        schema=None,
                        file_extension=".avro",
                        table_partition_cols=vec![]))]
    pub fn register_avro(
        &mut self,
        name: &str,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        file_extension: &str,
        table_partition_cols: Vec<(String, String)>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;

        let mut options = AvroReadOptions::default()
            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);

        let result = self.ctx.register_avro(name, path, options);
        wait_for_future(py, result)?;

        Ok(())
    }

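    /// Register a pyarrow `Dataset` as a table provider on this context.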
    pub fn register_dataset(
        &self,
        name: &str,
        dataset: &Bound<'_, PyAny>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);

        self.ctx.register_table(name, table)?;

        Ok(())
    }

    pub fn register_udf(&mut self, udf: PyScalarUDF) -> PyResult<()> {
        self.ctx.register_udf(udf.function);
        Ok(())
    }

    pub fn register_udaf(&mut self, udaf: PyAggregateUDF) -> PyResult<()> {
        self.ctx.register_udaf(udaf.function);
        Ok(())
    }

    pub fn register_udwf(&mut self, udwf: PyWindowUDF) -> PyResult<()> {
        self.ctx.register_udwf(udwf.function);
        Ok(())
    }

    #[pyo3(signature = (name="datafusion"))]
    pub fn catalog(&self, name: &str) -> PyResult<PyCatalog> {
        match self.ctx.catalog(name) {
            Some(catalog) => Ok(PyCatalog::new(catalog)),
            None => Err(PyKeyError::new_err(format!(
                "Catalog with name {} doesn't exist.",
                &name,
            ))),
        }
    }

    pub fn tables(&self) -> HashSet<String> {
        self.ctx
            .catalog_names()
            .into_iter()
            .filter_map(|name| self.ctx.catalog(&name))
            .flat_map(move |catalog| {
                catalog
                    .schema_names()
                    .into_iter()
                    .filter_map(move |name| catalog.schema(&name))
            })
            .flat_map(|schema| schema.table_names())
            .collect()
    }

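    /// Look up a registered table by name, raising a `KeyError` if it does not exist.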
    pub fn table(&self, name: &str, py: Python) -> PyResult<PyDataFrame> {
        let x = wait_for_future(py, self.ctx.table(name))
            .map_err(|e| PyKeyError::new_err(e.to_string()))?;
        Ok(PyDataFrame::new(x))
    }

    pub fn table_exist(&self, name: &str) -> PyDataFusionResult<bool> {
        Ok(self.ctx.table_exist(name)?)
    }

    pub fn empty_table(&self) -> PyDataFusionResult<PyDataFrame> {
        Ok(PyDataFrame::new(self.ctx.read_empty()?))
    }

    pub fn session_id(&self) -> String {
        self.ctx.session_id()
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))]
    pub fn read_json(
        &mut self,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, String)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
        let mut options = NdJsonReadOptions::default()
            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
            .file_compression_type(parse_file_compression_type(file_compression_type)?);
        options.schema_infer_max_records = schema_infer_max_records;
        options.file_extension = file_extension;
        let df = if let Some(schema) = schema {
            options.schema = Some(&schema.0);
            let result = self.ctx.read_json(path, options);
            wait_for_future(py, result)?
        } else {
            let result = self.ctx.read_json(path, options);
            wait_for_future(py, result)?
        };
        Ok(PyDataFrame::new(df))
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (
        path,
        schema=None,
        has_header=true,
        delimiter=",",
        schema_infer_max_records=1000,
        file_extension=".csv",
        table_partition_cols=vec![],
        file_compression_type=None))]
    pub fn read_csv(
        &self,
        path: &Bound<'_, PyAny>,
        schema: Option<PyArrowType<Schema>>,
        has_header: bool,
        delimiter: &str,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, String)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let delimiter = delimiter.as_bytes();
        if delimiter.len() != 1 {
            return Err(crate::errors::PyDataFusionError::PythonError(py_value_err(
                "Delimiter must be a single character",
            )));
        }

        let mut options = CsvReadOptions::new()
            .has_header(has_header)
            .delimiter(delimiter[0])
            .schema_infer_max_records(schema_infer_max_records)
            .file_extension(file_extension)
            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
            .file_compression_type(parse_file_compression_type(file_compression_type)?);
        options.schema = schema.as_ref().map(|x| &x.0);

        if path.is_instance_of::<PyList>() {
            let paths = path.extract::<Vec<String>>()?;
            let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
            let result = self.ctx.read_csv(paths, options);
            let df = PyDataFrame::new(wait_for_future(py, result)?);
            Ok(df)
        } else {
            let path = path.extract::<String>()?;
            let result = self.ctx.read_csv(path, options);
            let df = PyDataFrame::new(wait_for_future(py, result)?);
            Ok(df)
        }
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (
        path,
        table_partition_cols=vec![],
        parquet_pruning=true,
        file_extension=".parquet",
        skip_metadata=true,
        schema=None,
        file_sort_order=None))]
    pub fn read_parquet(
        &self,
        path: &str,
        table_partition_cols: Vec<(String, String)>,
        parquet_pruning: bool,
        file_extension: &str,
        skip_metadata: bool,
        schema: Option<PyArrowType<Schema>>,
        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let mut options = ParquetReadOptions::default()
            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
            .parquet_pruning(parquet_pruning)
            .skip_metadata(skip_metadata);
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);
        options.file_sort_order = file_sort_order
            .unwrap_or_default()
            .into_iter()
            .map(|e| e.into_iter().map(|f| f.into()).collect())
            .collect();

        let result = self.ctx.read_parquet(path, options);
        let df = PyDataFrame::new(wait_for_future(py, result)?);
        Ok(df)
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (path, schema=None, table_partition_cols=vec![], file_extension=".avro"))]
    pub fn read_avro(
        &self,
        path: &str,
        schema: Option<PyArrowType<Schema>>,
        table_partition_cols: Vec<(String, String)>,
        file_extension: &str,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let mut options = AvroReadOptions::default()
            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
        options.file_extension = file_extension;
        let df = if let Some(schema) = schema {
            options.schema = Some(&schema.0);
            let read_future = self.ctx.read_avro(path, options);
            wait_for_future(py, read_future)?
        } else {
            let read_future = self.ctx.read_avro(path, options);
            wait_for_future(py, read_future)?
        };
        Ok(PyDataFrame::new(df))
    }

    pub fn read_table(&self, table: &PyTable) -> PyDataFusionResult<PyDataFrame> {
        let df = self.ctx.read_table(table.table())?;
        Ok(PyDataFrame::new(df))
    }

    fn __repr__(&self) -> PyResult<String> {
        let config = self.ctx.copied_config();
        let mut config_entries = config
            .options()
            .entries()
            .iter()
            .filter(|e| e.value.is_some())
            .map(|e| format!("{} = {}", e.key, e.value.as_ref().unwrap()))
            .collect::<Vec<_>>();
        config_entries.sort();
        Ok(format!(
            "SessionContext: id={}; configs=[\n\t{}]",
            self.session_id(),
            config_entries.join("\n\t")
        ))
    }

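    /// Execute one partition of a physical plan on the dedicated Tokio runtime
    /// and return a stream of record batches.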
    pub fn execute(
        &self,
        plan: PyExecutionPlan,
        part: usize,
        py: Python,
    ) -> PyDataFusionResult<PyRecordBatchStream> {
        let ctx: TaskContext = TaskContext::from(&self.ctx.state());
        let rt = &get_tokio_runtime().0;
        let plan = plan.plan.clone();
        let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
            rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
        Ok(PyRecordBatchStream::new(stream?))
    }
}

impl PySessionContext {
    async fn _table(&self, name: &str) -> datafusion::common::Result<DataFrame> {
        self.ctx.table(name).await
    }

    async fn register_csv_from_multiple_paths(
        &self,
        name: &str,
        table_paths: Vec<String>,
        options: CsvReadOptions<'_>,
    ) -> datafusion::common::Result<()> {
        let table_paths = table_paths.to_urls()?;
        let session_config = self.ctx.copied_config();
        let listing_options =
            options.to_listing_options(&session_config, self.ctx.copied_table_options());

        let option_extension = listing_options.file_extension.clone();

        if table_paths.is_empty() {
            return exec_err!("No table paths were provided");
        }

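        // Reject any path whose file name does not carry the expected extension,
        // unless it points at a directory (collection) of files.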
        for path in &table_paths {
            let file_path = path.as_str();
            if !file_path.ends_with(option_extension.as_str()) && !path.is_collection() {
                return exec_err!(
                    "File path '{file_path}' does not match the expected extension '{option_extension}'"
                );
            }
        }

        let resolved_schema = options
            .get_resolved_schema(&session_config, self.ctx.state(), table_paths[0].clone())
            .await?;

        let config = ListingTableConfig::new_with_multi_paths(table_paths)
            .with_listing_options(listing_options)
            .with_schema(resolved_schema);
        let table = ListingTable::try_new(config)?;
        self.ctx
            .register_table(TableReference::Bare { table: name.into() }, Arc::new(table))?;
        Ok(())
    }
}

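/// Convert `(name, type)` partition-column pairs supplied from Python into Arrow
/// data types. Only `"string"` and `"int"` are currently supported.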
pub fn convert_table_partition_cols(
    table_partition_cols: Vec<(String, String)>,
) -> PyDataFusionResult<Vec<(String, DataType)>> {
    table_partition_cols
        .into_iter()
        .map(|(name, ty)| match ty.as_str() {
            "string" => Ok((name, DataType::Utf8)),
            "int" => Ok((name, DataType::Int32)),
            _ => Err(crate::errors::PyDataFusionError::Common(format!(
                "Unsupported data type '{ty}' for partition column. Supported types are 'string' and 'int'"
            ))),
        })
        .collect::<Result<Vec<_>, _>>()
}

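/// Parse an optional compression string into a `FileCompressionType`; a missing
/// value is treated as uncompressed.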
pub fn parse_file_compression_type(
    file_compression_type: Option<String>,
) -> Result<FileCompressionType, PyErr> {
    FileCompressionType::from_str(file_compression_type.unwrap_or_default().as_str()).map_err(
        |_| PyValueError::new_err("file_compression_type must be one of: gzip, bz2, xz, zstd"),
    )
}

impl From<PySessionContext> for SessionContext {
    fn from(ctx: PySessionContext) -> SessionContext {
        ctx.ctx
    }
}

impl From<SessionContext> for PySessionContext {
    fn from(ctx: SessionContext) -> PySessionContext {
        PySessionContext { ctx }
    }
}