datafusion_python/context.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;

use arrow::array::RecordBatchReader;
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::FromPyArrow;
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::catalog::CatalogProvider;
use datafusion::common::{exec_err, ScalarValue, TableReference};
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
use datafusion::datasource::file_format::parquet::ParquetFormat;
use datafusion::datasource::listing::{
    ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
};
use datafusion::datasource::{MemTable, TableProvider};
use datafusion::execution::context::{
    DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext,
};
use datafusion::execution::disk_manager::DiskManagerMode;
use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
use datafusion::execution::options::ReadOptions;
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::prelude::{
    AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
};
use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider};
use object_store::ObjectStore;
use pyo3::exceptions::{PyKeyError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
use pyo3::IntoPyObjectExt;
use url::Url;
use uuid::Uuid;

use crate::catalog::{PyCatalog, RustWrappedPyCatalogProvider};
use crate::common::data_type::PyScalarValue;
use crate::dataframe::PyDataFrame;
use crate::dataset::Dataset;
use crate::errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult};
use crate::expr::sort_expr::PySortExpr;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
use crate::sql::exceptions::py_value_err;
use crate::sql::logical::PyLogicalPlan;
use crate::sql::util::replace_placeholders_with_strings;
use crate::store::StorageContexts;
use crate::table::PyTable;
use crate::udaf::PyAggregateUDF;
use crate::udf::PyScalarUDF;
use crate::udtf::PyTableFunction;
use crate::udwf::PyWindowUDF;
use crate::utils::{get_global_ctx, spawn_future, validate_pycapsule, wait_for_future};

/// Configuration options for a SessionContext
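///
/// A minimal usage sketch from Python (illustrative only; the option key shown
/// is an assumption and may vary across DataFusion versions):
///
/// ```python
/// from datafusion import SessionConfig, SessionContext
///
/// # Start from string key/value options, then refine with builder methods.
/// config = (
///     SessionConfig({"datafusion.execution.batch_size": "4096"})
///     .with_target_partitions(8)
///     .with_information_schema(True)
/// )
/// ctx = SessionContext(config)
/// ```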
#[pyclass(frozen, name = "SessionConfig", module = "datafusion", subclass)]
#[derive(Clone, Default)]
pub struct PySessionConfig {
    pub config: SessionConfig,
}

impl From<SessionConfig> for PySessionConfig {
    fn from(config: SessionConfig) -> Self {
        Self { config }
    }
}

#[pymethods]
impl PySessionConfig {
    #[pyo3(signature = (config_options=None))]
    #[new]
    fn new(config_options: Option<HashMap<String, String>>) -> Self {
        let mut config = SessionConfig::new();
        if let Some(hash_map) = config_options {
            for (k, v) in &hash_map {
                config = config.set(k, &ScalarValue::Utf8(Some(v.clone())));
            }
        }

        Self { config }
    }

    fn with_create_default_catalog_and_schema(&self, enabled: bool) -> Self {
        Self::from(
            self.config
                .clone()
                .with_create_default_catalog_and_schema(enabled),
        )
    }

    fn with_default_catalog_and_schema(&self, catalog: &str, schema: &str) -> Self {
        Self::from(
            self.config
                .clone()
                .with_default_catalog_and_schema(catalog, schema),
        )
    }

    fn with_information_schema(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_information_schema(enabled))
    }

    fn with_batch_size(&self, batch_size: usize) -> Self {
        Self::from(self.config.clone().with_batch_size(batch_size))
    }

    fn with_target_partitions(&self, target_partitions: usize) -> Self {
        Self::from(
            self.config
                .clone()
                .with_target_partitions(target_partitions),
        )
    }

    fn with_repartition_aggregations(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_aggregations(enabled))
    }

    fn with_repartition_joins(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_joins(enabled))
    }

    fn with_repartition_windows(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_windows(enabled))
    }

    fn with_repartition_sorts(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_sorts(enabled))
    }

    fn with_repartition_file_scans(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_file_scans(enabled))
    }

    fn with_repartition_file_min_size(&self, size: usize) -> Self {
        Self::from(self.config.clone().with_repartition_file_min_size(size))
    }

    fn with_parquet_pruning(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_parquet_pruning(enabled))
    }

    fn set(&self, key: &str, value: &str) -> Self {
        Self::from(self.config.clone().set_str(key, value))
    }
}

/// Runtime options for a SessionContext
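///
/// A minimal sketch from Python, assuming the runtime is passed to
/// `SessionContext` alongside a config (pool size and spill path are illustrative):
///
/// ```python
/// from datafusion import RuntimeEnvBuilder, SessionConfig, SessionContext
///
/// runtime = (
///     RuntimeEnvBuilder()
///     .with_fair_spill_pool(10 * 1024 * 1024)  # ~10 MiB memory budget
///     .with_temp_file_path("/tmp/datafusion-spill")
/// )
/// ctx = SessionContext(SessionConfig(), runtime)
/// ```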
#[pyclass(frozen, name = "RuntimeEnvBuilder", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PyRuntimeEnvBuilder {
    pub builder: RuntimeEnvBuilder,
}

#[pymethods]
impl PyRuntimeEnvBuilder {
    #[new]
    fn new() -> Self {
        Self {
            builder: RuntimeEnvBuilder::default(),
        }
    }

    fn with_disk_manager_disabled(&self) -> Self {
        let mut runtime_builder = self.builder.clone();

        let mut disk_mgr_builder = runtime_builder
            .disk_manager_builder
            .clone()
            .unwrap_or_default();
        disk_mgr_builder.set_mode(DiskManagerMode::Disabled);

        runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
        Self {
            builder: runtime_builder,
        }
    }

    fn with_disk_manager_os(&self) -> Self {
        let mut runtime_builder = self.builder.clone();

        let mut disk_mgr_builder = runtime_builder
            .disk_manager_builder
            .clone()
            .unwrap_or_default();
        disk_mgr_builder.set_mode(DiskManagerMode::OsTmpDirectory);

        runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
        Self {
            builder: runtime_builder,
        }
    }

    fn with_disk_manager_specified(&self, paths: Vec<String>) -> Self {
        let paths = paths.iter().map(|s| s.into()).collect();
        let mut runtime_builder = self.builder.clone();

        let mut disk_mgr_builder = runtime_builder
            .disk_manager_builder
            .clone()
            .unwrap_or_default();
        disk_mgr_builder.set_mode(DiskManagerMode::Directories(paths));

        runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
        Self {
            builder: runtime_builder,
        }
    }

    fn with_unbounded_memory_pool(&self) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_memory_pool(Arc::new(UnboundedMemoryPool::default()));
        Self { builder }
    }

    fn with_fair_spill_pool(&self, size: usize) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_memory_pool(Arc::new(FairSpillPool::new(size)));
        Self { builder }
    }

    fn with_greedy_memory_pool(&self, size: usize) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_memory_pool(Arc::new(GreedyMemoryPool::new(size)));
        Self { builder }
    }

    fn with_temp_file_path(&self, path: &str) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_temp_file_path(path);
        Self { builder }
    }
}

/// `PySQLOptions` lets you specify options that control SQL execution.
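///
/// A brief sketch from Python, assuming the options object is passed to
/// `SessionContext.sql_with_options` (illustrative only):
///
/// ```python
/// from datafusion import SessionContext, SQLOptions
///
/// ctx = SessionContext()
/// # Forbid DDL and DML so the query can only read data.
/// opts = SQLOptions().with_allow_ddl(False).with_allow_dml(False)
/// df = ctx.sql_with_options("SELECT 1", opts)
/// ```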
#[pyclass(frozen, name = "SQLOptions", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PySQLOptions {
    pub options: SQLOptions,
}

impl From<SQLOptions> for PySQLOptions {
    fn from(options: SQLOptions) -> Self {
        Self { options }
    }
}

#[pymethods]
impl PySQLOptions {
    #[new]
    fn new() -> Self {
        let options = SQLOptions::new();
        Self { options }
    }

    /// Should DDL commands (e.g. `CREATE TABLE`) be run? Defaults to `true`.
    fn with_allow_ddl(&self, allow: bool) -> Self {
        Self::from(self.options.with_allow_ddl(allow))
    }

    /// Should DML data modification commands (e.g. `INSERT` and `COPY`) be run? Defaults to `true`.
    pub fn with_allow_dml(&self, allow: bool) -> Self {
        Self::from(self.options.with_allow_dml(allow))
    }

    /// Should statements (e.g. `SET VARIABLE` and `BEGIN TRANSACTION`) be run? Defaults to `true`.
    pub fn with_allow_statements(&self, allow: bool) -> Self {
        Self::from(self.options.with_allow_statements(allow))
    }
}

/// `PySessionContext` is able to plan and execute DataFusion plans.
/// It has a powerful optimizer, a physical planner for local execution, and a
/// multi-threaded execution engine.
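///
/// A minimal end-to-end sketch from Python (the file path is hypothetical):
///
/// ```python
/// from datafusion import SessionContext
///
/// ctx = SessionContext()
/// ctx.register_parquet("t", "data/example.parquet")
/// df = ctx.sql("SELECT count(*) FROM t")
/// df.show()
/// ```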
#[pyclass(frozen, name = "SessionContext", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PySessionContext {
    pub ctx: SessionContext,
}

#[pymethods]
impl PySessionContext {
    #[pyo3(signature = (config=None, runtime=None))]
    #[new]
    pub fn new(
        config: Option<PySessionConfig>,
        runtime: Option<PyRuntimeEnvBuilder>,
    ) -> PyDataFusionResult<Self> {
        let config = if let Some(c) = config {
            c.config
        } else {
            SessionConfig::default().with_information_schema(true)
        };
        let runtime_env_builder = if let Some(c) = runtime {
            c.builder
        } else {
            RuntimeEnvBuilder::default()
        };
        let runtime = Arc::new(runtime_env_builder.build()?);
        let session_state = SessionStateBuilder::new()
            .with_config(config)
            .with_runtime_env(runtime)
            .with_default_features()
            .build();
        Ok(PySessionContext {
            ctx: SessionContext::new_with_state(session_state),
        })
    }

    pub fn enable_url_table(&self) -> PyResult<Self> {
        Ok(PySessionContext {
            ctx: self.ctx.clone().enable_url_table(),
        })
    }

    #[classmethod]
    #[pyo3(signature = ())]
    fn global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
        Ok(Self {
            ctx: get_global_ctx().clone(),
        })
    }

    /// Register an object store with the given name
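    ///
    /// A hedged sketch from Python; `AmazonS3` stands in for whichever storage
    /// context you construct, and the bucket, region, and path are hypothetical:
    ///
    /// ```python
    /// from datafusion import SessionContext
    /// from datafusion.object_store import AmazonS3
    ///
    /// store = AmazonS3(bucket_name="my-bucket", region="us-east-1")
    /// ctx = SessionContext()
    /// ctx.register_object_store("s3://", store)
    /// ctx.register_parquet("t", "s3://my-bucket/path/to/data.parquet")
    /// ```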
    #[pyo3(signature = (scheme, store, host=None))]
    pub fn register_object_store(
        &self,
        scheme: &str,
        store: StorageContexts,
        host: Option<&str>,
    ) -> PyResult<()> {
        // for most stores the "host" is the bucket name and can be inferred from the store
        let (store, upstream_host): (Arc<dyn ObjectStore>, String) = match store {
            StorageContexts::AmazonS3(s3) => (s3.inner, s3.bucket_name),
            StorageContexts::GoogleCloudStorage(gcs) => (gcs.inner, gcs.bucket_name),
            StorageContexts::MicrosoftAzure(azure) => (azure.inner, azure.container_name),
            StorageContexts::LocalFileSystem(local) => (local.inner, "".to_string()),
            StorageContexts::HTTP(http) => (http.store, http.url),
        };

        // let users override the host to match the api signature from upstream
        let derived_host = if let Some(host) = host {
            host
        } else {
            &upstream_host
        };
        let url_string = format!("{scheme}{derived_host}");
        let url = Url::parse(&url_string).unwrap();
        self.ctx.runtime_env().register_object_store(&url, store);
        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name, path, table_partition_cols=vec![],
    file_extension=".parquet",
    schema=None,
    file_sort_order=None))]
    pub fn register_listing_table(
        &self,
        name: &str,
        path: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_extension: &str,
        schema: Option<PyArrowType<Schema>>,
        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
            .with_file_extension(file_extension)
            .with_table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .with_file_sort_order(
                file_sort_order
                    .unwrap_or_default()
                    .into_iter()
                    .map(|e| e.into_iter().map(|f| f.into()).collect())
                    .collect(),
            );
        let table_path = ListingTableUrl::parse(path)?;
        let resolved_schema: SchemaRef = match schema {
            Some(s) => Arc::new(s.0),
            None => {
                let state = self.ctx.state();
                let schema = options.infer_schema(&state, &table_path);
                wait_for_future(py, schema)??
            }
        };
        let config = ListingTableConfig::new(table_path)
            .with_listing_options(options)
            .with_schema(resolved_schema);
        let table = ListingTable::try_new(config)?;
        self.ctx.register_table(name, Arc::new(table))?;
        Ok(())
    }

    pub fn register_udtf(&self, func: PyTableFunction) {
        let name = func.name.clone();
        let func = Arc::new(func);
        self.ctx.register_udtf(&name, func);
    }

    #[pyo3(signature = (query, options=None, param_values=HashMap::default(), param_strings=HashMap::default()))]
    pub fn sql_with_options(
        &self,
        py: Python,
        mut query: String,
        options: Option<PySQLOptions>,
        param_values: HashMap<String, PyScalarValue>,
        param_strings: HashMap<String, String>,
    ) -> PyDataFusionResult<PyDataFrame> {
        let options = if let Some(options) = options {
            options.options
        } else {
            SQLOptions::new()
        };

        let param_values = param_values
            .into_iter()
            .map(|(name, value)| (name, ScalarValue::from(value)))
            .collect::<HashMap<_, _>>();

        let state = self.ctx.state();
        let dialect = state.config().options().sql_parser.dialect.as_ref();

        if !param_strings.is_empty() {
            query = replace_placeholders_with_strings(&query, dialect, param_strings)?;
        }

        let mut df = wait_for_future(py, async {
            self.ctx.sql_with_options(&query, options).await
        })??;

        if !param_values.is_empty() {
            df = df.with_param_values(param_values)?;
        }

        Ok(PyDataFrame::new(df))
    }

    #[pyo3(signature = (partitions, name=None, schema=None))]
    pub fn create_dataframe(
        &self,
        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
        name: Option<&str>,
        schema: Option<PyArrowType<Schema>>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let schema = if let Some(schema) = schema {
            SchemaRef::from(schema.0)
        } else {
            partitions.0[0][0].schema()
        };

        let table = MemTable::try_new(schema, partitions.0)?;

        // generate a random (unique) name for this table if none is provided
        // table name cannot start with numeric digit
        let table_name = match name {
            Some(val) => val.to_owned(),
            None => {
                "c".to_owned()
                    + Uuid::new_v4()
                        .simple()
                        .encode_lower(&mut Uuid::encode_buffer())
            }
        };

        self.ctx.register_table(&*table_name, Arc::new(table))?;

        let table = wait_for_future(py, self._table(&table_name))??;

        let df = PyDataFrame::new(table);
        Ok(df)
    }

    /// Create a DataFrame from an existing logical plan
    pub fn create_dataframe_from_logical_plan(&self, plan: PyLogicalPlan) -> PyDataFrame {
        PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()))
    }

    /// Construct datafusion dataframe from Python list
    #[pyo3(signature = (data, name=None))]
    pub fn from_pylist(
        &self,
        data: Bound<'_, PyList>,
        name: Option<&str>,
    ) -> PyResult<PyDataFrame> {
        // Acquire GIL Token
        let py = data.py();

        // Instantiate pyarrow Table object & convert to Arrow Table
        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[data])?;
        let table = table_class.call_method1("from_pylist", args)?;

        // Convert Arrow Table to datafusion DataFrame
        let df = self.from_arrow(table, name, py)?;
        Ok(df)
    }

    /// Construct datafusion dataframe from Python dictionary
    #[pyo3(signature = (data, name=None))]
    pub fn from_pydict(
        &self,
        data: Bound<'_, PyDict>,
        name: Option<&str>,
    ) -> PyResult<PyDataFrame> {
        // Acquire GIL Token
        let py = data.py();

        // Instantiate pyarrow Table object & convert to Arrow Table
        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[data])?;
        let table = table_class.call_method1("from_pydict", args)?;

        // Convert Arrow Table to datafusion DataFrame
        let df = self.from_arrow(table, name, py)?;
        Ok(df)
    }

    /// Construct datafusion dataframe from Arrow Table
    #[pyo3(signature = (data, name=None))]
    pub fn from_arrow(
        &self,
        data: Bound<'_, PyAny>,
        name: Option<&str>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let (schema, batches) =
            if let Ok(stream_reader) = ArrowArrayStreamReader::from_pyarrow_bound(&data) {
                // Works for any object that implements __arrow_c_stream__ in pycapsule.

                let schema = stream_reader.schema().as_ref().to_owned();
                let batches = stream_reader
                    .collect::<Result<Vec<RecordBatch>, arrow::error::ArrowError>>()?;

                (schema, batches)
            } else if let Ok(array) = RecordBatch::from_pyarrow_bound(&data) {
                // While this says RecordBatch, it will work for any object that implements
                // __arrow_c_array__ and returns a StructArray.

                (array.schema().as_ref().to_owned(), vec![array])
            } else {
                return Err(PyDataFusionError::Common(
                    "Expected either an Arrow Array or an Arrow Stream in from_arrow().".to_string(),
                ));
            };

        // Because create_dataframe() expects a vector of vectors of record batches
        // here we need to wrap the vector of record batches in an additional vector
        let list_of_batches = PyArrowType::from(vec![batches]);
        self.create_dataframe(list_of_batches, name, Some(schema.into()), py)
    }

    /// Construct datafusion dataframe from pandas
    #[allow(clippy::wrong_self_convention)]
    #[pyo3(signature = (data, name=None))]
    pub fn from_pandas(&self, data: Bound<'_, PyAny>, name: Option<&str>) -> PyResult<PyDataFrame> {
        // Obtain GIL token
        let py = data.py();

        // Instantiate pyarrow Table object & convert to Arrow Table
        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[data])?;
        let table = table_class.call_method1("from_pandas", args)?;

        // Convert Arrow Table to datafusion DataFrame
        let df = self.from_arrow(table, name, py)?;
        Ok(df)
    }

    /// Construct datafusion dataframe from polars
    #[pyo3(signature = (data, name=None))]
    pub fn from_polars(&self, data: Bound<'_, PyAny>, name: Option<&str>) -> PyResult<PyDataFrame> {
        // Convert Polars dataframe to Arrow Table
        let table = data.call_method0("to_arrow")?;

        // Convert Arrow Table to datafusion DataFrame
        let df = self.from_arrow(table, name, data.py())?;
        Ok(df)
    }

    pub fn register_table(&self, name: &str, table: Bound<'_, PyAny>) -> PyDataFusionResult<()> {
        let table = PyTable::new(&table)?;

        self.ctx.register_table(name, table.table)?;
        Ok(())
    }

    pub fn deregister_table(&self, name: &str) -> PyDataFusionResult<()> {
        self.ctx.deregister_table(name)?;
        Ok(())
    }

    pub fn register_catalog_provider(
        &self,
        name: &str,
        provider: Bound<'_, PyAny>,
    ) -> PyDataFusionResult<()> {
        let provider = if provider.hasattr("__datafusion_catalog_provider__")? {
            let capsule = provider
                .getattr("__datafusion_catalog_provider__")?
                .call0()?;
            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
            validate_pycapsule(capsule, "datafusion_catalog_provider")?;

            let provider = unsafe { capsule.reference::<FFI_CatalogProvider>() };
            let provider: ForeignCatalogProvider = provider.into();
            Arc::new(provider) as Arc<dyn CatalogProvider>
        } else {
            match provider.extract::<PyCatalog>() {
                Ok(py_catalog) => py_catalog.catalog,
                Err(_) => Arc::new(RustWrappedPyCatalogProvider::new(provider.into()))
                    as Arc<dyn CatalogProvider>,
            }
        };

        let _ = self.ctx.register_catalog(name, provider);

        Ok(())
    }

    /// Register a table provider with the given name. Deprecated: use `register_table` instead.
    pub fn register_table_provider(
        &self,
        name: &str,
        provider: Bound<'_, PyAny>,
    ) -> PyDataFusionResult<()> {
        // Deprecated: use `register_table` instead
        self.register_table(name, provider)
    }

    pub fn register_record_batches(
        &self,
        name: &str,
        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
    ) -> PyDataFusionResult<()> {
        let schema = partitions.0[0][0].schema();
        let table = MemTable::try_new(schema, partitions.0)?;
        self.ctx.register_table(name, Arc::new(table))?;
        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name, path, table_partition_cols=vec![],
                        parquet_pruning=true,
                        file_extension=".parquet",
                        skip_metadata=true,
                        schema=None,
                        file_sort_order=None))]
    pub fn register_parquet(
        &self,
        name: &str,
        path: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        parquet_pruning: bool,
        file_extension: &str,
        skip_metadata: bool,
        schema: Option<PyArrowType<Schema>>,
        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let mut options = ParquetReadOptions::default()
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .parquet_pruning(parquet_pruning)
            .skip_metadata(skip_metadata);
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);
        options.file_sort_order = file_sort_order
            .unwrap_or_default()
            .into_iter()
            .map(|e| e.into_iter().map(|f| f.into()).collect())
            .collect();

        let result = self.ctx.register_parquet(name, path, options);
        wait_for_future(py, result)??;
        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name,
                        path,
                        schema=None,
                        has_header=true,
                        delimiter=",",
                        schema_infer_max_records=1000,
                        file_extension=".csv",
                        file_compression_type=None))]
    pub fn register_csv(
        &self,
        name: &str,
        path: &Bound<'_, PyAny>,
        schema: Option<PyArrowType<Schema>>,
        has_header: bool,
        delimiter: &str,
        schema_infer_max_records: usize,
        file_extension: &str,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let delimiter = delimiter.as_bytes();
        if delimiter.len() != 1 {
            return Err(PyDataFusionError::PythonError(py_value_err(
                "Delimiter must be a single character",
            )));
        }

        let mut options = CsvReadOptions::new()
            .has_header(has_header)
            .delimiter(delimiter[0])
            .schema_infer_max_records(schema_infer_max_records)
            .file_extension(file_extension)
            .file_compression_type(parse_file_compression_type(file_compression_type)?);
        options.schema = schema.as_ref().map(|x| &x.0);

        if path.is_instance_of::<PyList>() {
            let paths = path.extract::<Vec<String>>()?;
            let result = self.register_csv_from_multiple_paths(name, paths, options);
            wait_for_future(py, result)??;
        } else {
            let path = path.extract::<String>()?;
            let result = self.ctx.register_csv(name, &path, options);
            wait_for_future(py, result)??;
        }

        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name,
                        path,
                        schema=None,
                        schema_infer_max_records=1000,
                        file_extension=".json",
                        table_partition_cols=vec![],
                        file_compression_type=None))]
    pub fn register_json(
        &self,
        name: &str,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;

        let mut options = NdJsonReadOptions::default()
            .file_compression_type(parse_file_compression_type(file_compression_type)?)
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            );
        options.schema_infer_max_records = schema_infer_max_records;
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);

        let result = self.ctx.register_json(name, path, options);
        wait_for_future(py, result)??;

        Ok(())
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name,
                        path,
                        schema=None,
                        file_extension=".avro",
                        table_partition_cols=vec![]))]
    pub fn register_avro(
        &self,
        name: &str,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        file_extension: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;

        let mut options = AvroReadOptions::default().table_partition_cols(
            table_partition_cols
                .into_iter()
                .map(|(name, ty)| (name, ty.0))
                .collect::<Vec<(String, DataType)>>(),
        );
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);

        let result = self.ctx.register_avro(name, path, options);
        wait_for_future(py, result)??;

        Ok(())
    }

    // Registers a PyArrow.Dataset
    pub fn register_dataset(
        &self,
        name: &str,
        dataset: &Bound<'_, PyAny>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);

        self.ctx.register_table(name, table)?;

        Ok(())
    }

    pub fn register_udf(&self, udf: PyScalarUDF) -> PyResult<()> {
        self.ctx.register_udf(udf.function);
        Ok(())
    }

    pub fn register_udaf(&self, udaf: PyAggregateUDF) -> PyResult<()> {
        self.ctx.register_udaf(udaf.function);
        Ok(())
    }

    pub fn register_udwf(&self, udwf: PyWindowUDF) -> PyResult<()> {
        self.ctx.register_udwf(udwf.function);
        Ok(())
    }

    #[pyo3(signature = (name="datafusion"))]
    pub fn catalog(&self, name: &str) -> PyResult<Py<PyAny>> {
        let catalog = self.ctx.catalog(name).ok_or(PyKeyError::new_err(format!(
            "Catalog with name {name} doesn't exist."
        )))?;

        Python::attach(|py| {
            match catalog
                .as_any()
                .downcast_ref::<RustWrappedPyCatalogProvider>()
            {
                Some(wrapped_schema) => Ok(wrapped_schema.catalog_provider.clone_ref(py)),
                None => PyCatalog::from(catalog).into_py_any(py),
            }
        })
    }

    pub fn catalog_names(&self) -> HashSet<String> {
        self.ctx.catalog_names().into_iter().collect()
    }

    pub fn tables(&self) -> HashSet<String> {
        self.ctx
            .catalog_names()
            .into_iter()
            .filter_map(|name| self.ctx.catalog(&name))
            .flat_map(move |catalog| {
                catalog
                    .schema_names()
                    .into_iter()
                    .filter_map(move |name| catalog.schema(&name))
            })
            .flat_map(|schema| schema.table_names())
            .collect()
    }

    pub fn table(&self, name: &str, py: Python) -> PyResult<PyDataFrame> {
        let res = wait_for_future(py, self.ctx.table(name))
            .map_err(|e| PyKeyError::new_err(e.to_string()))?;
        match res {
            Ok(df) => Ok(PyDataFrame::new(df)),
            Err(e) => {
                if let datafusion::error::DataFusionError::Plan(msg) = &e {
                    if msg.contains("No table named") {
                        return Err(PyKeyError::new_err(msg.to_string()));
                    }
                }
                Err(py_datafusion_err(e))
            }
        }
    }

    pub fn table_exist(&self, name: &str) -> PyDataFusionResult<bool> {
        Ok(self.ctx.table_exist(name)?)
    }

    pub fn empty_table(&self) -> PyDataFusionResult<PyDataFrame> {
        Ok(PyDataFrame::new(self.ctx.read_empty()?))
    }

    pub fn session_id(&self) -> String {
        self.ctx.session_id()
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))]
    pub fn read_json(
        &self,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
        let mut options = NdJsonReadOptions::default()
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .file_compression_type(parse_file_compression_type(file_compression_type)?);
        options.schema_infer_max_records = schema_infer_max_records;
        options.file_extension = file_extension;
        let df = if let Some(schema) = schema {
            options.schema = Some(&schema.0);
            let result = self.ctx.read_json(path, options);
            wait_for_future(py, result)??
        } else {
            let result = self.ctx.read_json(path, options);
            wait_for_future(py, result)??
        };
        Ok(PyDataFrame::new(df))
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (
        path,
        schema=None,
        has_header=true,
        delimiter=",",
        schema_infer_max_records=1000,
        file_extension=".csv",
        table_partition_cols=vec![],
        file_compression_type=None))]
    pub fn read_csv(
        &self,
        path: &Bound<'_, PyAny>,
        schema: Option<PyArrowType<Schema>>,
        has_header: bool,
        delimiter: &str,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let delimiter = delimiter.as_bytes();
        if delimiter.len() != 1 {
            return Err(PyDataFusionError::PythonError(py_value_err(
                "Delimiter must be a single character",
            )));
        };

        let mut options = CsvReadOptions::new()
            .has_header(has_header)
            .delimiter(delimiter[0])
            .schema_infer_max_records(schema_infer_max_records)
            .file_extension(file_extension)
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .file_compression_type(parse_file_compression_type(file_compression_type)?);
        options.schema = schema.as_ref().map(|x| &x.0);

        if path.is_instance_of::<PyList>() {
            let paths = path.extract::<Vec<String>>()?;
            let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
            let result = self.ctx.read_csv(paths, options);
            let df = PyDataFrame::new(wait_for_future(py, result)??);
            Ok(df)
        } else {
            let path = path.extract::<String>()?;
            let result = self.ctx.read_csv(path, options);
            let df = PyDataFrame::new(wait_for_future(py, result)??);
            Ok(df)
        }
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (
        path,
        table_partition_cols=vec![],
        parquet_pruning=true,
        file_extension=".parquet",
        skip_metadata=true,
        schema=None,
        file_sort_order=None))]
    pub fn read_parquet(
        &self,
        path: &str,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        parquet_pruning: bool,
        file_extension: &str,
        skip_metadata: bool,
        schema: Option<PyArrowType<Schema>>,
        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let mut options = ParquetReadOptions::default()
            .table_partition_cols(
                table_partition_cols
                    .into_iter()
                    .map(|(name, ty)| (name, ty.0))
                    .collect::<Vec<(String, DataType)>>(),
            )
            .parquet_pruning(parquet_pruning)
            .skip_metadata(skip_metadata);
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);
        options.file_sort_order = file_sort_order
            .unwrap_or_default()
            .into_iter()
            .map(|e| e.into_iter().map(|f| f.into()).collect())
            .collect();

        let result = self.ctx.read_parquet(path, options);
        let df = PyDataFrame::new(wait_for_future(py, result)??);
        Ok(df)
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (path, schema=None, table_partition_cols=vec![], file_extension=".avro"))]
    pub fn read_avro(
        &self,
        path: &str,
        schema: Option<PyArrowType<Schema>>,
        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
        file_extension: &str,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let mut options = AvroReadOptions::default().table_partition_cols(
            table_partition_cols
                .into_iter()
                .map(|(name, ty)| (name, ty.0))
                .collect::<Vec<(String, DataType)>>(),
        );
        options.file_extension = file_extension;
        let df = if let Some(schema) = schema {
            options.schema = Some(&schema.0);
            let read_future = self.ctx.read_avro(path, options);
            wait_for_future(py, read_future)??
        } else {
            let read_future = self.ctx.read_avro(path, options);
            wait_for_future(py, read_future)??
        };
        Ok(PyDataFrame::new(df))
    }

    pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult<PyDataFrame> {
        let table = PyTable::new(&table)?;
        let df = self.ctx.read_table(table.table())?;
        Ok(PyDataFrame::new(df))
    }

    fn __repr__(&self) -> PyResult<String> {
        let config = self.ctx.copied_config();
        let mut config_entries = config
            .options()
            .entries()
            .iter()
            .filter(|e| e.value.is_some())
            .map(|e| format!("{} = {}", e.key, e.value.as_ref().unwrap()))
            .collect::<Vec<_>>();
        config_entries.sort();
        Ok(format!(
            "SessionContext: id={}; configs=[\n\t{}]",
            self.session_id(),
            config_entries.join("\n\t")
        ))
    }

    /// Execute a partition of an execution plan and return a stream of record batches
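    ///
    /// A rough sketch from Python, assuming the plan comes from
    /// `DataFrame.execution_plan()` and the returned stream supports `next()`
    /// (illustrative only):
    ///
    /// ```python
    /// from datafusion import SessionContext
    ///
    /// ctx = SessionContext()
    /// df = ctx.sql("SELECT 1 AS x")
    /// plan = df.execution_plan()
    /// stream = ctx.execute(plan, 0)  # record batches for partition 0
    /// batch = next(stream)
    /// ```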
    pub fn execute(
        &self,
        plan: PyExecutionPlan,
        part: usize,
        py: Python,
    ) -> PyDataFusionResult<PyRecordBatchStream> {
        let ctx: TaskContext = TaskContext::from(&self.ctx.state());
        let plan = plan.plan.clone();
        let stream = spawn_future(py, async move { plan.execute(part, Arc::new(ctx)) })?;
        Ok(PyRecordBatchStream::new(stream))
    }
}

impl PySessionContext {
    async fn _table(&self, name: &str) -> datafusion::common::Result<DataFrame> {
        self.ctx.table(name).await
    }

    async fn register_csv_from_multiple_paths(
        &self,
        name: &str,
        table_paths: Vec<String>,
        options: CsvReadOptions<'_>,
    ) -> datafusion::common::Result<()> {
        let table_paths = table_paths.to_urls()?;
        let session_config = self.ctx.copied_config();
        let listing_options =
            options.to_listing_options(&session_config, self.ctx.copied_table_options());

        let option_extension = listing_options.file_extension.clone();

        if table_paths.is_empty() {
            return exec_err!("No table paths were provided");
        }

        // check if the file extension matches the expected extension
        for path in &table_paths {
            let file_path = path.as_str();
            if !file_path.ends_with(option_extension.clone().as_str()) && !path.is_collection() {
                return exec_err!(
                    "File path '{file_path}' does not match the expected extension '{option_extension}'"
                );
            }
        }

        let resolved_schema = options
            .get_resolved_schema(&session_config, self.ctx.state(), table_paths[0].clone())
            .await?;

        let config = ListingTableConfig::new_with_multi_paths(table_paths)
            .with_listing_options(listing_options)
            .with_schema(resolved_schema);
        let table = ListingTable::try_new(config)?;
        self.ctx
            .register_table(TableReference::Bare { table: name.into() }, Arc::new(table))?;
        Ok(())
    }
}

pub fn parse_file_compression_type(
    file_compression_type: Option<String>,
) -> Result<FileCompressionType, PyErr> {
    FileCompressionType::from_str(file_compression_type.as_deref().unwrap_or(""))
        .map_err(|_| {
            PyValueError::new_err("file_compression_type must be one of: gzip, bz2, xz, zstd")
        })
}

impl From<PySessionContext> for SessionContext {
    fn from(ctx: PySessionContext) -> SessionContext {
        ctx.ctx
    }
}

impl From<SessionContext> for PySessionContext {
    fn from(ctx: SessionContext) -> PySessionContext {
        PySessionContext { ctx }
    }
}