1use std::collections::{HashMap, HashSet};
19use std::path::PathBuf;
20use std::str::FromStr;
21use std::sync::Arc;
22
23use arrow::array::RecordBatchReader;
24use arrow::ffi_stream::ArrowArrayStreamReader;
25use arrow::pyarrow::FromPyArrow;
26use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
27use datafusion::arrow::pyarrow::PyArrowType;
28use datafusion::arrow::record_batch::RecordBatch;
29use datafusion::catalog::{CatalogProvider, CatalogProviderList};
30use datafusion::common::{ScalarValue, TableReference, exec_err};
31use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
32use datafusion::datasource::file_format::parquet::ParquetFormat;
33use datafusion::datasource::listing::{
34 ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
35};
36use datafusion::datasource::{MemTable, TableProvider};
37use datafusion::execution::TaskContextProvider;
38use datafusion::execution::context::{
39 DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext,
40};
41use datafusion::execution::disk_manager::DiskManagerMode;
42use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
43use datafusion::execution::options::ReadOptions;
44use datafusion::execution::runtime_env::RuntimeEnvBuilder;
45use datafusion::execution::session_state::SessionStateBuilder;
46use datafusion::prelude::{
47 AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
48};
49use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
50use datafusion_ffi::catalog_provider_list::FFI_CatalogProviderList;
51use datafusion_ffi::execution::FFI_TaskContextProvider;
52use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
53use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec;
54use object_store::ObjectStore;
55use pyo3::IntoPyObjectExt;
56use pyo3::exceptions::{PyKeyError, PyValueError};
57use pyo3::prelude::*;
58use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
59use url::Url;
60use uuid::Uuid;
61
62use crate::catalog::{
63 PyCatalog, PyCatalogList, RustWrappedPyCatalogProvider, RustWrappedPyCatalogProviderList,
64};
65use crate::common::data_type::PyScalarValue;
66use crate::dataframe::PyDataFrame;
67use crate::dataset::Dataset;
68use crate::errors::{
69 PyDataFusionError, PyDataFusionResult, from_datafusion_error, py_datafusion_err,
70};
71use crate::expr::sort_expr::PySortExpr;
72use crate::options::PyCsvReadOptions;
73use crate::physical_plan::PyExecutionPlan;
74use crate::record_batch::PyRecordBatchStream;
75use crate::sql::logical::PyLogicalPlan;
76use crate::sql::util::replace_placeholders_with_strings;
77use crate::store::StorageContexts;
78use crate::table::PyTable;
79use crate::udaf::PyAggregateUDF;
80use crate::udf::PyScalarUDF;
81use crate::udtf::PyTableFunction;
82use crate::udwf::PyWindowUDF;
83use crate::utils::{
84 create_logical_extension_capsule, extract_logical_extension_codec, get_global_ctx,
85 get_tokio_runtime, spawn_future, validate_pycapsule, wait_for_future,
86};
87
/// Python-facing wrapper around DataFusion's [`SessionConfig`].
///
/// Exposed to Python as `datafusion.SessionConfig`; `frozen` means the
/// Python object is immutable — every `with_*` method returns a new config.
#[pyclass(frozen, name = "SessionConfig", module = "datafusion", subclass)]
#[derive(Clone, Default)]
pub struct PySessionConfig {
    // The wrapped DataFusion configuration; `pub` so other Rust modules in
    // this crate can read it directly.
    pub config: SessionConfig,
}
94
95impl From<SessionConfig> for PySessionConfig {
96 fn from(config: SessionConfig) -> Self {
97 Self { config }
98 }
99}
100
101#[pymethods]
102impl PySessionConfig {
103 #[pyo3(signature = (config_options=None))]
104 #[new]
105 fn new(config_options: Option<HashMap<String, String>>) -> Self {
106 let mut config = SessionConfig::new();
107 if let Some(hash_map) = config_options {
108 for (k, v) in &hash_map {
109 config = config.set(k, &ScalarValue::Utf8(Some(v.clone())));
110 }
111 }
112
113 Self { config }
114 }
115
116 fn with_create_default_catalog_and_schema(&self, enabled: bool) -> Self {
117 Self::from(
118 self.config
119 .clone()
120 .with_create_default_catalog_and_schema(enabled),
121 )
122 }
123
124 fn with_default_catalog_and_schema(&self, catalog: &str, schema: &str) -> Self {
125 Self::from(
126 self.config
127 .clone()
128 .with_default_catalog_and_schema(catalog, schema),
129 )
130 }
131
132 fn with_information_schema(&self, enabled: bool) -> Self {
133 Self::from(self.config.clone().with_information_schema(enabled))
134 }
135
136 fn with_batch_size(&self, batch_size: usize) -> Self {
137 Self::from(self.config.clone().with_batch_size(batch_size))
138 }
139
140 fn with_target_partitions(&self, target_partitions: usize) -> Self {
141 Self::from(
142 self.config
143 .clone()
144 .with_target_partitions(target_partitions),
145 )
146 }
147
148 fn with_repartition_aggregations(&self, enabled: bool) -> Self {
149 Self::from(self.config.clone().with_repartition_aggregations(enabled))
150 }
151
152 fn with_repartition_joins(&self, enabled: bool) -> Self {
153 Self::from(self.config.clone().with_repartition_joins(enabled))
154 }
155
156 fn with_repartition_windows(&self, enabled: bool) -> Self {
157 Self::from(self.config.clone().with_repartition_windows(enabled))
158 }
159
160 fn with_repartition_sorts(&self, enabled: bool) -> Self {
161 Self::from(self.config.clone().with_repartition_sorts(enabled))
162 }
163
164 fn with_repartition_file_scans(&self, enabled: bool) -> Self {
165 Self::from(self.config.clone().with_repartition_file_scans(enabled))
166 }
167
168 fn with_repartition_file_min_size(&self, size: usize) -> Self {
169 Self::from(self.config.clone().with_repartition_file_min_size(size))
170 }
171
172 fn with_parquet_pruning(&self, enabled: bool) -> Self {
173 Self::from(self.config.clone().with_parquet_pruning(enabled))
174 }
175
176 fn set(&self, key: &str, value: &str) -> Self {
177 Self::from(self.config.clone().set_str(key, value))
178 }
179}
180
/// Python-facing wrapper around DataFusion's [`RuntimeEnvBuilder`].
///
/// Configures memory pools and disk spilling; `frozen` on the Python side,
/// so every `with_*` method returns a new builder.
#[pyclass(frozen, name = "RuntimeEnvBuilder", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PyRuntimeEnvBuilder {
    // The wrapped DataFusion builder.
    pub builder: RuntimeEnvBuilder,
}
187
188#[pymethods]
189impl PyRuntimeEnvBuilder {
190 #[new]
191 fn new() -> Self {
192 Self {
193 builder: RuntimeEnvBuilder::default(),
194 }
195 }
196
197 fn with_disk_manager_disabled(&self) -> Self {
198 let mut runtime_builder = self.builder.clone();
199
200 let mut disk_mgr_builder = runtime_builder
201 .disk_manager_builder
202 .clone()
203 .unwrap_or_default();
204 disk_mgr_builder.set_mode(DiskManagerMode::Disabled);
205
206 runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
207 Self {
208 builder: runtime_builder,
209 }
210 }
211
212 fn with_disk_manager_os(&self) -> Self {
213 let mut runtime_builder = self.builder.clone();
214
215 let mut disk_mgr_builder = runtime_builder
216 .disk_manager_builder
217 .clone()
218 .unwrap_or_default();
219 disk_mgr_builder.set_mode(DiskManagerMode::OsTmpDirectory);
220
221 runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
222 Self {
223 builder: runtime_builder,
224 }
225 }
226
227 fn with_disk_manager_specified(&self, paths: Vec<String>) -> Self {
228 let paths = paths.iter().map(|s| s.into()).collect();
229 let mut runtime_builder = self.builder.clone();
230
231 let mut disk_mgr_builder = runtime_builder
232 .disk_manager_builder
233 .clone()
234 .unwrap_or_default();
235 disk_mgr_builder.set_mode(DiskManagerMode::Directories(paths));
236
237 runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
238 Self {
239 builder: runtime_builder,
240 }
241 }
242
243 fn with_unbounded_memory_pool(&self) -> Self {
244 let builder = self.builder.clone();
245 let builder = builder.with_memory_pool(Arc::new(UnboundedMemoryPool::default()));
246 Self { builder }
247 }
248
249 fn with_fair_spill_pool(&self, size: usize) -> Self {
250 let builder = self.builder.clone();
251 let builder = builder.with_memory_pool(Arc::new(FairSpillPool::new(size)));
252 Self { builder }
253 }
254
255 fn with_greedy_memory_pool(&self, size: usize) -> Self {
256 let builder = self.builder.clone();
257 let builder = builder.with_memory_pool(Arc::new(GreedyMemoryPool::new(size)));
258 Self { builder }
259 }
260
261 fn with_temp_file_path(&self, path: &str) -> Self {
262 let builder = self.builder.clone();
263 let builder = builder.with_temp_file_path(path);
264 Self { builder }
265 }
266}
267
/// Python-facing wrapper around DataFusion's [`SQLOptions`], used to restrict
/// which statement classes (DDL / DML / other statements) a SQL call may run.
#[pyclass(frozen, name = "SQLOptions", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PySQLOptions {
    // The wrapped DataFusion options value.
    pub options: SQLOptions,
}
274
impl From<SQLOptions> for PySQLOptions {
    /// Wrap an existing DataFusion `SQLOptions` without modification.
    fn from(options: SQLOptions) -> Self {
        Self { options }
    }
}
280
281#[pymethods]
282impl PySQLOptions {
283 #[new]
284 fn new() -> Self {
285 let options = SQLOptions::new();
286 Self { options }
287 }
288
289 fn with_allow_ddl(&self, allow: bool) -> Self {
291 Self::from(self.options.with_allow_ddl(allow))
292 }
293
294 pub fn with_allow_dml(&self, allow: bool) -> Self {
296 Self::from(self.options.with_allow_dml(allow))
297 }
298
299 pub fn with_allow_statements(&self, allow: bool) -> Self {
301 Self::from(self.options.with_allow_statements(allow))
302 }
303}
304
/// Python-facing wrapper around a shared DataFusion [`SessionContext`].
///
/// Cloning this value clones the `Arc`s, so all clones operate on the same
/// underlying session.
#[pyclass(frozen, name = "SessionContext", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PySessionContext {
    // The shared DataFusion session all operations run against.
    pub ctx: Arc<SessionContext>,
    // Codec used when exchanging logical plans over the FFI capsule protocol.
    logical_codec: Arc<FFI_LogicalExtensionCodec>,
}
314
315#[pymethods]
316impl PySessionContext {
317 #[pyo3(signature = (config=None, runtime=None))]
318 #[new]
319 pub fn new(
320 config: Option<PySessionConfig>,
321 runtime: Option<PyRuntimeEnvBuilder>,
322 ) -> PyDataFusionResult<Self> {
323 let config = if let Some(c) = config {
324 c.config
325 } else {
326 SessionConfig::default().with_information_schema(true)
327 };
328 let runtime_env_builder = if let Some(c) = runtime {
329 c.builder
330 } else {
331 RuntimeEnvBuilder::default()
332 };
333 let runtime = Arc::new(runtime_env_builder.build()?);
334 let session_state = SessionStateBuilder::new()
335 .with_config(config)
336 .with_runtime_env(runtime)
337 .with_default_features()
338 .build();
339 let ctx = Arc::new(SessionContext::new_with_state(session_state));
340 let logical_codec = Self::default_logical_codec(&ctx);
341 Ok(PySessionContext { ctx, logical_codec })
342 }
343
344 pub fn enable_url_table(&self) -> PyResult<Self> {
345 Ok(PySessionContext {
346 ctx: Arc::new(self.ctx.as_ref().clone().enable_url_table()),
347 logical_codec: Arc::clone(&self.logical_codec),
348 })
349 }
350
351 #[staticmethod]
352 #[pyo3(signature = ())]
353 pub fn global_ctx() -> PyResult<Self> {
354 let ctx = get_global_ctx().clone();
355 let logical_codec = Self::default_logical_codec(&ctx);
356 Ok(Self { ctx, logical_codec })
357 }
358
359 #[pyo3(signature = (scheme, store, host=None))]
361 pub fn register_object_store(
362 &self,
363 scheme: &str,
364 store: StorageContexts,
365 host: Option<&str>,
366 ) -> PyResult<()> {
367 let (store, upstream_host): (Arc<dyn ObjectStore>, String) = match store {
369 StorageContexts::AmazonS3(s3) => (s3.inner, s3.bucket_name),
370 StorageContexts::GoogleCloudStorage(gcs) => (gcs.inner, gcs.bucket_name),
371 StorageContexts::MicrosoftAzure(azure) => (azure.inner, azure.container_name),
372 StorageContexts::LocalFileSystem(local) => (local.inner, "".to_string()),
373 StorageContexts::HTTP(http) => (http.store, http.url),
374 };
375
376 let derived_host = if let Some(host) = host {
378 host
379 } else {
380 &upstream_host
381 };
382 let url_string = format!("{scheme}{derived_host}");
383 let url = Url::parse(&url_string).unwrap();
384 self.ctx.runtime_env().register_object_store(&url, store);
385 Ok(())
386 }
387
    /// Register a directory (or file) of Parquet data as a listing table
    /// named `name`.
    ///
    /// If `schema` is `None`, the schema is inferred from the files at
    /// `path`, blocking the current thread on the async inference future.
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name, path, table_partition_cols=vec![],
        file_extension=".parquet",
        schema=None,
        file_sort_order=None))]
    pub fn register_listing_table(
        &self,
        name: &str,
        path: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_extension: &str,
        schema: Option<PyArrowType<Schema>>,
        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
            .with_file_extension(file_extension)
            .with_table_partition_cols(
                // Unwrap the PyArrow type wrappers into plain (name, DataType).
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .with_file_sort_order(
                file_sort_order
                    .unwrap_or_default()
                    .into_iter()
                    .map(|e| e.into_iter().map(|f| f.into()).collect())
                    .collect(),
            );
        let table_path = ListingTableUrl::parse(path)?;
        let resolved_schema: SchemaRef = match schema {
            Some(s) => Arc::new(s.0),
            None => {
                let state = self.ctx.state();
                let schema = options.infer_schema(&state, &table_path);
                // First `?` is the wait error, second the inference error.
                wait_for_future(py, schema)??
            }
        };
        let config = ListingTableConfig::new(table_path)
            .with_listing_options(options)
            .with_schema(resolved_schema);
        let table = ListingTable::try_new(config)?;
        self.ctx.register_table(name, Arc::new(table))?;
        Ok(())
    }
434
435 pub fn register_udtf(&self, func: PyTableFunction) {
436 let name = func.name.clone();
437 let func = Arc::new(func);
438 self.ctx.register_udtf(&name, func);
439 }
440
    /// Run a SQL query under the given `SQLOptions`, with optional parameter
    /// bindings.
    ///
    /// `param_strings` are substituted into the query text (dialect-aware)
    /// before planning; `param_values` are bound to placeholders on the
    /// planned DataFrame afterwards.
    #[pyo3(signature = (query, options=None, param_values=HashMap::default(), param_strings=HashMap::default()))]
    pub fn sql_with_options(
        &self,
        py: Python,
        mut query: String,
        options: Option<PySQLOptions>,
        param_values: HashMap<String, PyScalarValue>,
        param_strings: HashMap<String, String>,
    ) -> PyDataFusionResult<PyDataFrame> {
        let options = if let Some(options) = options {
            options.options
        } else {
            SQLOptions::new()
        };

        // Convert the Python scalar wrappers into DataFusion scalar values.
        let param_values = param_values
            .into_iter()
            .map(|(name, value)| (name, ScalarValue::from(value)))
            .collect::<HashMap<_, _>>();

        let state = self.ctx.state();
        let dialect = state.config().options().sql_parser.dialect.as_ref();

        // Textual substitution must happen before planning so the planner
        // sees the final SQL.
        if !param_strings.is_empty() {
            query = replace_placeholders_with_strings(&query, dialect, param_strings)?;
        }

        let mut df = wait_for_future(py, async {
            self.ctx.sql_with_options(&query, options).await
        })?
        .map_err(from_datafusion_error)?;

        if !param_values.is_empty() {
            df = df.with_param_values(param_values)?;
        }

        Ok(PyDataFrame::new(df))
    }
479
480 #[pyo3(signature = (partitions, name=None, schema=None))]
481 pub fn create_dataframe(
482 &self,
483 partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
484 name: Option<&str>,
485 schema: Option<PyArrowType<Schema>>,
486 py: Python,
487 ) -> PyDataFusionResult<PyDataFrame> {
488 let schema = if let Some(schema) = schema {
489 SchemaRef::from(schema.0)
490 } else {
491 partitions.0[0][0].schema()
492 };
493
494 let table = MemTable::try_new(schema, partitions.0)?;
495
496 let table_name = match name {
499 Some(val) => val.to_owned(),
500 None => {
501 "c".to_owned()
502 + Uuid::new_v4()
503 .simple()
504 .encode_lower(&mut Uuid::encode_buffer())
505 }
506 };
507
508 self.ctx.register_table(&*table_name, Arc::new(table))?;
509
510 let table = wait_for_future(py, self._table(&table_name))??;
511
512 let df = PyDataFrame::new(table);
513 Ok(df)
514 }
515
516 pub fn create_dataframe_from_logical_plan(&self, plan: PyLogicalPlan) -> PyDataFrame {
518 PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()))
519 }
520
521 #[pyo3(signature = (data, name=None))]
523 pub fn from_pylist(
524 &self,
525 data: Bound<'_, PyList>,
526 name: Option<&str>,
527 ) -> PyResult<PyDataFrame> {
528 let py = data.py();
530
531 let table_class = py.import("pyarrow")?.getattr("Table")?;
533 let args = PyTuple::new(py, &[data])?;
534 let table = table_class.call_method1("from_pylist", args)?;
535
536 let df = self.from_arrow(table, name, py)?;
538 Ok(df)
539 }
540
541 #[pyo3(signature = (data, name=None))]
543 pub fn from_pydict(
544 &self,
545 data: Bound<'_, PyDict>,
546 name: Option<&str>,
547 ) -> PyResult<PyDataFrame> {
548 let py = data.py();
550
551 let table_class = py.import("pyarrow")?.getattr("Table")?;
553 let args = PyTuple::new(py, &[data])?;
554 let table = table_class.call_method1("from_pydict", args)?;
555
556 let df = self.from_arrow(table, name, py)?;
558 Ok(df)
559 }
560
    /// Build a DataFrame from any object implementing the Arrow C stream
    /// interface, or from a single pyarrow `RecordBatch`.
    #[pyo3(signature = (data, name=None))]
    pub fn from_arrow(
        &self,
        data: Bound<'_, PyAny>,
        name: Option<&str>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        // Prefer the stream interface (may carry multiple batches); fall back
        // to a single RecordBatch before giving up.
        let (schema, batches) =
            if let Ok(stream_reader) = ArrowArrayStreamReader::from_pyarrow_bound(&data) {
                let schema = stream_reader.schema().as_ref().to_owned();
                let batches = stream_reader
                    .collect::<Result<Vec<RecordBatch>, arrow::error::ArrowError>>()?;

                (schema, batches)
            } else if let Ok(array) = RecordBatch::from_pyarrow_bound(&data) {
                (array.schema().as_ref().to_owned(), vec![array])
            } else {
                return Err(PyDataFusionError::Common(
                    "Expected either a Arrow Array or Arrow Stream in from_arrow().".to_string(),
                ));
            };

        // All batches are placed into a single partition.
        let list_of_batches = PyArrowType::from(vec![batches]);
        self.create_dataframe(list_of_batches, name, Some(schema.into()), py)
    }
594
595 #[allow(clippy::wrong_self_convention)]
597 #[pyo3(signature = (data, name=None))]
598 pub fn from_pandas(&self, data: Bound<'_, PyAny>, name: Option<&str>) -> PyResult<PyDataFrame> {
599 let py = data.py();
601
602 let table_class = py.import("pyarrow")?.getattr("Table")?;
604 let args = PyTuple::new(py, &[data])?;
605 let table = table_class.call_method1("from_pandas", args)?;
606
607 let df = self.from_arrow(table, name, py)?;
609 Ok(df)
610 }
611
612 #[pyo3(signature = (data, name=None))]
614 pub fn from_polars(&self, data: Bound<'_, PyAny>, name: Option<&str>) -> PyResult<PyDataFrame> {
615 let table = data.call_method0("to_arrow")?;
617
618 let df = self.from_arrow(table, name, data.py())?;
620 Ok(df)
621 }
622
623 pub fn register_table(&self, name: &str, table: Bound<'_, PyAny>) -> PyDataFusionResult<()> {
624 let session = self.clone().into_bound_py_any(table.py())?;
625 let table = PyTable::new(table, Some(session))?;
626
627 self.ctx.register_table(name, table.table)?;
628 Ok(())
629 }
630
631 pub fn deregister_table(&self, name: &str) -> PyDataFusionResult<()> {
632 self.ctx.deregister_table(name)?;
633 Ok(())
634 }
635
    /// Replace this session's catalog provider list.
    ///
    /// Accepts, in order of precedence: an object exposing
    /// `__datafusion_catalog_provider_list__` (FFI capsule protocol), a raw
    /// PyCapsule, a `PyCatalogList`, or any other Python object, which is
    /// wrapped so Rust can call back into Python.
    pub fn register_catalog_provider_list(
        &self,
        mut provider: Bound<PyAny>,
    ) -> PyDataFusionResult<()> {
        // FFI protocol: hand our codec to the provider and use the capsule
        // it returns in place of the original object.
        if provider.hasattr("__datafusion_catalog_provider_list__")? {
            let py = provider.py();
            let codec_capsule = create_logical_extension_capsule(py, self.logical_codec.as_ref())?;
            provider = provider
                .getattr("__datafusion_catalog_provider_list__")?
                .call1((codec_capsule,))?;
        }

        let provider =
            if let Ok(capsule) = provider.downcast::<PyCapsule>().map_err(py_datafusion_err) {
                // The capsule name is checked before the contents are trusted.
                validate_pycapsule(capsule, "datafusion_catalog_provider_list")?;

                // SAFETY: the capsule name was validated above, so it holds an
                // FFI_CatalogProviderList as the protocol requires.
                let provider = unsafe { capsule.reference::<FFI_CatalogProviderList>() };
                let provider: Arc<dyn CatalogProviderList + Send> = provider.into();
                provider as Arc<dyn CatalogProviderList>
            } else {
                match provider.extract::<PyCatalogList>() {
                    Ok(py_catalog_list) => py_catalog_list.catalog_list,
                    // Arbitrary Python object: wrap it so Rust can delegate
                    // catalog calls back into Python.
                    Err(_) => Arc::new(RustWrappedPyCatalogProviderList::new(
                        provider.into(),
                        Arc::clone(&self.logical_codec),
                    )) as Arc<dyn CatalogProviderList>,
                }
            };

        self.ctx.register_catalog_list(provider);

        Ok(())
    }
669
    /// Register a catalog provider under `name`.
    ///
    /// Accepts, in order of precedence: an object exposing
    /// `__datafusion_catalog_provider__` (FFI capsule protocol), a raw
    /// PyCapsule, a `PyCatalog`, or any other Python object, which is wrapped
    /// so Rust can call back into Python.
    pub fn register_catalog_provider(
        &self,
        name: &str,
        mut provider: Bound<'_, PyAny>,
    ) -> PyDataFusionResult<()> {
        // FFI protocol: hand our codec to the provider and use the capsule
        // it returns in place of the original object.
        if provider.hasattr("__datafusion_catalog_provider__")? {
            let py = provider.py();
            let codec_capsule = create_logical_extension_capsule(py, self.logical_codec.as_ref())?;
            provider = provider
                .getattr("__datafusion_catalog_provider__")?
                .call1((codec_capsule,))?;
        }

        let provider =
            if let Ok(capsule) = provider.downcast::<PyCapsule>().map_err(py_datafusion_err) {
                // The capsule name is checked before the contents are trusted.
                validate_pycapsule(capsule, "datafusion_catalog_provider")?;

                // SAFETY: the capsule name was validated above, so it holds an
                // FFI_CatalogProvider as the protocol requires.
                let provider = unsafe { capsule.reference::<FFI_CatalogProvider>() };
                let provider: Arc<dyn CatalogProvider + Send> = provider.into();
                provider as Arc<dyn CatalogProvider>
            } else {
                match provider.extract::<PyCatalog>() {
                    Ok(py_catalog) => py_catalog.catalog,
                    // Arbitrary Python object: wrap it so Rust can delegate
                    // catalog calls back into Python.
                    Err(_) => Arc::new(RustWrappedPyCatalogProvider::new(
                        provider.into(),
                        Arc::clone(&self.logical_codec),
                    )) as Arc<dyn CatalogProvider>,
                }
            };

        // register_catalog returns any previously registered catalog with the
        // same name; it is intentionally discarded here.
        let _ = self.ctx.register_catalog(name, provider);

        Ok(())
    }
704
    /// Alias for [`Self::register_table`].
    pub fn register_table_provider(
        &self,
        name: &str,
        provider: Bound<'_, PyAny>,
    ) -> PyDataFusionResult<()> {
        self.register_table(name, provider)
    }
714
715 pub fn register_record_batches(
716 &self,
717 name: &str,
718 partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
719 ) -> PyDataFusionResult<()> {
720 let schema = partitions.0[0][0].schema();
721 let table = MemTable::try_new(schema, partitions.0)?;
722 self.ctx.register_table(name, Arc::new(table))?;
723 Ok(())
724 }
725
    /// Register Parquet data under `name`.
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name, path, table_partition_cols=vec![],
        parquet_pruning=true,
        file_extension=".parquet",
        skip_metadata=true,
        schema=None,
        file_sort_order=None))]
    pub fn register_parquet(
        &self,
        name: &str,
        path: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        parquet_pruning: bool,
        file_extension: &str,
        skip_metadata: bool,
        schema: Option<PyArrowType<Schema>>,
        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        // The remaining fields are assigned directly because the options
        // borrow `schema` (note `schema.as_ref()`), so `schema` must stay
        // alive for as long as `options` is used.
        let mut options = ParquetReadOptions::default()
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .parquet_pruning(parquet_pruning)
            .skip_metadata(skip_metadata);
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);
        options.file_sort_order = file_sort_order
            .unwrap_or_default()
            .into_iter()
            .map(|e| e.into_iter().map(|f| f.into()).collect())
            .collect();

        let result = self.ctx.register_parquet(name, path, options);
        // First `?` is the wait error, second the registration error.
        wait_for_future(py, result)??;
        Ok(())
    }
766
    /// Register CSV data under `name`.
    ///
    /// `path` accepts either a single path string or a Python list of path
    /// strings; a list registers one table over multiple files.
    #[pyo3(signature = (name,
        path,
        options=None))]
    pub fn register_csv(
        &self,
        name: &str,
        path: &Bound<'_, PyAny>,
        options: Option<&PyCsvReadOptions>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        // Fall back to default CSV options when none are supplied.
        let options = options
            .map(|opts| opts.try_into())
            .transpose()?
            .unwrap_or_default();

        if path.is_instance_of::<PyList>() {
            let paths = path.extract::<Vec<String>>()?;
            let result = self.register_csv_from_multiple_paths(name, paths, options);
            wait_for_future(py, result)??;
        } else {
            let path = path.extract::<String>()?;
            let result = self.ctx.register_csv(name, &path, options);
            wait_for_future(py, result)??;
        }

        Ok(())
    }
794
    /// Register newline-delimited JSON data under `name`.
    ///
    /// If `schema` is `None` it is inferred, examining at most
    /// `schema_infer_max_records` records.
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name,
        path,
        schema=None,
        schema_infer_max_records=1000,
        file_extension=".json",
        table_partition_cols=vec![],
        file_compression_type=None))]
    pub fn register_json(
        &self,
        name: &str,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;

        // Fields below are assigned directly because the options borrow
        // `schema`, which therefore must outlive `options`.
        let mut options = NdJsonReadOptions::default()
            .file_compression_type(parse_file_compression_type(file_compression_type)?)
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            );
        options.schema_infer_max_records = schema_infer_max_records;
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);

        let result = self.ctx.register_json(name, path, options);
        wait_for_future(py, result)??;

        Ok(())
    }
835
    /// Register Avro data under `name`.
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name,
        path,
        schema=None,
        file_extension=".avro",
        table_partition_cols=vec![]))]
    pub fn register_avro(
        &self,
        name: &str,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        file_extension: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;

        // Fields below are assigned directly because the options borrow
        // `schema`, which therefore must outlive `options`.
        let mut options = AvroReadOptions::default().table_partition_cols(
            table_partition_cols
                .into_iter()
                .map(|(name, ty)| (name, ty.0))
                .collect::<Vec<(String, DataType)>>(),
        );
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);

        let result = self.ctx.register_avro(name, path, options);
        wait_for_future(py, result)??;

        Ok(())
    }
869
870 pub fn register_dataset(
872 &self,
873 name: &str,
874 dataset: &Bound<'_, PyAny>,
875 py: Python,
876 ) -> PyDataFusionResult<()> {
877 let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);
878
879 self.ctx.register_table(name, table)?;
880
881 Ok(())
882 }
883
    /// Register a scalar Python UDF for use in queries on this session.
    pub fn register_udf(&self, udf: PyScalarUDF) -> PyResult<()> {
        self.ctx.register_udf(udf.function);
        Ok(())
    }
888
    /// Register an aggregate Python UDF for use in queries on this session.
    pub fn register_udaf(&self, udaf: PyAggregateUDF) -> PyResult<()> {
        self.ctx.register_udaf(udaf.function);
        Ok(())
    }
893
    /// Register a window Python UDF for use in queries on this session.
    pub fn register_udwf(&self, udwf: PyWindowUDF) -> PyResult<()> {
        self.ctx.register_udwf(udwf.function);
        Ok(())
    }
898
899 #[pyo3(signature = (name="datafusion"))]
900 pub fn catalog(&self, py: Python, name: &str) -> PyResult<Py<PyAny>> {
901 let catalog = self.ctx.catalog(name).ok_or(PyKeyError::new_err(format!(
902 "Catalog with name {name} doesn't exist."
903 )))?;
904
905 match catalog
906 .as_any()
907 .downcast_ref::<RustWrappedPyCatalogProvider>()
908 {
909 Some(wrapped_schema) => Ok(wrapped_schema.catalog_provider.clone_ref(py)),
910 None => Ok(
911 PyCatalog::new_from_parts(catalog, Arc::clone(&self.logical_codec))
912 .into_py_any(py)?,
913 ),
914 }
915 }
916
917 pub fn catalog_names(&self) -> HashSet<String> {
918 self.ctx.catalog_names().into_iter().collect()
919 }
920
921 pub fn tables(&self) -> HashSet<String> {
922 self.ctx
923 .catalog_names()
924 .into_iter()
925 .filter_map(|name| self.ctx.catalog(&name))
926 .flat_map(move |catalog| {
927 catalog
928 .schema_names()
929 .into_iter()
930 .filter_map(move |name| catalog.schema(&name))
931 })
932 .flat_map(|schema| schema.table_names())
933 .collect()
934 }
935
    /// Look up a registered table and return it as a DataFrame.
    ///
    /// Raises `KeyError` when the lookup future fails or when the plan error
    /// indicates the table does not exist; other failures are surfaced as
    /// DataFusion errors.
    pub fn table(&self, name: &str, py: Python) -> PyResult<PyDataFrame> {
        // wait_for_future yields Result<Result<_>>: the outer is the wait
        // itself, the inner the table lookup.
        let res = wait_for_future(py, self.ctx.table(name))
            .map_err(|e| PyKeyError::new_err(e.to_string()))?;
        match res {
            Ok(df) => Ok(PyDataFrame::new(df)),
            Err(e) => {
                // Translate "No table named" plan errors into KeyError so
                // Python callers can catch a missing table idiomatically.
                if let datafusion::error::DataFusionError::Plan(msg) = &e
                    && msg.contains("No table named")
                {
                    return Err(PyKeyError::new_err(msg.to_string()));
                }
                Err(py_datafusion_err(e))
            }
        }
    }
951
952 pub fn table_exist(&self, name: &str) -> PyDataFusionResult<bool> {
953 Ok(self.ctx.table_exist(name)?)
954 }
955
956 pub fn empty_table(&self) -> PyDataFusionResult<PyDataFrame> {
957 Ok(PyDataFrame::new(self.ctx.read_empty()?))
958 }
959
960 pub fn session_id(&self) -> String {
961 self.ctx.session_id()
962 }
963
    /// Read newline-delimited JSON into a DataFrame.
    ///
    /// If `schema` is `None` it is inferred, examining at most
    /// `schema_infer_max_records` records.
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))]
    pub fn read_json(
        &self,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
        let mut options = NdJsonReadOptions::default()
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .file_compression_type(parse_file_compression_type(file_compression_type)?);
        options.schema_infer_max_records = schema_infer_max_records;
        options.file_extension = file_extension;
        // The two branches differ only in setting `options.schema`; the split
        // exists because the options borrow `schema` for their lifetime.
        let df = if let Some(schema) = schema {
            options.schema = Some(&schema.0);
            let result = self.ctx.read_json(path, options);
            wait_for_future(py, result)??
        } else {
            let result = self.ctx.read_json(path, options);
            wait_for_future(py, result)??
        };
        Ok(PyDataFrame::new(df))
    }
999
1000 #[pyo3(signature = (
1001 path,
1002 options=None))]
1003 pub fn read_csv(
1004 &self,
1005 path: &Bound<'_, PyAny>,
1006 options: Option<&PyCsvReadOptions>,
1007 py: Python,
1008 ) -> PyDataFusionResult<PyDataFrame> {
1009 let options = options
1010 .map(|opts| opts.try_into())
1011 .transpose()?
1012 .unwrap_or_default();
1013
1014 if path.is_instance_of::<PyList>() {
1015 let paths = path.extract::<Vec<String>>()?;
1016 let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
1017 let result = self.ctx.read_csv(paths, options);
1018 let df = PyDataFrame::new(wait_for_future(py, result)??);
1019 Ok(df)
1020 } else {
1021 let path = path.extract::<String>()?;
1022 let result = self.ctx.read_csv(path, options);
1023 let df = PyDataFrame::new(wait_for_future(py, result)??);
1024 Ok(df)
1025 }
1026 }
1027
    /// Read Parquet data into a DataFrame.
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (
        path,
        table_partition_cols=vec![],
        parquet_pruning=true,
        file_extension=".parquet",
        skip_metadata=true,
        schema=None,
        file_sort_order=None))]
    pub fn read_parquet(
        &self,
        path: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        parquet_pruning: bool,
        file_extension: &str,
        skip_metadata: bool,
        schema: Option<PyArrowType<Schema>>,
        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        // Fields below are assigned directly because the options borrow
        // `schema`, which therefore must outlive `options`.
        let mut options = ParquetReadOptions::default()
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .parquet_pruning(parquet_pruning)
            .skip_metadata(skip_metadata);
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);
        options.file_sort_order = file_sort_order
            .unwrap_or_default()
            .into_iter()
            .map(|e| e.into_iter().map(|f| f.into()).collect())
            .collect();

        let result = self.ctx.read_parquet(path, options);
        // First `?` is the wait error, second the read error.
        let df = PyDataFrame::new(wait_for_future(py, result)??);
        Ok(df)
    }
1069
    /// Read Avro data into a DataFrame.
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (path, schema=None, table_partition_cols=vec![], file_extension=".avro"))]
    pub fn read_avro(
        &self,
        path: &str,
        schema: Option<PyArrowType<Schema>>,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_extension: &str,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let mut options = AvroReadOptions::default().table_partition_cols(
            table_partition_cols
                .into_iter()
                .map(|(name, ty)| (name, ty.0))
                .collect::<Vec<(String, DataType)>>(),
        );
        options.file_extension = file_extension;
        // The two branches differ only in setting `options.schema`; the split
        // exists because the options borrow `schema` for their lifetime.
        let df = if let Some(schema) = schema {
            options.schema = Some(&schema.0);
            let read_future = self.ctx.read_avro(path, options);
            wait_for_future(py, read_future)??
        } else {
            let read_future = self.ctx.read_avro(path, options);
            wait_for_future(py, read_future)??
        };
        Ok(PyDataFrame::new(df))
    }
1097
1098 pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult<PyDataFrame> {
1099 let session = self.clone().into_bound_py_any(table.py())?;
1100 let table = PyTable::new(table, Some(session))?;
1101 let df = self.ctx.read_table(table.table())?;
1102 Ok(PyDataFrame::new(df))
1103 }
1104
1105 fn __repr__(&self) -> PyResult<String> {
1106 let config = self.ctx.copied_config();
1107 let mut config_entries = config
1108 .options()
1109 .entries()
1110 .iter()
1111 .filter(|e| e.value.is_some())
1112 .map(|e| format!("{} = {}", e.key, e.value.as_ref().unwrap()))
1113 .collect::<Vec<_>>();
1114 config_entries.sort();
1115 Ok(format!(
1116 "SessionContext: id={}; configs=[\n\t{}]",
1117 self.session_id(),
1118 config_entries.join("\n\t")
1119 ))
1120 }
1121
1122 pub fn execute(
1124 &self,
1125 plan: PyExecutionPlan,
1126 part: usize,
1127 py: Python,
1128 ) -> PyDataFusionResult<PyRecordBatchStream> {
1129 let ctx: TaskContext = TaskContext::from(&self.ctx.state());
1130 let plan = plan.plan.clone();
1131 let stream = spawn_future(py, async move { plan.execute(part, Arc::new(ctx)) })?;
1132 Ok(PyRecordBatchStream::new(stream))
1133 }
1134
1135 pub fn __datafusion_task_context_provider__<'py>(
1136 &self,
1137 py: Python<'py>,
1138 ) -> PyResult<Bound<'py, PyCapsule>> {
1139 let name = cr"datafusion_task_context_provider".into();
1140
1141 let ctx_provider = Arc::clone(&self.ctx) as Arc<dyn TaskContextProvider>;
1142 let ffi_ctx_provider = FFI_TaskContextProvider::from(&ctx_provider);
1143
1144 PyCapsule::new(py, ffi_ctx_provider, Some(name))
1145 }
1146
    /// Expose this session's logical extension codec as a PyCapsule for FFI
    /// consumers; delegates to the file-level capsule helper.
    pub fn __datafusion_logical_extension_codec__<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<Bound<'py, PyCapsule>> {
        create_logical_extension_capsule(py, self.logical_codec.as_ref())
    }
1153
1154 pub fn with_logical_extension_codec<'py>(
1155 &self,
1156 codec: Bound<'py, PyAny>,
1157 ) -> PyDataFusionResult<Self> {
1158 let py = codec.py();
1159 let logical_codec = extract_logical_extension_codec(py, Some(codec))?;
1160
1161 Ok({
1162 Self {
1163 ctx: Arc::clone(&self.ctx),
1164 logical_codec,
1165 }
1166 })
1167 }
1168}
1169
1170impl PySessionContext {
1171 async fn _table(&self, name: &str) -> datafusion::common::Result<DataFrame> {
1172 self.ctx.table(name).await
1173 }
1174
1175 async fn register_csv_from_multiple_paths(
1176 &self,
1177 name: &str,
1178 table_paths: Vec<String>,
1179 options: CsvReadOptions<'_>,
1180 ) -> datafusion::common::Result<()> {
1181 let table_paths = table_paths.to_urls()?;
1182 let session_config = self.ctx.copied_config();
1183 let listing_options =
1184 options.to_listing_options(&session_config, self.ctx.copied_table_options());
1185
1186 let option_extension = listing_options.file_extension.clone();
1187
1188 if table_paths.is_empty() {
1189 return exec_err!("No table paths were provided");
1190 }
1191
1192 for path in &table_paths {
1194 let file_path = path.as_str();
1195 if !file_path.ends_with(option_extension.clone().as_str()) && !path.is_collection() {
1196 return exec_err!(
1197 "File path '{file_path}' does not match the expected extension '{option_extension}'"
1198 );
1199 }
1200 }
1201
1202 let resolved_schema = options
1203 .get_resolved_schema(&session_config, self.ctx.state(), table_paths[0].clone())
1204 .await?;
1205
1206 let config = ListingTableConfig::new_with_multi_paths(table_paths)
1207 .with_listing_options(listing_options)
1208 .with_schema(resolved_schema);
1209 let table = ListingTable::try_new(config)?;
1210 self.ctx
1211 .register_table(TableReference::Bare { table: name.into() }, Arc::new(table))?;
1212 Ok(())
1213 }
1214
1215 fn default_logical_codec(ctx: &Arc<SessionContext>) -> Arc<FFI_LogicalExtensionCodec> {
1216 let codec = Arc::new(DefaultLogicalExtensionCodec {});
1217 let runtime = get_tokio_runtime().0.handle().clone();
1218 let ctx_provider = Arc::clone(ctx) as Arc<dyn TaskContextProvider>;
1219 Arc::new(FFI_LogicalExtensionCodec::new(
1220 codec,
1221 Some(runtime),
1222 &ctx_provider,
1223 ))
1224 }
1225}
1226
1227pub fn parse_file_compression_type(
1228 file_compression_type: Option<String>,
1229) -> Result<FileCompressionType, PyErr> {
1230 FileCompressionType::from_str(&*file_compression_type.unwrap_or("".to_string()).as_str())
1231 .map_err(|_| {
1232 PyValueError::new_err("file_compression_type must one of: gzip, bz2, xz, zstd")
1233 })
1234}
1235
1236impl From<PySessionContext> for SessionContext {
1237 fn from(ctx: PySessionContext) -> SessionContext {
1238 ctx.ctx.as_ref().clone()
1239 }
1240}
1241
1242impl From<SessionContext> for PySessionContext {
1243 fn from(ctx: SessionContext) -> PySessionContext {
1244 let ctx = Arc::new(ctx);
1245 let logical_codec = Self::default_logical_codec(&ctx);
1246
1247 PySessionContext { ctx, logical_codec }
1248 }
1249}