datafusion_python/
context.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::{HashMap, HashSet};
19use std::path::PathBuf;
20use std::str::FromStr;
21use std::sync::Arc;
22
23use arrow::array::RecordBatchReader;
24use arrow::ffi_stream::ArrowArrayStreamReader;
25use arrow::pyarrow::FromPyArrow;
26use datafusion::execution::session_state::SessionStateBuilder;
27use object_store::ObjectStore;
28use url::Url;
29use uuid::Uuid;
30
31use pyo3::exceptions::{PyKeyError, PyValueError};
32use pyo3::prelude::*;
33
34use crate::catalog::{PyCatalog, PyTable};
35use crate::dataframe::PyDataFrame;
36use crate::dataset::Dataset;
37use crate::errors::{py_datafusion_err, PyDataFusionResult};
38use crate::expr::sort_expr::PySortExpr;
39use crate::physical_plan::PyExecutionPlan;
40use crate::record_batch::PyRecordBatchStream;
41use crate::sql::exceptions::py_value_err;
42use crate::sql::logical::PyLogicalPlan;
43use crate::store::StorageContexts;
44use crate::udaf::PyAggregateUDF;
45use crate::udf::PyScalarUDF;
46use crate::udtf::PyTableFunction;
47use crate::udwf::PyWindowUDF;
48use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future};
49use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
50use datafusion::arrow::pyarrow::PyArrowType;
51use datafusion::arrow::record_batch::RecordBatch;
52use datafusion::common::TableReference;
53use datafusion::common::{exec_err, ScalarValue};
54use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
55use datafusion::datasource::file_format::parquet::ParquetFormat;
56use datafusion::datasource::listing::{
57    ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
58};
59use datafusion::datasource::MemTable;
60use datafusion::datasource::TableProvider;
61use datafusion::execution::context::{
62    DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext,
63};
64use datafusion::execution::disk_manager::DiskManagerConfig;
65use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
66use datafusion::execution::options::ReadOptions;
67use datafusion::execution::runtime_env::RuntimeEnvBuilder;
68use datafusion::physical_plan::SendableRecordBatchStream;
69use datafusion::prelude::{
70    AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
71};
72use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
73use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
74use tokio::task::JoinHandle;
75
76/// Configuration options for a SessionContext
77#[pyclass(name = "SessionConfig", module = "datafusion", subclass)]
78#[derive(Clone, Default)]
79pub struct PySessionConfig {
80    pub config: SessionConfig,
81}
82
83impl From<SessionConfig> for PySessionConfig {
84    fn from(config: SessionConfig) -> Self {
85        Self { config }
86    }
87}
88
89#[pymethods]
90impl PySessionConfig {
91    #[pyo3(signature = (config_options=None))]
92    #[new]
93    fn new(config_options: Option<HashMap<String, String>>) -> Self {
94        let mut config = SessionConfig::new();
95        if let Some(hash_map) = config_options {
96            for (k, v) in &hash_map {
97                config = config.set(k, &ScalarValue::Utf8(Some(v.clone())));
98            }
99        }
100
101        Self { config }
102    }
103
104    fn with_create_default_catalog_and_schema(&self, enabled: bool) -> Self {
105        Self::from(
106            self.config
107                .clone()
108                .with_create_default_catalog_and_schema(enabled),
109        )
110    }
111
112    fn with_default_catalog_and_schema(&self, catalog: &str, schema: &str) -> Self {
113        Self::from(
114            self.config
115                .clone()
116                .with_default_catalog_and_schema(catalog, schema),
117        )
118    }
119
120    fn with_information_schema(&self, enabled: bool) -> Self {
121        Self::from(self.config.clone().with_information_schema(enabled))
122    }
123
124    fn with_batch_size(&self, batch_size: usize) -> Self {
125        Self::from(self.config.clone().with_batch_size(batch_size))
126    }
127
128    fn with_target_partitions(&self, target_partitions: usize) -> Self {
129        Self::from(
130            self.config
131                .clone()
132                .with_target_partitions(target_partitions),
133        )
134    }
135
136    fn with_repartition_aggregations(&self, enabled: bool) -> Self {
137        Self::from(self.config.clone().with_repartition_aggregations(enabled))
138    }
139
140    fn with_repartition_joins(&self, enabled: bool) -> Self {
141        Self::from(self.config.clone().with_repartition_joins(enabled))
142    }
143
144    fn with_repartition_windows(&self, enabled: bool) -> Self {
145        Self::from(self.config.clone().with_repartition_windows(enabled))
146    }
147
148    fn with_repartition_sorts(&self, enabled: bool) -> Self {
149        Self::from(self.config.clone().with_repartition_sorts(enabled))
150    }
151
152    fn with_repartition_file_scans(&self, enabled: bool) -> Self {
153        Self::from(self.config.clone().with_repartition_file_scans(enabled))
154    }
155
156    fn with_repartition_file_min_size(&self, size: usize) -> Self {
157        Self::from(self.config.clone().with_repartition_file_min_size(size))
158    }
159
160    fn with_parquet_pruning(&self, enabled: bool) -> Self {
161        Self::from(self.config.clone().with_parquet_pruning(enabled))
162    }
163
164    fn set(&self, key: &str, value: &str) -> Self {
165        Self::from(self.config.clone().set_str(key, value))
166    }
167}
168
169/// Runtime options for a SessionContext
170#[pyclass(name = "RuntimeEnvBuilder", module = "datafusion", subclass)]
171#[derive(Clone)]
172pub struct PyRuntimeEnvBuilder {
173    pub builder: RuntimeEnvBuilder,
174}
175
176#[pymethods]
177impl PyRuntimeEnvBuilder {
178    #[new]
179    fn new() -> Self {
180        Self {
181            builder: RuntimeEnvBuilder::default(),
182        }
183    }
184
185    fn with_disk_manager_disabled(&self) -> Self {
186        let mut builder = self.builder.clone();
187        builder = builder.with_disk_manager(DiskManagerConfig::Disabled);
188        Self { builder }
189    }
190
191    fn with_disk_manager_os(&self) -> Self {
192        let builder = self.builder.clone();
193        let builder = builder.with_disk_manager(DiskManagerConfig::NewOs);
194        Self { builder }
195    }
196
197    fn with_disk_manager_specified(&self, paths: Vec<String>) -> Self {
198        let builder = self.builder.clone();
199        let paths = paths.iter().map(|s| s.into()).collect();
200        let builder = builder.with_disk_manager(DiskManagerConfig::NewSpecified(paths));
201        Self { builder }
202    }
203
204    fn with_unbounded_memory_pool(&self) -> Self {
205        let builder = self.builder.clone();
206        let builder = builder.with_memory_pool(Arc::new(UnboundedMemoryPool::default()));
207        Self { builder }
208    }
209
210    fn with_fair_spill_pool(&self, size: usize) -> Self {
211        let builder = self.builder.clone();
212        let builder = builder.with_memory_pool(Arc::new(FairSpillPool::new(size)));
213        Self { builder }
214    }
215
216    fn with_greedy_memory_pool(&self, size: usize) -> Self {
217        let builder = self.builder.clone();
218        let builder = builder.with_memory_pool(Arc::new(GreedyMemoryPool::new(size)));
219        Self { builder }
220    }
221
222    fn with_temp_file_path(&self, path: &str) -> Self {
223        let builder = self.builder.clone();
224        let builder = builder.with_temp_file_path(path);
225        Self { builder }
226    }
227}
228
229/// `PySQLOptions` allows you to specify options to the sql execution.
230#[pyclass(name = "SQLOptions", module = "datafusion", subclass)]
231#[derive(Clone)]
232pub struct PySQLOptions {
233    pub options: SQLOptions,
234}
235
236impl From<SQLOptions> for PySQLOptions {
237    fn from(options: SQLOptions) -> Self {
238        Self { options }
239    }
240}
241
242#[pymethods]
243impl PySQLOptions {
244    #[new]
245    fn new() -> Self {
246        let options = SQLOptions::new();
247        Self { options }
248    }
249
250    /// Should DDL data modification commands  (e.g. `CREATE TABLE`) be run? Defaults to `true`.
251    fn with_allow_ddl(&self, allow: bool) -> Self {
252        Self::from(self.options.with_allow_ddl(allow))
253    }
254
255    /// Should DML data modification commands (e.g. `INSERT and COPY`) be run? Defaults to `true`
256    pub fn with_allow_dml(&self, allow: bool) -> Self {
257        Self::from(self.options.with_allow_dml(allow))
258    }
259
260    /// Should Statements such as (e.g. `SET VARIABLE and `BEGIN TRANSACTION` ...`) be run?. Defaults to `true`
261    pub fn with_allow_statements(&self, allow: bool) -> Self {
262        Self::from(self.options.with_allow_statements(allow))
263    }
264}
265
266/// `PySessionContext` is able to plan and execute DataFusion plans.
267/// It has a powerful optimizer, a physical planner for local execution, and a
268/// multi-threaded execution engine to perform the execution.
269#[pyclass(name = "SessionContext", module = "datafusion", subclass)]
270#[derive(Clone)]
271pub struct PySessionContext {
272    pub ctx: SessionContext,
273}
274
275#[pymethods]
276impl PySessionContext {
277    #[pyo3(signature = (config=None, runtime=None))]
278    #[new]
279    pub fn new(
280        config: Option<PySessionConfig>,
281        runtime: Option<PyRuntimeEnvBuilder>,
282    ) -> PyDataFusionResult<Self> {
283        let config = if let Some(c) = config {
284            c.config
285        } else {
286            SessionConfig::default().with_information_schema(true)
287        };
288        let runtime_env_builder = if let Some(c) = runtime {
289            c.builder
290        } else {
291            RuntimeEnvBuilder::default()
292        };
293        let runtime = Arc::new(runtime_env_builder.build()?);
294        let session_state = SessionStateBuilder::new()
295            .with_config(config)
296            .with_runtime_env(runtime)
297            .with_default_features()
298            .build();
299        Ok(PySessionContext {
300            ctx: SessionContext::new_with_state(session_state),
301        })
302    }
303
304    pub fn enable_url_table(&self) -> PyResult<Self> {
305        Ok(PySessionContext {
306            ctx: self.ctx.clone().enable_url_table(),
307        })
308    }
309
310    #[classmethod]
311    #[pyo3(signature = ())]
312    fn global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
313        Ok(Self {
314            ctx: get_global_ctx().clone(),
315        })
316    }
317
318    /// Register an object store with the given name
319    #[pyo3(signature = (scheme, store, host=None))]
320    pub fn register_object_store(
321        &mut self,
322        scheme: &str,
323        store: StorageContexts,
324        host: Option<&str>,
325    ) -> PyResult<()> {
326        // for most stores the "host" is the bucket name and can be inferred from the store
327        let (store, upstream_host): (Arc<dyn ObjectStore>, String) = match store {
328            StorageContexts::AmazonS3(s3) => (s3.inner, s3.bucket_name),
329            StorageContexts::GoogleCloudStorage(gcs) => (gcs.inner, gcs.bucket_name),
330            StorageContexts::MicrosoftAzure(azure) => (azure.inner, azure.container_name),
331            StorageContexts::LocalFileSystem(local) => (local.inner, "".to_string()),
332            StorageContexts::HTTP(http) => (http.store, http.url),
333        };
334
335        // let users override the host to match the api signature from upstream
336        let derived_host = if let Some(host) = host {
337            host
338        } else {
339            &upstream_host
340        };
341        let url_string = format!("{}{}", scheme, derived_host);
342        let url = Url::parse(&url_string).unwrap();
343        self.ctx.runtime_env().register_object_store(&url, store);
344        Ok(())
345    }
346
347    #[allow(clippy::too_many_arguments)]
348    #[pyo3(signature = (name, path, table_partition_cols=vec![],
349    file_extension=".parquet",
350    schema=None,
351    file_sort_order=None))]
352    pub fn register_listing_table(
353        &mut self,
354        name: &str,
355        path: &str,
356        table_partition_cols: Vec<(String, String)>,
357        file_extension: &str,
358        schema: Option<PyArrowType<Schema>>,
359        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
360        py: Python,
361    ) -> PyDataFusionResult<()> {
362        let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
363            .with_file_extension(file_extension)
364            .with_table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
365            .with_file_sort_order(
366                file_sort_order
367                    .unwrap_or_default()
368                    .into_iter()
369                    .map(|e| e.into_iter().map(|f| f.into()).collect())
370                    .collect(),
371            );
372        let table_path = ListingTableUrl::parse(path)?;
373        let resolved_schema: SchemaRef = match schema {
374            Some(s) => Arc::new(s.0),
375            None => {
376                let state = self.ctx.state();
377                let schema = options.infer_schema(&state, &table_path);
378                wait_for_future(py, schema)?
379            }
380        };
381        let config = ListingTableConfig::new(table_path)
382            .with_listing_options(options)
383            .with_schema(resolved_schema);
384        let table = ListingTable::try_new(config)?;
385        self.register_table(
386            name,
387            &PyTable {
388                table: Arc::new(table),
389            },
390        )?;
391        Ok(())
392    }
393
394    pub fn register_udtf(&mut self, func: PyTableFunction) {
395        let name = func.name.clone();
396        let func = Arc::new(func);
397        self.ctx.register_udtf(&name, func);
398    }
399
400    /// Returns a PyDataFrame whose plan corresponds to the SQL statement.
401    pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
402        let result = self.ctx.sql(query);
403        let df = wait_for_future(py, result)?;
404        Ok(PyDataFrame::new(df))
405    }
406
407    #[pyo3(signature = (query, options=None))]
408    pub fn sql_with_options(
409        &mut self,
410        query: &str,
411        options: Option<PySQLOptions>,
412        py: Python,
413    ) -> PyDataFusionResult<PyDataFrame> {
414        let options = if let Some(options) = options {
415            options.options
416        } else {
417            SQLOptions::new()
418        };
419        let result = self.ctx.sql_with_options(query, options);
420        let df = wait_for_future(py, result)?;
421        Ok(PyDataFrame::new(df))
422    }
423
424    #[pyo3(signature = (partitions, name=None, schema=None))]
425    pub fn create_dataframe(
426        &mut self,
427        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
428        name: Option<&str>,
429        schema: Option<PyArrowType<Schema>>,
430        py: Python,
431    ) -> PyDataFusionResult<PyDataFrame> {
432        let schema = if let Some(schema) = schema {
433            SchemaRef::from(schema.0)
434        } else {
435            partitions.0[0][0].schema()
436        };
437
438        let table = MemTable::try_new(schema, partitions.0)?;
439
440        // generate a random (unique) name for this table if none is provided
441        // table name cannot start with numeric digit
442        let table_name = match name {
443            Some(val) => val.to_owned(),
444            None => {
445                "c".to_owned()
446                    + Uuid::new_v4()
447                        .simple()
448                        .encode_lower(&mut Uuid::encode_buffer())
449            }
450        };
451
452        self.ctx.register_table(&*table_name, Arc::new(table))?;
453
454        let table = wait_for_future(py, self._table(&table_name))?;
455
456        let df = PyDataFrame::new(table);
457        Ok(df)
458    }
459
460    /// Create a DataFrame from an existing logical plan
461    pub fn create_dataframe_from_logical_plan(&mut self, plan: PyLogicalPlan) -> PyDataFrame {
462        PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()))
463    }
464
465    /// Construct datafusion dataframe from Python list
466    #[pyo3(signature = (data, name=None))]
467    pub fn from_pylist(
468        &mut self,
469        data: Bound<'_, PyList>,
470        name: Option<&str>,
471    ) -> PyResult<PyDataFrame> {
472        // Acquire GIL Token
473        let py = data.py();
474
475        // Instantiate pyarrow Table object & convert to Arrow Table
476        let table_class = py.import("pyarrow")?.getattr("Table")?;
477        let args = PyTuple::new(py, &[data])?;
478        let table = table_class.call_method1("from_pylist", args)?;
479
480        // Convert Arrow Table to datafusion DataFrame
481        let df = self.from_arrow(table, name, py)?;
482        Ok(df)
483    }
484
485    /// Construct datafusion dataframe from Python dictionary
486    #[pyo3(signature = (data, name=None))]
487    pub fn from_pydict(
488        &mut self,
489        data: Bound<'_, PyDict>,
490        name: Option<&str>,
491    ) -> PyResult<PyDataFrame> {
492        // Acquire GIL Token
493        let py = data.py();
494
495        // Instantiate pyarrow Table object & convert to Arrow Table
496        let table_class = py.import("pyarrow")?.getattr("Table")?;
497        let args = PyTuple::new(py, &[data])?;
498        let table = table_class.call_method1("from_pydict", args)?;
499
500        // Convert Arrow Table to datafusion DataFrame
501        let df = self.from_arrow(table, name, py)?;
502        Ok(df)
503    }
504
505    /// Construct datafusion dataframe from Arrow Table
506    #[pyo3(signature = (data, name=None))]
507    pub fn from_arrow(
508        &mut self,
509        data: Bound<'_, PyAny>,
510        name: Option<&str>,
511        py: Python,
512    ) -> PyDataFusionResult<PyDataFrame> {
513        let (schema, batches) =
514            if let Ok(stream_reader) = ArrowArrayStreamReader::from_pyarrow_bound(&data) {
515                // Works for any object that implements __arrow_c_stream__ in pycapsule.
516
517                let schema = stream_reader.schema().as_ref().to_owned();
518                let batches = stream_reader
519                    .collect::<Result<Vec<RecordBatch>, arrow::error::ArrowError>>()?;
520
521                (schema, batches)
522            } else if let Ok(array) = RecordBatch::from_pyarrow_bound(&data) {
523                // While this says RecordBatch, it will work for any object that implements
524                // __arrow_c_array__ and returns a StructArray.
525
526                (array.schema().as_ref().to_owned(), vec![array])
527            } else {
528                return Err(crate::errors::PyDataFusionError::Common(
529                    "Expected either a Arrow Array or Arrow Stream in from_arrow().".to_string(),
530                ));
531            };
532
533        // Because create_dataframe() expects a vector of vectors of record batches
534        // here we need to wrap the vector of record batches in an additional vector
535        let list_of_batches = PyArrowType::from(vec![batches]);
536        self.create_dataframe(list_of_batches, name, Some(schema.into()), py)
537    }
538
539    /// Construct datafusion dataframe from pandas
540    #[allow(clippy::wrong_self_convention)]
541    #[pyo3(signature = (data, name=None))]
542    pub fn from_pandas(
543        &mut self,
544        data: Bound<'_, PyAny>,
545        name: Option<&str>,
546    ) -> PyResult<PyDataFrame> {
547        // Obtain GIL token
548        let py = data.py();
549
550        // Instantiate pyarrow Table object & convert to Arrow Table
551        let table_class = py.import("pyarrow")?.getattr("Table")?;
552        let args = PyTuple::new(py, &[data])?;
553        let table = table_class.call_method1("from_pandas", args)?;
554
555        // Convert Arrow Table to datafusion DataFrame
556        let df = self.from_arrow(table, name, py)?;
557        Ok(df)
558    }
559
560    /// Construct datafusion dataframe from polars
561    #[pyo3(signature = (data, name=None))]
562    pub fn from_polars(
563        &mut self,
564        data: Bound<'_, PyAny>,
565        name: Option<&str>,
566    ) -> PyResult<PyDataFrame> {
567        // Convert Polars dataframe to Arrow Table
568        let table = data.call_method0("to_arrow")?;
569
570        // Convert Arrow Table to datafusion DataFrame
571        let df = self.from_arrow(table, name, data.py())?;
572        Ok(df)
573    }
574
575    pub fn register_table(&mut self, name: &str, table: &PyTable) -> PyDataFusionResult<()> {
576        self.ctx.register_table(name, table.table())?;
577        Ok(())
578    }
579
580    pub fn deregister_table(&mut self, name: &str) -> PyDataFusionResult<()> {
581        self.ctx.deregister_table(name)?;
582        Ok(())
583    }
584
585    /// Construct datafusion dataframe from Arrow Table
586    pub fn register_table_provider(
587        &mut self,
588        name: &str,
589        provider: Bound<'_, PyAny>,
590    ) -> PyDataFusionResult<()> {
591        if provider.hasattr("__datafusion_table_provider__")? {
592            let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
593            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
594            validate_pycapsule(capsule, "datafusion_table_provider")?;
595
596            let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
597            let provider: ForeignTableProvider = provider.into();
598
599            let _ = self.ctx.register_table(name, Arc::new(provider))?;
600
601            Ok(())
602        } else {
603            Err(crate::errors::PyDataFusionError::Common(
604                "__datafusion_table_provider__ does not exist on Table Provider object."
605                    .to_string(),
606            ))
607        }
608    }
609
610    pub fn register_record_batches(
611        &mut self,
612        name: &str,
613        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
614    ) -> PyDataFusionResult<()> {
615        let schema = partitions.0[0][0].schema();
616        let table = MemTable::try_new(schema, partitions.0)?;
617        self.ctx.register_table(name, Arc::new(table))?;
618        Ok(())
619    }
620
621    #[allow(clippy::too_many_arguments)]
622    #[pyo3(signature = (name, path, table_partition_cols=vec![],
623                        parquet_pruning=true,
624                        file_extension=".parquet",
625                        skip_metadata=true,
626                        schema=None,
627                        file_sort_order=None))]
628    pub fn register_parquet(
629        &mut self,
630        name: &str,
631        path: &str,
632        table_partition_cols: Vec<(String, String)>,
633        parquet_pruning: bool,
634        file_extension: &str,
635        skip_metadata: bool,
636        schema: Option<PyArrowType<Schema>>,
637        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
638        py: Python,
639    ) -> PyDataFusionResult<()> {
640        let mut options = ParquetReadOptions::default()
641            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
642            .parquet_pruning(parquet_pruning)
643            .skip_metadata(skip_metadata);
644        options.file_extension = file_extension;
645        options.schema = schema.as_ref().map(|x| &x.0);
646        options.file_sort_order = file_sort_order
647            .unwrap_or_default()
648            .into_iter()
649            .map(|e| e.into_iter().map(|f| f.into()).collect())
650            .collect();
651
652        let result = self.ctx.register_parquet(name, path, options);
653        wait_for_future(py, result)?;
654        Ok(())
655    }
656
657    #[allow(clippy::too_many_arguments)]
658    #[pyo3(signature = (name,
659                        path,
660                        schema=None,
661                        has_header=true,
662                        delimiter=",",
663                        schema_infer_max_records=1000,
664                        file_extension=".csv",
665                        file_compression_type=None))]
666    pub fn register_csv(
667        &mut self,
668        name: &str,
669        path: &Bound<'_, PyAny>,
670        schema: Option<PyArrowType<Schema>>,
671        has_header: bool,
672        delimiter: &str,
673        schema_infer_max_records: usize,
674        file_extension: &str,
675        file_compression_type: Option<String>,
676        py: Python,
677    ) -> PyDataFusionResult<()> {
678        let delimiter = delimiter.as_bytes();
679        if delimiter.len() != 1 {
680            return Err(crate::errors::PyDataFusionError::PythonError(py_value_err(
681                "Delimiter must be a single character",
682            )));
683        }
684
685        let mut options = CsvReadOptions::new()
686            .has_header(has_header)
687            .delimiter(delimiter[0])
688            .schema_infer_max_records(schema_infer_max_records)
689            .file_extension(file_extension)
690            .file_compression_type(parse_file_compression_type(file_compression_type)?);
691        options.schema = schema.as_ref().map(|x| &x.0);
692
693        if path.is_instance_of::<PyList>() {
694            let paths = path.extract::<Vec<String>>()?;
695            let result = self.register_csv_from_multiple_paths(name, paths, options);
696            wait_for_future(py, result)?;
697        } else {
698            let path = path.extract::<String>()?;
699            let result = self.ctx.register_csv(name, &path, options);
700            wait_for_future(py, result)?;
701        }
702
703        Ok(())
704    }
705
706    #[allow(clippy::too_many_arguments)]
707    #[pyo3(signature = (name,
708                        path,
709                        schema=None,
710                        schema_infer_max_records=1000,
711                        file_extension=".json",
712                        table_partition_cols=vec![],
713                        file_compression_type=None))]
714    pub fn register_json(
715        &mut self,
716        name: &str,
717        path: PathBuf,
718        schema: Option<PyArrowType<Schema>>,
719        schema_infer_max_records: usize,
720        file_extension: &str,
721        table_partition_cols: Vec<(String, String)>,
722        file_compression_type: Option<String>,
723        py: Python,
724    ) -> PyDataFusionResult<()> {
725        let path = path
726            .to_str()
727            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
728
729        let mut options = NdJsonReadOptions::default()
730            .file_compression_type(parse_file_compression_type(file_compression_type)?)
731            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
732        options.schema_infer_max_records = schema_infer_max_records;
733        options.file_extension = file_extension;
734        options.schema = schema.as_ref().map(|x| &x.0);
735
736        let result = self.ctx.register_json(name, path, options);
737        wait_for_future(py, result)?;
738
739        Ok(())
740    }
741
742    #[allow(clippy::too_many_arguments)]
743    #[pyo3(signature = (name,
744                        path,
745                        schema=None,
746                        file_extension=".avro",
747                        table_partition_cols=vec![]))]
748    pub fn register_avro(
749        &mut self,
750        name: &str,
751        path: PathBuf,
752        schema: Option<PyArrowType<Schema>>,
753        file_extension: &str,
754        table_partition_cols: Vec<(String, String)>,
755        py: Python,
756    ) -> PyDataFusionResult<()> {
757        let path = path
758            .to_str()
759            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
760
761        let mut options = AvroReadOptions::default()
762            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
763        options.file_extension = file_extension;
764        options.schema = schema.as_ref().map(|x| &x.0);
765
766        let result = self.ctx.register_avro(name, path, options);
767        wait_for_future(py, result)?;
768
769        Ok(())
770    }
771
772    // Registers a PyArrow.Dataset
773    pub fn register_dataset(
774        &self,
775        name: &str,
776        dataset: &Bound<'_, PyAny>,
777        py: Python,
778    ) -> PyDataFusionResult<()> {
779        let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);
780
781        self.ctx.register_table(name, table)?;
782
783        Ok(())
784    }
785
786    pub fn register_udf(&mut self, udf: PyScalarUDF) -> PyResult<()> {
787        self.ctx.register_udf(udf.function);
788        Ok(())
789    }
790
791    pub fn register_udaf(&mut self, udaf: PyAggregateUDF) -> PyResult<()> {
792        self.ctx.register_udaf(udaf.function);
793        Ok(())
794    }
795
796    pub fn register_udwf(&mut self, udwf: PyWindowUDF) -> PyResult<()> {
797        self.ctx.register_udwf(udwf.function);
798        Ok(())
799    }
800
801    #[pyo3(signature = (name="datafusion"))]
802    pub fn catalog(&self, name: &str) -> PyResult<PyCatalog> {
803        match self.ctx.catalog(name) {
804            Some(catalog) => Ok(PyCatalog::new(catalog)),
805            None => Err(PyKeyError::new_err(format!(
806                "Catalog with name {} doesn't exist.",
807                &name,
808            ))),
809        }
810    }
811
812    pub fn tables(&self) -> HashSet<String> {
813        self.ctx
814            .catalog_names()
815            .into_iter()
816            .filter_map(|name| self.ctx.catalog(&name))
817            .flat_map(move |catalog| {
818                catalog
819                    .schema_names()
820                    .into_iter()
821                    .filter_map(move |name| catalog.schema(&name))
822            })
823            .flat_map(|schema| schema.table_names())
824            .collect()
825    }
826
827    pub fn table(&self, name: &str, py: Python) -> PyResult<PyDataFrame> {
828        let x = wait_for_future(py, self.ctx.table(name))
829            .map_err(|e| PyKeyError::new_err(e.to_string()))?;
830        Ok(PyDataFrame::new(x))
831    }
832
833    pub fn table_exist(&self, name: &str) -> PyDataFusionResult<bool> {
834        Ok(self.ctx.table_exist(name)?)
835    }
836
837    pub fn empty_table(&self) -> PyDataFusionResult<PyDataFrame> {
838        Ok(PyDataFrame::new(self.ctx.read_empty()?))
839    }
840
841    pub fn session_id(&self) -> String {
842        self.ctx.session_id()
843    }
844
845    #[allow(clippy::too_many_arguments)]
846    #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))]
847    pub fn read_json(
848        &mut self,
849        path: PathBuf,
850        schema: Option<PyArrowType<Schema>>,
851        schema_infer_max_records: usize,
852        file_extension: &str,
853        table_partition_cols: Vec<(String, String)>,
854        file_compression_type: Option<String>,
855        py: Python,
856    ) -> PyDataFusionResult<PyDataFrame> {
857        let path = path
858            .to_str()
859            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
860        let mut options = NdJsonReadOptions::default()
861            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
862            .file_compression_type(parse_file_compression_type(file_compression_type)?);
863        options.schema_infer_max_records = schema_infer_max_records;
864        options.file_extension = file_extension;
865        let df = if let Some(schema) = schema {
866            options.schema = Some(&schema.0);
867            let result = self.ctx.read_json(path, options);
868            wait_for_future(py, result)?
869        } else {
870            let result = self.ctx.read_json(path, options);
871            wait_for_future(py, result)?
872        };
873        Ok(PyDataFrame::new(df))
874    }
875
876    #[allow(clippy::too_many_arguments)]
877    #[pyo3(signature = (
878        path,
879        schema=None,
880        has_header=true,
881        delimiter=",",
882        schema_infer_max_records=1000,
883        file_extension=".csv",
884        table_partition_cols=vec![],
885        file_compression_type=None))]
886    pub fn read_csv(
887        &self,
888        path: &Bound<'_, PyAny>,
889        schema: Option<PyArrowType<Schema>>,
890        has_header: bool,
891        delimiter: &str,
892        schema_infer_max_records: usize,
893        file_extension: &str,
894        table_partition_cols: Vec<(String, String)>,
895        file_compression_type: Option<String>,
896        py: Python,
897    ) -> PyDataFusionResult<PyDataFrame> {
898        let delimiter = delimiter.as_bytes();
899        if delimiter.len() != 1 {
900            return Err(crate::errors::PyDataFusionError::PythonError(py_value_err(
901                "Delimiter must be a single character",
902            )));
903        };
904
905        let mut options = CsvReadOptions::new()
906            .has_header(has_header)
907            .delimiter(delimiter[0])
908            .schema_infer_max_records(schema_infer_max_records)
909            .file_extension(file_extension)
910            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
911            .file_compression_type(parse_file_compression_type(file_compression_type)?);
912        options.schema = schema.as_ref().map(|x| &x.0);
913
914        if path.is_instance_of::<PyList>() {
915            let paths = path.extract::<Vec<String>>()?;
916            let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
917            let result = self.ctx.read_csv(paths, options);
918            let df = PyDataFrame::new(wait_for_future(py, result)?);
919            Ok(df)
920        } else {
921            let path = path.extract::<String>()?;
922            let result = self.ctx.read_csv(path, options);
923            let df = PyDataFrame::new(wait_for_future(py, result)?);
924            Ok(df)
925        }
926    }
927
928    #[allow(clippy::too_many_arguments)]
929    #[pyo3(signature = (
930        path,
931        table_partition_cols=vec![],
932        parquet_pruning=true,
933        file_extension=".parquet",
934        skip_metadata=true,
935        schema=None,
936        file_sort_order=None))]
937    pub fn read_parquet(
938        &self,
939        path: &str,
940        table_partition_cols: Vec<(String, String)>,
941        parquet_pruning: bool,
942        file_extension: &str,
943        skip_metadata: bool,
944        schema: Option<PyArrowType<Schema>>,
945        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
946        py: Python,
947    ) -> PyDataFusionResult<PyDataFrame> {
948        let mut options = ParquetReadOptions::default()
949            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
950            .parquet_pruning(parquet_pruning)
951            .skip_metadata(skip_metadata);
952        options.file_extension = file_extension;
953        options.schema = schema.as_ref().map(|x| &x.0);
954        options.file_sort_order = file_sort_order
955            .unwrap_or_default()
956            .into_iter()
957            .map(|e| e.into_iter().map(|f| f.into()).collect())
958            .collect();
959
960        let result = self.ctx.read_parquet(path, options);
961        let df = PyDataFrame::new(wait_for_future(py, result)?);
962        Ok(df)
963    }
964
965    #[allow(clippy::too_many_arguments)]
966    #[pyo3(signature = (path, schema=None, table_partition_cols=vec![], file_extension=".avro"))]
967    pub fn read_avro(
968        &self,
969        path: &str,
970        schema: Option<PyArrowType<Schema>>,
971        table_partition_cols: Vec<(String, String)>,
972        file_extension: &str,
973        py: Python,
974    ) -> PyDataFusionResult<PyDataFrame> {
975        let mut options = AvroReadOptions::default()
976            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
977        options.file_extension = file_extension;
978        let df = if let Some(schema) = schema {
979            options.schema = Some(&schema.0);
980            let read_future = self.ctx.read_avro(path, options);
981            wait_for_future(py, read_future)?
982        } else {
983            let read_future = self.ctx.read_avro(path, options);
984            wait_for_future(py, read_future)?
985        };
986        Ok(PyDataFrame::new(df))
987    }
988
989    pub fn read_table(&self, table: &PyTable) -> PyDataFusionResult<PyDataFrame> {
990        let df = self.ctx.read_table(table.table())?;
991        Ok(PyDataFrame::new(df))
992    }
993
994    fn __repr__(&self) -> PyResult<String> {
995        let config = self.ctx.copied_config();
996        let mut config_entries = config
997            .options()
998            .entries()
999            .iter()
1000            .filter(|e| e.value.is_some())
1001            .map(|e| format!("{} = {}", e.key, e.value.as_ref().unwrap()))
1002            .collect::<Vec<_>>();
1003        config_entries.sort();
1004        Ok(format!(
1005            "SessionContext: id={}; configs=[\n\t{}]",
1006            self.session_id(),
1007            config_entries.join("\n\t")
1008        ))
1009    }
1010
1011    /// Execute a partition of an execution plan and return a stream of record batches
1012    pub fn execute(
1013        &self,
1014        plan: PyExecutionPlan,
1015        part: usize,
1016        py: Python,
1017    ) -> PyDataFusionResult<PyRecordBatchStream> {
1018        let ctx: TaskContext = TaskContext::from(&self.ctx.state());
1019        // create a Tokio runtime to run the async code
1020        let rt = &get_tokio_runtime().0;
1021        let plan = plan.plan.clone();
1022        let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
1023            rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
1024        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
1025        Ok(PyRecordBatchStream::new(stream?))
1026    }
1027}
1028
1029impl PySessionContext {
1030    async fn _table(&self, name: &str) -> datafusion::common::Result<DataFrame> {
1031        self.ctx.table(name).await
1032    }
1033
1034    async fn register_csv_from_multiple_paths(
1035        &self,
1036        name: &str,
1037        table_paths: Vec<String>,
1038        options: CsvReadOptions<'_>,
1039    ) -> datafusion::common::Result<()> {
1040        let table_paths = table_paths.to_urls()?;
1041        let session_config = self.ctx.copied_config();
1042        let listing_options =
1043            options.to_listing_options(&session_config, self.ctx.copied_table_options());
1044
1045        let option_extension = listing_options.file_extension.clone();
1046
1047        if table_paths.is_empty() {
1048            return exec_err!("No table paths were provided");
1049        }
1050
1051        // check if the file extension matches the expected extension
1052        for path in &table_paths {
1053            let file_path = path.as_str();
1054            if !file_path.ends_with(option_extension.clone().as_str()) && !path.is_collection() {
1055                return exec_err!(
1056                    "File path '{file_path}' does not match the expected extension '{option_extension}'"
1057                );
1058            }
1059        }
1060
1061        let resolved_schema = options
1062            .get_resolved_schema(&session_config, self.ctx.state(), table_paths[0].clone())
1063            .await?;
1064
1065        let config = ListingTableConfig::new_with_multi_paths(table_paths)
1066            .with_listing_options(listing_options)
1067            .with_schema(resolved_schema);
1068        let table = ListingTable::try_new(config)?;
1069        self.ctx
1070            .register_table(TableReference::Bare { table: name.into() }, Arc::new(table))?;
1071        Ok(())
1072    }
1073}
1074
1075pub fn convert_table_partition_cols(
1076    table_partition_cols: Vec<(String, String)>,
1077) -> PyDataFusionResult<Vec<(String, DataType)>> {
1078    table_partition_cols
1079        .into_iter()
1080        .map(|(name, ty)| match ty.as_str() {
1081            "string" => Ok((name, DataType::Utf8)),
1082            "int" => Ok((name, DataType::Int32)),
1083            _ => Err(crate::errors::PyDataFusionError::Common(format!(
1084                "Unsupported data type '{ty}' for partition column. Supported types are 'string' and 'int'"
1085            ))),
1086        })
1087        .collect::<Result<Vec<_>, _>>()
1088}
1089
1090pub fn parse_file_compression_type(
1091    file_compression_type: Option<String>,
1092) -> Result<FileCompressionType, PyErr> {
1093    FileCompressionType::from_str(&*file_compression_type.unwrap_or("".to_string()).as_str())
1094        .map_err(|_| {
1095            PyValueError::new_err("file_compression_type must one of: gzip, bz2, xz, zstd")
1096        })
1097}
1098
1099impl From<PySessionContext> for SessionContext {
1100    fn from(ctx: PySessionContext) -> SessionContext {
1101        ctx.ctx
1102    }
1103}
1104
1105impl From<SessionContext> for PySessionContext {
1106    fn from(ctx: SessionContext) -> PySessionContext {
1107        PySessionContext { ctx }
1108    }
1109}