datafusion_python/
context.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::{HashMap, HashSet};
19use std::path::PathBuf;
20use std::str::FromStr;
21use std::sync::Arc;
22
23use arrow::array::RecordBatchReader;
24use arrow::ffi_stream::ArrowArrayStreamReader;
25use arrow::pyarrow::FromPyArrow;
26use datafusion::execution::session_state::SessionStateBuilder;
27use object_store::ObjectStore;
28use url::Url;
29use uuid::Uuid;
30
31use pyo3::exceptions::{PyKeyError, PyValueError};
32use pyo3::prelude::*;
33
34use crate::catalog::{PyCatalog, RustWrappedPyCatalogProvider};
35use crate::dataframe::PyDataFrame;
36use crate::dataset::Dataset;
37use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
38use crate::expr::sort_expr::PySortExpr;
39use crate::physical_plan::PyExecutionPlan;
40use crate::record_batch::PyRecordBatchStream;
41use crate::sql::exceptions::py_value_err;
42use crate::sql::logical::PyLogicalPlan;
43use crate::store::StorageContexts;
44use crate::table::PyTable;
45use crate::udaf::PyAggregateUDF;
46use crate::udf::PyScalarUDF;
47use crate::udtf::PyTableFunction;
48use crate::udwf::PyWindowUDF;
49use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future};
50use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
51use datafusion::arrow::pyarrow::PyArrowType;
52use datafusion::arrow::record_batch::RecordBatch;
53use datafusion::catalog::CatalogProvider;
54use datafusion::common::TableReference;
55use datafusion::common::{exec_err, ScalarValue};
56use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
57use datafusion::datasource::file_format::parquet::ParquetFormat;
58use datafusion::datasource::listing::{
59    ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
60};
61use datafusion::datasource::MemTable;
62use datafusion::datasource::TableProvider;
63use datafusion::execution::context::{
64    DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext,
65};
66use datafusion::execution::disk_manager::DiskManagerMode;
67use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
68use datafusion::execution::options::ReadOptions;
69use datafusion::execution::runtime_env::RuntimeEnvBuilder;
70use datafusion::physical_plan::SendableRecordBatchStream;
71use datafusion::prelude::{
72    AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
73};
74use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider};
75use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
76use pyo3::IntoPyObjectExt;
77use tokio::task::JoinHandle;
78
79/// Configuration options for a SessionContext
80#[pyclass(frozen, name = "SessionConfig", module = "datafusion", subclass)]
81#[derive(Clone, Default)]
82pub struct PySessionConfig {
83    pub config: SessionConfig,
84}
85
86impl From<SessionConfig> for PySessionConfig {
87    fn from(config: SessionConfig) -> Self {
88        Self { config }
89    }
90}
91
92#[pymethods]
93impl PySessionConfig {
94    #[pyo3(signature = (config_options=None))]
95    #[new]
96    fn new(config_options: Option<HashMap<String, String>>) -> Self {
97        let mut config = SessionConfig::new();
98        if let Some(hash_map) = config_options {
99            for (k, v) in &hash_map {
100                config = config.set(k, &ScalarValue::Utf8(Some(v.clone())));
101            }
102        }
103
104        Self { config }
105    }
106
107    fn with_create_default_catalog_and_schema(&self, enabled: bool) -> Self {
108        Self::from(
109            self.config
110                .clone()
111                .with_create_default_catalog_and_schema(enabled),
112        )
113    }
114
115    fn with_default_catalog_and_schema(&self, catalog: &str, schema: &str) -> Self {
116        Self::from(
117            self.config
118                .clone()
119                .with_default_catalog_and_schema(catalog, schema),
120        )
121    }
122
123    fn with_information_schema(&self, enabled: bool) -> Self {
124        Self::from(self.config.clone().with_information_schema(enabled))
125    }
126
127    fn with_batch_size(&self, batch_size: usize) -> Self {
128        Self::from(self.config.clone().with_batch_size(batch_size))
129    }
130
131    fn with_target_partitions(&self, target_partitions: usize) -> Self {
132        Self::from(
133            self.config
134                .clone()
135                .with_target_partitions(target_partitions),
136        )
137    }
138
139    fn with_repartition_aggregations(&self, enabled: bool) -> Self {
140        Self::from(self.config.clone().with_repartition_aggregations(enabled))
141    }
142
143    fn with_repartition_joins(&self, enabled: bool) -> Self {
144        Self::from(self.config.clone().with_repartition_joins(enabled))
145    }
146
147    fn with_repartition_windows(&self, enabled: bool) -> Self {
148        Self::from(self.config.clone().with_repartition_windows(enabled))
149    }
150
151    fn with_repartition_sorts(&self, enabled: bool) -> Self {
152        Self::from(self.config.clone().with_repartition_sorts(enabled))
153    }
154
155    fn with_repartition_file_scans(&self, enabled: bool) -> Self {
156        Self::from(self.config.clone().with_repartition_file_scans(enabled))
157    }
158
159    fn with_repartition_file_min_size(&self, size: usize) -> Self {
160        Self::from(self.config.clone().with_repartition_file_min_size(size))
161    }
162
163    fn with_parquet_pruning(&self, enabled: bool) -> Self {
164        Self::from(self.config.clone().with_parquet_pruning(enabled))
165    }
166
167    fn set(&self, key: &str, value: &str) -> Self {
168        Self::from(self.config.clone().set_str(key, value))
169    }
170}
171
172/// Runtime options for a SessionContext
173#[pyclass(frozen, name = "RuntimeEnvBuilder", module = "datafusion", subclass)]
174#[derive(Clone)]
175pub struct PyRuntimeEnvBuilder {
176    pub builder: RuntimeEnvBuilder,
177}
178
179#[pymethods]
180impl PyRuntimeEnvBuilder {
181    #[new]
182    fn new() -> Self {
183        Self {
184            builder: RuntimeEnvBuilder::default(),
185        }
186    }
187
188    fn with_disk_manager_disabled(&self) -> Self {
189        let mut runtime_builder = self.builder.clone();
190
191        let mut disk_mgr_builder = runtime_builder
192            .disk_manager_builder
193            .clone()
194            .unwrap_or_default();
195        disk_mgr_builder.set_mode(DiskManagerMode::Disabled);
196
197        runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
198        Self {
199            builder: runtime_builder,
200        }
201    }
202
203    fn with_disk_manager_os(&self) -> Self {
204        let mut runtime_builder = self.builder.clone();
205
206        let mut disk_mgr_builder = runtime_builder
207            .disk_manager_builder
208            .clone()
209            .unwrap_or_default();
210        disk_mgr_builder.set_mode(DiskManagerMode::OsTmpDirectory);
211
212        runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
213        Self {
214            builder: runtime_builder,
215        }
216    }
217
218    fn with_disk_manager_specified(&self, paths: Vec<String>) -> Self {
219        let paths = paths.iter().map(|s| s.into()).collect();
220        let mut runtime_builder = self.builder.clone();
221
222        let mut disk_mgr_builder = runtime_builder
223            .disk_manager_builder
224            .clone()
225            .unwrap_or_default();
226        disk_mgr_builder.set_mode(DiskManagerMode::Directories(paths));
227
228        runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
229        Self {
230            builder: runtime_builder,
231        }
232    }
233
234    fn with_unbounded_memory_pool(&self) -> Self {
235        let builder = self.builder.clone();
236        let builder = builder.with_memory_pool(Arc::new(UnboundedMemoryPool::default()));
237        Self { builder }
238    }
239
240    fn with_fair_spill_pool(&self, size: usize) -> Self {
241        let builder = self.builder.clone();
242        let builder = builder.with_memory_pool(Arc::new(FairSpillPool::new(size)));
243        Self { builder }
244    }
245
246    fn with_greedy_memory_pool(&self, size: usize) -> Self {
247        let builder = self.builder.clone();
248        let builder = builder.with_memory_pool(Arc::new(GreedyMemoryPool::new(size)));
249        Self { builder }
250    }
251
252    fn with_temp_file_path(&self, path: &str) -> Self {
253        let builder = self.builder.clone();
254        let builder = builder.with_temp_file_path(path);
255        Self { builder }
256    }
257}
258
259/// `PySQLOptions` allows you to specify options to the sql execution.
260#[pyclass(frozen, name = "SQLOptions", module = "datafusion", subclass)]
261#[derive(Clone)]
262pub struct PySQLOptions {
263    pub options: SQLOptions,
264}
265
266impl From<SQLOptions> for PySQLOptions {
267    fn from(options: SQLOptions) -> Self {
268        Self { options }
269    }
270}
271
272#[pymethods]
273impl PySQLOptions {
274    #[new]
275    fn new() -> Self {
276        let options = SQLOptions::new();
277        Self { options }
278    }
279
280    /// Should DDL data modification commands  (e.g. `CREATE TABLE`) be run? Defaults to `true`.
281    fn with_allow_ddl(&self, allow: bool) -> Self {
282        Self::from(self.options.with_allow_ddl(allow))
283    }
284
285    /// Should DML data modification commands (e.g. `INSERT and COPY`) be run? Defaults to `true`
286    pub fn with_allow_dml(&self, allow: bool) -> Self {
287        Self::from(self.options.with_allow_dml(allow))
288    }
289
290    /// Should Statements such as (e.g. `SET VARIABLE and `BEGIN TRANSACTION` ...`) be run?. Defaults to `true`
291    pub fn with_allow_statements(&self, allow: bool) -> Self {
292        Self::from(self.options.with_allow_statements(allow))
293    }
294}
295
296/// `PySessionContext` is able to plan and execute DataFusion plans.
297/// It has a powerful optimizer, a physical planner for local execution, and a
298/// multi-threaded execution engine to perform the execution.
299#[pyclass(frozen, name = "SessionContext", module = "datafusion", subclass)]
300#[derive(Clone)]
301pub struct PySessionContext {
302    pub ctx: SessionContext,
303}
304
305#[pymethods]
306impl PySessionContext {
307    #[pyo3(signature = (config=None, runtime=None))]
308    #[new]
309    pub fn new(
310        config: Option<PySessionConfig>,
311        runtime: Option<PyRuntimeEnvBuilder>,
312    ) -> PyDataFusionResult<Self> {
313        let config = if let Some(c) = config {
314            c.config
315        } else {
316            SessionConfig::default().with_information_schema(true)
317        };
318        let runtime_env_builder = if let Some(c) = runtime {
319            c.builder
320        } else {
321            RuntimeEnvBuilder::default()
322        };
323        let runtime = Arc::new(runtime_env_builder.build()?);
324        let session_state = SessionStateBuilder::new()
325            .with_config(config)
326            .with_runtime_env(runtime)
327            .with_default_features()
328            .build();
329        Ok(PySessionContext {
330            ctx: SessionContext::new_with_state(session_state),
331        })
332    }
333
334    pub fn enable_url_table(&self) -> PyResult<Self> {
335        Ok(PySessionContext {
336            ctx: self.ctx.clone().enable_url_table(),
337        })
338    }
339
340    #[classmethod]
341    #[pyo3(signature = ())]
342    fn global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
343        Ok(Self {
344            ctx: get_global_ctx().clone(),
345        })
346    }
347
348    /// Register an object store with the given name
349    #[pyo3(signature = (scheme, store, host=None))]
350    pub fn register_object_store(
351        &self,
352        scheme: &str,
353        store: StorageContexts,
354        host: Option<&str>,
355    ) -> PyResult<()> {
356        // for most stores the "host" is the bucket name and can be inferred from the store
357        let (store, upstream_host): (Arc<dyn ObjectStore>, String) = match store {
358            StorageContexts::AmazonS3(s3) => (s3.inner, s3.bucket_name),
359            StorageContexts::GoogleCloudStorage(gcs) => (gcs.inner, gcs.bucket_name),
360            StorageContexts::MicrosoftAzure(azure) => (azure.inner, azure.container_name),
361            StorageContexts::LocalFileSystem(local) => (local.inner, "".to_string()),
362            StorageContexts::HTTP(http) => (http.store, http.url),
363        };
364
365        // let users override the host to match the api signature from upstream
366        let derived_host = if let Some(host) = host {
367            host
368        } else {
369            &upstream_host
370        };
371        let url_string = format!("{scheme}{derived_host}");
372        let url = Url::parse(&url_string).unwrap();
373        self.ctx.runtime_env().register_object_store(&url, store);
374        Ok(())
375    }
376
377    #[allow(clippy::too_many_arguments)]
378    #[pyo3(signature = (name, path, table_partition_cols=vec![],
379    file_extension=".parquet",
380    schema=None,
381    file_sort_order=None))]
382    pub fn register_listing_table(
383        &self,
384        name: &str,
385        path: &str,
386        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
387        file_extension: &str,
388        schema: Option<PyArrowType<Schema>>,
389        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
390        py: Python,
391    ) -> PyDataFusionResult<()> {
392        let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
393            .with_file_extension(file_extension)
394            .with_table_partition_cols(
395                table_partition_cols
396                    .into_iter()
397                    .map(|(name, ty)| (name, ty.0))
398                    .collect::<Vec<(String, DataType)>>(),
399            )
400            .with_file_sort_order(
401                file_sort_order
402                    .unwrap_or_default()
403                    .into_iter()
404                    .map(|e| e.into_iter().map(|f| f.into()).collect())
405                    .collect(),
406            );
407        let table_path = ListingTableUrl::parse(path)?;
408        let resolved_schema: SchemaRef = match schema {
409            Some(s) => Arc::new(s.0),
410            None => {
411                let state = self.ctx.state();
412                let schema = options.infer_schema(&state, &table_path);
413                wait_for_future(py, schema)??
414            }
415        };
416        let config = ListingTableConfig::new(table_path)
417            .with_listing_options(options)
418            .with_schema(resolved_schema);
419        let table = ListingTable::try_new(config)?;
420        self.ctx.register_table(name, Arc::new(table))?;
421        Ok(())
422    }
423
424    pub fn register_udtf(&self, func: PyTableFunction) {
425        let name = func.name.clone();
426        let func = Arc::new(func);
427        self.ctx.register_udtf(&name, func);
428    }
429
430    /// Returns a PyDataFrame whose plan corresponds to the SQL statement.
431    pub fn sql(&self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
432        let result = self.ctx.sql(query);
433        let df = wait_for_future(py, result)??;
434        Ok(PyDataFrame::new(df))
435    }
436
437    #[pyo3(signature = (query, options=None))]
438    pub fn sql_with_options(
439        &self,
440        query: &str,
441        options: Option<PySQLOptions>,
442        py: Python,
443    ) -> PyDataFusionResult<PyDataFrame> {
444        let options = if let Some(options) = options {
445            options.options
446        } else {
447            SQLOptions::new()
448        };
449        let result = self.ctx.sql_with_options(query, options);
450        let df = wait_for_future(py, result)??;
451        Ok(PyDataFrame::new(df))
452    }
453
454    #[pyo3(signature = (partitions, name=None, schema=None))]
455    pub fn create_dataframe(
456        &self,
457        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
458        name: Option<&str>,
459        schema: Option<PyArrowType<Schema>>,
460        py: Python,
461    ) -> PyDataFusionResult<PyDataFrame> {
462        let schema = if let Some(schema) = schema {
463            SchemaRef::from(schema.0)
464        } else {
465            partitions.0[0][0].schema()
466        };
467
468        let table = MemTable::try_new(schema, partitions.0)?;
469
470        // generate a random (unique) name for this table if none is provided
471        // table name cannot start with numeric digit
472        let table_name = match name {
473            Some(val) => val.to_owned(),
474            None => {
475                "c".to_owned()
476                    + Uuid::new_v4()
477                        .simple()
478                        .encode_lower(&mut Uuid::encode_buffer())
479            }
480        };
481
482        self.ctx.register_table(&*table_name, Arc::new(table))?;
483
484        let table = wait_for_future(py, self._table(&table_name))??;
485
486        let df = PyDataFrame::new(table);
487        Ok(df)
488    }
489
490    /// Create a DataFrame from an existing logical plan
491    pub fn create_dataframe_from_logical_plan(&self, plan: PyLogicalPlan) -> PyDataFrame {
492        PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()))
493    }
494
495    /// Construct datafusion dataframe from Python list
496    #[pyo3(signature = (data, name=None))]
497    pub fn from_pylist(
498        &self,
499        data: Bound<'_, PyList>,
500        name: Option<&str>,
501    ) -> PyResult<PyDataFrame> {
502        // Acquire GIL Token
503        let py = data.py();
504
505        // Instantiate pyarrow Table object & convert to Arrow Table
506        let table_class = py.import("pyarrow")?.getattr("Table")?;
507        let args = PyTuple::new(py, &[data])?;
508        let table = table_class.call_method1("from_pylist", args)?;
509
510        // Convert Arrow Table to datafusion DataFrame
511        let df = self.from_arrow(table, name, py)?;
512        Ok(df)
513    }
514
515    /// Construct datafusion dataframe from Python dictionary
516    #[pyo3(signature = (data, name=None))]
517    pub fn from_pydict(
518        &self,
519        data: Bound<'_, PyDict>,
520        name: Option<&str>,
521    ) -> PyResult<PyDataFrame> {
522        // Acquire GIL Token
523        let py = data.py();
524
525        // Instantiate pyarrow Table object & convert to Arrow Table
526        let table_class = py.import("pyarrow")?.getattr("Table")?;
527        let args = PyTuple::new(py, &[data])?;
528        let table = table_class.call_method1("from_pydict", args)?;
529
530        // Convert Arrow Table to datafusion DataFrame
531        let df = self.from_arrow(table, name, py)?;
532        Ok(df)
533    }
534
535    /// Construct datafusion dataframe from Arrow Table
536    #[pyo3(signature = (data, name=None))]
537    pub fn from_arrow(
538        &self,
539        data: Bound<'_, PyAny>,
540        name: Option<&str>,
541        py: Python,
542    ) -> PyDataFusionResult<PyDataFrame> {
543        let (schema, batches) =
544            if let Ok(stream_reader) = ArrowArrayStreamReader::from_pyarrow_bound(&data) {
545                // Works for any object that implements __arrow_c_stream__ in pycapsule.
546
547                let schema = stream_reader.schema().as_ref().to_owned();
548                let batches = stream_reader
549                    .collect::<Result<Vec<RecordBatch>, arrow::error::ArrowError>>()?;
550
551                (schema, batches)
552            } else if let Ok(array) = RecordBatch::from_pyarrow_bound(&data) {
553                // While this says RecordBatch, it will work for any object that implements
554                // __arrow_c_array__ and returns a StructArray.
555
556                (array.schema().as_ref().to_owned(), vec![array])
557            } else {
558                return Err(crate::errors::PyDataFusionError::Common(
559                    "Expected either a Arrow Array or Arrow Stream in from_arrow().".to_string(),
560                ));
561            };
562
563        // Because create_dataframe() expects a vector of vectors of record batches
564        // here we need to wrap the vector of record batches in an additional vector
565        let list_of_batches = PyArrowType::from(vec![batches]);
566        self.create_dataframe(list_of_batches, name, Some(schema.into()), py)
567    }
568
569    /// Construct datafusion dataframe from pandas
570    #[allow(clippy::wrong_self_convention)]
571    #[pyo3(signature = (data, name=None))]
572    pub fn from_pandas(&self, data: Bound<'_, PyAny>, name: Option<&str>) -> PyResult<PyDataFrame> {
573        // Obtain GIL token
574        let py = data.py();
575
576        // Instantiate pyarrow Table object & convert to Arrow Table
577        let table_class = py.import("pyarrow")?.getattr("Table")?;
578        let args = PyTuple::new(py, &[data])?;
579        let table = table_class.call_method1("from_pandas", args)?;
580
581        // Convert Arrow Table to datafusion DataFrame
582        let df = self.from_arrow(table, name, py)?;
583        Ok(df)
584    }
585
586    /// Construct datafusion dataframe from polars
587    #[pyo3(signature = (data, name=None))]
588    pub fn from_polars(&self, data: Bound<'_, PyAny>, name: Option<&str>) -> PyResult<PyDataFrame> {
589        // Convert Polars dataframe to Arrow Table
590        let table = data.call_method0("to_arrow")?;
591
592        // Convert Arrow Table to datafusion DataFrame
593        let df = self.from_arrow(table, name, data.py())?;
594        Ok(df)
595    }
596
597    pub fn register_table(&self, name: &str, table: Bound<'_, PyAny>) -> PyDataFusionResult<()> {
598        let table = PyTable::new(&table)?;
599
600        self.ctx.register_table(name, table.table)?;
601        Ok(())
602    }
603
604    pub fn deregister_table(&self, name: &str) -> PyDataFusionResult<()> {
605        self.ctx.deregister_table(name)?;
606        Ok(())
607    }
608
609    pub fn register_catalog_provider(
610        &self,
611        name: &str,
612        provider: Bound<'_, PyAny>,
613    ) -> PyDataFusionResult<()> {
614        let provider = if provider.hasattr("__datafusion_catalog_provider__")? {
615            let capsule = provider
616                .getattr("__datafusion_catalog_provider__")?
617                .call0()?;
618            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
619            validate_pycapsule(capsule, "datafusion_catalog_provider")?;
620
621            let provider = unsafe { capsule.reference::<FFI_CatalogProvider>() };
622            let provider: ForeignCatalogProvider = provider.into();
623            Arc::new(provider) as Arc<dyn CatalogProvider>
624        } else {
625            match provider.extract::<PyCatalog>() {
626                Ok(py_catalog) => py_catalog.catalog,
627                Err(_) => Arc::new(RustWrappedPyCatalogProvider::new(provider.into()))
628                    as Arc<dyn CatalogProvider>,
629            }
630        };
631
632        let _ = self.ctx.register_catalog(name, provider);
633
634        Ok(())
635    }
636
637    /// Construct datafusion dataframe from Arrow Table
638    pub fn register_table_provider(
639        &self,
640        name: &str,
641        provider: Bound<'_, PyAny>,
642    ) -> PyDataFusionResult<()> {
643        // Deprecated: use `register_table` instead
644        self.register_table(name, provider)
645    }
646
647    pub fn register_record_batches(
648        &self,
649        name: &str,
650        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
651    ) -> PyDataFusionResult<()> {
652        let schema = partitions.0[0][0].schema();
653        let table = MemTable::try_new(schema, partitions.0)?;
654        self.ctx.register_table(name, Arc::new(table))?;
655        Ok(())
656    }
657
658    #[allow(clippy::too_many_arguments)]
659    #[pyo3(signature = (name, path, table_partition_cols=vec![],
660                        parquet_pruning=true,
661                        file_extension=".parquet",
662                        skip_metadata=true,
663                        schema=None,
664                        file_sort_order=None))]
665    pub fn register_parquet(
666        &self,
667        name: &str,
668        path: &str,
669        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
670        parquet_pruning: bool,
671        file_extension: &str,
672        skip_metadata: bool,
673        schema: Option<PyArrowType<Schema>>,
674        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
675        py: Python,
676    ) -> PyDataFusionResult<()> {
677        let mut options = ParquetReadOptions::default()
678            .table_partition_cols(
679                table_partition_cols
680                    .into_iter()
681                    .map(|(name, ty)| (name, ty.0))
682                    .collect::<Vec<(String, DataType)>>(),
683            )
684            .parquet_pruning(parquet_pruning)
685            .skip_metadata(skip_metadata);
686        options.file_extension = file_extension;
687        options.schema = schema.as_ref().map(|x| &x.0);
688        options.file_sort_order = file_sort_order
689            .unwrap_or_default()
690            .into_iter()
691            .map(|e| e.into_iter().map(|f| f.into()).collect())
692            .collect();
693
694        let result = self.ctx.register_parquet(name, path, options);
695        wait_for_future(py, result)??;
696        Ok(())
697    }
698
699    #[allow(clippy::too_many_arguments)]
700    #[pyo3(signature = (name,
701                        path,
702                        schema=None,
703                        has_header=true,
704                        delimiter=",",
705                        schema_infer_max_records=1000,
706                        file_extension=".csv",
707                        file_compression_type=None))]
708    pub fn register_csv(
709        &self,
710        name: &str,
711        path: &Bound<'_, PyAny>,
712        schema: Option<PyArrowType<Schema>>,
713        has_header: bool,
714        delimiter: &str,
715        schema_infer_max_records: usize,
716        file_extension: &str,
717        file_compression_type: Option<String>,
718        py: Python,
719    ) -> PyDataFusionResult<()> {
720        let delimiter = delimiter.as_bytes();
721        if delimiter.len() != 1 {
722            return Err(crate::errors::PyDataFusionError::PythonError(py_value_err(
723                "Delimiter must be a single character",
724            )));
725        }
726
727        let mut options = CsvReadOptions::new()
728            .has_header(has_header)
729            .delimiter(delimiter[0])
730            .schema_infer_max_records(schema_infer_max_records)
731            .file_extension(file_extension)
732            .file_compression_type(parse_file_compression_type(file_compression_type)?);
733        options.schema = schema.as_ref().map(|x| &x.0);
734
735        if path.is_instance_of::<PyList>() {
736            let paths = path.extract::<Vec<String>>()?;
737            let result = self.register_csv_from_multiple_paths(name, paths, options);
738            wait_for_future(py, result)??;
739        } else {
740            let path = path.extract::<String>()?;
741            let result = self.ctx.register_csv(name, &path, options);
742            wait_for_future(py, result)??;
743        }
744
745        Ok(())
746    }
747
748    #[allow(clippy::too_many_arguments)]
749    #[pyo3(signature = (name,
750                        path,
751                        schema=None,
752                        schema_infer_max_records=1000,
753                        file_extension=".json",
754                        table_partition_cols=vec![],
755                        file_compression_type=None))]
756    pub fn register_json(
757        &self,
758        name: &str,
759        path: PathBuf,
760        schema: Option<PyArrowType<Schema>>,
761        schema_infer_max_records: usize,
762        file_extension: &str,
763        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
764        file_compression_type: Option<String>,
765        py: Python,
766    ) -> PyDataFusionResult<()> {
767        let path = path
768            .to_str()
769            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
770
771        let mut options = NdJsonReadOptions::default()
772            .file_compression_type(parse_file_compression_type(file_compression_type)?)
773            .table_partition_cols(
774                table_partition_cols
775                    .into_iter()
776                    .map(|(name, ty)| (name, ty.0))
777                    .collect::<Vec<(String, DataType)>>(),
778            );
779        options.schema_infer_max_records = schema_infer_max_records;
780        options.file_extension = file_extension;
781        options.schema = schema.as_ref().map(|x| &x.0);
782
783        let result = self.ctx.register_json(name, path, options);
784        wait_for_future(py, result)??;
785
786        Ok(())
787    }
788
789    #[allow(clippy::too_many_arguments)]
790    #[pyo3(signature = (name,
791                        path,
792                        schema=None,
793                        file_extension=".avro",
794                        table_partition_cols=vec![]))]
795    pub fn register_avro(
796        &self,
797        name: &str,
798        path: PathBuf,
799        schema: Option<PyArrowType<Schema>>,
800        file_extension: &str,
801        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
802        py: Python,
803    ) -> PyDataFusionResult<()> {
804        let path = path
805            .to_str()
806            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
807
808        let mut options = AvroReadOptions::default().table_partition_cols(
809            table_partition_cols
810                .into_iter()
811                .map(|(name, ty)| (name, ty.0))
812                .collect::<Vec<(String, DataType)>>(),
813        );
814        options.file_extension = file_extension;
815        options.schema = schema.as_ref().map(|x| &x.0);
816
817        let result = self.ctx.register_avro(name, path, options);
818        wait_for_future(py, result)??;
819
820        Ok(())
821    }
822
823    // Registers a PyArrow.Dataset
824    pub fn register_dataset(
825        &self,
826        name: &str,
827        dataset: &Bound<'_, PyAny>,
828        py: Python,
829    ) -> PyDataFusionResult<()> {
830        let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);
831
832        self.ctx.register_table(name, table)?;
833
834        Ok(())
835    }
836
837    pub fn register_udf(&self, udf: PyScalarUDF) -> PyResult<()> {
838        self.ctx.register_udf(udf.function);
839        Ok(())
840    }
841
842    pub fn register_udaf(&self, udaf: PyAggregateUDF) -> PyResult<()> {
843        self.ctx.register_udaf(udaf.function);
844        Ok(())
845    }
846
847    pub fn register_udwf(&self, udwf: PyWindowUDF) -> PyResult<()> {
848        self.ctx.register_udwf(udwf.function);
849        Ok(())
850    }
851
852    #[pyo3(signature = (name="datafusion"))]
853    pub fn catalog(&self, name: &str) -> PyResult<PyObject> {
854        let catalog = self.ctx.catalog(name).ok_or(PyKeyError::new_err(format!(
855            "Catalog with name {name} doesn't exist."
856        )))?;
857
858        Python::with_gil(|py| {
859            match catalog
860                .as_any()
861                .downcast_ref::<RustWrappedPyCatalogProvider>()
862            {
863                Some(wrapped_schema) => Ok(wrapped_schema.catalog_provider.clone_ref(py)),
864                None => PyCatalog::from(catalog).into_py_any(py),
865            }
866        })
867    }
868
869    pub fn catalog_names(&self) -> HashSet<String> {
870        self.ctx.catalog_names().into_iter().collect()
871    }
872
873    pub fn tables(&self) -> HashSet<String> {
874        self.ctx
875            .catalog_names()
876            .into_iter()
877            .filter_map(|name| self.ctx.catalog(&name))
878            .flat_map(move |catalog| {
879                catalog
880                    .schema_names()
881                    .into_iter()
882                    .filter_map(move |name| catalog.schema(&name))
883            })
884            .flat_map(|schema| schema.table_names())
885            .collect()
886    }
887
888    pub fn table(&self, name: &str, py: Python) -> PyResult<PyDataFrame> {
889        let res = wait_for_future(py, self.ctx.table(name))
890            .map_err(|e| PyKeyError::new_err(e.to_string()))?;
891        match res {
892            Ok(df) => Ok(PyDataFrame::new(df)),
893            Err(e) => {
894                if let datafusion::error::DataFusionError::Plan(msg) = &e {
895                    if msg.contains("No table named") {
896                        return Err(PyKeyError::new_err(msg.to_string()));
897                    }
898                }
899                Err(py_datafusion_err(e))
900            }
901        }
902    }
903
904    pub fn table_exist(&self, name: &str) -> PyDataFusionResult<bool> {
905        Ok(self.ctx.table_exist(name)?)
906    }
907
908    pub fn empty_table(&self) -> PyDataFusionResult<PyDataFrame> {
909        Ok(PyDataFrame::new(self.ctx.read_empty()?))
910    }
911
912    pub fn session_id(&self) -> String {
913        self.ctx.session_id()
914    }
915
916    #[allow(clippy::too_many_arguments)]
917    #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))]
918    pub fn read_json(
919        &self,
920        path: PathBuf,
921        schema: Option<PyArrowType<Schema>>,
922        schema_infer_max_records: usize,
923        file_extension: &str,
924        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
925        file_compression_type: Option<String>,
926        py: Python,
927    ) -> PyDataFusionResult<PyDataFrame> {
928        let path = path
929            .to_str()
930            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
931        let mut options = NdJsonReadOptions::default()
932            .table_partition_cols(
933                table_partition_cols
934                    .into_iter()
935                    .map(|(name, ty)| (name, ty.0))
936                    .collect::<Vec<(String, DataType)>>(),
937            )
938            .file_compression_type(parse_file_compression_type(file_compression_type)?);
939        options.schema_infer_max_records = schema_infer_max_records;
940        options.file_extension = file_extension;
941        let df = if let Some(schema) = schema {
942            options.schema = Some(&schema.0);
943            let result = self.ctx.read_json(path, options);
944            wait_for_future(py, result)??
945        } else {
946            let result = self.ctx.read_json(path, options);
947            wait_for_future(py, result)??
948        };
949        Ok(PyDataFrame::new(df))
950    }
951
952    #[allow(clippy::too_many_arguments)]
953    #[pyo3(signature = (
954        path,
955        schema=None,
956        has_header=true,
957        delimiter=",",
958        schema_infer_max_records=1000,
959        file_extension=".csv",
960        table_partition_cols=vec![],
961        file_compression_type=None))]
962    pub fn read_csv(
963        &self,
964        path: &Bound<'_, PyAny>,
965        schema: Option<PyArrowType<Schema>>,
966        has_header: bool,
967        delimiter: &str,
968        schema_infer_max_records: usize,
969        file_extension: &str,
970        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
971        file_compression_type: Option<String>,
972        py: Python,
973    ) -> PyDataFusionResult<PyDataFrame> {
974        let delimiter = delimiter.as_bytes();
975        if delimiter.len() != 1 {
976            return Err(crate::errors::PyDataFusionError::PythonError(py_value_err(
977                "Delimiter must be a single character",
978            )));
979        };
980
981        let mut options = CsvReadOptions::new()
982            .has_header(has_header)
983            .delimiter(delimiter[0])
984            .schema_infer_max_records(schema_infer_max_records)
985            .file_extension(file_extension)
986            .table_partition_cols(
987                table_partition_cols
988                    .into_iter()
989                    .map(|(name, ty)| (name, ty.0))
990                    .collect::<Vec<(String, DataType)>>(),
991            )
992            .file_compression_type(parse_file_compression_type(file_compression_type)?);
993        options.schema = schema.as_ref().map(|x| &x.0);
994
995        if path.is_instance_of::<PyList>() {
996            let paths = path.extract::<Vec<String>>()?;
997            let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
998            let result = self.ctx.read_csv(paths, options);
999            let df = PyDataFrame::new(wait_for_future(py, result)??);
1000            Ok(df)
1001        } else {
1002            let path = path.extract::<String>()?;
1003            let result = self.ctx.read_csv(path, options);
1004            let df = PyDataFrame::new(wait_for_future(py, result)??);
1005            Ok(df)
1006        }
1007    }
1008
1009    #[allow(clippy::too_many_arguments)]
1010    #[pyo3(signature = (
1011        path,
1012        table_partition_cols=vec![],
1013        parquet_pruning=true,
1014        file_extension=".parquet",
1015        skip_metadata=true,
1016        schema=None,
1017        file_sort_order=None))]
1018    pub fn read_parquet(
1019        &self,
1020        path: &str,
1021        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
1022        parquet_pruning: bool,
1023        file_extension: &str,
1024        skip_metadata: bool,
1025        schema: Option<PyArrowType<Schema>>,
1026        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
1027        py: Python,
1028    ) -> PyDataFusionResult<PyDataFrame> {
1029        let mut options = ParquetReadOptions::default()
1030            .table_partition_cols(
1031                table_partition_cols
1032                    .into_iter()
1033                    .map(|(name, ty)| (name, ty.0))
1034                    .collect::<Vec<(String, DataType)>>(),
1035            )
1036            .parquet_pruning(parquet_pruning)
1037            .skip_metadata(skip_metadata);
1038        options.file_extension = file_extension;
1039        options.schema = schema.as_ref().map(|x| &x.0);
1040        options.file_sort_order = file_sort_order
1041            .unwrap_or_default()
1042            .into_iter()
1043            .map(|e| e.into_iter().map(|f| f.into()).collect())
1044            .collect();
1045
1046        let result = self.ctx.read_parquet(path, options);
1047        let df = PyDataFrame::new(wait_for_future(py, result)??);
1048        Ok(df)
1049    }
1050
1051    #[allow(clippy::too_many_arguments)]
1052    #[pyo3(signature = (path, schema=None, table_partition_cols=vec![], file_extension=".avro"))]
1053    pub fn read_avro(
1054        &self,
1055        path: &str,
1056        schema: Option<PyArrowType<Schema>>,
1057        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
1058        file_extension: &str,
1059        py: Python,
1060    ) -> PyDataFusionResult<PyDataFrame> {
1061        let mut options = AvroReadOptions::default().table_partition_cols(
1062            table_partition_cols
1063                .into_iter()
1064                .map(|(name, ty)| (name, ty.0))
1065                .collect::<Vec<(String, DataType)>>(),
1066        );
1067        options.file_extension = file_extension;
1068        let df = if let Some(schema) = schema {
1069            options.schema = Some(&schema.0);
1070            let read_future = self.ctx.read_avro(path, options);
1071            wait_for_future(py, read_future)??
1072        } else {
1073            let read_future = self.ctx.read_avro(path, options);
1074            wait_for_future(py, read_future)??
1075        };
1076        Ok(PyDataFrame::new(df))
1077    }
1078
1079    pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult<PyDataFrame> {
1080        let table = PyTable::new(&table)?;
1081        let df = self.ctx.read_table(table.table())?;
1082        Ok(PyDataFrame::new(df))
1083    }
1084
1085    fn __repr__(&self) -> PyResult<String> {
1086        let config = self.ctx.copied_config();
1087        let mut config_entries = config
1088            .options()
1089            .entries()
1090            .iter()
1091            .filter(|e| e.value.is_some())
1092            .map(|e| format!("{} = {}", e.key, e.value.as_ref().unwrap()))
1093            .collect::<Vec<_>>();
1094        config_entries.sort();
1095        Ok(format!(
1096            "SessionContext: id={}; configs=[\n\t{}]",
1097            self.session_id(),
1098            config_entries.join("\n\t")
1099        ))
1100    }
1101
1102    /// Execute a partition of an execution plan and return a stream of record batches
1103    pub fn execute(
1104        &self,
1105        plan: PyExecutionPlan,
1106        part: usize,
1107        py: Python,
1108    ) -> PyDataFusionResult<PyRecordBatchStream> {
1109        let ctx: TaskContext = TaskContext::from(&self.ctx.state());
1110        // create a Tokio runtime to run the async code
1111        let rt = &get_tokio_runtime().0;
1112        let plan = plan.plan.clone();
1113        let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
1114            rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
1115        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
1116        Ok(PyRecordBatchStream::new(stream))
1117    }
1118}
1119
1120impl PySessionContext {
1121    async fn _table(&self, name: &str) -> datafusion::common::Result<DataFrame> {
1122        self.ctx.table(name).await
1123    }
1124
1125    async fn register_csv_from_multiple_paths(
1126        &self,
1127        name: &str,
1128        table_paths: Vec<String>,
1129        options: CsvReadOptions<'_>,
1130    ) -> datafusion::common::Result<()> {
1131        let table_paths = table_paths.to_urls()?;
1132        let session_config = self.ctx.copied_config();
1133        let listing_options =
1134            options.to_listing_options(&session_config, self.ctx.copied_table_options());
1135
1136        let option_extension = listing_options.file_extension.clone();
1137
1138        if table_paths.is_empty() {
1139            return exec_err!("No table paths were provided");
1140        }
1141
1142        // check if the file extension matches the expected extension
1143        for path in &table_paths {
1144            let file_path = path.as_str();
1145            if !file_path.ends_with(option_extension.clone().as_str()) && !path.is_collection() {
1146                return exec_err!(
1147                    "File path '{file_path}' does not match the expected extension '{option_extension}'"
1148                );
1149            }
1150        }
1151
1152        let resolved_schema = options
1153            .get_resolved_schema(&session_config, self.ctx.state(), table_paths[0].clone())
1154            .await?;
1155
1156        let config = ListingTableConfig::new_with_multi_paths(table_paths)
1157            .with_listing_options(listing_options)
1158            .with_schema(resolved_schema);
1159        let table = ListingTable::try_new(config)?;
1160        self.ctx
1161            .register_table(TableReference::Bare { table: name.into() }, Arc::new(table))?;
1162        Ok(())
1163    }
1164}
1165
1166pub fn parse_file_compression_type(
1167    file_compression_type: Option<String>,
1168) -> Result<FileCompressionType, PyErr> {
1169    FileCompressionType::from_str(&*file_compression_type.unwrap_or("".to_string()).as_str())
1170        .map_err(|_| {
1171            PyValueError::new_err("file_compression_type must one of: gzip, bz2, xz, zstd")
1172        })
1173}
1174
1175impl From<PySessionContext> for SessionContext {
1176    fn from(ctx: PySessionContext) -> SessionContext {
1177        ctx.ctx
1178    }
1179}
1180
1181impl From<SessionContext> for PySessionContext {
1182    fn from(ctx: SessionContext) -> PySessionContext {
1183        PySessionContext { ctx }
1184    }
1185}