use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;

use arrow::array::RecordBatchReader;
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::FromPyArrow;
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::catalog::CatalogProvider;
use datafusion::common::{exec_err, ScalarValue, TableReference};
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
use datafusion::datasource::file_format::parquet::ParquetFormat;
use datafusion::datasource::listing::{
    ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
};
use datafusion::datasource::{MemTable, TableProvider};
use datafusion::execution::context::{
    DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext,
};
use datafusion::execution::disk_manager::DiskManagerMode;
use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
use datafusion::execution::options::ReadOptions;
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::prelude::{
    AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
};
use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider};
use object_store::ObjectStore;
use pyo3::exceptions::{PyKeyError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
use pyo3::IntoPyObjectExt;
use url::Url;
use uuid::Uuid;

use crate::catalog::{PyCatalog, RustWrappedPyCatalogProvider};
use crate::common::data_type::PyScalarValue;
use crate::dataframe::PyDataFrame;
use crate::dataset::Dataset;
use crate::errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult};
use crate::expr::sort_expr::PySortExpr;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
use crate::sql::exceptions::py_value_err;
use crate::sql::logical::PyLogicalPlan;
use crate::sql::util::replace_placeholders_with_strings;
use crate::store::StorageContexts;
use crate::table::PyTable;
use crate::udaf::PyAggregateUDF;
use crate::udf::PyScalarUDF;
use crate::udtf::PyTableFunction;
use crate::udwf::PyWindowUDF;
use crate::utils::{get_global_ctx, spawn_future, validate_pycapsule, wait_for_future};

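/// Configuration options for a `SessionContext`, exposed to Python as
/// `datafusion.SessionConfig`. Each `with_*` builder method returns a new
/// configuration value rather than mutating this one in place.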
#[pyclass(frozen, name = "SessionConfig", module = "datafusion", subclass)]
#[derive(Clone, Default)]
pub struct PySessionConfig {
    pub config: SessionConfig,
}

impl From<SessionConfig> for PySessionConfig {
    fn from(config: SessionConfig) -> Self {
        Self { config }
    }
}

#[pymethods]
impl PySessionConfig {
    #[pyo3(signature = (config_options=None))]
    #[new]
    fn new(config_options: Option<HashMap<String, String>>) -> Self {
        let mut config = SessionConfig::new();
        if let Some(hash_map) = config_options {
            for (k, v) in &hash_map {
                config = config.set(k, &ScalarValue::Utf8(Some(v.clone())));
            }
        }

        Self { config }
    }

    fn with_create_default_catalog_and_schema(&self, enabled: bool) -> Self {
        Self::from(
            self.config
                .clone()
                .with_create_default_catalog_and_schema(enabled),
        )
    }

    fn with_default_catalog_and_schema(&self, catalog: &str, schema: &str) -> Self {
        Self::from(
            self.config
                .clone()
                .with_default_catalog_and_schema(catalog, schema),
        )
    }

    fn with_information_schema(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_information_schema(enabled))
    }

    fn with_batch_size(&self, batch_size: usize) -> Self {
        Self::from(self.config.clone().with_batch_size(batch_size))
    }

    fn with_target_partitions(&self, target_partitions: usize) -> Self {
        Self::from(
            self.config
                .clone()
                .with_target_partitions(target_partitions),
        )
    }

    fn with_repartition_aggregations(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_aggregations(enabled))
    }

    fn with_repartition_joins(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_joins(enabled))
    }

    fn with_repartition_windows(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_windows(enabled))
    }

    fn with_repartition_sorts(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_sorts(enabled))
    }

    fn with_repartition_file_scans(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_file_scans(enabled))
    }

    fn with_repartition_file_min_size(&self, size: usize) -> Self {
        Self::from(self.config.clone().with_repartition_file_min_size(size))
    }

    fn with_parquet_pruning(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_parquet_pruning(enabled))
    }

    fn set(&self, key: &str, value: &str) -> Self {
        Self::from(self.config.clone().set_str(key, value))
    }
}

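/// Runtime environment builder, exposed to Python as
/// `datafusion.RuntimeEnvBuilder`. Controls the disk manager and memory pool
/// used when building a `SessionContext`.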
#[pyclass(frozen, name = "RuntimeEnvBuilder", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PyRuntimeEnvBuilder {
    pub builder: RuntimeEnvBuilder,
}

#[pymethods]
impl PyRuntimeEnvBuilder {
    #[new]
    fn new() -> Self {
        Self {
            builder: RuntimeEnvBuilder::default(),
        }
    }

    fn with_disk_manager_disabled(&self) -> Self {
        let mut runtime_builder = self.builder.clone();

        let mut disk_mgr_builder = runtime_builder
            .disk_manager_builder
            .clone()
            .unwrap_or_default();
        disk_mgr_builder.set_mode(DiskManagerMode::Disabled);

        runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
        Self {
            builder: runtime_builder,
        }
    }

    fn with_disk_manager_os(&self) -> Self {
        let mut runtime_builder = self.builder.clone();

        let mut disk_mgr_builder = runtime_builder
            .disk_manager_builder
            .clone()
            .unwrap_or_default();
        disk_mgr_builder.set_mode(DiskManagerMode::OsTmpDirectory);

        runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
        Self {
            builder: runtime_builder,
        }
    }

    fn with_disk_manager_specified(&self, paths: Vec<String>) -> Self {
        let paths = paths.iter().map(|s| s.into()).collect();
        let mut runtime_builder = self.builder.clone();

        let mut disk_mgr_builder = runtime_builder
            .disk_manager_builder
            .clone()
            .unwrap_or_default();
        disk_mgr_builder.set_mode(DiskManagerMode::Directories(paths));

        runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
        Self {
            builder: runtime_builder,
        }
    }

    fn with_unbounded_memory_pool(&self) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_memory_pool(Arc::new(UnboundedMemoryPool::default()));
        Self { builder }
    }

    fn with_fair_spill_pool(&self, size: usize) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_memory_pool(Arc::new(FairSpillPool::new(size)));
        Self { builder }
    }

    fn with_greedy_memory_pool(&self, size: usize) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_memory_pool(Arc::new(GreedyMemoryPool::new(size)));
        Self { builder }
    }

    fn with_temp_file_path(&self, path: &str) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_temp_file_path(path);
        Self { builder }
    }
}

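/// Options that control which kinds of SQL are accepted by `sql_with_options`
/// (DDL, DML, and statements such as `SET VARIABLE`). Exposed to Python as
/// `datafusion.SQLOptions`.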
#[pyclass(frozen, name = "SQLOptions", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PySQLOptions {
    pub options: SQLOptions,
}

impl From<SQLOptions> for PySQLOptions {
    fn from(options: SQLOptions) -> Self {
        Self { options }
    }
}

#[pymethods]
impl PySQLOptions {
    #[new]
    fn new() -> Self {
        let options = SQLOptions::new();
        Self { options }
    }

    fn with_allow_ddl(&self, allow: bool) -> Self {
        Self::from(self.options.with_allow_ddl(allow))
    }

    pub fn with_allow_dml(&self, allow: bool) -> Self {
        Self::from(self.options.with_allow_dml(allow))
    }

    pub fn with_allow_statements(&self, allow: bool) -> Self {
        Self::from(self.options.with_allow_statements(allow))
    }
}

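/// A session context tracks the state of a DataFusion session: registered
/// tables, catalogs, user-defined functions, and configuration. Exposed to
/// Python as `datafusion.SessionContext`.
///
/// Illustrative Python-side usage (a minimal sketch; the Python wrappers in
/// the `datafusion` package add further conveniences on top of these methods):
///
/// ```python
/// from datafusion import SessionContext
///
/// ctx = SessionContext()
/// df = ctx.sql("SELECT 1 AS one")
/// ```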
#[pyclass(frozen, name = "SessionContext", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PySessionContext {
    pub ctx: SessionContext,
}

#[pymethods]
impl PySessionContext {
    #[pyo3(signature = (config=None, runtime=None))]
    #[new]
    pub fn new(
        config: Option<PySessionConfig>,
        runtime: Option<PyRuntimeEnvBuilder>,
    ) -> PyDataFusionResult<Self> {
        let config = if let Some(c) = config {
            c.config
        } else {
            SessionConfig::default().with_information_schema(true)
        };
        let runtime_env_builder = if let Some(c) = runtime {
            c.builder
        } else {
            RuntimeEnvBuilder::default()
        };
        let runtime = Arc::new(runtime_env_builder.build()?);
        let session_state = SessionStateBuilder::new()
            .with_config(config)
            .with_runtime_env(runtime)
            .with_default_features()
            .build();
        Ok(PySessionContext {
            ctx: SessionContext::new_with_state(session_state),
        })
    }

    pub fn enable_url_table(&self) -> PyResult<Self> {
        Ok(PySessionContext {
            ctx: self.ctx.clone().enable_url_table(),
        })
    }

    #[classmethod]
    #[pyo3(signature = ())]
    fn global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
        Ok(Self {
            ctx: get_global_ctx().clone(),
        })
    }

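    /// Register an `ObjectStore` (S3, GCS, Azure, HTTP, or local filesystem)
    /// under the given URL scheme so tables can be read from it.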
    #[pyo3(signature = (scheme, store, host=None))]
    pub fn register_object_store(
        &self,
        scheme: &str,
        store: StorageContexts,
        host: Option<&str>,
    ) -> PyResult<()> {
        let (store, upstream_host): (Arc<dyn ObjectStore>, String) = match store {
            StorageContexts::AmazonS3(s3) => (s3.inner, s3.bucket_name),
            StorageContexts::GoogleCloudStorage(gcs) => (gcs.inner, gcs.bucket_name),
            StorageContexts::MicrosoftAzure(azure) => (azure.inner, azure.container_name),
            StorageContexts::LocalFileSystem(local) => (local.inner, "".to_string()),
            StorageContexts::HTTP(http) => (http.store, http.url),
        };

        let derived_host = if let Some(host) = host {
            host
        } else {
            &upstream_host
        };
        let url_string = format!("{scheme}{derived_host}");
        // Surface an invalid scheme/host combination as a Python ValueError
        // instead of panicking.
        let url = Url::parse(&url_string).map_err(|e| PyValueError::new_err(e.to_string()))?;
        self.ctx.runtime_env().register_object_store(&url, store);
        Ok(())
    }

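    /// Register a directory of Parquet files as a named `ListingTable`,
    /// optionally with partition columns, an explicit schema, and a file sort
    /// order. The schema is inferred from the files when not provided.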
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name, path, table_partition_cols=vec![],
                        file_extension=".parquet",
                        schema=None,
                        file_sort_order=None))]
    pub fn register_listing_table(
        &self,
        name: &str,
        path: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_extension: &str,
        schema: Option<PyArrowType<Schema>>,
        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
            .with_file_extension(file_extension)
            .with_table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .with_file_sort_order(
                file_sort_order
                    .unwrap_or_default()
                    .into_iter()
                    .map(|e| e.into_iter().map(|f| f.into()).collect())
                    .collect(),
            );
        let table_path = ListingTableUrl::parse(path)?;
        let resolved_schema: SchemaRef = match schema {
            Some(s) => Arc::new(s.0),
            None => {
                let state = self.ctx.state();
                let schema = options.infer_schema(&state, &table_path);
                wait_for_future(py, schema)??
            }
        };
        let config = ListingTableConfig::new(table_path)
            .with_listing_options(options)
            .with_schema(resolved_schema);
        let table = ListingTable::try_new(config)?;
        self.ctx.register_table(name, Arc::new(table))?;
        Ok(())
    }

    pub fn register_udtf(&self, func: PyTableFunction) {
        let name = func.name.clone();
        let func = Arc::new(func);
        self.ctx.register_udtf(&name, func);
    }

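    /// Plan a SQL query, optionally restricting the allowed statement kinds
    /// via `SQLOptions` and binding `$name` placeholders from `param_values`
    /// (scalar values) or `param_strings` (raw string substitution).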
    #[pyo3(signature = (query, options=None, param_values=HashMap::default(), param_strings=HashMap::default()))]
    pub fn sql_with_options(
        &self,
        py: Python,
        mut query: String,
        options: Option<PySQLOptions>,
        param_values: HashMap<String, PyScalarValue>,
        param_strings: HashMap<String, String>,
    ) -> PyDataFusionResult<PyDataFrame> {
        let options = if let Some(options) = options {
            options.options
        } else {
            SQLOptions::new()
        };

        let param_values = param_values
            .into_iter()
            .map(|(name, value)| (name, ScalarValue::from(value)))
            .collect::<HashMap<_, _>>();

        let state = self.ctx.state();
        let dialect = state.config().options().sql_parser.dialect.as_ref();

        if !param_strings.is_empty() {
            query = replace_placeholders_with_strings(&query, dialect, param_strings)?;
        }

        let mut df = wait_for_future(py, async {
            self.ctx.sql_with_options(&query, options).await
        })??;

        if !param_values.is_empty() {
            df = df.with_param_values(param_values)?;
        }

        Ok(PyDataFrame::new(df))
    }

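    /// Create a `DataFrame` from in-memory record batch partitions by
    /// registering them as a `MemTable`. When no name is given, a unique
    /// table name prefixed with "c" is generated from a UUID.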
    #[pyo3(signature = (partitions, name=None, schema=None))]
    pub fn create_dataframe(
        &self,
        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
        name: Option<&str>,
        schema: Option<PyArrowType<Schema>>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let schema = if let Some(schema) = schema {
            SchemaRef::from(schema.0)
        } else {
            partitions.0[0][0].schema()
        };

        let table = MemTable::try_new(schema, partitions.0)?;

        let table_name = match name {
            Some(val) => val.to_owned(),
            None => {
                "c".to_owned()
                    + Uuid::new_v4()
                        .simple()
                        .encode_lower(&mut Uuid::encode_buffer())
            }
        };

        self.ctx.register_table(&*table_name, Arc::new(table))?;

        let table = wait_for_future(py, self._table(&table_name))??;

        let df = PyDataFrame::new(table);
        Ok(df)
    }

    pub fn create_dataframe_from_logical_plan(&self, plan: PyLogicalPlan) -> PyDataFrame {
        PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()))
    }

    #[pyo3(signature = (data, name=None))]
    pub fn from_pylist(
        &self,
        data: Bound<'_, PyList>,
        name: Option<&str>,
    ) -> PyResult<PyDataFrame> {
        let py = data.py();

        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[data])?;
        let table = table_class.call_method1("from_pylist", args)?;

        let df = self.from_arrow(table, name, py)?;
        Ok(df)
    }

    #[pyo3(signature = (data, name=None))]
    pub fn from_pydict(
        &self,
        data: Bound<'_, PyDict>,
        name: Option<&str>,
    ) -> PyResult<PyDataFrame> {
        let py = data.py();

        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[data])?;
        let table = table_class.call_method1("from_pydict", args)?;

        let df = self.from_arrow(table, name, py)?;
        Ok(df)
    }

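    /// Build a `DataFrame` from an Arrow stream (anything accepted by
    /// `ArrowArrayStreamReader::from_pyarrow_bound`, e.g. a pyarrow Table or
    /// RecordBatchReader) or from a single Arrow record batch.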
    #[pyo3(signature = (data, name=None))]
    pub fn from_arrow(
        &self,
        data: Bound<'_, PyAny>,
        name: Option<&str>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let (schema, batches) =
            if let Ok(stream_reader) = ArrowArrayStreamReader::from_pyarrow_bound(&data) {
                // Prefer the Arrow stream interface and collect all batches eagerly.
                let schema = stream_reader.schema().as_ref().to_owned();
                let batches = stream_reader
                    .collect::<Result<Vec<RecordBatch>, arrow::error::ArrowError>>()?;

                (schema, batches)
            } else if let Ok(array) = RecordBatch::from_pyarrow_bound(&data) {
                // Fall back to treating the input as a single record batch.
                (array.schema().as_ref().to_owned(), vec![array])
            } else {
                return Err(PyDataFusionError::Common(
                    "Expected either an Arrow Array or an Arrow Stream in from_arrow().".to_string(),
                ));
            };

        let list_of_batches = PyArrowType::from(vec![batches]);
        self.create_dataframe(list_of_batches, name, Some(schema.into()), py)
    }

    #[allow(clippy::wrong_self_convention)]
    #[pyo3(signature = (data, name=None))]
    pub fn from_pandas(&self, data: Bound<'_, PyAny>, name: Option<&str>) -> PyResult<PyDataFrame> {
        let py = data.py();

        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[data])?;
        let table = table_class.call_method1("from_pandas", args)?;

        let df = self.from_arrow(table, name, py)?;
        Ok(df)
    }

    #[pyo3(signature = (data, name=None))]
    pub fn from_polars(&self, data: Bound<'_, PyAny>, name: Option<&str>) -> PyResult<PyDataFrame> {
        let table = data.call_method0("to_arrow")?;

        let df = self.from_arrow(table, name, data.py())?;
        Ok(df)
    }

    pub fn register_table(&self, name: &str, table: Bound<'_, PyAny>) -> PyDataFusionResult<()> {
        let table = PyTable::new(&table)?;

        self.ctx.register_table(name, table.table)?;
        Ok(())
    }

    pub fn deregister_table(&self, name: &str) -> PyDataFusionResult<()> {
        self.ctx.deregister_table(name)?;
        Ok(())
    }

    pub fn register_catalog_provider(
        &self,
        name: &str,
        provider: Bound<'_, PyAny>,
    ) -> PyDataFusionResult<()> {
        let provider = if provider.hasattr("__datafusion_catalog_provider__")? {
            let capsule = provider
                .getattr("__datafusion_catalog_provider__")?
                .call0()?;
            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
            validate_pycapsule(capsule, "datafusion_catalog_provider")?;

            let provider = unsafe { capsule.reference::<FFI_CatalogProvider>() };
            let provider: ForeignCatalogProvider = provider.into();
            Arc::new(provider) as Arc<dyn CatalogProvider>
        } else {
            match provider.extract::<PyCatalog>() {
                Ok(py_catalog) => py_catalog.catalog,
                Err(_) => Arc::new(RustWrappedPyCatalogProvider::new(provider.into()))
                    as Arc<dyn CatalogProvider>,
            }
        };

        let _ = self.ctx.register_catalog(name, provider);

        Ok(())
    }

    pub fn register_table_provider(
        &self,
        name: &str,
        provider: Bound<'_, PyAny>,
    ) -> PyDataFusionResult<()> {
        self.register_table(name, provider)
    }

    pub fn register_record_batches(
        &self,
        name: &str,
        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
    ) -> PyDataFusionResult<()> {
        let schema = partitions.0[0][0].schema();
        let table = MemTable::try_new(schema, partitions.0)?;
        self.ctx.register_table(name, Arc::new(table))?;
        Ok(())
    }

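    /// Register a Parquet file or directory as a named table. Supports
    /// partition columns, pruning, skipping metadata, an explicit schema,
    /// and a file sort order.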
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name, path, table_partition_cols=vec![],
                        parquet_pruning=true,
                        file_extension=".parquet",
                        skip_metadata=true,
                        schema=None,
                        file_sort_order=None))]
    pub fn register_parquet(
        &self,
        name: &str,
        path: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        parquet_pruning: bool,
        file_extension: &str,
        skip_metadata: bool,
        schema: Option<PyArrowType<Schema>>,
        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let mut options = ParquetReadOptions::default()
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .parquet_pruning(parquet_pruning)
            .skip_metadata(skip_metadata);
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);
        options.file_sort_order = file_sort_order
            .unwrap_or_default()
            .into_iter()
            .map(|e| e.into_iter().map(|f| f.into()).collect())
            .collect();

        let result = self.ctx.register_parquet(name, path, options);
        wait_for_future(py, result)??;
        Ok(())
    }

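    /// Register one CSV file or a list of CSV files as a named table. The
    /// delimiter must be a single character; compressed files are supported
    /// via `file_compression_type`.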
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name,
                        path,
                        schema=None,
                        has_header=true,
                        delimiter=",",
                        schema_infer_max_records=1000,
                        file_extension=".csv",
                        file_compression_type=None))]
    pub fn register_csv(
        &self,
        name: &str,
        path: &Bound<'_, PyAny>,
        schema: Option<PyArrowType<Schema>>,
        has_header: bool,
        delimiter: &str,
        schema_infer_max_records: usize,
        file_extension: &str,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let delimiter = delimiter.as_bytes();
        if delimiter.len() != 1 {
            return Err(PyDataFusionError::PythonError(py_value_err(
                "Delimiter must be a single character",
            )));
        }

        let mut options = CsvReadOptions::new()
            .has_header(has_header)
            .delimiter(delimiter[0])
            .schema_infer_max_records(schema_infer_max_records)
            .file_extension(file_extension)
            .file_compression_type(parse_file_compression_type(file_compression_type)?);
        options.schema = schema.as_ref().map(|x| &x.0);

        if path.is_instance_of::<PyList>() {
            let paths = path.extract::<Vec<String>>()?;
            let result = self.register_csv_from_multiple_paths(name, paths, options);
            wait_for_future(py, result)??;
        } else {
            let path = path.extract::<String>()?;
            let result = self.ctx.register_csv(name, &path, options);
            wait_for_future(py, result)??;
        }

        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name,
                        path,
                        schema=None,
                        schema_infer_max_records=1000,
                        file_extension=".json",
                        table_partition_cols=vec![],
                        file_compression_type=None))]
    pub fn register_json(
        &self,
        name: &str,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;

        let mut options = NdJsonReadOptions::default()
            .file_compression_type(parse_file_compression_type(file_compression_type)?)
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            );
        options.schema_infer_max_records = schema_infer_max_records;
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);

        let result = self.ctx.register_json(name, path, options);
        wait_for_future(py, result)??;

        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name,
                        path,
                        schema=None,
                        file_extension=".avro",
                        table_partition_cols=vec![]))]
    pub fn register_avro(
        &self,
        name: &str,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        file_extension: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;

        let mut options = AvroReadOptions::default().table_partition_cols(
            table_partition_cols
                .into_iter()
                .map(|(name, ty)| (name, ty.0))
                .collect::<Vec<(String, DataType)>>(),
        );
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);

        let result = self.ctx.register_avro(name, path, options);
        wait_for_future(py, result)??;

        Ok(())
    }

    pub fn register_dataset(
        &self,
        name: &str,
        dataset: &Bound<'_, PyAny>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);

        self.ctx.register_table(name, table)?;

        Ok(())
    }

    pub fn register_udf(&self, udf: PyScalarUDF) -> PyResult<()> {
        self.ctx.register_udf(udf.function);
        Ok(())
    }

    pub fn register_udaf(&self, udaf: PyAggregateUDF) -> PyResult<()> {
        self.ctx.register_udaf(udaf.function);
        Ok(())
    }

    pub fn register_udwf(&self, udwf: PyWindowUDF) -> PyResult<()> {
        self.ctx.register_udwf(udwf.function);
        Ok(())
    }

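    /// Return the catalog with the given name, raising a Python `KeyError`
    /// if it does not exist. Catalogs backed by a Python provider are
    /// returned as the original Python object.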
    #[pyo3(signature = (name="datafusion"))]
    pub fn catalog(&self, name: &str) -> PyResult<Py<PyAny>> {
        let catalog = self.ctx.catalog(name).ok_or(PyKeyError::new_err(format!(
            "Catalog with name {name} doesn't exist."
        )))?;

        Python::attach(|py| {
            match catalog
                .as_any()
                .downcast_ref::<RustWrappedPyCatalogProvider>()
            {
                Some(wrapped_schema) => Ok(wrapped_schema.catalog_provider.clone_ref(py)),
                None => PyCatalog::from(catalog).into_py_any(py),
            }
        })
    }

    pub fn catalog_names(&self) -> HashSet<String> {
        self.ctx.catalog_names().into_iter().collect()
    }

    pub fn tables(&self) -> HashSet<String> {
        self.ctx
            .catalog_names()
            .into_iter()
            .filter_map(|name| self.ctx.catalog(&name))
            .flat_map(move |catalog| {
                catalog
                    .schema_names()
                    .into_iter()
                    .filter_map(move |name| catalog.schema(&name))
            })
            .flat_map(|schema| schema.table_names())
            .collect()
    }

    pub fn table(&self, name: &str, py: Python) -> PyResult<PyDataFrame> {
        let res = wait_for_future(py, self.ctx.table(name))
            .map_err(|e| PyKeyError::new_err(e.to_string()))?;
        match res {
            Ok(df) => Ok(PyDataFrame::new(df)),
            Err(e) => {
                if let datafusion::error::DataFusionError::Plan(msg) = &e {
                    if msg.contains("No table named") {
                        return Err(PyKeyError::new_err(msg.to_string()));
                    }
                }
                Err(py_datafusion_err(e))
            }
        }
    }

    pub fn table_exist(&self, name: &str) -> PyDataFusionResult<bool> {
        Ok(self.ctx.table_exist(name)?)
    }

    pub fn empty_table(&self) -> PyDataFusionResult<PyDataFrame> {
        Ok(PyDataFrame::new(self.ctx.read_empty()?))
    }

    pub fn session_id(&self) -> String {
        self.ctx.session_id()
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))]
    pub fn read_json(
        &self,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
        let mut options = NdJsonReadOptions::default()
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .file_compression_type(parse_file_compression_type(file_compression_type)?);
        options.schema_infer_max_records = schema_infer_max_records;
        options.file_extension = file_extension;
        let df = if let Some(schema) = schema {
            options.schema = Some(&schema.0);
            let result = self.ctx.read_json(path, options);
            wait_for_future(py, result)??
        } else {
            let result = self.ctx.read_json(path, options);
            wait_for_future(py, result)??
        };
        Ok(PyDataFrame::new(df))
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (
        path,
        schema=None,
        has_header=true,
        delimiter=",",
        schema_infer_max_records=1000,
        file_extension=".csv",
        table_partition_cols=vec![],
        file_compression_type=None))]
    pub fn read_csv(
        &self,
        path: &Bound<'_, PyAny>,
        schema: Option<PyArrowType<Schema>>,
        has_header: bool,
        delimiter: &str,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let delimiter = delimiter.as_bytes();
        if delimiter.len() != 1 {
            return Err(PyDataFusionError::PythonError(py_value_err(
                "Delimiter must be a single character",
            )));
        };

        let mut options = CsvReadOptions::new()
            .has_header(has_header)
            .delimiter(delimiter[0])
            .schema_infer_max_records(schema_infer_max_records)
            .file_extension(file_extension)
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .file_compression_type(parse_file_compression_type(file_compression_type)?);
        options.schema = schema.as_ref().map(|x| &x.0);

        if path.is_instance_of::<PyList>() {
            let paths = path.extract::<Vec<String>>()?;
            let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
            let result = self.ctx.read_csv(paths, options);
            let df = PyDataFrame::new(wait_for_future(py, result)??);
            Ok(df)
        } else {
            let path = path.extract::<String>()?;
            let result = self.ctx.read_csv(path, options);
            let df = PyDataFrame::new(wait_for_future(py, result)??);
            Ok(df)
        }
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (
        path,
        table_partition_cols=vec![],
        parquet_pruning=true,
        file_extension=".parquet",
        skip_metadata=true,
        schema=None,
        file_sort_order=None))]
    pub fn read_parquet(
        &self,
        path: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        parquet_pruning: bool,
        file_extension: &str,
        skip_metadata: bool,
        schema: Option<PyArrowType<Schema>>,
        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let mut options = ParquetReadOptions::default()
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .parquet_pruning(parquet_pruning)
            .skip_metadata(skip_metadata);
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);
        options.file_sort_order = file_sort_order
            .unwrap_or_default()
            .into_iter()
            .map(|e| e.into_iter().map(|f| f.into()).collect())
            .collect();

        let result = self.ctx.read_parquet(path, options);
        let df = PyDataFrame::new(wait_for_future(py, result)??);
        Ok(df)
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (path, schema=None, table_partition_cols=vec![], file_extension=".avro"))]
    pub fn read_avro(
        &self,
        path: &str,
        schema: Option<PyArrowType<Schema>>,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_extension: &str,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let mut options = AvroReadOptions::default().table_partition_cols(
            table_partition_cols
                .into_iter()
                .map(|(name, ty)| (name, ty.0))
                .collect::<Vec<(String, DataType)>>(),
        );
        options.file_extension = file_extension;
        let df = if let Some(schema) = schema {
            options.schema = Some(&schema.0);
            let read_future = self.ctx.read_avro(path, options);
            wait_for_future(py, read_future)??
        } else {
            let read_future = self.ctx.read_avro(path, options);
            wait_for_future(py, read_future)??
        };
        Ok(PyDataFrame::new(df))
    }

    pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult<PyDataFrame> {
        let table = PyTable::new(&table)?;
        let df = self.ctx.read_table(table.table())?;
        Ok(PyDataFrame::new(df))
    }

    fn __repr__(&self) -> PyResult<String> {
        let config = self.ctx.copied_config();
        let mut config_entries = config
            .options()
            .entries()
            .iter()
            .filter(|e| e.value.is_some())
            .map(|e| format!("{} = {}", e.key, e.value.as_ref().unwrap()))
            .collect::<Vec<_>>();
        config_entries.sort();
        Ok(format!(
            "SessionContext: id={}; configs=[\n\t{}]",
            self.session_id(),
            config_entries.join("\n\t")
        ))
    }

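    /// Execute a single partition of a physical plan and return a stream of
    /// record batches.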
    pub fn execute(
        &self,
        plan: PyExecutionPlan,
        part: usize,
        py: Python,
    ) -> PyDataFusionResult<PyRecordBatchStream> {
        let ctx: TaskContext = TaskContext::from(&self.ctx.state());
        let plan = plan.plan.clone();
        let stream = spawn_future(py, async move { plan.execute(part, Arc::new(ctx)) })?;
        Ok(PyRecordBatchStream::new(stream))
    }
}

impl PySessionContext {
    async fn _table(&self, name: &str) -> datafusion::common::Result<DataFrame> {
        self.ctx.table(name).await
    }

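    /// Register a single `ListingTable` built from several CSV paths, after
    /// checking that each path either matches the expected file extension or
    /// is a directory.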
    async fn register_csv_from_multiple_paths(
        &self,
        name: &str,
        table_paths: Vec<String>,
        options: CsvReadOptions<'_>,
    ) -> datafusion::common::Result<()> {
        let table_paths = table_paths.to_urls()?;
        let session_config = self.ctx.copied_config();
        let listing_options =
            options.to_listing_options(&session_config, self.ctx.copied_table_options());

        let option_extension = listing_options.file_extension.clone();

        if table_paths.is_empty() {
            return exec_err!("No table paths were provided");
        }

        for path in &table_paths {
            let file_path = path.as_str();
            if !file_path.ends_with(option_extension.clone().as_str()) && !path.is_collection() {
                return exec_err!(
                    "File path '{file_path}' does not match the expected extension '{option_extension}'"
                );
            }
        }

        let resolved_schema = options
            .get_resolved_schema(&session_config, self.ctx.state(), table_paths[0].clone())
            .await?;

        let config = ListingTableConfig::new_with_multi_paths(table_paths)
            .with_listing_options(listing_options)
            .with_schema(resolved_schema);
        let table = ListingTable::try_new(config)?;
        self.ctx
            .register_table(TableReference::Bare { table: name.into() }, Arc::new(table))?;
        Ok(())
    }
}

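/// Convert an optional compression name ("gzip", "bz2", "xz", "zstd", or an
/// empty/absent value for uncompressed) into a `FileCompressionType`,
/// raising a Python `ValueError` for anything else.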
pub fn parse_file_compression_type(
    file_compression_type: Option<String>,
) -> Result<FileCompressionType, PyErr> {
    FileCompressionType::from_str(file_compression_type.unwrap_or_default().as_str()).map_err(
        |_| PyValueError::new_err("file_compression_type must be one of: gzip, bz2, xz, zstd"),
    )
}

impl From<PySessionContext> for SessionContext {
    fn from(ctx: PySessionContext) -> SessionContext {
        ctx.ctx
    }
}

impl From<SessionContext> for PySessionContext {
    fn from(ctx: SessionContext) -> PySessionContext {
        PySessionContext { ctx }
    }
}