Skip to main content

datafusion_python/
context.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::collections::{HashMap, HashSet};
19use std::path::PathBuf;
20use std::str::FromStr;
21use std::sync::Arc;
22
23use arrow::array::RecordBatchReader;
24use arrow::ffi_stream::ArrowArrayStreamReader;
25use arrow::pyarrow::FromPyArrow;
26use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
27use datafusion::arrow::pyarrow::PyArrowType;
28use datafusion::arrow::record_batch::RecordBatch;
29use datafusion::catalog::{CatalogProvider, CatalogProviderList};
30use datafusion::common::{ScalarValue, TableReference, exec_err};
31use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
32use datafusion::datasource::file_format::parquet::ParquetFormat;
33use datafusion::datasource::listing::{
34    ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
35};
36use datafusion::datasource::{MemTable, TableProvider};
37use datafusion::execution::TaskContextProvider;
38use datafusion::execution::context::{
39    DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext,
40};
41use datafusion::execution::disk_manager::DiskManagerMode;
42use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
43use datafusion::execution::options::ReadOptions;
44use datafusion::execution::runtime_env::RuntimeEnvBuilder;
45use datafusion::execution::session_state::SessionStateBuilder;
46use datafusion::prelude::{
47    AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
48};
49use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
50use datafusion_ffi::catalog_provider_list::FFI_CatalogProviderList;
51use datafusion_ffi::execution::FFI_TaskContextProvider;
52use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
53use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec;
54use object_store::ObjectStore;
55use pyo3::IntoPyObjectExt;
56use pyo3::exceptions::{PyKeyError, PyValueError};
57use pyo3::prelude::*;
58use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
59use url::Url;
60use uuid::Uuid;
61
62use crate::catalog::{
63    PyCatalog, PyCatalogList, RustWrappedPyCatalogProvider, RustWrappedPyCatalogProviderList,
64};
65use crate::common::data_type::PyScalarValue;
66use crate::dataframe::PyDataFrame;
67use crate::dataset::Dataset;
68use crate::errors::{
69    PyDataFusionError, PyDataFusionResult, from_datafusion_error, py_datafusion_err,
70};
71use crate::expr::sort_expr::PySortExpr;
72use crate::options::PyCsvReadOptions;
73use crate::physical_plan::PyExecutionPlan;
74use crate::record_batch::PyRecordBatchStream;
75use crate::sql::logical::PyLogicalPlan;
76use crate::sql::util::replace_placeholders_with_strings;
77use crate::store::StorageContexts;
78use crate::table::PyTable;
79use crate::udaf::PyAggregateUDF;
80use crate::udf::PyScalarUDF;
81use crate::udtf::PyTableFunction;
82use crate::udwf::PyWindowUDF;
83use crate::utils::{
84    create_logical_extension_capsule, extract_logical_extension_codec, get_global_ctx,
85    get_tokio_runtime, spawn_future, validate_pycapsule, wait_for_future,
86};
87
88/// Configuration options for a SessionContext
89#[pyclass(frozen, name = "SessionConfig", module = "datafusion", subclass)]
90#[derive(Clone, Default)]
91pub struct PySessionConfig {
92    pub config: SessionConfig,
93}
94
95impl From<SessionConfig> for PySessionConfig {
96    fn from(config: SessionConfig) -> Self {
97        Self { config }
98    }
99}
100
101#[pymethods]
102impl PySessionConfig {
103    #[pyo3(signature = (config_options=None))]
104    #[new]
105    fn new(config_options: Option<HashMap<String, String>>) -> Self {
106        let mut config = SessionConfig::new();
107        if let Some(hash_map) = config_options {
108            for (k, v) in &hash_map {
109                config = config.set(k, &ScalarValue::Utf8(Some(v.clone())));
110            }
111        }
112
113        Self { config }
114    }
115
116    fn with_create_default_catalog_and_schema(&self, enabled: bool) -> Self {
117        Self::from(
118            self.config
119                .clone()
120                .with_create_default_catalog_and_schema(enabled),
121        )
122    }
123
124    fn with_default_catalog_and_schema(&self, catalog: &str, schema: &str) -> Self {
125        Self::from(
126            self.config
127                .clone()
128                .with_default_catalog_and_schema(catalog, schema),
129        )
130    }
131
132    fn with_information_schema(&self, enabled: bool) -> Self {
133        Self::from(self.config.clone().with_information_schema(enabled))
134    }
135
136    fn with_batch_size(&self, batch_size: usize) -> Self {
137        Self::from(self.config.clone().with_batch_size(batch_size))
138    }
139
140    fn with_target_partitions(&self, target_partitions: usize) -> Self {
141        Self::from(
142            self.config
143                .clone()
144                .with_target_partitions(target_partitions),
145        )
146    }
147
148    fn with_repartition_aggregations(&self, enabled: bool) -> Self {
149        Self::from(self.config.clone().with_repartition_aggregations(enabled))
150    }
151
152    fn with_repartition_joins(&self, enabled: bool) -> Self {
153        Self::from(self.config.clone().with_repartition_joins(enabled))
154    }
155
156    fn with_repartition_windows(&self, enabled: bool) -> Self {
157        Self::from(self.config.clone().with_repartition_windows(enabled))
158    }
159
160    fn with_repartition_sorts(&self, enabled: bool) -> Self {
161        Self::from(self.config.clone().with_repartition_sorts(enabled))
162    }
163
164    fn with_repartition_file_scans(&self, enabled: bool) -> Self {
165        Self::from(self.config.clone().with_repartition_file_scans(enabled))
166    }
167
168    fn with_repartition_file_min_size(&self, size: usize) -> Self {
169        Self::from(self.config.clone().with_repartition_file_min_size(size))
170    }
171
172    fn with_parquet_pruning(&self, enabled: bool) -> Self {
173        Self::from(self.config.clone().with_parquet_pruning(enabled))
174    }
175
176    fn set(&self, key: &str, value: &str) -> Self {
177        Self::from(self.config.clone().set_str(key, value))
178    }
179}
180
181/// Runtime options for a SessionContext
182#[pyclass(frozen, name = "RuntimeEnvBuilder", module = "datafusion", subclass)]
183#[derive(Clone)]
184pub struct PyRuntimeEnvBuilder {
185    pub builder: RuntimeEnvBuilder,
186}
187
188#[pymethods]
189impl PyRuntimeEnvBuilder {
190    #[new]
191    fn new() -> Self {
192        Self {
193            builder: RuntimeEnvBuilder::default(),
194        }
195    }
196
197    fn with_disk_manager_disabled(&self) -> Self {
198        let mut runtime_builder = self.builder.clone();
199
200        let mut disk_mgr_builder = runtime_builder
201            .disk_manager_builder
202            .clone()
203            .unwrap_or_default();
204        disk_mgr_builder.set_mode(DiskManagerMode::Disabled);
205
206        runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
207        Self {
208            builder: runtime_builder,
209        }
210    }
211
212    fn with_disk_manager_os(&self) -> Self {
213        let mut runtime_builder = self.builder.clone();
214
215        let mut disk_mgr_builder = runtime_builder
216            .disk_manager_builder
217            .clone()
218            .unwrap_or_default();
219        disk_mgr_builder.set_mode(DiskManagerMode::OsTmpDirectory);
220
221        runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
222        Self {
223            builder: runtime_builder,
224        }
225    }
226
227    fn with_disk_manager_specified(&self, paths: Vec<String>) -> Self {
228        let paths = paths.iter().map(|s| s.into()).collect();
229        let mut runtime_builder = self.builder.clone();
230
231        let mut disk_mgr_builder = runtime_builder
232            .disk_manager_builder
233            .clone()
234            .unwrap_or_default();
235        disk_mgr_builder.set_mode(DiskManagerMode::Directories(paths));
236
237        runtime_builder = runtime_builder.with_disk_manager_builder(disk_mgr_builder);
238        Self {
239            builder: runtime_builder,
240        }
241    }
242
243    fn with_unbounded_memory_pool(&self) -> Self {
244        let builder = self.builder.clone();
245        let builder = builder.with_memory_pool(Arc::new(UnboundedMemoryPool::default()));
246        Self { builder }
247    }
248
249    fn with_fair_spill_pool(&self, size: usize) -> Self {
250        let builder = self.builder.clone();
251        let builder = builder.with_memory_pool(Arc::new(FairSpillPool::new(size)));
252        Self { builder }
253    }
254
255    fn with_greedy_memory_pool(&self, size: usize) -> Self {
256        let builder = self.builder.clone();
257        let builder = builder.with_memory_pool(Arc::new(GreedyMemoryPool::new(size)));
258        Self { builder }
259    }
260
261    fn with_temp_file_path(&self, path: &str) -> Self {
262        let builder = self.builder.clone();
263        let builder = builder.with_temp_file_path(path);
264        Self { builder }
265    }
266}
267
268/// `PySQLOptions` allows you to specify options to the sql execution.
269#[pyclass(frozen, name = "SQLOptions", module = "datafusion", subclass)]
270#[derive(Clone)]
271pub struct PySQLOptions {
272    pub options: SQLOptions,
273}
274
275impl From<SQLOptions> for PySQLOptions {
276    fn from(options: SQLOptions) -> Self {
277        Self { options }
278    }
279}
280
281#[pymethods]
282impl PySQLOptions {
283    #[new]
284    fn new() -> Self {
285        let options = SQLOptions::new();
286        Self { options }
287    }
288
289    /// Should DDL data modification commands  (e.g. `CREATE TABLE`) be run? Defaults to `true`.
290    fn with_allow_ddl(&self, allow: bool) -> Self {
291        Self::from(self.options.with_allow_ddl(allow))
292    }
293
294    /// Should DML data modification commands (e.g. `INSERT and COPY`) be run? Defaults to `true`
295    pub fn with_allow_dml(&self, allow: bool) -> Self {
296        Self::from(self.options.with_allow_dml(allow))
297    }
298
299    /// Should Statements such as (e.g. `SET VARIABLE and `BEGIN TRANSACTION` ...`) be run?. Defaults to `true`
300    pub fn with_allow_statements(&self, allow: bool) -> Self {
301        Self::from(self.options.with_allow_statements(allow))
302    }
303}
304
305/// `PySessionContext` is able to plan and execute DataFusion plans.
306/// It has a powerful optimizer, a physical planner for local execution, and a
307/// multi-threaded execution engine to perform the execution.
308#[pyclass(frozen, name = "SessionContext", module = "datafusion", subclass)]
309#[derive(Clone)]
310pub struct PySessionContext {
311    pub ctx: Arc<SessionContext>,
312    logical_codec: Arc<FFI_LogicalExtensionCodec>,
313}
314
315#[pymethods]
316impl PySessionContext {
317    #[pyo3(signature = (config=None, runtime=None))]
318    #[new]
319    pub fn new(
320        config: Option<PySessionConfig>,
321        runtime: Option<PyRuntimeEnvBuilder>,
322    ) -> PyDataFusionResult<Self> {
323        let config = if let Some(c) = config {
324            c.config
325        } else {
326            SessionConfig::default().with_information_schema(true)
327        };
328        let runtime_env_builder = if let Some(c) = runtime {
329            c.builder
330        } else {
331            RuntimeEnvBuilder::default()
332        };
333        let runtime = Arc::new(runtime_env_builder.build()?);
334        let session_state = SessionStateBuilder::new()
335            .with_config(config)
336            .with_runtime_env(runtime)
337            .with_default_features()
338            .build();
339        let ctx = Arc::new(SessionContext::new_with_state(session_state));
340        let logical_codec = Self::default_logical_codec(&ctx);
341        Ok(PySessionContext { ctx, logical_codec })
342    }
343
344    pub fn enable_url_table(&self) -> PyResult<Self> {
345        Ok(PySessionContext {
346            ctx: Arc::new(self.ctx.as_ref().clone().enable_url_table()),
347            logical_codec: Arc::clone(&self.logical_codec),
348        })
349    }
350
351    #[staticmethod]
352    #[pyo3(signature = ())]
353    pub fn global_ctx() -> PyResult<Self> {
354        let ctx = get_global_ctx().clone();
355        let logical_codec = Self::default_logical_codec(&ctx);
356        Ok(Self { ctx, logical_codec })
357    }
358
359    /// Register an object store with the given name
360    #[pyo3(signature = (scheme, store, host=None))]
361    pub fn register_object_store(
362        &self,
363        scheme: &str,
364        store: StorageContexts,
365        host: Option<&str>,
366    ) -> PyResult<()> {
367        // for most stores the "host" is the bucket name and can be inferred from the store
368        let (store, upstream_host): (Arc<dyn ObjectStore>, String) = match store {
369            StorageContexts::AmazonS3(s3) => (s3.inner, s3.bucket_name),
370            StorageContexts::GoogleCloudStorage(gcs) => (gcs.inner, gcs.bucket_name),
371            StorageContexts::MicrosoftAzure(azure) => (azure.inner, azure.container_name),
372            StorageContexts::LocalFileSystem(local) => (local.inner, "".to_string()),
373            StorageContexts::HTTP(http) => (http.store, http.url),
374        };
375
376        // let users override the host to match the api signature from upstream
377        let derived_host = if let Some(host) = host {
378            host
379        } else {
380            &upstream_host
381        };
382        let url_string = format!("{scheme}{derived_host}");
383        let url = Url::parse(&url_string).unwrap();
384        self.ctx.runtime_env().register_object_store(&url, store);
385        Ok(())
386    }
387
388    #[allow(clippy::too_many_arguments)]
389    #[pyo3(signature = (name, path, table_partition_cols=vec![],
390    file_extension=".parquet",
391    schema=None,
392    file_sort_order=None))]
393    pub fn register_listing_table(
394        &self,
395        name: &str,
396        path: &str,
397        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
398        file_extension: &str,
399        schema: Option<PyArrowType<Schema>>,
400        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
401        py: Python,
402    ) -> PyDataFusionResult<()> {
403        let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
404            .with_file_extension(file_extension)
405            .with_table_partition_cols(
406                table_partition_cols
407                    .into_iter()
408                    .map(|(name, ty)| (name, ty.0))
409                    .collect::<Vec<(String, DataType)>>(),
410            )
411            .with_file_sort_order(
412                file_sort_order
413                    .unwrap_or_default()
414                    .into_iter()
415                    .map(|e| e.into_iter().map(|f| f.into()).collect())
416                    .collect(),
417            );
418        let table_path = ListingTableUrl::parse(path)?;
419        let resolved_schema: SchemaRef = match schema {
420            Some(s) => Arc::new(s.0),
421            None => {
422                let state = self.ctx.state();
423                let schema = options.infer_schema(&state, &table_path);
424                wait_for_future(py, schema)??
425            }
426        };
427        let config = ListingTableConfig::new(table_path)
428            .with_listing_options(options)
429            .with_schema(resolved_schema);
430        let table = ListingTable::try_new(config)?;
431        self.ctx.register_table(name, Arc::new(table))?;
432        Ok(())
433    }
434
435    pub fn register_udtf(&self, func: PyTableFunction) {
436        let name = func.name.clone();
437        let func = Arc::new(func);
438        self.ctx.register_udtf(&name, func);
439    }
440
441    #[pyo3(signature = (query, options=None, param_values=HashMap::default(), param_strings=HashMap::default()))]
442    pub fn sql_with_options(
443        &self,
444        py: Python,
445        mut query: String,
446        options: Option<PySQLOptions>,
447        param_values: HashMap<String, PyScalarValue>,
448        param_strings: HashMap<String, String>,
449    ) -> PyDataFusionResult<PyDataFrame> {
450        let options = if let Some(options) = options {
451            options.options
452        } else {
453            SQLOptions::new()
454        };
455
456        let param_values = param_values
457            .into_iter()
458            .map(|(name, value)| (name, ScalarValue::from(value)))
459            .collect::<HashMap<_, _>>();
460
461        let state = self.ctx.state();
462        let dialect = state.config().options().sql_parser.dialect.as_ref();
463
464        if !param_strings.is_empty() {
465            query = replace_placeholders_with_strings(&query, dialect, param_strings)?;
466        }
467
468        let mut df = wait_for_future(py, async {
469            self.ctx.sql_with_options(&query, options).await
470        })?
471        .map_err(from_datafusion_error)?;
472
473        if !param_values.is_empty() {
474            df = df.with_param_values(param_values)?;
475        }
476
477        Ok(PyDataFrame::new(df))
478    }
479
480    #[pyo3(signature = (partitions, name=None, schema=None))]
481    pub fn create_dataframe(
482        &self,
483        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
484        name: Option<&str>,
485        schema: Option<PyArrowType<Schema>>,
486        py: Python,
487    ) -> PyDataFusionResult<PyDataFrame> {
488        let schema = if let Some(schema) = schema {
489            SchemaRef::from(schema.0)
490        } else {
491            partitions.0[0][0].schema()
492        };
493
494        let table = MemTable::try_new(schema, partitions.0)?;
495
496        // generate a random (unique) name for this table if none is provided
497        // table name cannot start with numeric digit
498        let table_name = match name {
499            Some(val) => val.to_owned(),
500            None => {
501                "c".to_owned()
502                    + Uuid::new_v4()
503                        .simple()
504                        .encode_lower(&mut Uuid::encode_buffer())
505            }
506        };
507
508        self.ctx.register_table(&*table_name, Arc::new(table))?;
509
510        let table = wait_for_future(py, self._table(&table_name))??;
511
512        let df = PyDataFrame::new(table);
513        Ok(df)
514    }
515
516    /// Create a DataFrame from an existing logical plan
517    pub fn create_dataframe_from_logical_plan(&self, plan: PyLogicalPlan) -> PyDataFrame {
518        PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()))
519    }
520
521    /// Construct datafusion dataframe from Python list
522    #[pyo3(signature = (data, name=None))]
523    pub fn from_pylist(
524        &self,
525        data: Bound<'_, PyList>,
526        name: Option<&str>,
527    ) -> PyResult<PyDataFrame> {
528        // Acquire GIL Token
529        let py = data.py();
530
531        // Instantiate pyarrow Table object & convert to Arrow Table
532        let table_class = py.import("pyarrow")?.getattr("Table")?;
533        let args = PyTuple::new(py, &[data])?;
534        let table = table_class.call_method1("from_pylist", args)?;
535
536        // Convert Arrow Table to datafusion DataFrame
537        let df = self.from_arrow(table, name, py)?;
538        Ok(df)
539    }
540
541    /// Construct datafusion dataframe from Python dictionary
542    #[pyo3(signature = (data, name=None))]
543    pub fn from_pydict(
544        &self,
545        data: Bound<'_, PyDict>,
546        name: Option<&str>,
547    ) -> PyResult<PyDataFrame> {
548        // Acquire GIL Token
549        let py = data.py();
550
551        // Instantiate pyarrow Table object & convert to Arrow Table
552        let table_class = py.import("pyarrow")?.getattr("Table")?;
553        let args = PyTuple::new(py, &[data])?;
554        let table = table_class.call_method1("from_pydict", args)?;
555
556        // Convert Arrow Table to datafusion DataFrame
557        let df = self.from_arrow(table, name, py)?;
558        Ok(df)
559    }
560
561    /// Construct datafusion dataframe from Arrow Table
562    #[pyo3(signature = (data, name=None))]
563    pub fn from_arrow(
564        &self,
565        data: Bound<'_, PyAny>,
566        name: Option<&str>,
567        py: Python,
568    ) -> PyDataFusionResult<PyDataFrame> {
569        let (schema, batches) =
570            if let Ok(stream_reader) = ArrowArrayStreamReader::from_pyarrow_bound(&data) {
571                // Works for any object that implements __arrow_c_stream__ in pycapsule.
572
573                let schema = stream_reader.schema().as_ref().to_owned();
574                let batches = stream_reader
575                    .collect::<Result<Vec<RecordBatch>, arrow::error::ArrowError>>()?;
576
577                (schema, batches)
578            } else if let Ok(array) = RecordBatch::from_pyarrow_bound(&data) {
579                // While this says RecordBatch, it will work for any object that implements
580                // __arrow_c_array__ and returns a StructArray.
581
582                (array.schema().as_ref().to_owned(), vec![array])
583            } else {
584                return Err(PyDataFusionError::Common(
585                    "Expected either a Arrow Array or Arrow Stream in from_arrow().".to_string(),
586                ));
587            };
588
589        // Because create_dataframe() expects a vector of vectors of record batches
590        // here we need to wrap the vector of record batches in an additional vector
591        let list_of_batches = PyArrowType::from(vec![batches]);
592        self.create_dataframe(list_of_batches, name, Some(schema.into()), py)
593    }
594
595    /// Construct datafusion dataframe from pandas
596    #[allow(clippy::wrong_self_convention)]
597    #[pyo3(signature = (data, name=None))]
598    pub fn from_pandas(&self, data: Bound<'_, PyAny>, name: Option<&str>) -> PyResult<PyDataFrame> {
599        // Obtain GIL token
600        let py = data.py();
601
602        // Instantiate pyarrow Table object & convert to Arrow Table
603        let table_class = py.import("pyarrow")?.getattr("Table")?;
604        let args = PyTuple::new(py, &[data])?;
605        let table = table_class.call_method1("from_pandas", args)?;
606
607        // Convert Arrow Table to datafusion DataFrame
608        let df = self.from_arrow(table, name, py)?;
609        Ok(df)
610    }
611
612    /// Construct datafusion dataframe from polars
613    #[pyo3(signature = (data, name=None))]
614    pub fn from_polars(&self, data: Bound<'_, PyAny>, name: Option<&str>) -> PyResult<PyDataFrame> {
615        // Convert Polars dataframe to Arrow Table
616        let table = data.call_method0("to_arrow")?;
617
618        // Convert Arrow Table to datafusion DataFrame
619        let df = self.from_arrow(table, name, data.py())?;
620        Ok(df)
621    }
622
623    pub fn register_table(&self, name: &str, table: Bound<'_, PyAny>) -> PyDataFusionResult<()> {
624        let session = self.clone().into_bound_py_any(table.py())?;
625        let table = PyTable::new(table, Some(session))?;
626
627        self.ctx.register_table(name, table.table)?;
628        Ok(())
629    }
630
631    pub fn deregister_table(&self, name: &str) -> PyDataFusionResult<()> {
632        self.ctx.deregister_table(name)?;
633        Ok(())
634    }
635
636    pub fn register_catalog_provider_list(
637        &self,
638        mut provider: Bound<PyAny>,
639    ) -> PyDataFusionResult<()> {
640        if provider.hasattr("__datafusion_catalog_provider_list__")? {
641            let py = provider.py();
642            let codec_capsule = create_logical_extension_capsule(py, self.logical_codec.as_ref())?;
643            provider = provider
644                .getattr("__datafusion_catalog_provider_list__")?
645                .call1((codec_capsule,))?;
646        }
647
648        let provider =
649            if let Ok(capsule) = provider.downcast::<PyCapsule>().map_err(py_datafusion_err) {
650                validate_pycapsule(capsule, "datafusion_catalog_provider_list")?;
651
652                let provider = unsafe { capsule.reference::<FFI_CatalogProviderList>() };
653                let provider: Arc<dyn CatalogProviderList + Send> = provider.into();
654                provider as Arc<dyn CatalogProviderList>
655            } else {
656                match provider.extract::<PyCatalogList>() {
657                    Ok(py_catalog_list) => py_catalog_list.catalog_list,
658                    Err(_) => Arc::new(RustWrappedPyCatalogProviderList::new(
659                        provider.into(),
660                        Arc::clone(&self.logical_codec),
661                    )) as Arc<dyn CatalogProviderList>,
662                }
663            };
664
665        self.ctx.register_catalog_list(provider);
666
667        Ok(())
668    }
669
670    pub fn register_catalog_provider(
671        &self,
672        name: &str,
673        mut provider: Bound<'_, PyAny>,
674    ) -> PyDataFusionResult<()> {
675        if provider.hasattr("__datafusion_catalog_provider__")? {
676            let py = provider.py();
677            let codec_capsule = create_logical_extension_capsule(py, self.logical_codec.as_ref())?;
678            provider = provider
679                .getattr("__datafusion_catalog_provider__")?
680                .call1((codec_capsule,))?;
681        }
682
683        let provider =
684            if let Ok(capsule) = provider.downcast::<PyCapsule>().map_err(py_datafusion_err) {
685                validate_pycapsule(capsule, "datafusion_catalog_provider")?;
686
687                let provider = unsafe { capsule.reference::<FFI_CatalogProvider>() };
688                let provider: Arc<dyn CatalogProvider + Send> = provider.into();
689                provider as Arc<dyn CatalogProvider>
690            } else {
691                match provider.extract::<PyCatalog>() {
692                    Ok(py_catalog) => py_catalog.catalog,
693                    Err(_) => Arc::new(RustWrappedPyCatalogProvider::new(
694                        provider.into(),
695                        Arc::clone(&self.logical_codec),
696                    )) as Arc<dyn CatalogProvider>,
697                }
698            };
699
700        let _ = self.ctx.register_catalog(name, provider);
701
702        Ok(())
703    }
704
705    /// Construct datafusion dataframe from Arrow Table
706    pub fn register_table_provider(
707        &self,
708        name: &str,
709        provider: Bound<'_, PyAny>,
710    ) -> PyDataFusionResult<()> {
711        // Deprecated: use `register_table` instead
712        self.register_table(name, provider)
713    }
714
715    pub fn register_record_batches(
716        &self,
717        name: &str,
718        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
719    ) -> PyDataFusionResult<()> {
720        let schema = partitions.0[0][0].schema();
721        let table = MemTable::try_new(schema, partitions.0)?;
722        self.ctx.register_table(name, Arc::new(table))?;
723        Ok(())
724    }
725
726    #[allow(clippy::too_many_arguments)]
727    #[pyo3(signature = (name, path, table_partition_cols=vec![],
728                        parquet_pruning=true,
729                        file_extension=".parquet",
730                        skip_metadata=true,
731                        schema=None,
732                        file_sort_order=None))]
733    pub fn register_parquet(
734        &self,
735        name: &str,
736        path: &str,
737        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
738        parquet_pruning: bool,
739        file_extension: &str,
740        skip_metadata: bool,
741        schema: Option<PyArrowType<Schema>>,
742        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
743        py: Python,
744    ) -> PyDataFusionResult<()> {
745        let mut options = ParquetReadOptions::default()
746            .table_partition_cols(
747                table_partition_cols
748                    .into_iter()
749                    .map(|(name, ty)| (name, ty.0))
750                    .collect::<Vec<(String, DataType)>>(),
751            )
752            .parquet_pruning(parquet_pruning)
753            .skip_metadata(skip_metadata);
754        options.file_extension = file_extension;
755        options.schema = schema.as_ref().map(|x| &x.0);
756        options.file_sort_order = file_sort_order
757            .unwrap_or_default()
758            .into_iter()
759            .map(|e| e.into_iter().map(|f| f.into()).collect())
760            .collect();
761
762        let result = self.ctx.register_parquet(name, path, options);
763        wait_for_future(py, result)??;
764        Ok(())
765    }
766
767    #[pyo3(signature = (name,
768                        path,
769                        options=None))]
770    pub fn register_csv(
771        &self,
772        name: &str,
773        path: &Bound<'_, PyAny>,
774        options: Option<&PyCsvReadOptions>,
775        py: Python,
776    ) -> PyDataFusionResult<()> {
777        let options = options
778            .map(|opts| opts.try_into())
779            .transpose()?
780            .unwrap_or_default();
781
782        if path.is_instance_of::<PyList>() {
783            let paths = path.extract::<Vec<String>>()?;
784            let result = self.register_csv_from_multiple_paths(name, paths, options);
785            wait_for_future(py, result)??;
786        } else {
787            let path = path.extract::<String>()?;
788            let result = self.ctx.register_csv(name, &path, options);
789            wait_for_future(py, result)??;
790        }
791
792        Ok(())
793    }
794
795    #[allow(clippy::too_many_arguments)]
796    #[pyo3(signature = (name,
797                        path,
798                        schema=None,
799                        schema_infer_max_records=1000,
800                        file_extension=".json",
801                        table_partition_cols=vec![],
802                        file_compression_type=None))]
803    pub fn register_json(
804        &self,
805        name: &str,
806        path: PathBuf,
807        schema: Option<PyArrowType<Schema>>,
808        schema_infer_max_records: usize,
809        file_extension: &str,
810        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
811        file_compression_type: Option<String>,
812        py: Python,
813    ) -> PyDataFusionResult<()> {
814        let path = path
815            .to_str()
816            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
817
818        let mut options = NdJsonReadOptions::default()
819            .file_compression_type(parse_file_compression_type(file_compression_type)?)
820            .table_partition_cols(
821                table_partition_cols
822                    .into_iter()
823                    .map(|(name, ty)| (name, ty.0))
824                    .collect::<Vec<(String, DataType)>>(),
825            );
826        options.schema_infer_max_records = schema_infer_max_records;
827        options.file_extension = file_extension;
828        options.schema = schema.as_ref().map(|x| &x.0);
829
830        let result = self.ctx.register_json(name, path, options);
831        wait_for_future(py, result)??;
832
833        Ok(())
834    }
835
836    #[allow(clippy::too_many_arguments)]
837    #[pyo3(signature = (name,
838                        path,
839                        schema=None,
840                        file_extension=".avro",
841                        table_partition_cols=vec![]))]
842    pub fn register_avro(
843        &self,
844        name: &str,
845        path: PathBuf,
846        schema: Option<PyArrowType<Schema>>,
847        file_extension: &str,
848        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
849        py: Python,
850    ) -> PyDataFusionResult<()> {
851        let path = path
852            .to_str()
853            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
854
855        let mut options = AvroReadOptions::default().table_partition_cols(
856            table_partition_cols
857                .into_iter()
858                .map(|(name, ty)| (name, ty.0))
859                .collect::<Vec<(String, DataType)>>(),
860        );
861        options.file_extension = file_extension;
862        options.schema = schema.as_ref().map(|x| &x.0);
863
864        let result = self.ctx.register_avro(name, path, options);
865        wait_for_future(py, result)??;
866
867        Ok(())
868    }
869
870    // Registers a PyArrow.Dataset
871    pub fn register_dataset(
872        &self,
873        name: &str,
874        dataset: &Bound<'_, PyAny>,
875        py: Python,
876    ) -> PyDataFusionResult<()> {
877        let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);
878
879        self.ctx.register_table(name, table)?;
880
881        Ok(())
882    }
883
884    pub fn register_udf(&self, udf: PyScalarUDF) -> PyResult<()> {
885        self.ctx.register_udf(udf.function);
886        Ok(())
887    }
888
889    pub fn register_udaf(&self, udaf: PyAggregateUDF) -> PyResult<()> {
890        self.ctx.register_udaf(udaf.function);
891        Ok(())
892    }
893
894    pub fn register_udwf(&self, udwf: PyWindowUDF) -> PyResult<()> {
895        self.ctx.register_udwf(udwf.function);
896        Ok(())
897    }
898
899    #[pyo3(signature = (name="datafusion"))]
900    pub fn catalog(&self, py: Python, name: &str) -> PyResult<Py<PyAny>> {
901        let catalog = self.ctx.catalog(name).ok_or(PyKeyError::new_err(format!(
902            "Catalog with name {name} doesn't exist."
903        )))?;
904
905        match catalog
906            .as_any()
907            .downcast_ref::<RustWrappedPyCatalogProvider>()
908        {
909            Some(wrapped_schema) => Ok(wrapped_schema.catalog_provider.clone_ref(py)),
910            None => Ok(
911                PyCatalog::new_from_parts(catalog, Arc::clone(&self.logical_codec))
912                    .into_py_any(py)?,
913            ),
914        }
915    }
916
917    pub fn catalog_names(&self) -> HashSet<String> {
918        self.ctx.catalog_names().into_iter().collect()
919    }
920
921    pub fn tables(&self) -> HashSet<String> {
922        self.ctx
923            .catalog_names()
924            .into_iter()
925            .filter_map(|name| self.ctx.catalog(&name))
926            .flat_map(move |catalog| {
927                catalog
928                    .schema_names()
929                    .into_iter()
930                    .filter_map(move |name| catalog.schema(&name))
931            })
932            .flat_map(|schema| schema.table_names())
933            .collect()
934    }
935
936    pub fn table(&self, name: &str, py: Python) -> PyResult<PyDataFrame> {
937        let res = wait_for_future(py, self.ctx.table(name))
938            .map_err(|e| PyKeyError::new_err(e.to_string()))?;
939        match res {
940            Ok(df) => Ok(PyDataFrame::new(df)),
941            Err(e) => {
942                if let datafusion::error::DataFusionError::Plan(msg) = &e
943                    && msg.contains("No table named")
944                {
945                    return Err(PyKeyError::new_err(msg.to_string()));
946                }
947                Err(py_datafusion_err(e))
948            }
949        }
950    }
951
952    pub fn table_exist(&self, name: &str) -> PyDataFusionResult<bool> {
953        Ok(self.ctx.table_exist(name)?)
954    }
955
956    pub fn empty_table(&self) -> PyDataFusionResult<PyDataFrame> {
957        Ok(PyDataFrame::new(self.ctx.read_empty()?))
958    }
959
960    pub fn session_id(&self) -> String {
961        self.ctx.session_id()
962    }
963
964    #[allow(clippy::too_many_arguments)]
965    #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))]
966    pub fn read_json(
967        &self,
968        path: PathBuf,
969        schema: Option<PyArrowType<Schema>>,
970        schema_infer_max_records: usize,
971        file_extension: &str,
972        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
973        file_compression_type: Option<String>,
974        py: Python,
975    ) -> PyDataFusionResult<PyDataFrame> {
976        let path = path
977            .to_str()
978            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
979        let mut options = NdJsonReadOptions::default()
980            .table_partition_cols(
981                table_partition_cols
982                    .into_iter()
983                    .map(|(name, ty)| (name, ty.0))
984                    .collect::<Vec<(String, DataType)>>(),
985            )
986            .file_compression_type(parse_file_compression_type(file_compression_type)?);
987        options.schema_infer_max_records = schema_infer_max_records;
988        options.file_extension = file_extension;
989        let df = if let Some(schema) = schema {
990            options.schema = Some(&schema.0);
991            let result = self.ctx.read_json(path, options);
992            wait_for_future(py, result)??
993        } else {
994            let result = self.ctx.read_json(path, options);
995            wait_for_future(py, result)??
996        };
997        Ok(PyDataFrame::new(df))
998    }
999
1000    #[pyo3(signature = (
1001        path,
1002        options=None))]
1003    pub fn read_csv(
1004        &self,
1005        path: &Bound<'_, PyAny>,
1006        options: Option<&PyCsvReadOptions>,
1007        py: Python,
1008    ) -> PyDataFusionResult<PyDataFrame> {
1009        let options = options
1010            .map(|opts| opts.try_into())
1011            .transpose()?
1012            .unwrap_or_default();
1013
1014        if path.is_instance_of::<PyList>() {
1015            let paths = path.extract::<Vec<String>>()?;
1016            let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
1017            let result = self.ctx.read_csv(paths, options);
1018            let df = PyDataFrame::new(wait_for_future(py, result)??);
1019            Ok(df)
1020        } else {
1021            let path = path.extract::<String>()?;
1022            let result = self.ctx.read_csv(path, options);
1023            let df = PyDataFrame::new(wait_for_future(py, result)??);
1024            Ok(df)
1025        }
1026    }
1027
1028    #[allow(clippy::too_many_arguments)]
1029    #[pyo3(signature = (
1030        path,
1031        table_partition_cols=vec![],
1032        parquet_pruning=true,
1033        file_extension=".parquet",
1034        skip_metadata=true,
1035        schema=None,
1036        file_sort_order=None))]
1037    pub fn read_parquet(
1038        &self,
1039        path: &str,
1040        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
1041        parquet_pruning: bool,
1042        file_extension: &str,
1043        skip_metadata: bool,
1044        schema: Option<PyArrowType<Schema>>,
1045        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
1046        py: Python,
1047    ) -> PyDataFusionResult<PyDataFrame> {
1048        let mut options = ParquetReadOptions::default()
1049            .table_partition_cols(
1050                table_partition_cols
1051                    .into_iter()
1052                    .map(|(name, ty)| (name, ty.0))
1053                    .collect::<Vec<(String, DataType)>>(),
1054            )
1055            .parquet_pruning(parquet_pruning)
1056            .skip_metadata(skip_metadata);
1057        options.file_extension = file_extension;
1058        options.schema = schema.as_ref().map(|x| &x.0);
1059        options.file_sort_order = file_sort_order
1060            .unwrap_or_default()
1061            .into_iter()
1062            .map(|e| e.into_iter().map(|f| f.into()).collect())
1063            .collect();
1064
1065        let result = self.ctx.read_parquet(path, options);
1066        let df = PyDataFrame::new(wait_for_future(py, result)??);
1067        Ok(df)
1068    }
1069
1070    #[allow(clippy::too_many_arguments)]
1071    #[pyo3(signature = (path, schema=None, table_partition_cols=vec![], file_extension=".avro"))]
1072    pub fn read_avro(
1073        &self,
1074        path: &str,
1075        schema: Option<PyArrowType<Schema>>,
1076        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
1077        file_extension: &str,
1078        py: Python,
1079    ) -> PyDataFusionResult<PyDataFrame> {
1080        let mut options = AvroReadOptions::default().table_partition_cols(
1081            table_partition_cols
1082                .into_iter()
1083                .map(|(name, ty)| (name, ty.0))
1084                .collect::<Vec<(String, DataType)>>(),
1085        );
1086        options.file_extension = file_extension;
1087        let df = if let Some(schema) = schema {
1088            options.schema = Some(&schema.0);
1089            let read_future = self.ctx.read_avro(path, options);
1090            wait_for_future(py, read_future)??
1091        } else {
1092            let read_future = self.ctx.read_avro(path, options);
1093            wait_for_future(py, read_future)??
1094        };
1095        Ok(PyDataFrame::new(df))
1096    }
1097
1098    pub fn read_table(&self, table: Bound<'_, PyAny>) -> PyDataFusionResult<PyDataFrame> {
1099        let session = self.clone().into_bound_py_any(table.py())?;
1100        let table = PyTable::new(table, Some(session))?;
1101        let df = self.ctx.read_table(table.table())?;
1102        Ok(PyDataFrame::new(df))
1103    }
1104
1105    fn __repr__(&self) -> PyResult<String> {
1106        let config = self.ctx.copied_config();
1107        let mut config_entries = config
1108            .options()
1109            .entries()
1110            .iter()
1111            .filter(|e| e.value.is_some())
1112            .map(|e| format!("{} = {}", e.key, e.value.as_ref().unwrap()))
1113            .collect::<Vec<_>>();
1114        config_entries.sort();
1115        Ok(format!(
1116            "SessionContext: id={}; configs=[\n\t{}]",
1117            self.session_id(),
1118            config_entries.join("\n\t")
1119        ))
1120    }
1121
1122    /// Execute a partition of an execution plan and return a stream of record batches
1123    pub fn execute(
1124        &self,
1125        plan: PyExecutionPlan,
1126        part: usize,
1127        py: Python,
1128    ) -> PyDataFusionResult<PyRecordBatchStream> {
1129        let ctx: TaskContext = TaskContext::from(&self.ctx.state());
1130        let plan = plan.plan.clone();
1131        let stream = spawn_future(py, async move { plan.execute(part, Arc::new(ctx)) })?;
1132        Ok(PyRecordBatchStream::new(stream))
1133    }
1134
1135    pub fn __datafusion_task_context_provider__<'py>(
1136        &self,
1137        py: Python<'py>,
1138    ) -> PyResult<Bound<'py, PyCapsule>> {
1139        let name = cr"datafusion_task_context_provider".into();
1140
1141        let ctx_provider = Arc::clone(&self.ctx) as Arc<dyn TaskContextProvider>;
1142        let ffi_ctx_provider = FFI_TaskContextProvider::from(&ctx_provider);
1143
1144        PyCapsule::new(py, ffi_ctx_provider, Some(name))
1145    }
1146
1147    pub fn __datafusion_logical_extension_codec__<'py>(
1148        &self,
1149        py: Python<'py>,
1150    ) -> PyResult<Bound<'py, PyCapsule>> {
1151        create_logical_extension_capsule(py, self.logical_codec.as_ref())
1152    }
1153
1154    pub fn with_logical_extension_codec<'py>(
1155        &self,
1156        codec: Bound<'py, PyAny>,
1157    ) -> PyDataFusionResult<Self> {
1158        let py = codec.py();
1159        let logical_codec = extract_logical_extension_codec(py, Some(codec))?;
1160
1161        Ok({
1162            Self {
1163                ctx: Arc::clone(&self.ctx),
1164                logical_codec,
1165            }
1166        })
1167    }
1168}
1169
1170impl PySessionContext {
1171    async fn _table(&self, name: &str) -> datafusion::common::Result<DataFrame> {
1172        self.ctx.table(name).await
1173    }
1174
1175    async fn register_csv_from_multiple_paths(
1176        &self,
1177        name: &str,
1178        table_paths: Vec<String>,
1179        options: CsvReadOptions<'_>,
1180    ) -> datafusion::common::Result<()> {
1181        let table_paths = table_paths.to_urls()?;
1182        let session_config = self.ctx.copied_config();
1183        let listing_options =
1184            options.to_listing_options(&session_config, self.ctx.copied_table_options());
1185
1186        let option_extension = listing_options.file_extension.clone();
1187
1188        if table_paths.is_empty() {
1189            return exec_err!("No table paths were provided");
1190        }
1191
1192        // check if the file extension matches the expected extension
1193        for path in &table_paths {
1194            let file_path = path.as_str();
1195            if !file_path.ends_with(option_extension.clone().as_str()) && !path.is_collection() {
1196                return exec_err!(
1197                    "File path '{file_path}' does not match the expected extension '{option_extension}'"
1198                );
1199            }
1200        }
1201
1202        let resolved_schema = options
1203            .get_resolved_schema(&session_config, self.ctx.state(), table_paths[0].clone())
1204            .await?;
1205
1206        let config = ListingTableConfig::new_with_multi_paths(table_paths)
1207            .with_listing_options(listing_options)
1208            .with_schema(resolved_schema);
1209        let table = ListingTable::try_new(config)?;
1210        self.ctx
1211            .register_table(TableReference::Bare { table: name.into() }, Arc::new(table))?;
1212        Ok(())
1213    }
1214
1215    fn default_logical_codec(ctx: &Arc<SessionContext>) -> Arc<FFI_LogicalExtensionCodec> {
1216        let codec = Arc::new(DefaultLogicalExtensionCodec {});
1217        let runtime = get_tokio_runtime().0.handle().clone();
1218        let ctx_provider = Arc::clone(ctx) as Arc<dyn TaskContextProvider>;
1219        Arc::new(FFI_LogicalExtensionCodec::new(
1220            codec,
1221            Some(runtime),
1222            &ctx_provider,
1223        ))
1224    }
1225}
1226
1227pub fn parse_file_compression_type(
1228    file_compression_type: Option<String>,
1229) -> Result<FileCompressionType, PyErr> {
1230    FileCompressionType::from_str(&*file_compression_type.unwrap_or("".to_string()).as_str())
1231        .map_err(|_| {
1232            PyValueError::new_err("file_compression_type must one of: gzip, bz2, xz, zstd")
1233        })
1234}
1235
1236impl From<PySessionContext> for SessionContext {
1237    fn from(ctx: PySessionContext) -> SessionContext {
1238        ctx.ctx.as_ref().clone()
1239    }
1240}
1241
1242impl From<SessionContext> for PySessionContext {
1243    fn from(ctx: SessionContext) -> PySessionContext {
1244        let ctx = Arc::new(ctx);
1245        let logical_codec = Self::default_logical_codec(&ctx);
1246
1247        PySessionContext { ctx, logical_codec }
1248    }
1249}