datafusion_python/
context.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::{HashMap, HashSet};
19use std::path::PathBuf;
20use std::str::FromStr;
21use std::sync::Arc;
22
23use arrow::array::RecordBatchReader;
24use arrow::ffi_stream::ArrowArrayStreamReader;
25use arrow::pyarrow::FromPyArrow;
26use datafusion::execution::session_state::SessionStateBuilder;
27use object_store::ObjectStore;
28use url::Url;
29use uuid::Uuid;
30
31use pyo3::exceptions::{PyKeyError, PyValueError};
32use pyo3::prelude::*;
33
34use crate::catalog::{PyCatalog, PyTable};
35use crate::dataframe::PyDataFrame;
36use crate::dataset::Dataset;
37use crate::errors::{py_datafusion_err, PyDataFusionResult};
38use crate::expr::sort_expr::PySortExpr;
39use crate::physical_plan::PyExecutionPlan;
40use crate::record_batch::PyRecordBatchStream;
41use crate::sql::exceptions::py_value_err;
42use crate::sql::logical::PyLogicalPlan;
43use crate::store::StorageContexts;
44use crate::udaf::PyAggregateUDF;
45use crate::udf::PyScalarUDF;
46use crate::udwf::PyWindowUDF;
47use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future};
48use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
49use datafusion::arrow::pyarrow::PyArrowType;
50use datafusion::arrow::record_batch::RecordBatch;
51use datafusion::common::TableReference;
52use datafusion::common::{exec_err, ScalarValue};
53use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
54use datafusion::datasource::file_format::parquet::ParquetFormat;
55use datafusion::datasource::listing::{
56    ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
57};
58use datafusion::datasource::MemTable;
59use datafusion::datasource::TableProvider;
60use datafusion::execution::context::{
61    DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext,
62};
63use datafusion::execution::disk_manager::DiskManagerConfig;
64use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
65use datafusion::execution::options::ReadOptions;
66use datafusion::execution::runtime_env::RuntimeEnvBuilder;
67use datafusion::physical_plan::SendableRecordBatchStream;
68use datafusion::prelude::{
69    AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
70};
71use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
72use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
73use tokio::task::JoinHandle;
74
75/// Configuration options for a SessionContext
76#[pyclass(name = "SessionConfig", module = "datafusion", subclass)]
77#[derive(Clone, Default)]
78pub struct PySessionConfig {
79    pub config: SessionConfig,
80}
81
82impl From<SessionConfig> for PySessionConfig {
83    fn from(config: SessionConfig) -> Self {
84        Self { config }
85    }
86}
87
88#[pymethods]
89impl PySessionConfig {
90    #[pyo3(signature = (config_options=None))]
91    #[new]
92    fn new(config_options: Option<HashMap<String, String>>) -> Self {
93        let mut config = SessionConfig::new();
94        if let Some(hash_map) = config_options {
95            for (k, v) in &hash_map {
96                config = config.set(k, &ScalarValue::Utf8(Some(v.clone())));
97            }
98        }
99
100        Self { config }
101    }
102
103    fn with_create_default_catalog_and_schema(&self, enabled: bool) -> Self {
104        Self::from(
105            self.config
106                .clone()
107                .with_create_default_catalog_and_schema(enabled),
108        )
109    }
110
111    fn with_default_catalog_and_schema(&self, catalog: &str, schema: &str) -> Self {
112        Self::from(
113            self.config
114                .clone()
115                .with_default_catalog_and_schema(catalog, schema),
116        )
117    }
118
119    fn with_information_schema(&self, enabled: bool) -> Self {
120        Self::from(self.config.clone().with_information_schema(enabled))
121    }
122
123    fn with_batch_size(&self, batch_size: usize) -> Self {
124        Self::from(self.config.clone().with_batch_size(batch_size))
125    }
126
127    fn with_target_partitions(&self, target_partitions: usize) -> Self {
128        Self::from(
129            self.config
130                .clone()
131                .with_target_partitions(target_partitions),
132        )
133    }
134
135    fn with_repartition_aggregations(&self, enabled: bool) -> Self {
136        Self::from(self.config.clone().with_repartition_aggregations(enabled))
137    }
138
139    fn with_repartition_joins(&self, enabled: bool) -> Self {
140        Self::from(self.config.clone().with_repartition_joins(enabled))
141    }
142
143    fn with_repartition_windows(&self, enabled: bool) -> Self {
144        Self::from(self.config.clone().with_repartition_windows(enabled))
145    }
146
147    fn with_repartition_sorts(&self, enabled: bool) -> Self {
148        Self::from(self.config.clone().with_repartition_sorts(enabled))
149    }
150
151    fn with_repartition_file_scans(&self, enabled: bool) -> Self {
152        Self::from(self.config.clone().with_repartition_file_scans(enabled))
153    }
154
155    fn with_repartition_file_min_size(&self, size: usize) -> Self {
156        Self::from(self.config.clone().with_repartition_file_min_size(size))
157    }
158
159    fn with_parquet_pruning(&self, enabled: bool) -> Self {
160        Self::from(self.config.clone().with_parquet_pruning(enabled))
161    }
162
163    fn set(&self, key: &str, value: &str) -> Self {
164        Self::from(self.config.clone().set_str(key, value))
165    }
166}
167
168/// Runtime options for a SessionContext
169#[pyclass(name = "RuntimeEnvBuilder", module = "datafusion", subclass)]
170#[derive(Clone)]
171pub struct PyRuntimeEnvBuilder {
172    pub builder: RuntimeEnvBuilder,
173}
174
175#[pymethods]
176impl PyRuntimeEnvBuilder {
177    #[new]
178    fn new() -> Self {
179        Self {
180            builder: RuntimeEnvBuilder::default(),
181        }
182    }
183
184    fn with_disk_manager_disabled(&self) -> Self {
185        let mut builder = self.builder.clone();
186        builder = builder.with_disk_manager(DiskManagerConfig::Disabled);
187        Self { builder }
188    }
189
190    fn with_disk_manager_os(&self) -> Self {
191        let builder = self.builder.clone();
192        let builder = builder.with_disk_manager(DiskManagerConfig::NewOs);
193        Self { builder }
194    }
195
196    fn with_disk_manager_specified(&self, paths: Vec<String>) -> Self {
197        let builder = self.builder.clone();
198        let paths = paths.iter().map(|s| s.into()).collect();
199        let builder = builder.with_disk_manager(DiskManagerConfig::NewSpecified(paths));
200        Self { builder }
201    }
202
203    fn with_unbounded_memory_pool(&self) -> Self {
204        let builder = self.builder.clone();
205        let builder = builder.with_memory_pool(Arc::new(UnboundedMemoryPool::default()));
206        Self { builder }
207    }
208
209    fn with_fair_spill_pool(&self, size: usize) -> Self {
210        let builder = self.builder.clone();
211        let builder = builder.with_memory_pool(Arc::new(FairSpillPool::new(size)));
212        Self { builder }
213    }
214
215    fn with_greedy_memory_pool(&self, size: usize) -> Self {
216        let builder = self.builder.clone();
217        let builder = builder.with_memory_pool(Arc::new(GreedyMemoryPool::new(size)));
218        Self { builder }
219    }
220
221    fn with_temp_file_path(&self, path: &str) -> Self {
222        let builder = self.builder.clone();
223        let builder = builder.with_temp_file_path(path);
224        Self { builder }
225    }
226}
227
228/// `PySQLOptions` allows you to specify options to the sql execution.
229#[pyclass(name = "SQLOptions", module = "datafusion", subclass)]
230#[derive(Clone)]
231pub struct PySQLOptions {
232    pub options: SQLOptions,
233}
234
235impl From<SQLOptions> for PySQLOptions {
236    fn from(options: SQLOptions) -> Self {
237        Self { options }
238    }
239}
240
241#[pymethods]
242impl PySQLOptions {
243    #[new]
244    fn new() -> Self {
245        let options = SQLOptions::new();
246        Self { options }
247    }
248
249    /// Should DDL data modification commands  (e.g. `CREATE TABLE`) be run? Defaults to `true`.
250    fn with_allow_ddl(&self, allow: bool) -> Self {
251        Self::from(self.options.with_allow_ddl(allow))
252    }
253
254    /// Should DML data modification commands (e.g. `INSERT and COPY`) be run? Defaults to `true`
255    pub fn with_allow_dml(&self, allow: bool) -> Self {
256        Self::from(self.options.with_allow_dml(allow))
257    }
258
259    /// Should Statements such as (e.g. `SET VARIABLE and `BEGIN TRANSACTION` ...`) be run?. Defaults to `true`
260    pub fn with_allow_statements(&self, allow: bool) -> Self {
261        Self::from(self.options.with_allow_statements(allow))
262    }
263}
264
265/// `PySessionContext` is able to plan and execute DataFusion plans.
266/// It has a powerful optimizer, a physical planner for local execution, and a
267/// multi-threaded execution engine to perform the execution.
268#[pyclass(name = "SessionContext", module = "datafusion", subclass)]
269#[derive(Clone)]
270pub struct PySessionContext {
271    pub ctx: SessionContext,
272}
273
274#[pymethods]
275impl PySessionContext {
276    #[pyo3(signature = (config=None, runtime=None))]
277    #[new]
278    pub fn new(
279        config: Option<PySessionConfig>,
280        runtime: Option<PyRuntimeEnvBuilder>,
281    ) -> PyDataFusionResult<Self> {
282        let config = if let Some(c) = config {
283            c.config
284        } else {
285            SessionConfig::default().with_information_schema(true)
286        };
287        let runtime_env_builder = if let Some(c) = runtime {
288            c.builder
289        } else {
290            RuntimeEnvBuilder::default()
291        };
292        let runtime = Arc::new(runtime_env_builder.build()?);
293        let session_state = SessionStateBuilder::new()
294            .with_config(config)
295            .with_runtime_env(runtime)
296            .with_default_features()
297            .build();
298        Ok(PySessionContext {
299            ctx: SessionContext::new_with_state(session_state),
300        })
301    }
302
303    pub fn enable_url_table(&self) -> PyResult<Self> {
304        Ok(PySessionContext {
305            ctx: self.ctx.clone().enable_url_table(),
306        })
307    }
308
309    #[classmethod]
310    #[pyo3(signature = ())]
311    fn global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
312        Ok(Self {
313            ctx: get_global_ctx().clone(),
314        })
315    }
316
317    /// Register an object store with the given name
318    #[pyo3(signature = (scheme, store, host=None))]
319    pub fn register_object_store(
320        &mut self,
321        scheme: &str,
322        store: StorageContexts,
323        host: Option<&str>,
324    ) -> PyResult<()> {
325        // for most stores the "host" is the bucket name and can be inferred from the store
326        let (store, upstream_host): (Arc<dyn ObjectStore>, String) = match store {
327            StorageContexts::AmazonS3(s3) => (s3.inner, s3.bucket_name),
328            StorageContexts::GoogleCloudStorage(gcs) => (gcs.inner, gcs.bucket_name),
329            StorageContexts::MicrosoftAzure(azure) => (azure.inner, azure.container_name),
330            StorageContexts::LocalFileSystem(local) => (local.inner, "".to_string()),
331            StorageContexts::HTTP(http) => (http.store, http.url),
332        };
333
334        // let users override the host to match the api signature from upstream
335        let derived_host = if let Some(host) = host {
336            host
337        } else {
338            &upstream_host
339        };
340        let url_string = format!("{}{}", scheme, derived_host);
341        let url = Url::parse(&url_string).unwrap();
342        self.ctx.runtime_env().register_object_store(&url, store);
343        Ok(())
344    }
345
346    #[allow(clippy::too_many_arguments)]
347    #[pyo3(signature = (name, path, table_partition_cols=vec![],
348    file_extension=".parquet",
349    schema=None,
350    file_sort_order=None))]
351    pub fn register_listing_table(
352        &mut self,
353        name: &str,
354        path: &str,
355        table_partition_cols: Vec<(String, String)>,
356        file_extension: &str,
357        schema: Option<PyArrowType<Schema>>,
358        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
359        py: Python,
360    ) -> PyDataFusionResult<()> {
361        let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
362            .with_file_extension(file_extension)
363            .with_table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
364            .with_file_sort_order(
365                file_sort_order
366                    .unwrap_or_default()
367                    .into_iter()
368                    .map(|e| e.into_iter().map(|f| f.into()).collect())
369                    .collect(),
370            );
371        let table_path = ListingTableUrl::parse(path)?;
372        let resolved_schema: SchemaRef = match schema {
373            Some(s) => Arc::new(s.0),
374            None => {
375                let state = self.ctx.state();
376                let schema = options.infer_schema(&state, &table_path);
377                wait_for_future(py, schema)?
378            }
379        };
380        let config = ListingTableConfig::new(table_path)
381            .with_listing_options(options)
382            .with_schema(resolved_schema);
383        let table = ListingTable::try_new(config)?;
384        self.register_table(
385            name,
386            &PyTable {
387                table: Arc::new(table),
388            },
389        )?;
390        Ok(())
391    }
392
393    /// Returns a PyDataFrame whose plan corresponds to the SQL statement.
394    pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
395        let result = self.ctx.sql(query);
396        let df = wait_for_future(py, result)?;
397        Ok(PyDataFrame::new(df))
398    }
399
400    #[pyo3(signature = (query, options=None))]
401    pub fn sql_with_options(
402        &mut self,
403        query: &str,
404        options: Option<PySQLOptions>,
405        py: Python,
406    ) -> PyDataFusionResult<PyDataFrame> {
407        let options = if let Some(options) = options {
408            options.options
409        } else {
410            SQLOptions::new()
411        };
412        let result = self.ctx.sql_with_options(query, options);
413        let df = wait_for_future(py, result)?;
414        Ok(PyDataFrame::new(df))
415    }
416
417    #[pyo3(signature = (partitions, name=None, schema=None))]
418    pub fn create_dataframe(
419        &mut self,
420        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
421        name: Option<&str>,
422        schema: Option<PyArrowType<Schema>>,
423        py: Python,
424    ) -> PyDataFusionResult<PyDataFrame> {
425        let schema = if let Some(schema) = schema {
426            SchemaRef::from(schema.0)
427        } else {
428            partitions.0[0][0].schema()
429        };
430
431        let table = MemTable::try_new(schema, partitions.0)?;
432
433        // generate a random (unique) name for this table if none is provided
434        // table name cannot start with numeric digit
435        let table_name = match name {
436            Some(val) => val.to_owned(),
437            None => {
438                "c".to_owned()
439                    + Uuid::new_v4()
440                        .simple()
441                        .encode_lower(&mut Uuid::encode_buffer())
442            }
443        };
444
445        self.ctx.register_table(&*table_name, Arc::new(table))?;
446
447        let table = wait_for_future(py, self._table(&table_name))?;
448
449        let df = PyDataFrame::new(table);
450        Ok(df)
451    }
452
453    /// Create a DataFrame from an existing logical plan
454    pub fn create_dataframe_from_logical_plan(&mut self, plan: PyLogicalPlan) -> PyDataFrame {
455        PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()))
456    }
457
458    /// Construct datafusion dataframe from Python list
459    #[pyo3(signature = (data, name=None))]
460    pub fn from_pylist(
461        &mut self,
462        data: Bound<'_, PyList>,
463        name: Option<&str>,
464    ) -> PyResult<PyDataFrame> {
465        // Acquire GIL Token
466        let py = data.py();
467
468        // Instantiate pyarrow Table object & convert to Arrow Table
469        let table_class = py.import("pyarrow")?.getattr("Table")?;
470        let args = PyTuple::new(py, &[data])?;
471        let table = table_class.call_method1("from_pylist", args)?;
472
473        // Convert Arrow Table to datafusion DataFrame
474        let df = self.from_arrow(table, name, py)?;
475        Ok(df)
476    }
477
478    /// Construct datafusion dataframe from Python dictionary
479    #[pyo3(signature = (data, name=None))]
480    pub fn from_pydict(
481        &mut self,
482        data: Bound<'_, PyDict>,
483        name: Option<&str>,
484    ) -> PyResult<PyDataFrame> {
485        // Acquire GIL Token
486        let py = data.py();
487
488        // Instantiate pyarrow Table object & convert to Arrow Table
489        let table_class = py.import("pyarrow")?.getattr("Table")?;
490        let args = PyTuple::new(py, &[data])?;
491        let table = table_class.call_method1("from_pydict", args)?;
492
493        // Convert Arrow Table to datafusion DataFrame
494        let df = self.from_arrow(table, name, py)?;
495        Ok(df)
496    }
497
498    /// Construct datafusion dataframe from Arrow Table
499    #[pyo3(signature = (data, name=None))]
500    pub fn from_arrow(
501        &mut self,
502        data: Bound<'_, PyAny>,
503        name: Option<&str>,
504        py: Python,
505    ) -> PyDataFusionResult<PyDataFrame> {
506        let (schema, batches) =
507            if let Ok(stream_reader) = ArrowArrayStreamReader::from_pyarrow_bound(&data) {
508                // Works for any object that implements __arrow_c_stream__ in pycapsule.
509
510                let schema = stream_reader.schema().as_ref().to_owned();
511                let batches = stream_reader
512                    .collect::<Result<Vec<RecordBatch>, arrow::error::ArrowError>>()?;
513
514                (schema, batches)
515            } else if let Ok(array) = RecordBatch::from_pyarrow_bound(&data) {
516                // While this says RecordBatch, it will work for any object that implements
517                // __arrow_c_array__ and returns a StructArray.
518
519                (array.schema().as_ref().to_owned(), vec![array])
520            } else {
521                return Err(crate::errors::PyDataFusionError::Common(
522                    "Expected either a Arrow Array or Arrow Stream in from_arrow().".to_string(),
523                ));
524            };
525
526        // Because create_dataframe() expects a vector of vectors of record batches
527        // here we need to wrap the vector of record batches in an additional vector
528        let list_of_batches = PyArrowType::from(vec![batches]);
529        self.create_dataframe(list_of_batches, name, Some(schema.into()), py)
530    }
531
532    /// Construct datafusion dataframe from pandas
533    #[allow(clippy::wrong_self_convention)]
534    #[pyo3(signature = (data, name=None))]
535    pub fn from_pandas(
536        &mut self,
537        data: Bound<'_, PyAny>,
538        name: Option<&str>,
539    ) -> PyResult<PyDataFrame> {
540        // Obtain GIL token
541        let py = data.py();
542
543        // Instantiate pyarrow Table object & convert to Arrow Table
544        let table_class = py.import("pyarrow")?.getattr("Table")?;
545        let args = PyTuple::new(py, &[data])?;
546        let table = table_class.call_method1("from_pandas", args)?;
547
548        // Convert Arrow Table to datafusion DataFrame
549        let df = self.from_arrow(table, name, py)?;
550        Ok(df)
551    }
552
553    /// Construct datafusion dataframe from polars
554    #[pyo3(signature = (data, name=None))]
555    pub fn from_polars(
556        &mut self,
557        data: Bound<'_, PyAny>,
558        name: Option<&str>,
559    ) -> PyResult<PyDataFrame> {
560        // Convert Polars dataframe to Arrow Table
561        let table = data.call_method0("to_arrow")?;
562
563        // Convert Arrow Table to datafusion DataFrame
564        let df = self.from_arrow(table, name, data.py())?;
565        Ok(df)
566    }
567
568    pub fn register_table(&mut self, name: &str, table: &PyTable) -> PyDataFusionResult<()> {
569        self.ctx.register_table(name, table.table())?;
570        Ok(())
571    }
572
573    pub fn deregister_table(&mut self, name: &str) -> PyDataFusionResult<()> {
574        self.ctx.deregister_table(name)?;
575        Ok(())
576    }
577
578    /// Construct datafusion dataframe from Arrow Table
579    pub fn register_table_provider(
580        &mut self,
581        name: &str,
582        provider: Bound<'_, PyAny>,
583    ) -> PyDataFusionResult<()> {
584        if provider.hasattr("__datafusion_table_provider__")? {
585            let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
586            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
587            validate_pycapsule(capsule, "datafusion_table_provider")?;
588
589            let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
590            let provider: ForeignTableProvider = provider.into();
591
592            let _ = self.ctx.register_table(name, Arc::new(provider))?;
593
594            Ok(())
595        } else {
596            Err(crate::errors::PyDataFusionError::Common(
597                "__datafusion_table_provider__ does not exist on Table Provider object."
598                    .to_string(),
599            ))
600        }
601    }
602
603    pub fn register_record_batches(
604        &mut self,
605        name: &str,
606        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
607    ) -> PyDataFusionResult<()> {
608        let schema = partitions.0[0][0].schema();
609        let table = MemTable::try_new(schema, partitions.0)?;
610        self.ctx.register_table(name, Arc::new(table))?;
611        Ok(())
612    }
613
614    #[allow(clippy::too_many_arguments)]
615    #[pyo3(signature = (name, path, table_partition_cols=vec![],
616                        parquet_pruning=true,
617                        file_extension=".parquet",
618                        skip_metadata=true,
619                        schema=None,
620                        file_sort_order=None))]
621    pub fn register_parquet(
622        &mut self,
623        name: &str,
624        path: &str,
625        table_partition_cols: Vec<(String, String)>,
626        parquet_pruning: bool,
627        file_extension: &str,
628        skip_metadata: bool,
629        schema: Option<PyArrowType<Schema>>,
630        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
631        py: Python,
632    ) -> PyDataFusionResult<()> {
633        let mut options = ParquetReadOptions::default()
634            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
635            .parquet_pruning(parquet_pruning)
636            .skip_metadata(skip_metadata);
637        options.file_extension = file_extension;
638        options.schema = schema.as_ref().map(|x| &x.0);
639        options.file_sort_order = file_sort_order
640            .unwrap_or_default()
641            .into_iter()
642            .map(|e| e.into_iter().map(|f| f.into()).collect())
643            .collect();
644
645        let result = self.ctx.register_parquet(name, path, options);
646        wait_for_future(py, result)?;
647        Ok(())
648    }
649
650    #[allow(clippy::too_many_arguments)]
651    #[pyo3(signature = (name,
652                        path,
653                        schema=None,
654                        has_header=true,
655                        delimiter=",",
656                        schema_infer_max_records=1000,
657                        file_extension=".csv",
658                        file_compression_type=None))]
659    pub fn register_csv(
660        &mut self,
661        name: &str,
662        path: &Bound<'_, PyAny>,
663        schema: Option<PyArrowType<Schema>>,
664        has_header: bool,
665        delimiter: &str,
666        schema_infer_max_records: usize,
667        file_extension: &str,
668        file_compression_type: Option<String>,
669        py: Python,
670    ) -> PyDataFusionResult<()> {
671        let delimiter = delimiter.as_bytes();
672        if delimiter.len() != 1 {
673            return Err(crate::errors::PyDataFusionError::PythonError(py_value_err(
674                "Delimiter must be a single character",
675            )));
676        }
677
678        let mut options = CsvReadOptions::new()
679            .has_header(has_header)
680            .delimiter(delimiter[0])
681            .schema_infer_max_records(schema_infer_max_records)
682            .file_extension(file_extension)
683            .file_compression_type(parse_file_compression_type(file_compression_type)?);
684        options.schema = schema.as_ref().map(|x| &x.0);
685
686        if path.is_instance_of::<PyList>() {
687            let paths = path.extract::<Vec<String>>()?;
688            let result = self.register_csv_from_multiple_paths(name, paths, options);
689            wait_for_future(py, result)?;
690        } else {
691            let path = path.extract::<String>()?;
692            let result = self.ctx.register_csv(name, &path, options);
693            wait_for_future(py, result)?;
694        }
695
696        Ok(())
697    }
698
699    #[allow(clippy::too_many_arguments)]
700    #[pyo3(signature = (name,
701                        path,
702                        schema=None,
703                        schema_infer_max_records=1000,
704                        file_extension=".json",
705                        table_partition_cols=vec![],
706                        file_compression_type=None))]
707    pub fn register_json(
708        &mut self,
709        name: &str,
710        path: PathBuf,
711        schema: Option<PyArrowType<Schema>>,
712        schema_infer_max_records: usize,
713        file_extension: &str,
714        table_partition_cols: Vec<(String, String)>,
715        file_compression_type: Option<String>,
716        py: Python,
717    ) -> PyDataFusionResult<()> {
718        let path = path
719            .to_str()
720            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
721
722        let mut options = NdJsonReadOptions::default()
723            .file_compression_type(parse_file_compression_type(file_compression_type)?)
724            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
725        options.schema_infer_max_records = schema_infer_max_records;
726        options.file_extension = file_extension;
727        options.schema = schema.as_ref().map(|x| &x.0);
728
729        let result = self.ctx.register_json(name, path, options);
730        wait_for_future(py, result)?;
731
732        Ok(())
733    }
734
735    #[allow(clippy::too_many_arguments)]
736    #[pyo3(signature = (name,
737                        path,
738                        schema=None,
739                        file_extension=".avro",
740                        table_partition_cols=vec![]))]
741    pub fn register_avro(
742        &mut self,
743        name: &str,
744        path: PathBuf,
745        schema: Option<PyArrowType<Schema>>,
746        file_extension: &str,
747        table_partition_cols: Vec<(String, String)>,
748        py: Python,
749    ) -> PyDataFusionResult<()> {
750        let path = path
751            .to_str()
752            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
753
754        let mut options = AvroReadOptions::default()
755            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
756        options.file_extension = file_extension;
757        options.schema = schema.as_ref().map(|x| &x.0);
758
759        let result = self.ctx.register_avro(name, path, options);
760        wait_for_future(py, result)?;
761
762        Ok(())
763    }
764
765    // Registers a PyArrow.Dataset
766    pub fn register_dataset(
767        &self,
768        name: &str,
769        dataset: &Bound<'_, PyAny>,
770        py: Python,
771    ) -> PyDataFusionResult<()> {
772        let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);
773
774        self.ctx.register_table(name, table)?;
775
776        Ok(())
777    }
778
779    pub fn register_udf(&mut self, udf: PyScalarUDF) -> PyResult<()> {
780        self.ctx.register_udf(udf.function);
781        Ok(())
782    }
783
784    pub fn register_udaf(&mut self, udaf: PyAggregateUDF) -> PyResult<()> {
785        self.ctx.register_udaf(udaf.function);
786        Ok(())
787    }
788
789    pub fn register_udwf(&mut self, udwf: PyWindowUDF) -> PyResult<()> {
790        self.ctx.register_udwf(udwf.function);
791        Ok(())
792    }
793
794    #[pyo3(signature = (name="datafusion"))]
795    pub fn catalog(&self, name: &str) -> PyResult<PyCatalog> {
796        match self.ctx.catalog(name) {
797            Some(catalog) => Ok(PyCatalog::new(catalog)),
798            None => Err(PyKeyError::new_err(format!(
799                "Catalog with name {} doesn't exist.",
800                &name,
801            ))),
802        }
803    }
804
805    pub fn tables(&self) -> HashSet<String> {
806        self.ctx
807            .catalog_names()
808            .into_iter()
809            .filter_map(|name| self.ctx.catalog(&name))
810            .flat_map(move |catalog| {
811                catalog
812                    .schema_names()
813                    .into_iter()
814                    .filter_map(move |name| catalog.schema(&name))
815            })
816            .flat_map(|schema| schema.table_names())
817            .collect()
818    }
819
820    pub fn table(&self, name: &str, py: Python) -> PyResult<PyDataFrame> {
821        let x = wait_for_future(py, self.ctx.table(name))
822            .map_err(|e| PyKeyError::new_err(e.to_string()))?;
823        Ok(PyDataFrame::new(x))
824    }
825
826    pub fn table_exist(&self, name: &str) -> PyDataFusionResult<bool> {
827        Ok(self.ctx.table_exist(name)?)
828    }
829
830    pub fn empty_table(&self) -> PyDataFusionResult<PyDataFrame> {
831        Ok(PyDataFrame::new(self.ctx.read_empty()?))
832    }
833
834    pub fn session_id(&self) -> String {
835        self.ctx.session_id()
836    }
837
838    #[allow(clippy::too_many_arguments)]
839    #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))]
840    pub fn read_json(
841        &mut self,
842        path: PathBuf,
843        schema: Option<PyArrowType<Schema>>,
844        schema_infer_max_records: usize,
845        file_extension: &str,
846        table_partition_cols: Vec<(String, String)>,
847        file_compression_type: Option<String>,
848        py: Python,
849    ) -> PyDataFusionResult<PyDataFrame> {
850        let path = path
851            .to_str()
852            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
853        let mut options = NdJsonReadOptions::default()
854            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
855            .file_compression_type(parse_file_compression_type(file_compression_type)?);
856        options.schema_infer_max_records = schema_infer_max_records;
857        options.file_extension = file_extension;
858        let df = if let Some(schema) = schema {
859            options.schema = Some(&schema.0);
860            let result = self.ctx.read_json(path, options);
861            wait_for_future(py, result)?
862        } else {
863            let result = self.ctx.read_json(path, options);
864            wait_for_future(py, result)?
865        };
866        Ok(PyDataFrame::new(df))
867    }
868
869    #[allow(clippy::too_many_arguments)]
870    #[pyo3(signature = (
871        path,
872        schema=None,
873        has_header=true,
874        delimiter=",",
875        schema_infer_max_records=1000,
876        file_extension=".csv",
877        table_partition_cols=vec![],
878        file_compression_type=None))]
879    pub fn read_csv(
880        &self,
881        path: &Bound<'_, PyAny>,
882        schema: Option<PyArrowType<Schema>>,
883        has_header: bool,
884        delimiter: &str,
885        schema_infer_max_records: usize,
886        file_extension: &str,
887        table_partition_cols: Vec<(String, String)>,
888        file_compression_type: Option<String>,
889        py: Python,
890    ) -> PyDataFusionResult<PyDataFrame> {
891        let delimiter = delimiter.as_bytes();
892        if delimiter.len() != 1 {
893            return Err(crate::errors::PyDataFusionError::PythonError(py_value_err(
894                "Delimiter must be a single character",
895            )));
896        };
897
898        let mut options = CsvReadOptions::new()
899            .has_header(has_header)
900            .delimiter(delimiter[0])
901            .schema_infer_max_records(schema_infer_max_records)
902            .file_extension(file_extension)
903            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
904            .file_compression_type(parse_file_compression_type(file_compression_type)?);
905        options.schema = schema.as_ref().map(|x| &x.0);
906
907        if path.is_instance_of::<PyList>() {
908            let paths = path.extract::<Vec<String>>()?;
909            let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
910            let result = self.ctx.read_csv(paths, options);
911            let df = PyDataFrame::new(wait_for_future(py, result)?);
912            Ok(df)
913        } else {
914            let path = path.extract::<String>()?;
915            let result = self.ctx.read_csv(path, options);
916            let df = PyDataFrame::new(wait_for_future(py, result)?);
917            Ok(df)
918        }
919    }
920
921    #[allow(clippy::too_many_arguments)]
922    #[pyo3(signature = (
923        path,
924        table_partition_cols=vec![],
925        parquet_pruning=true,
926        file_extension=".parquet",
927        skip_metadata=true,
928        schema=None,
929        file_sort_order=None))]
930    pub fn read_parquet(
931        &self,
932        path: &str,
933        table_partition_cols: Vec<(String, String)>,
934        parquet_pruning: bool,
935        file_extension: &str,
936        skip_metadata: bool,
937        schema: Option<PyArrowType<Schema>>,
938        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
939        py: Python,
940    ) -> PyDataFusionResult<PyDataFrame> {
941        let mut options = ParquetReadOptions::default()
942            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
943            .parquet_pruning(parquet_pruning)
944            .skip_metadata(skip_metadata);
945        options.file_extension = file_extension;
946        options.schema = schema.as_ref().map(|x| &x.0);
947        options.file_sort_order = file_sort_order
948            .unwrap_or_default()
949            .into_iter()
950            .map(|e| e.into_iter().map(|f| f.into()).collect())
951            .collect();
952
953        let result = self.ctx.read_parquet(path, options);
954        let df = PyDataFrame::new(wait_for_future(py, result)?);
955        Ok(df)
956    }
957
958    #[allow(clippy::too_many_arguments)]
959    #[pyo3(signature = (path, schema=None, table_partition_cols=vec![], file_extension=".avro"))]
960    pub fn read_avro(
961        &self,
962        path: &str,
963        schema: Option<PyArrowType<Schema>>,
964        table_partition_cols: Vec<(String, String)>,
965        file_extension: &str,
966        py: Python,
967    ) -> PyDataFusionResult<PyDataFrame> {
968        let mut options = AvroReadOptions::default()
969            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
970        options.file_extension = file_extension;
971        let df = if let Some(schema) = schema {
972            options.schema = Some(&schema.0);
973            let read_future = self.ctx.read_avro(path, options);
974            wait_for_future(py, read_future)?
975        } else {
976            let read_future = self.ctx.read_avro(path, options);
977            wait_for_future(py, read_future)?
978        };
979        Ok(PyDataFrame::new(df))
980    }
981
982    pub fn read_table(&self, table: &PyTable) -> PyDataFusionResult<PyDataFrame> {
983        let df = self.ctx.read_table(table.table())?;
984        Ok(PyDataFrame::new(df))
985    }
986
987    fn __repr__(&self) -> PyResult<String> {
988        let config = self.ctx.copied_config();
989        let mut config_entries = config
990            .options()
991            .entries()
992            .iter()
993            .filter(|e| e.value.is_some())
994            .map(|e| format!("{} = {}", e.key, e.value.as_ref().unwrap()))
995            .collect::<Vec<_>>();
996        config_entries.sort();
997        Ok(format!(
998            "SessionContext: id={}; configs=[\n\t{}]",
999            self.session_id(),
1000            config_entries.join("\n\t")
1001        ))
1002    }
1003
1004    /// Execute a partition of an execution plan and return a stream of record batches
1005    pub fn execute(
1006        &self,
1007        plan: PyExecutionPlan,
1008        part: usize,
1009        py: Python,
1010    ) -> PyDataFusionResult<PyRecordBatchStream> {
1011        let ctx: TaskContext = TaskContext::from(&self.ctx.state());
1012        // create a Tokio runtime to run the async code
1013        let rt = &get_tokio_runtime().0;
1014        let plan = plan.plan.clone();
1015        let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
1016            rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
1017        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
1018        Ok(PyRecordBatchStream::new(stream?))
1019    }
1020}
1021
1022impl PySessionContext {
1023    async fn _table(&self, name: &str) -> datafusion::common::Result<DataFrame> {
1024        self.ctx.table(name).await
1025    }
1026
1027    async fn register_csv_from_multiple_paths(
1028        &self,
1029        name: &str,
1030        table_paths: Vec<String>,
1031        options: CsvReadOptions<'_>,
1032    ) -> datafusion::common::Result<()> {
1033        let table_paths = table_paths.to_urls()?;
1034        let session_config = self.ctx.copied_config();
1035        let listing_options =
1036            options.to_listing_options(&session_config, self.ctx.copied_table_options());
1037
1038        let option_extension = listing_options.file_extension.clone();
1039
1040        if table_paths.is_empty() {
1041            return exec_err!("No table paths were provided");
1042        }
1043
1044        // check if the file extension matches the expected extension
1045        for path in &table_paths {
1046            let file_path = path.as_str();
1047            if !file_path.ends_with(option_extension.clone().as_str()) && !path.is_collection() {
1048                return exec_err!(
1049                    "File path '{file_path}' does not match the expected extension '{option_extension}'"
1050                );
1051            }
1052        }
1053
1054        let resolved_schema = options
1055            .get_resolved_schema(&session_config, self.ctx.state(), table_paths[0].clone())
1056            .await?;
1057
1058        let config = ListingTableConfig::new_with_multi_paths(table_paths)
1059            .with_listing_options(listing_options)
1060            .with_schema(resolved_schema);
1061        let table = ListingTable::try_new(config)?;
1062        self.ctx
1063            .register_table(TableReference::Bare { table: name.into() }, Arc::new(table))?;
1064        Ok(())
1065    }
1066}
1067
1068pub fn convert_table_partition_cols(
1069    table_partition_cols: Vec<(String, String)>,
1070) -> PyDataFusionResult<Vec<(String, DataType)>> {
1071    table_partition_cols
1072        .into_iter()
1073        .map(|(name, ty)| match ty.as_str() {
1074            "string" => Ok((name, DataType::Utf8)),
1075            "int" => Ok((name, DataType::Int32)),
1076            _ => Err(crate::errors::PyDataFusionError::Common(format!(
1077                "Unsupported data type '{ty}' for partition column. Supported types are 'string' and 'int'"
1078            ))),
1079        })
1080        .collect::<Result<Vec<_>, _>>()
1081}
1082
1083pub fn parse_file_compression_type(
1084    file_compression_type: Option<String>,
1085) -> Result<FileCompressionType, PyErr> {
1086    FileCompressionType::from_str(&*file_compression_type.unwrap_or("".to_string()).as_str())
1087        .map_err(|_| {
1088            PyValueError::new_err("file_compression_type must one of: gzip, bz2, xz, zstd")
1089        })
1090}
1091
1092impl From<PySessionContext> for SessionContext {
1093    fn from(ctx: PySessionContext) -> SessionContext {
1094        ctx.ctx
1095    }
1096}
1097
1098impl From<SessionContext> for PySessionContext {
1099    fn from(ctx: SessionContext) -> PySessionContext {
1100        PySessionContext { ctx }
1101    }
1102}