Skip to main content

polars_python/
sql.rs

1use parking_lot::RwLock;
2use polars::sql::{SQLContext, extract_table_identifiers};
3use pyo3::prelude::*;
4
5use crate::PyLazyFrame;
6use crate::error::PyPolarsErr;
7
8#[pyclass(frozen)]
9#[repr(transparent)]
10pub struct PySQLContext {
11    pub context: RwLock<SQLContext>,
12}
13
14impl Clone for PySQLContext {
15    fn clone(&self) -> Self {
16        Self {
17            context: RwLock::new(self.context.read().clone()),
18        }
19    }
20}
21
22#[pymethods]
23#[allow(
24    clippy::wrong_self_convention,
25    clippy::should_implement_trait,
26    clippy::len_without_is_empty
27)]
28impl PySQLContext {
29    #[staticmethod]
30    #[allow(clippy::new_without_default)]
31    pub fn new() -> PySQLContext {
32        PySQLContext {
33            context: RwLock::new(SQLContext::new()),
34        }
35    }
36
37    /// Execute a SQL query in the current SQLContext.
38    pub fn execute(&self, query: &str) -> PyResult<PyLazyFrame> {
39        Ok(self
40            .context
41            .write()
42            .execute(query)
43            .map_err(PyPolarsErr::from)?
44            .into())
45    }
46
47    /// Get a list of table names registered in the current SQLContext.
48    pub fn get_tables(&self) -> PyResult<Vec<String>> {
49        Ok(self.context.read().get_tables())
50    }
51
52    /// Register a table in the current SQLContext.
53    pub fn register(&self, name: &str, lf: PyLazyFrame) {
54        self.context.write().register(name, lf.ldf.into_inner())
55    }
56
57    /// Unregister a table from the current SQLContext.
58    pub fn unregister(&self, name: &str) {
59        self.context.write().unregister(name)
60    }
61
62    /// Extract table identifiers from a SQL query string.
63    #[staticmethod]
64    #[pyo3(signature = (query, include_schema=true, unique=false))]
65    pub fn table_identifiers(
66        query: &str,
67        include_schema: bool,
68        unique: bool,
69    ) -> PyResult<Vec<String>> {
70        extract_table_identifiers(query, include_schema, unique)
71            .map_err(PyPolarsErr::from)
72            .map_err(Into::into)
73    }
74}