1use parking_lot::RwLock;
2use polars::sql::{SQLContext, extract_table_identifiers};
3use pyo3::prelude::*;
4
5use crate::PyLazyFrame;
6use crate::error::PyPolarsErr;
7
// Python-facing handle around a polars `SQLContext`.
//
// `#[pyclass(frozen)]` means PyO3 only hands out shared (`&self`) access to the
// object, so the context is wrapped in a `parking_lot::RwLock` to permit the
// mutating operations (registering/unregistering tables, executing queries)
// from `&self` methods.
//
// NOTE: plain `//` comments are used on purpose — a `///` doc comment on a
// `#[pyclass]` would become the Python class `__doc__`.
#[pyclass(frozen)]
#[repr(transparent)]
pub struct PySQLContext {
    // The wrapped context: read-locked for lookups, write-locked for mutation.
    pub context: RwLock<SQLContext>,
}
13
14impl Clone for PySQLContext {
15 fn clone(&self) -> Self {
16 Self {
17 context: RwLock::new(self.context.read().clone()),
18 }
19 }
20}
21
22#[pymethods]
23#[allow(
24 clippy::wrong_self_convention,
25 clippy::should_implement_trait,
26 clippy::len_without_is_empty
27)]
28impl PySQLContext {
29 #[staticmethod]
30 #[allow(clippy::new_without_default)]
31 pub fn new() -> PySQLContext {
32 PySQLContext {
33 context: RwLock::new(SQLContext::new()),
34 }
35 }
36
37 pub fn execute(&self, query: &str) -> PyResult<PyLazyFrame> {
39 Ok(self
40 .context
41 .write()
42 .execute(query)
43 .map_err(PyPolarsErr::from)?
44 .into())
45 }
46
47 pub fn get_tables(&self) -> PyResult<Vec<String>> {
49 Ok(self.context.read().get_tables())
50 }
51
52 pub fn register(&self, name: &str, lf: PyLazyFrame) {
54 self.context.write().register(name, lf.ldf.into_inner())
55 }
56
57 pub fn unregister(&self, name: &str) {
59 self.context.write().unregister(name)
60 }
61
62 #[staticmethod]
64 #[pyo3(signature = (query, include_schema=true, unique=false))]
65 pub fn table_identifiers(
66 query: &str,
67 include_schema: bool,
68 unique: bool,
69 ) -> PyResult<Vec<String>> {
70 extract_table_identifiers(query, include_schema, unique)
71 .map_err(PyPolarsErr::from)
72 .map_err(Into::into)
73 }
74}