use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;

use arrow::array::RecordBatchReader;
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::FromPyArrow;
use datafusion::execution::session_state::SessionStateBuilder;
use object_store::ObjectStore;
use url::Url;
use uuid::Uuid;

use pyo3::exceptions::{PyKeyError, PyValueError};
use pyo3::prelude::*;

use crate::catalog::{PyCatalog, PyTable};
use crate::dataframe::PyDataFrame;
use crate::dataset::Dataset;
use crate::errors::{py_datafusion_err, PyDataFusionResult};
use crate::expr::sort_expr::PySortExpr;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
use crate::sql::exceptions::py_value_err;
use crate::sql::logical::PyLogicalPlan;
use crate::store::StorageContexts;
use crate::udaf::PyAggregateUDF;
use crate::udf::PyScalarUDF;
use crate::udwf::PyWindowUDF;
use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future};
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::TableReference;
use datafusion::common::{exec_err, ScalarValue};
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
use datafusion::datasource::file_format::parquet::ParquetFormat;
use datafusion::datasource::listing::{
    ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
};
use datafusion::datasource::MemTable;
use datafusion::datasource::TableProvider;
use datafusion::execution::context::{
    DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext,
};
use datafusion::execution::disk_manager::DiskManagerConfig;
use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
use datafusion::execution::options::ReadOptions;
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
use datafusion::physical_plan::SendableRecordBatchStream;
use datafusion::prelude::{
    AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
};
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
use tokio::task::JoinHandle;

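/// Configuration options for a `SessionContext`, wrapping DataFusion's
/// [`SessionConfig`]. Each builder-style method returns a new
/// `PySessionConfig` rather than mutating in place.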
#[pyclass(name = "SessionConfig", module = "datafusion", subclass)]
#[derive(Clone, Default)]
pub struct PySessionConfig {
    pub config: SessionConfig,
}

impl From<SessionConfig> for PySessionConfig {
    fn from(config: SessionConfig) -> Self {
        Self { config }
    }
}

#[pymethods]
impl PySessionConfig {
    #[pyo3(signature = (config_options=None))]
    #[new]
    fn new(config_options: Option<HashMap<String, String>>) -> Self {
        let mut config = SessionConfig::new();
        if let Some(hash_map) = config_options {
            for (k, v) in &hash_map {
                config = config.set(k, &ScalarValue::Utf8(Some(v.clone())));
            }
        }

        Self { config }
    }

    fn with_create_default_catalog_and_schema(&self, enabled: bool) -> Self {
        Self::from(
            self.config
                .clone()
                .with_create_default_catalog_and_schema(enabled),
        )
    }

    fn with_default_catalog_and_schema(&self, catalog: &str, schema: &str) -> Self {
        Self::from(
            self.config
                .clone()
                .with_default_catalog_and_schema(catalog, schema),
        )
    }

    fn with_information_schema(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_information_schema(enabled))
    }

    fn with_batch_size(&self, batch_size: usize) -> Self {
        Self::from(self.config.clone().with_batch_size(batch_size))
    }

    fn with_target_partitions(&self, target_partitions: usize) -> Self {
        Self::from(
            self.config
                .clone()
                .with_target_partitions(target_partitions),
        )
    }

    fn with_repartition_aggregations(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_aggregations(enabled))
    }

    fn with_repartition_joins(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_joins(enabled))
    }

    fn with_repartition_windows(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_windows(enabled))
    }

    fn with_repartition_sorts(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_sorts(enabled))
    }

    fn with_repartition_file_scans(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_repartition_file_scans(enabled))
    }

    fn with_repartition_file_min_size(&self, size: usize) -> Self {
        Self::from(self.config.clone().with_repartition_file_min_size(size))
    }

    fn with_parquet_pruning(&self, enabled: bool) -> Self {
        Self::from(self.config.clone().with_parquet_pruning(enabled))
    }

    fn set(&self, key: &str, value: &str) -> Self {
        Self::from(self.config.clone().set_str(key, value))
    }
}

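/// Builder for the runtime environment (disk manager, memory pools, temp file
/// path) used by a `SessionContext`. Wraps DataFusion's [`RuntimeEnvBuilder`].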
#[pyclass(name = "RuntimeEnvBuilder", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PyRuntimeEnvBuilder {
    pub builder: RuntimeEnvBuilder,
}

#[pymethods]
impl PyRuntimeEnvBuilder {
    #[new]
    fn new() -> Self {
        Self {
            builder: RuntimeEnvBuilder::default(),
        }
    }

    fn with_disk_manager_disabled(&self) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_disk_manager(DiskManagerConfig::Disabled);
        Self { builder }
    }

    fn with_disk_manager_os(&self) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_disk_manager(DiskManagerConfig::NewOs);
        Self { builder }
    }

    fn with_disk_manager_specified(&self, paths: Vec<String>) -> Self {
        let builder = self.builder.clone();
        let paths = paths.iter().map(|s| s.into()).collect();
        let builder = builder.with_disk_manager(DiskManagerConfig::NewSpecified(paths));
        Self { builder }
    }

    fn with_unbounded_memory_pool(&self) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_memory_pool(Arc::new(UnboundedMemoryPool::default()));
        Self { builder }
    }

    fn with_fair_spill_pool(&self, size: usize) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_memory_pool(Arc::new(FairSpillPool::new(size)));
        Self { builder }
    }

    fn with_greedy_memory_pool(&self, size: usize) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_memory_pool(Arc::new(GreedyMemoryPool::new(size)));
        Self { builder }
    }

    fn with_temp_file_path(&self, path: &str) -> Self {
        let builder = self.builder.clone();
        let builder = builder.with_temp_file_path(path);
        Self { builder }
    }
}

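/// Controls which kinds of SQL are accepted by `sql_with_options`: DDL, DML,
/// and statements such as `SET`.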
#[pyclass(name = "SQLOptions", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PySQLOptions {
    pub options: SQLOptions,
}

impl From<SQLOptions> for PySQLOptions {
    fn from(options: SQLOptions) -> Self {
        Self { options }
    }
}

#[pymethods]
impl PySQLOptions {
    #[new]
    fn new() -> Self {
        let options = SQLOptions::new();
        Self { options }
    }

    fn with_allow_ddl(&self, allow: bool) -> Self {
        Self::from(self.options.with_allow_ddl(allow))
    }

    pub fn with_allow_dml(&self, allow: bool) -> Self {
        Self::from(self.options.with_allow_dml(allow))
    }

    pub fn with_allow_statements(&self, allow: bool) -> Self {
        Self::from(self.options.with_allow_statements(allow))
    }
}

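/// `PySessionContext` is the main entry point for the Python bindings: it
/// plans and executes SQL queries and DataFrame operations against registered
/// tables. A minimal usage sketch from Python (assuming the class is exported
/// as `datafusion.SessionContext`):
///
/// ```python
/// from datafusion import SessionContext
///
/// ctx = SessionContext()
/// df = ctx.sql("SELECT 1 AS one")
/// ```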
#[pyclass(name = "SessionContext", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PySessionContext {
    pub ctx: SessionContext,
}

#[pymethods]
impl PySessionContext {
    #[pyo3(signature = (config=None, runtime=None))]
    #[new]
    pub fn new(
        config: Option<PySessionConfig>,
        runtime: Option<PyRuntimeEnvBuilder>,
    ) -> PyDataFusionResult<Self> {
        let config = if let Some(c) = config {
            c.config
        } else {
            SessionConfig::default().with_information_schema(true)
        };
        let runtime_env_builder = if let Some(c) = runtime {
            c.builder
        } else {
            RuntimeEnvBuilder::default()
        };
        let runtime = Arc::new(runtime_env_builder.build()?);
        let session_state = SessionStateBuilder::new()
            .with_config(config)
            .with_runtime_env(runtime)
            .with_default_features()
            .build();
        Ok(PySessionContext {
            ctx: SessionContext::new_with_state(session_state),
        })
    }

    pub fn enable_url_table(&self) -> PyResult<Self> {
        Ok(PySessionContext {
            ctx: self.ctx.clone().enable_url_table(),
        })
    }

    #[classmethod]
    #[pyo3(signature = ())]
    fn global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
        Ok(Self {
            ctx: get_global_ctx().clone(),
        })
    }

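    /// Register an object store (S3, GCS, Azure, HTTP, or the local
    /// filesystem) under the given scheme so that file readers and listing
    /// tables can resolve paths against it.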
    #[pyo3(signature = (scheme, store, host=None))]
    pub fn register_object_store(
        &mut self,
        scheme: &str,
        store: StorageContexts,
        host: Option<&str>,
    ) -> PyResult<()> {
        let (store, upstream_host): (Arc<dyn ObjectStore>, String) = match store {
            StorageContexts::AmazonS3(s3) => (s3.inner, s3.bucket_name),
            StorageContexts::GoogleCloudStorage(gcs) => (gcs.inner, gcs.bucket_name),
            StorageContexts::MicrosoftAzure(azure) => (azure.inner, azure.container_name),
            StorageContexts::LocalFileSystem(local) => (local.inner, "".to_string()),
            StorageContexts::HTTP(http) => (http.store, http.url),
        };

        let derived_host = if let Some(host) = host {
            host
        } else {
            &upstream_host
        };
        let url_string = format!("{}{}", scheme, derived_host);
        // Surface an invalid scheme/host combination as a Python ValueError
        // instead of panicking.
        let url = Url::parse(&url_string).map_err(|e| PyValueError::new_err(e.to_string()))?;
        self.ctx.runtime_env().register_object_store(&url, store);
        Ok(())
    }

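    /// Register a directory of Parquet files as a named `ListingTable`,
    /// optionally partitioned by the given columns. When no schema is
    /// provided, it is inferred from the files at `path`.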
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name, path, table_partition_cols=vec![],
                        file_extension=".parquet",
                        schema=None,
                        file_sort_order=None))]
    pub fn register_listing_table(
        &mut self,
        name: &str,
        path: &str,
        table_partition_cols: Vec<(String, String)>,
        file_extension: &str,
        schema: Option<PyArrowType<Schema>>,
        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
            .with_file_extension(file_extension)
            .with_table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
            .with_file_sort_order(
                file_sort_order
                    .unwrap_or_default()
                    .into_iter()
                    .map(|e| e.into_iter().map(|f| f.into()).collect())
                    .collect(),
            );
        let table_path = ListingTableUrl::parse(path)?;
        let resolved_schema: SchemaRef = match schema {
            Some(s) => Arc::new(s.0),
            None => {
                let state = self.ctx.state();
                let schema = options.infer_schema(&state, &table_path);
                wait_for_future(py, schema)?
            }
        };
        let config = ListingTableConfig::new(table_path)
            .with_listing_options(options)
            .with_schema(resolved_schema);
        let table = ListingTable::try_new(config)?;
        self.register_table(
            name,
            &PyTable {
                table: Arc::new(table),
            },
        )?;
        Ok(())
    }

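    /// Plan and execute a SQL statement, returning the result as a
    /// `DataFrame`. The async call is driven to completion on the Tokio
    /// runtime via `wait_for_future`.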
    pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
        let result = self.ctx.sql(query);
        let df = wait_for_future(py, result)?;
        Ok(PyDataFrame::new(df))
    }

    #[pyo3(signature = (query, options=None))]
    pub fn sql_with_options(
        &mut self,
        query: &str,
        options: Option<PySQLOptions>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let options = if let Some(options) = options {
            options.options
        } else {
            SQLOptions::new()
        };
        let result = self.ctx.sql_with_options(query, options);
        let df = wait_for_future(py, result)?;
        Ok(PyDataFrame::new(df))
    }

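    /// Create a `DataFrame` from in-memory record batches. When no name is
    /// given, the data is registered as a `MemTable` under a generated name
    /// prefixed with `c`.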
    #[pyo3(signature = (partitions, name=None, schema=None))]
    pub fn create_dataframe(
        &mut self,
        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
        name: Option<&str>,
        schema: Option<PyArrowType<Schema>>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let schema = if let Some(schema) = schema {
            SchemaRef::from(schema.0)
        } else {
            partitions.0[0][0].schema()
        };

        let table = MemTable::try_new(schema, partitions.0)?;

        let table_name = match name {
            Some(val) => val.to_owned(),
            None => {
                "c".to_owned()
                    + Uuid::new_v4()
                        .simple()
                        .encode_lower(&mut Uuid::encode_buffer())
            }
        };

        self.ctx.register_table(&*table_name, Arc::new(table))?;

        let table = wait_for_future(py, self._table(&table_name))?;

        let df = PyDataFrame::new(table);
        Ok(df)
    }

    pub fn create_dataframe_from_logical_plan(&mut self, plan: PyLogicalPlan) -> PyDataFrame {
        PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()))
    }

    #[pyo3(signature = (data, name=None))]
    pub fn from_pylist(
        &mut self,
        data: Bound<'_, PyList>,
        name: Option<&str>,
    ) -> PyResult<PyDataFrame> {
        let py = data.py();

        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[data])?;
        let table = table_class.call_method1("from_pylist", args)?;

        let df = self.from_arrow(table, name, py)?;
        Ok(df)
    }

    #[pyo3(signature = (data, name=None))]
    pub fn from_pydict(
        &mut self,
        data: Bound<'_, PyDict>,
        name: Option<&str>,
    ) -> PyResult<PyDataFrame> {
        let py = data.py();

        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[data])?;
        let table = table_class.call_method1("from_pydict", args)?;

        let df = self.from_arrow(table, name, py)?;
        Ok(df)
    }

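    /// Construct a `DataFrame` from an Arrow source: either a record batch
    /// stream/reader (for example a pyarrow `Table`) or a single record batch.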
    #[pyo3(signature = (data, name=None))]
    pub fn from_arrow(
        &mut self,
        data: Bound<'_, PyAny>,
        name: Option<&str>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let (schema, batches) =
            if let Ok(stream_reader) = ArrowArrayStreamReader::from_pyarrow_bound(&data) {
                // Works for any object providing an Arrow record batch stream
                // (e.g. a pyarrow Table).
                let schema = stream_reader.schema().as_ref().to_owned();
                let batches = stream_reader
                    .collect::<Result<Vec<RecordBatch>, arrow::error::ArrowError>>()?;

                (schema, batches)
            } else if let Ok(array) = RecordBatch::from_pyarrow_bound(&data) {
                // A single batch (e.g. a pyarrow RecordBatch) becomes a
                // one-element partition.
                (array.schema().as_ref().to_owned(), vec![array])
            } else {
                return Err(crate::errors::PyDataFusionError::Common(
                    "Expected either an Arrow Array or an Arrow Stream in from_arrow().".to_string(),
                ));
            };

        let list_of_batches = PyArrowType::from(vec![batches]);
        self.create_dataframe(list_of_batches, name, Some(schema.into()), py)
    }

    #[allow(clippy::wrong_self_convention)]
    #[pyo3(signature = (data, name=None))]
    pub fn from_pandas(
        &mut self,
        data: Bound<'_, PyAny>,
        name: Option<&str>,
    ) -> PyResult<PyDataFrame> {
        let py = data.py();

        let table_class = py.import("pyarrow")?.getattr("Table")?;
        let args = PyTuple::new(py, &[data])?;
        let table = table_class.call_method1("from_pandas", args)?;

        let df = self.from_arrow(table, name, py)?;
        Ok(df)
    }

    #[pyo3(signature = (data, name=None))]
    pub fn from_polars(
        &mut self,
        data: Bound<'_, PyAny>,
        name: Option<&str>,
    ) -> PyResult<PyDataFrame> {
        let table = data.call_method0("to_arrow")?;

        let df = self.from_arrow(table, name, data.py())?;
        Ok(df)
    }

    pub fn register_table(&mut self, name: &str, table: &PyTable) -> PyDataFusionResult<()> {
        self.ctx.register_table(name, table.table())?;
        Ok(())
    }

    pub fn deregister_table(&mut self, name: &str) -> PyDataFusionResult<()> {
        self.ctx.deregister_table(name)?;
        Ok(())
    }

    pub fn register_table_provider(
        &mut self,
        name: &str,
        provider: Bound<'_, PyAny>,
    ) -> PyDataFusionResult<()> {
        if provider.hasattr("__datafusion_table_provider__")? {
            let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
            validate_pycapsule(capsule, "datafusion_table_provider")?;

            let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
            let provider: ForeignTableProvider = provider.into();

            let _ = self.ctx.register_table(name, Arc::new(provider))?;

            Ok(())
        } else {
            Err(crate::errors::PyDataFusionError::Common(
                "__datafusion_table_provider__ does not exist on Table Provider object."
                    .to_string(),
            ))
        }
    }

    pub fn register_record_batches(
        &mut self,
        name: &str,
        partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
    ) -> PyDataFusionResult<()> {
        let schema = partitions.0[0][0].schema();
        let table = MemTable::try_new(schema, partitions.0)?;
        self.ctx.register_table(name, Arc::new(table))?;
        Ok(())
    }

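    /// Register a Parquet file or directory as a named table, with optional
    /// partition columns, predicate pruning, and an explicit file sort order.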
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name, path, table_partition_cols=vec![],
                        parquet_pruning=true,
                        file_extension=".parquet",
                        skip_metadata=true,
                        schema=None,
                        file_sort_order=None))]
    pub fn register_parquet(
        &mut self,
        name: &str,
        path: &str,
        table_partition_cols: Vec<(String, String)>,
        parquet_pruning: bool,
        file_extension: &str,
        skip_metadata: bool,
        schema: Option<PyArrowType<Schema>>,
        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let mut options = ParquetReadOptions::default()
            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
            .parquet_pruning(parquet_pruning)
            .skip_metadata(skip_metadata);
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);
        options.file_sort_order = file_sort_order
            .unwrap_or_default()
            .into_iter()
            .map(|e| e.into_iter().map(|f| f.into()).collect())
            .collect();

        let result = self.ctx.register_parquet(name, path, options);
        wait_for_future(py, result)?;
        Ok(())
    }

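    /// Register one CSV file or a list of CSV files as a named table. The
    /// delimiter must be a single character.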
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name,
                        path,
                        schema=None,
                        has_header=true,
                        delimiter=",",
                        schema_infer_max_records=1000,
                        file_extension=".csv",
                        file_compression_type=None))]
    pub fn register_csv(
        &mut self,
        name: &str,
        path: &Bound<'_, PyAny>,
        schema: Option<PyArrowType<Schema>>,
        has_header: bool,
        delimiter: &str,
        schema_infer_max_records: usize,
        file_extension: &str,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let delimiter = delimiter.as_bytes();
        if delimiter.len() != 1 {
            return Err(crate::errors::PyDataFusionError::PythonError(py_value_err(
                "Delimiter must be a single character",
            )));
        }

        let mut options = CsvReadOptions::new()
            .has_header(has_header)
            .delimiter(delimiter[0])
            .schema_infer_max_records(schema_infer_max_records)
            .file_extension(file_extension)
            .file_compression_type(parse_file_compression_type(file_compression_type)?);
        options.schema = schema.as_ref().map(|x| &x.0);

        if path.is_instance_of::<PyList>() {
            let paths = path.extract::<Vec<String>>()?;
            let result = self.register_csv_from_multiple_paths(name, paths, options);
            wait_for_future(py, result)?;
        } else {
            let path = path.extract::<String>()?;
            let result = self.ctx.register_csv(name, &path, options);
            wait_for_future(py, result)?;
        }

        Ok(())
    }

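    /// Register a newline-delimited JSON file as a named table.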
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name,
                        path,
                        schema=None,
                        schema_infer_max_records=1000,
                        file_extension=".json",
                        table_partition_cols=vec![],
                        file_compression_type=None))]
    pub fn register_json(
        &mut self,
        name: &str,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, String)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;

        let mut options = NdJsonReadOptions::default()
            .file_compression_type(parse_file_compression_type(file_compression_type)?)
            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
        options.schema_infer_max_records = schema_infer_max_records;
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);

        let result = self.ctx.register_json(name, path, options);
        wait_for_future(py, result)?;

        Ok(())
    }

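    /// Register an Avro file as a named table.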
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (name,
                        path,
                        schema=None,
                        file_extension=".avro",
                        table_partition_cols=vec![]))]
    pub fn register_avro(
        &mut self,
        name: &str,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        file_extension: &str,
        table_partition_cols: Vec<(String, String)>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;

        let mut options = AvroReadOptions::default()
            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);

        let result = self.ctx.register_avro(name, path, options);
        wait_for_future(py, result)?;

        Ok(())
    }

    pub fn register_dataset(
        &self,
        name: &str,
        dataset: &Bound<'_, PyAny>,
        py: Python,
    ) -> PyDataFusionResult<()> {
        let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);

        self.ctx.register_table(name, table)?;

        Ok(())
    }

    pub fn register_udf(&mut self, udf: PyScalarUDF) -> PyResult<()> {
        self.ctx.register_udf(udf.function);
        Ok(())
    }

    pub fn register_udaf(&mut self, udaf: PyAggregateUDF) -> PyResult<()> {
        self.ctx.register_udaf(udaf.function);
        Ok(())
    }

    pub fn register_udwf(&mut self, udwf: PyWindowUDF) -> PyResult<()> {
        self.ctx.register_udwf(udwf.function);
        Ok(())
    }

    #[pyo3(signature = (name="datafusion"))]
    pub fn catalog(&self, name: &str) -> PyResult<PyCatalog> {
        match self.ctx.catalog(name) {
            Some(catalog) => Ok(PyCatalog::new(catalog)),
            None => Err(PyKeyError::new_err(format!(
                "Catalog with name {} doesn't exist.",
                &name,
            ))),
        }
    }

    pub fn tables(&self) -> HashSet<String> {
        self.ctx
            .catalog_names()
            .into_iter()
            .filter_map(|name| self.ctx.catalog(&name))
            .flat_map(move |catalog| {
                catalog
                    .schema_names()
                    .into_iter()
                    .filter_map(move |name| catalog.schema(&name))
            })
            .flat_map(|schema| schema.table_names())
            .collect()
    }

    pub fn table(&self, name: &str, py: Python) -> PyResult<PyDataFrame> {
        let x = wait_for_future(py, self.ctx.table(name))
            .map_err(|e| PyKeyError::new_err(e.to_string()))?;
        Ok(PyDataFrame::new(x))
    }

    pub fn table_exist(&self, name: &str) -> PyDataFusionResult<bool> {
        Ok(self.ctx.table_exist(name)?)
    }

    pub fn empty_table(&self) -> PyDataFusionResult<PyDataFrame> {
        Ok(PyDataFrame::new(self.ctx.read_empty()?))
    }

    pub fn session_id(&self) -> String {
        self.ctx.session_id()
    }

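    /// Read a newline-delimited JSON file into a `DataFrame` without
    /// registering it as a named table.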
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (path, schema=None, schema_infer_max_records=1000, file_extension=".json", table_partition_cols=vec![], file_compression_type=None))]
    pub fn read_json(
        &mut self,
        path: PathBuf,
        schema: Option<PyArrowType<Schema>>,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, String)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let path = path
            .to_str()
            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
        let mut options = NdJsonReadOptions::default()
            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
            .file_compression_type(parse_file_compression_type(file_compression_type)?);
        options.schema_infer_max_records = schema_infer_max_records;
        options.file_extension = file_extension;
        let df = if let Some(schema) = schema {
            options.schema = Some(&schema.0);
            let result = self.ctx.read_json(path, options);
            wait_for_future(py, result)?
        } else {
            let result = self.ctx.read_json(path, options);
            wait_for_future(py, result)?
        };
        Ok(PyDataFrame::new(df))
    }

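    /// Read one CSV file or a list of CSV files into a `DataFrame` without
    /// registering it as a named table.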
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (
        path,
        schema=None,
        has_header=true,
        delimiter=",",
        schema_infer_max_records=1000,
        file_extension=".csv",
        table_partition_cols=vec![],
        file_compression_type=None))]
    pub fn read_csv(
        &self,
        path: &Bound<'_, PyAny>,
        schema: Option<PyArrowType<Schema>>,
        has_header: bool,
        delimiter: &str,
        schema_infer_max_records: usize,
        file_extension: &str,
        table_partition_cols: Vec<(String, String)>,
        file_compression_type: Option<String>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let delimiter = delimiter.as_bytes();
        if delimiter.len() != 1 {
            return Err(crate::errors::PyDataFusionError::PythonError(py_value_err(
                "Delimiter must be a single character",
            )));
        };

        let mut options = CsvReadOptions::new()
            .has_header(has_header)
            .delimiter(delimiter[0])
            .schema_infer_max_records(schema_infer_max_records)
            .file_extension(file_extension)
            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
            .file_compression_type(parse_file_compression_type(file_compression_type)?);
        options.schema = schema.as_ref().map(|x| &x.0);

        if path.is_instance_of::<PyList>() {
            let paths = path.extract::<Vec<String>>()?;
            let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
            let result = self.ctx.read_csv(paths, options);
            let df = PyDataFrame::new(wait_for_future(py, result)?);
            Ok(df)
        } else {
            let path = path.extract::<String>()?;
            let result = self.ctx.read_csv(path, options);
            let df = PyDataFrame::new(wait_for_future(py, result)?);
            Ok(df)
        }
    }

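    /// Read a Parquet file or directory into a `DataFrame` without registering
    /// it as a named table.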
    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (
        path,
        table_partition_cols=vec![],
        parquet_pruning=true,
        file_extension=".parquet",
        skip_metadata=true,
        schema=None,
        file_sort_order=None))]
    pub fn read_parquet(
        &self,
        path: &str,
        table_partition_cols: Vec<(String, String)>,
        parquet_pruning: bool,
        file_extension: &str,
        skip_metadata: bool,
        schema: Option<PyArrowType<Schema>>,
        file_sort_order: Option<Vec<Vec<PySortExpr>>>,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let mut options = ParquetReadOptions::default()
            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
            .parquet_pruning(parquet_pruning)
            .skip_metadata(skip_metadata);
        options.file_extension = file_extension;
        options.schema = schema.as_ref().map(|x| &x.0);
        options.file_sort_order = file_sort_order
            .unwrap_or_default()
            .into_iter()
            .map(|e| e.into_iter().map(|f| f.into()).collect())
            .collect();

        let result = self.ctx.read_parquet(path, options);
        let df = PyDataFrame::new(wait_for_future(py, result)?);
        Ok(df)
    }

    #[allow(clippy::too_many_arguments)]
    #[pyo3(signature = (path, schema=None, table_partition_cols=vec![], file_extension=".avro"))]
    pub fn read_avro(
        &self,
        path: &str,
        schema: Option<PyArrowType<Schema>>,
        table_partition_cols: Vec<(String, String)>,
        file_extension: &str,
        py: Python,
    ) -> PyDataFusionResult<PyDataFrame> {
        let mut options = AvroReadOptions::default()
            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
        options.file_extension = file_extension;
        let df = if let Some(schema) = schema {
            options.schema = Some(&schema.0);
            let read_future = self.ctx.read_avro(path, options);
            wait_for_future(py, read_future)?
        } else {
            let read_future = self.ctx.read_avro(path, options);
            wait_for_future(py, read_future)?
        };
        Ok(PyDataFrame::new(df))
    }

    pub fn read_table(&self, table: &PyTable) -> PyDataFusionResult<PyDataFrame> {
        let df = self.ctx.read_table(table.table())?;
        Ok(PyDataFrame::new(df))
    }

    fn __repr__(&self) -> PyResult<String> {
        let config = self.ctx.copied_config();
        let mut config_entries = config
            .options()
            .entries()
            .iter()
            .filter(|e| e.value.is_some())
            .map(|e| format!("{} = {}", e.key, e.value.as_ref().unwrap()))
            .collect::<Vec<_>>();
        config_entries.sort();
        Ok(format!(
            "SessionContext: id={}; configs=[\n\t{}]",
            self.session_id(),
            config_entries.join("\n\t")
        ))
    }

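    /// Execute a single partition of a physical plan on the Tokio runtime and
    /// return the resulting stream of record batches.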
    pub fn execute(
        &self,
        plan: PyExecutionPlan,
        part: usize,
        py: Python,
    ) -> PyDataFusionResult<PyRecordBatchStream> {
        let ctx: TaskContext = TaskContext::from(&self.ctx.state());
        let rt = &get_tokio_runtime().0;
        let plan = plan.plan.clone();
        let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
            rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
        Ok(PyRecordBatchStream::new(stream?))
    }
}

impl PySessionContext {
    async fn _table(&self, name: &str) -> datafusion::common::Result<DataFrame> {
        self.ctx.table(name).await
    }

    async fn register_csv_from_multiple_paths(
        &self,
        name: &str,
        table_paths: Vec<String>,
        options: CsvReadOptions<'_>,
    ) -> datafusion::common::Result<()> {
        let table_paths = table_paths.to_urls()?;
        let session_config = self.ctx.copied_config();
        let listing_options =
            options.to_listing_options(&session_config, self.ctx.copied_table_options());

        let option_extension = listing_options.file_extension.clone();

        if table_paths.is_empty() {
            return exec_err!("No table paths were provided");
        }

        for path in &table_paths {
            let file_path = path.as_str();
            if !file_path.ends_with(option_extension.clone().as_str()) && !path.is_collection() {
                return exec_err!(
                    "File path '{file_path}' does not match the expected extension '{option_extension}'"
                );
            }
        }

        let resolved_schema = options
            .get_resolved_schema(&session_config, self.ctx.state(), table_paths[0].clone())
            .await?;

        let config = ListingTableConfig::new_with_multi_paths(table_paths)
            .with_listing_options(listing_options)
            .with_schema(resolved_schema);
        let table = ListingTable::try_new(config)?;
        self.ctx
            .register_table(TableReference::Bare { table: name.into() }, Arc::new(table))?;
        Ok(())
    }
}

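/// Convert `(column name, type name)` pairs supplied from Python into Arrow
/// `DataType`s. Only `"string"` and `"int"` are currently supported.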
pub fn convert_table_partition_cols(
    table_partition_cols: Vec<(String, String)>,
) -> PyDataFusionResult<Vec<(String, DataType)>> {
    table_partition_cols
        .into_iter()
        .map(|(name, ty)| match ty.as_str() {
            "string" => Ok((name, DataType::Utf8)),
            "int" => Ok((name, DataType::Int32)),
            _ => Err(crate::errors::PyDataFusionError::Common(format!(
                "Unsupported data type '{ty}' for partition column. Supported types are 'string' and 'int'"
            ))),
        })
        .collect::<Result<Vec<_>, _>>()
}

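/// Parse an optional file compression string (`gzip`, `bz2`, `xz`, `zstd`)
/// into a `FileCompressionType`, returning a Python `ValueError` for
/// unsupported values.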
pub fn parse_file_compression_type(
    file_compression_type: Option<String>,
) -> Result<FileCompressionType, PyErr> {
    FileCompressionType::from_str(file_compression_type.unwrap_or_default().as_str())
        .map_err(|_| {
            PyValueError::new_err("file_compression_type must be one of: gzip, bz2, xz, zstd")
        })
}

impl From<PySessionContext> for SessionContext {
    fn from(ctx: PySessionContext) -> SessionContext {
        ctx.ctx
    }
}

impl From<SessionContext> for PySessionContext {
    fn from(ctx: SessionContext) -> PySessionContext {
        PySessionContext { ctx }
    }
}