use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;

use arrow::array::RecordBatchReader;
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::FromPyArrow;
use datafusion::execution::session_state::SessionStateBuilder;
use object_store::ObjectStore;
use url::Url;
use uuid::Uuid;

use pyo3::exceptions::{PyKeyError, PyValueError};
use pyo3::prelude::*;

use crate::catalog::{PyCatalog, RustWrappedPyCatalogProvider};
use crate::dataframe::PyDataFrame;
use crate::dataset::Dataset;
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
use crate::expr::sort_expr::PySortExpr;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
use crate::sql::exceptions::py_value_err;
use crate::sql::logical::PyLogicalPlan;
use crate::store::StorageContexts;
use crate::table::PyTable;
use crate::udaf::PyAggregateUDF;
use crate::udf::PyScalarUDF;
use crate::udtf::PyTableFunction;
use crate::udwf::PyWindowUDF;
use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future};
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::catalog::CatalogProvider;
use datafusion::common::TableReference;
use datafusion::common::{exec_err, ScalarValue};
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
use datafusion::datasource::file_format::parquet::ParquetFormat;
use datafusion::datasource::listing::{
    ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
};
use datafusion::datasource::MemTable;
use datafusion::datasource::TableProvider;
use datafusion::execution::context::{
    DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext,
};
use datafusion::execution::disk_manager::DiskManagerMode;
use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
use datafusion::execution::options::ReadOptions;
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
use datafusion::physical_plan::SendableRecordBatchStream;
use datafusion::prelude::{
    AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
};
use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider};
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
use pyo3::IntoPyObjectExt;
use tokio::task::JoinHandle;

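/// Configuration options for a `SessionContext`.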
#[pyclass(frozen, name = "SessionConfig", module = "datafusion", subclass)]
#[derive(Clone, Default)]
pub struct PySessionConfig {
    pub config: SessionConfig,
}

impl From<SessionConfig> for PySessionConfig {
    fn from(config: SessionConfig) -> Self {
        Self { config }
    }
}

#[pymethods]
impl PySessionConfig {
    #[pyo3(signature = (config_options=None))]
    #[new]
    fn new(config_options: Option<HashMap<String, String>>) -> Self {
        let mut config = SessionConfig::new();
        if let Some(hash_map) = config_options {
            for (k, v) in &hash_map {
                config = config.set(k, &ScalarValue::Utf8(Some(v.clone())));
            }
        }

        Self { config }
    }

    fn with_create_default_catalog_and_schema(&self, enabled: bool) -> Self {
        Self::from(
            self.config
                .clone()
                .with_create_default_catalog_and_schema(enabled),
        )
    }

    fn with_default_catalog_and_schema(&self, catalog: &str, schema: &str) -> Self {
        Self::from(
            self.config
                .clone()
                .with_default_catalog_and_schema(catalog, schema),
        )
    }

    fn with_information_schema(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_information_schema(enabled))
    }

    fn with_batch_size(&self, batch_size: usize) -> Self {
        Self::from(self.config.clone().with_batch_size(batch_size))
    }

    fn with_target_partitions(&self, target_partitions: usize) -> Self {
        Self::from(
            self.config
                .clone()
                .with_target_partitions(target_partitions),
        )
    }

    fn with_repartition_aggregations(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_aggregations(enabled))
    }

    fn with_repartition_joins(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_joins(enabled))
    }

    fn with_repartition_windows(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_windows(enabled))
    }

    fn with_repartition_sorts(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_sorts(enabled))
    }

    fn with_repartition_file_scans(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_file_scans(enabled))
    }

    fn with_repartition_file_min_size(&self, size: usize) -> Self {
        Self::from(self.config.clone().with_repartition_file_min_size(size))
    }

    fn with_parquet_pruning(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_parquet_pruning(enabled))
    }

    fn set(&self, key: &str, value: &str) -> Self {
        Self::from(self.config.clone().set_str(key, value))
    }
}

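/// Runtime environment options for a `SessionContext`: disk manager mode,
/// memory pool selection, and temporary file paths.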
#[pyclass(frozen, name = "RuntimeEnvBuilder", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PyRuntimeEnvBuilder {
    pub builder: RuntimeEnvBuilder,
}

#[pymethods]
impl PyRuntimeEnvBuilder {
    #[new]
    fn new() -> Self {
        Self {
            builder: RuntimeEnvBuilder::default(),
        }
    }

    fn with_disk_manager_disabled(&self) -> Self {
        let mut runtime_builder = self.builder.clone();

        let mut disk_mgr_builder = runtime_builder
            .disk_manager_builder
            .clone()
            .unwrap_or_default();
        disk_mgr_builder.set_mode(DiskManagerMode::Disabled);

        runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
        Self {
            builder: runtime_builder,
        }
    }

    fn with_disk_manager_os(&self) -> Self {
        let mut runtime_builder = self.builder.clone();

        let mut disk_mgr_builder = runtime_builder
            .disk_manager_builder
            .clone()
            .unwrap_or_default();
        disk_mgr_builder.set_mode(DiskManagerMode::OsTmpDirectory);

        runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
        Self {
            builder: runtime_builder,
        }
    }

    fn with_disk_manager_specified(&self, paths: Vec<String>) -> Self {
        let paths = paths.iter().map(|s| s.into()).collect();
        let mut runtime_builder = self.builder.clone();

        let mut disk_mgr_builder = runtime_builder
            .disk_manager_builder
            .clone()
            .unwrap_or_default();
        disk_mgr_builder.set_mode(DiskManagerMode::Directories(paths));

        runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
        Self {
            builder: runtime_builder,
        }
    }

    fn with_unbounded_memory_pool(&self) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_memory_pool(Arc::new(UnboundedMemoryPool::default()));
        Self { builder }
    }

    fn with_fair_spill_pool(&self, size: usize) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_memory_pool(Arc::new(FairSpillPool::new(size)));
        Self { builder }
    }

    fn with_greedy_memory_pool(&self, size: usize) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_memory_pool(Arc::new(GreedyMemoryPool::new(size)));
        Self { builder }
    }

    fn with_temp_file_path(&self, path: &str) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_temp_file_path(path);
        Self { builder }
    }
}

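/// Options that control which categories of SQL (DDL, DML, statements) a
/// `SessionContext` will accept.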
#[pyclass(frozen, name = "SQLOptions", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PySQLOptions {
    pub options: SQLOptions,
}

impl From<SQLOptions> for PySQLOptions {
    fn from(options: SQLOptions) -> Self {
        Self { options }
    }
}

#[pymethods]
impl PySQLOptions {
    #[new]
    fn new() -> Self {
        let options = SQLOptions::new();
        Self { options }
    }

    fn with_allow_ddl(&self, allow: bool) -> Self {
        Self::from(self.options.with_allow_ddl(allow))
    }

    pub fn with_allow_dml(&self, allow: bool) -> Self {
        Self::from(self.options.with_allow_dml(allow))
    }

    pub fn with_allow_statements(&self, allow: bool) -> Self {
        Self::from(self.options.with_allow_statements(allow))
    }
}

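/// `PySessionContext` is the main interface for executing queries with DataFusion.
/// It wraps DataFusion's `SessionContext`, which maintains the state of the
/// connection between a user and an instance of the DataFusion engine.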
#[pyclass(frozen, name = "SessionContext", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PySessionContext {
    pub ctx: SessionContext,
}

#[pymethods]
impl PySessionContext {
    #[pyo3(signature = (config=None, runtime=None))]
    #[new]
    pub fn new(
        config: Option<PySessionConfig>,
        runtime: Option<PyRuntimeEnvBuilder>,
    ) -> PyDataFusionResult<Self> {
        let config = if let Some(c) = config {
            c.config
        } else {
            SessionConfig::default().with_information_schema(true)
        };
        let runtime_env_builder = if let Some(c) = runtime {
            c.builder
        } else {
            RuntimeEnvBuilder::default()
        };
        let runtime = Arc::new(runtime_env_builder.build()?);
        let session_state = SessionStateBuilder::new()
            .with_config(config)
            .with_runtime_env(runtime)
            .with_default_features()
            .build();
        Ok(PySessionContext {
            ctx: SessionContext::new_with_state(session_state),
        })
    }

    pub fn enable_url_table(&self) -> PyResult<Self> {
        Ok(PySessionContext {
            ctx: self.ctx.clone().enable_url_table(),
        })
    }

    #[classmethod]
    #[pyo3(signature = ())]
    fn global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
        Ok(Self {
            ctx: get_global_ctx().clone(),
        })
    }

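    /// Register an `ObjectStore` (S3, GCS, Azure, HTTP, or the local filesystem)
    /// with this context under the given URL scheme and optional host.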
    #[pyo3(signature = (scheme, store, host=None))]
    pub fn register_object_store(
        &self,
        scheme: &str,
        store: StorageContexts,
        host: Option<&str>,
    ) -> PyResult<()> {
        let (store, upstream_host): (Arc<dyn ObjectStore>, String) = match store {
            StorageContexts::AmazonS3(s3) => (s3.inner, s3.bucket_name),
            StorageContexts::GoogleCloudStorage(gcs) => (gcs.inner, gcs.bucket_name),
            StorageContexts::MicrosoftAzure(azure) => (azure.inner, azure.container_name),
            StorageContexts::LocalFileSystem(local) => (local.inner, "".to_string()),
            StorageContexts::HTTP(http) => (http.store, http.url),
        };

        let derived_host = if let Some(host) = host {
            host
        } else {
            &upstream_host
        };
        let url_string = format!("{scheme}{derived_host}");
        let url = Url::parse(&url_string).map_err(|e| {
            PyValueError::new_err(format!("Invalid object store URL '{url_string}': {e}"))
        })?;
        self.ctx.runtime_env().register_object_store(&url, store);
        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name, path, table_partition_cols=vec![],
                        file_extension=".parquet",
                        schema=None,
                        file_sort_order=None))]
    pub fn register_listing_table(
        &self,
        name: &str,
        path: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_extension: &str,
        schema: Option<PyArrowType<Schema>>,
        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
            .with_file_extension(file_extension)
            .with_table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .with_file_sort_order(
                file_sort_order
                    .unwrap_or_default()
                    .into_iter()
                    .map(|e| e.into_iter().map(|f| f.into()).collect())
                    .collect(),
            );
        let table_path = ListingTableUrl::parse(path)?;
        let resolved_schema: SchemaRef = match schema {
            Some(s) => Arc::new(s.0),
            None => {
                let state = self.ctx.state();
                let schema = options.infer_schema(&state, &table_path);
                wait_for_future(py, schema)??
            }
        };
        let config = ListingTableConfig::new(table_path)
            .with_listing_options(options)
            .with_schema(resolved_schema);
        let table = ListingTable::try_new(config)?;
        self.ctx.register_table(name, Arc::new(table))?;
        Ok(())
    }

    pub fn register_udtf(&self, func: PyTableFunction) {
        let name = func.name.clone();
        let func = Arc::new(func);
        self.ctx.register_udtf(&name, func);
    }

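    /// Execute a SQL statement and return the result as a `DataFrame`.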
    pub fn sql(&self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
        let result = self.ctx.sql(query);
        let df = wait_for_future(py, result)??;
        Ok(PyDataFrame::new(df))
    }

    #[pyo3(signature = (query, options=None))]
    pub fn sql_with_options(
        &self,
        query: &str,
        options: Option<PySQLOptions>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let options = if let Some(options) = options {
            options.options
        } else {
            SQLOptions::new()
        };
        let result = self.ctx.sql_with_options(query, options);
        let df = wait_for_future(py, result)??;
        Ok(PyDataFrame::new(df))
    }

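    /// Create a `DataFrame` from partitions of Arrow record batches, registering
    /// them as an in-memory table under `name` or a generated unique name.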
    #[pyo3(signature = (partitions, name=None, schema=None))]
    pub fn create_dataframe(
        &self,
        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
        name: Option<&str>,
        schema: Option<PyArrowType<Schema>>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let schema = if let Some(schema) = schema {
            SchemaRef::from(schema.0)
        } else {
            partitions.0[0][0].schema()
        };

        let table = MemTable::try_new(schema, partitions.0)?;

        let table_name = match name {
            Some(val) => val.to_owned(),
            None => {
                "c".to_owned()
                    + Uuid::new_v4()
                        .simple()
                        .encode_lower(&mut Uuid::encode_buffer())
            }
        };

        self.ctx.register_table(&*table_name, Arc::new(table))?;

        let table = wait_for_future(py, self._table(&table_name))??;

        let df = PyDataFrame::new(table);
        Ok(df)
    }

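    /// Create a `DataFrame` from an existing logical plan.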
    pub fn create_dataframe_from_logical_plan(&self, plan: PyLogicalPlan) -> PyDataFrame {
        PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()))
    }

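    /// Create a `DataFrame` from a Python list of row dictionaries via
    /// `pyarrow.Table.from_pylist`.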
    #[pyo3(signature = (data, name=None))]
    pub fn from_pylist(
        &self,
        data: Bound<'_, PyList>,
        name: Option<&str>,
    ) -> PyResult<PyDataFrame> {
        let py = data.py();

        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[data])?;
        let table = table_class.call_method1("from_pylist", args)?;

        let df = self.from_arrow(table, name, py)?;
        Ok(df)
    }

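    /// Create a `DataFrame` from a Python dictionary of column lists via
    /// `pyarrow.Table.from_pydict`.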
    #[pyo3(signature = (data, name=None))]
    pub fn from_pydict(
        &self,
        data: Bound<'_, PyDict>,
        name: Option<&str>,
    ) -> PyResult<PyDataFrame> {
        let py = data.py();

        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[data])?;
        let table = table_class.call_method1("from_pydict", args)?;

        let df = self.from_arrow(table, name, py)?;
        Ok(df)
    }

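    /// Create a `DataFrame` from an Arrow source: any object readable as an Arrow
    /// stream (such as a `pyarrow.Table` or record batch reader) or a single
    /// `RecordBatch`.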
    #[pyo3(signature = (data, name=None))]
    pub fn from_arrow(
        &self,
        data: Bound<'_, PyAny>,
        name: Option<&str>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let (schema, batches) =
            if let Ok(stream_reader) = ArrowArrayStreamReader::from_pyarrow_bound(&data) {
                let schema = stream_reader.schema().as_ref().to_owned();
                let batches = stream_reader
                    .collect::<Result<Vec<RecordBatch>, arrow::error::ArrowError>>()?;

                (schema, batches)
            } else if let Ok(array) = RecordBatch::from_pyarrow_bound(&data) {
                (array.schema().as_ref().to_owned(), vec![array])
            } else {
                return Err(crate::errors::PyDataFusionError::Common(
                    "Expected either an Arrow array or an Arrow stream in from_arrow().".to_string(),
                ));
            };

        let list_of_batches = PyArrowType::from(vec![batches]);
        self.create_dataframe(list_of_batches, name, Some(schema.into()), py)
    }

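    /// Create a `DataFrame` from a pandas `DataFrame` by converting it with
    /// `pyarrow.Table.from_pandas`.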
    #[allow(clippy::wrong_self_convention)]
    #[pyo3(signature = (data, name=None))]
    pub fn from_pandas(&self, data: Bound<'_, PyAny>, name: Option<&str>) -> PyResult<PyDataFrame> {
        let py = data.py();

        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[data])?;
        let table = table_class.call_method1("from_pandas", args)?;

        let df = self.from_arrow(table, name, py)?;
        Ok(df)
    }

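    /// Create a `DataFrame` from a polars `DataFrame` via its `to_arrow` method.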
    #[pyo3(signature = (data, name=None))]
    pub fn from_polars(&self, data: Bound<'_, PyAny>, name: Option<&str>) -> PyResult<PyDataFrame> {
        let table = data.call_method0("to_arrow")?;

        let df = self.from_arrow(table, name, data.py())?;
        Ok(df)
    }

    pub fn register_table(&self, name: &str, table: Bound<'_, PyAny>) -> PyDataFusionResult<()> {
        let table = PyTable::new(&table)?;

        self.ctx.register_table(name, table.table)?;
        Ok(())
    }

    pub fn deregister_table(&self, name: &str) -> PyDataFusionResult<()> {
        self.ctx.deregister_table(name)?;
        Ok(())
    }

    pub fn register_catalog_provider(
        &self,
        name: &str,
        provider: Bound<'_, PyAny>,
    ) -> PyDataFusionResult<()> {
        let provider = if provider.hasattr("__datafusion_catalog_provider__")? {
            let capsule = provider
                .getattr("__datafusion_catalog_provider__")?
                .call0()?;
            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
            validate_pycapsule(capsule, "datafusion_catalog_provider")?;

            let provider = unsafe { capsule.reference::<FFI_CatalogProvider>() };
            let provider: ForeignCatalogProvider = provider.into();
            Arc::new(provider) as Arc<dyn CatalogProvider>
        } else {
            match provider.extract::<PyCatalog>() {
                Ok(py_catalog) => py_catalog.catalog,
                Err(_) => Arc::new(RustWrappedPyCatalogProvider::new(provider.into()))
                    as Arc<dyn CatalogProvider>,
            }
        };

        let _ = self.ctx.register_catalog(name, provider);

        Ok(())
    }

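    /// Register a table provider under `name`; currently an alias for `register_table`.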
    pub fn register_table_provider(
        &self,
        name: &str,
        provider: Bound<'_, PyAny>,
    ) -> PyDataFusionResult<()> {
        self.register_table(name, provider)
    }

    pub fn register_record_batches(
        &self,
        name: &str,
        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
    ) -> PyDataFusionResult<()> {
        let schema = partitions.0[0][0].schema();
        let table = MemTable::try_new(schema, partitions.0)?;
        self.ctx.register_table(name, Arc::new(table))?;
        Ok(())
    }

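    /// Register Parquet files at `path` as a named table that can be referenced from SQL.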
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name, path, table_partition_cols=vec![],
                        parquet_pruning=true,
                        file_extension=".parquet",
                        skip_metadata=true,
                        schema=None,
                        file_sort_order=None))]
    pub fn register_parquet(
        &self,
        name: &str,
        path: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        parquet_pruning: bool,
        file_extension: &str,
        skip_metadata: bool,
        schema: Option<PyArrowType<Schema>>,
        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let mut options = ParquetReadOptions::default()
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .parquet_pruning(parquet_pruning)
            .skip_metadata(skip_metadata);
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);
        options.file_sort_order = file_sort_order
            .unwrap_or_default()
            .into_iter()
            .map(|e| e.into_iter().map(|f| f.into()).collect())
            .collect();

        let result = self.ctx.register_parquet(name, path, options);
        wait_for_future(py, result)??;
        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name,
                        path,
                        schema=None,
                        has_header=true,
                        delimiter=",",
                        schema_infer_max_records=1000,
                        file_extension=".csv",
                        file_compression_type=None))]
    pub fn register_csv(
        &self,
        name: &str,
        path: &Bound<'_, PyAny>,
        schema: Option<PyArrowType<Schema>>,
        has_header: bool,
        delimiter: &str,
        schema_infer_max_records: usize,
        file_extension: &str,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let delimiter = delimiter.as_bytes();
        if delimiter.len() != 1 {
            return Err(crate::errors::PyDataFusionError::PythonError(py_value_err(
                "Delimiter must be a single character",
            )));
        }

        let mut options = CsvReadOptions::new()
            .has_header(has_header)
            .delimiter(delimiter[0])
            .schema_infer_max_records(schema_infer_max_records)
            .file_extension(file_extension)
            .file_compression_type(parse_file_compression_type(file_compression_type)?);
        options.schema = schema.as_ref().map(|x| &x.0);

        if path.is_instance_of::<PyList>() {
            let paths = path.extract::<Vec<String>>()?;
            let result = self.register_csv_from_multiple_paths(name, paths, options);
            wait_for_future(py, result)??;
        } else {
            let path = path.extract::<String>()?;
            let result = self.ctx.register_csv(name, &path, options);
            wait_for_future(py, result)??;
        }

        Ok(())
    }

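    /// Register a newline-delimited JSON (NDJSON) file as a named table.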
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name,
                        path,
                        schema=None,
                        schema_infer_max_records=1000,
                        file_extension=".json",
                        table_partition_cols=vec![],
                        file_compression_type=None))]
    pub fn register_json(
        &self,
        name: &str,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;

        let mut options = NdJsonReadOptions::default()
            .file_compression_type(parse_file_compression_type(file_compression_type)?)
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            );
        options.schema_infer_max_records = schema_infer_max_records;
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);

        let result = self.ctx.register_json(name, path, options);
        wait_for_future(py, result)??;

        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name,
                        path,
                        schema=None,
                        file_extension=".avro",
                        table_partition_cols=vec![]))]
    pub fn register_avro(
        &self,
        name: &str,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        file_extension: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;

        let mut options = AvroReadOptions::default().table_partition_cols(
            table_partition_cols
                .into_iter()
                .map(|(name, ty)| (name, ty.0))
                .collect::<Vec<(String, DataType)>>(),
        );
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);

        let result = self.ctx.register_avro(name, path, options);
        wait_for_future(py, result)??;

        Ok(())
    }

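    /// Register a Python dataset (e.g. a `pyarrow.dataset.Dataset`) as a table
    /// provider under `name`.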
    pub fn register_dataset(
        &self,
        name: &str,
        dataset: &Bound<'_, PyAny>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);

        self.ctx.register_table(name, table)?;

        Ok(())
    }

    pub fn register_udf(&self, udf: PyScalarUDF) -> PyResult<()> {
        self.ctx.register_udf(udf.function);
        Ok(())
    }

    pub fn register_udaf(&self, udaf: PyAggregateUDF) -> PyResult<()> {
        self.ctx.register_udaf(udaf.function);
        Ok(())
    }

    pub fn register_udwf(&self, udwf: PyWindowUDF) -> PyResult<()> {
        self.ctx.register_udwf(udwf.function);
        Ok(())
    }

    #[pyo3(signature = (name="datafusion"))]
    pub fn catalog(&self, name: &str) -> PyResult<PyObject> {
        let catalog = self.ctx.catalog(name).ok_or(PyKeyError::new_err(format!(
            "Catalog with name {name} doesn't exist."
        )))?;

        Python::with_gil(|py| {
            match catalog
                .as_any()
                .downcast_ref::<RustWrappedPyCatalogProvider>()
            {
                Some(wrapped_schema) => Ok(wrapped_schema.catalog_provider.clone_ref(py)),
                None => PyCatalog::from(catalog).into_py_any(py),
            }
        })
    }

    pub fn catalog_names(&self) -> HashSet<String> {
        self.ctx.catalog_names().into_iter().collect()
    }

    pub fn tables(&self) -> HashSet<String> {
        self.ctx
            .catalog_names()
            .into_iter()
            .filter_map(|name| self.ctx.catalog(&name))
            .flat_map(move |catalog| {
                catalog
                    .schema_names()
                    .into_iter()
                    .filter_map(move |name| catalog.schema(&name))
            })
            .flat_map(|schema| schema.table_names())
            .collect()
    }

    pub fn table(&self, name: &str, py: Python) -> PyResult<PyDataFrame> {
        let res = wait_for_future(py, self.ctx.table(name))
            .map_err(|e| PyKeyError::new_err(e.to_string()))?;
        match res {
            Ok(df) => Ok(PyDataFrame::new(df)),
            Err(e) => {
                if let datafusion::error::DataFusionError::Plan(msg) = &e {
                    if msg.contains("No table named") {
                        return Err(PyKeyError::new_err(msg.to_string()));
                    }
                }
                Err(py_datafusion_err(e))
            }
        }
    }

    pub fn table_exist(&self, name: &str) -> PyDataFusionResult<bool> {
        Ok(self.ctx.table_exist(name)?)
    }

    pub fn empty_table(&self) -> PyDataFusionResult<PyDataFrame> {
        Ok(PyDataFrame::new(self.ctx.read_empty()?))
    }

    pub fn session_id(&self) -> String {
        self.ctx.session_id()
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))]
    pub fn read_json(
        &self,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
        let mut options = NdJsonReadOptions::default()
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .file_compression_type(parse_file_compression_type(file_compression_type)?);
        options.schema_infer_max_records = schema_infer_max_records;
        options.file_extension = file_extension;
        let df = if let Some(schema) = schema {
            options.schema = Some(&schema.0);
            let result = self.ctx.read_json(path, options);
            wait_for_future(py, result)??
        } else {
            let result = self.ctx.read_json(path, options);
            wait_for_future(py, result)??
        };
        Ok(PyDataFrame::new(df))
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (
        path,
        schema=None,
        has_header=true,
        delimiter=",",
        schema_infer_max_records=1000,
        file_extension=".csv",
        table_partition_cols=vec![],
        file_compression_type=None))]
    pub fn read_csv(
        &self,
        path: &Bound<'_, PyAny>,
        schema: Option<PyArrowType<Schema>>,
        has_header: bool,
        delimiter: &str,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let delimiter = delimiter.as_bytes();
        if delimiter.len() != 1 {
            return Err(crate::errors::PyDataFusionError::PythonError(py_value_err(
                "Delimiter must be a single character",
            )));
        };

        let mut options = CsvReadOptions::new()
            .has_header(has_header)
            .delimiter(delimiter[0])
            .schema_infer_max_records(schema_infer_max_records)
            .file_extension(file_extension)
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .file_compression_type(parse_file_compression_type(file_compression_type)?);
        options.schema = schema.as_ref().map(|x| &x.0);

        if path.is_instance_of::<PyList>() {
            let paths = path.extract::<Vec<String>>()?;
            let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
            let result = self.ctx.read_csv(paths, options);
            let df = PyDataFrame::new(wait_for_future(py, result)??);
            Ok(df)
        } else {
            let path = path.extract::<String>()?;
            let result = self.ctx.read_csv(path, options);
            let df = PyDataFrame::new(wait_for_future(py, result)??);
            Ok(df)
        }
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (
        path,
        table_partition_cols=vec![],
        parquet_pruning=true,
        file_extension=".parquet",
        skip_metadata=true,
        schema=None,
        file_sort_order=None))]
    pub fn read_parquet(
        &self,
        path: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        parquet_pruning: bool,
        file_extension: &str,
        skip_metadata: bool,
        schema: Option<PyArrowType<Schema>>,
        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let mut options = ParquetReadOptions::default()
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .parquet_pruning(parquet_pruning)
            .skip_metadata(skip_metadata);
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);
        options.file_sort_order = file_sort_order
            .unwrap_or_default()
            .into_iter()
            .map(|e| e.into_iter().map(|f| f.into()).collect())
            .collect();

        let result = self.ctx.read_parquet(path, options);
        let df = PyDataFrame::new(wait_for_future(py, result)??);
        Ok(df)
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (path, schema=None, table_partition_cols=vec![], file_extension=".avro"))]
    pub fn read_avro(
        &self,
        path: &str,
        schema: Option<PyArrowType<Schema>>,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_extension: &str,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let mut options = AvroReadOptions::default().table_partition_cols(
            table_partition_cols
                .into_iter()
                .map(|(name, ty)| (name, ty.0))
                .collect::<Vec<(String, DataType)>>(),
        );
        options.file_extension = file_extension;
        let df = if let Some(schema) = schema {
            options.schema = Some(&schema.0);
            let read_future = self.ctx.read_avro(path, options);
            wait_for_future(py, read_future)??
        } else {
            let read_future = self.ctx.read_avro(path, options);
            wait_for_future(py, read_future)??
        };
        Ok(PyDataFrame::new(df))
    }

    pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult<PyDataFrame> {
        let table = PyTable::new(&table)?;
        let df = self.ctx.read_table(table.table())?;
        Ok(PyDataFrame::new(df))
    }

    fn __repr__(&self) -> PyResult<String> {
        let config = self.ctx.copied_config();
        let mut config_entries = config
            .options()
            .entries()
            .iter()
            .filter(|e| e.value.is_some())
            .map(|e| format!("{} = {}", e.key, e.value.as_ref().unwrap()))
            .collect::<Vec<_>>();
        config_entries.sort();
        Ok(format!(
            "SessionContext: id={}; configs=[\n\t{}]",
            self.session_id(),
            config_entries.join("\n\t")
        ))
    }

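    /// Execute partition `part` of a physical plan on the Tokio runtime and
    /// return the resulting stream of record batches.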
    pub fn execute(
        &self,
        plan: PyExecutionPlan,
        part: usize,
        py: Python,
    ) -> PyDataFusionResult<PyRecordBatchStream> {
        let ctx: TaskContext = TaskContext::from(&self.ctx.state());
        let rt = &get_tokio_runtime().0;
        let plan = plan.plan.clone();
        let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
            rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
        Ok(PyRecordBatchStream::new(stream))
    }
}

impl PySessionContext {
    async fn _table(&self, name: &str) -> datafusion::common::Result<DataFrame> {
        self.ctx.table(name).await
    }

    async fn register_csv_from_multiple_paths(
        &self,
        name: &str,
        table_paths: Vec<String>,
        options: CsvReadOptions<'_>,
    ) -> datafusion::common::Result<()> {
        let table_paths = table_paths.to_urls()?;
        let session_config = self.ctx.copied_config();
        let listing_options =
            options.to_listing_options(&session_config, self.ctx.copied_table_options());

        let option_extension = listing_options.file_extension.clone();

        if table_paths.is_empty() {
            return exec_err!("No table paths were provided");
        }

        for path in &table_paths {
            let file_path = path.as_str();
            if !file_path.ends_with(option_extension.clone().as_str()) && !path.is_collection() {
                return exec_err!(
                    "File path '{file_path}' does not match the expected extension '{option_extension}'"
                );
            }
        }

        let resolved_schema = options
            .get_resolved_schema(&session_config, self.ctx.state(), table_paths[0].clone())
            .await?;

        let config = ListingTableConfig::new_with_multi_paths(table_paths)
            .with_listing_options(listing_options)
            .with_schema(resolved_schema);
        let table = ListingTable::try_new(config)?;
        self.ctx
            .register_table(TableReference::Bare { table: name.into() }, Arc::new(table))?;
        Ok(())
    }
}

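/// Parse an optional compression name (gzip, bz2, xz, zstd) into a
/// `FileCompressionType`; `None` selects no compression.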
pub fn parse_file_compression_type(
    file_compression_type: Option<String>,
) -> Result<FileCompressionType, PyErr> {
    FileCompressionType::from_str(file_compression_type.unwrap_or_default().as_str()).map_err(
        |_| PyValueError::new_err("file_compression_type must be one of: gzip, bz2, xz, zstd"),
    )
}

impl From<PySessionContext> for SessionContext {
    fn from(ctx: PySessionContext) -> SessionContext {
        ctx.ctx
    }
}

impl From<SessionContext> for PySessionContext {
    fn from(ctx: SessionContext) -> PySessionContext {
        PySessionContext { ctx }
    }
}