datafusion_python/common/
schema.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::borrow::Cow;
20use std::fmt::{self, Display, Formatter};
21use std::sync::Arc;
22
23use arrow::datatypes::Schema;
24use arrow::pyarrow::PyArrowType;
25use datafusion::arrow::datatypes::SchemaRef;
26use datafusion::common::Constraints;
27use datafusion::datasource::TableType;
28use datafusion::logical_expr::utils::split_conjunction;
29use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableSource};
30use parking_lot::RwLock;
31use pyo3::prelude::*;
32
33use super::data_type::DataTypeMap;
34use super::function::SqlFunction;
35use crate::sql::logical::PyLogicalPlan;
36
37#[pyclass(name = "SqlSchema", module = "datafusion.common", subclass, frozen)]
38#[derive(Debug, Clone)]
39pub struct SqlSchema {
40    name: Arc<RwLock<String>>,
41    tables: Arc<RwLock<Vec<SqlTable>>>,
42    views: Arc<RwLock<Vec<SqlView>>>,
43    functions: Arc<RwLock<Vec<SqlFunction>>>,
44}
45
46#[pyclass(name = "SqlTable", module = "datafusion.common", subclass)]
47#[derive(Debug, Clone)]
48pub struct SqlTable {
49    #[pyo3(get, set)]
50    pub name: String,
51    #[pyo3(get, set)]
52    pub columns: Vec<(String, DataTypeMap)>,
53    #[pyo3(get, set)]
54    pub primary_key: Option<String>,
55    #[pyo3(get, set)]
56    pub foreign_keys: Vec<String>,
57    #[pyo3(get, set)]
58    pub indexes: Vec<String>,
59    #[pyo3(get, set)]
60    pub constraints: Vec<String>,
61    #[pyo3(get, set)]
62    pub statistics: SqlStatistics,
63    #[pyo3(get, set)]
64    pub filepaths: Option<Vec<String>>,
65}
66
67#[pymethods]
68impl SqlTable {
69    #[new]
70    #[pyo3(signature = (table_name, columns, row_count, filepaths=None))]
71    pub fn new(
72        table_name: String,
73        columns: Vec<(String, DataTypeMap)>,
74        row_count: f64,
75        filepaths: Option<Vec<String>>,
76    ) -> Self {
77        Self {
78            name: table_name,
79            columns,
80            primary_key: None,
81            foreign_keys: Vec::new(),
82            indexes: Vec::new(),
83            constraints: Vec::new(),
84            statistics: SqlStatistics::new(row_count),
85            filepaths,
86        }
87    }
88}
89
90#[pyclass(name = "SqlView", module = "datafusion.common", subclass)]
91#[derive(Debug, Clone)]
92pub struct SqlView {
93    #[pyo3(get, set)]
94    pub name: String,
95    #[pyo3(get, set)]
96    pub definition: String, // SQL code that defines the view
97}
98
99#[pymethods]
100impl SqlSchema {
101    #[new]
102    pub fn new(schema_name: &str) -> Self {
103        Self {
104            name: Arc::new(RwLock::new(schema_name.to_owned())),
105            tables: Arc::new(RwLock::new(Vec::new())),
106            views: Arc::new(RwLock::new(Vec::new())),
107            functions: Arc::new(RwLock::new(Vec::new())),
108        }
109    }
110
111    #[getter]
112    fn name(&self) -> PyResult<String> {
113        Ok(self.name.read().clone())
114    }
115
116    #[setter]
117    fn set_name(&self, value: String) -> PyResult<()> {
118        *self.name.write() = value;
119        Ok(())
120    }
121
122    #[getter]
123    fn tables(&self) -> PyResult<Vec<SqlTable>> {
124        Ok(self.tables.read().clone())
125    }
126
127    #[setter]
128    fn set_tables(&self, tables: Vec<SqlTable>) -> PyResult<()> {
129        *self.tables.write() = tables;
130        Ok(())
131    }
132
133    #[getter]
134    fn views(&self) -> PyResult<Vec<SqlView>> {
135        Ok(self.views.read().clone())
136    }
137
138    #[setter]
139    fn set_views(&self, views: Vec<SqlView>) -> PyResult<()> {
140        *self.views.write() = views;
141        Ok(())
142    }
143
144    #[getter]
145    fn functions(&self) -> PyResult<Vec<SqlFunction>> {
146        Ok(self.functions.read().clone())
147    }
148
149    #[setter]
150    fn set_functions(&self, functions: Vec<SqlFunction>) -> PyResult<()> {
151        *self.functions.write() = functions;
152        Ok(())
153    }
154
155    pub fn table_by_name(&self, table_name: &str) -> Option<SqlTable> {
156        let tables = self.tables.read();
157        tables.iter().find(|tbl| tbl.name.eq(table_name)).cloned()
158    }
159
160    pub fn add_table(&self, table: SqlTable) {
161        let mut tables = self.tables.write();
162        tables.push(table);
163    }
164
165    pub fn drop_table(&self, table_name: String) {
166        let mut tables = self.tables.write();
167        tables.retain(|x| !x.name.eq(&table_name));
168    }
169}
170
171/// SqlTable wrapper that is compatible with DataFusion logical query plans
172pub struct SqlTableSource {
173    schema: SchemaRef,
174    statistics: Option<SqlStatistics>,
175    filepaths: Option<Vec<String>>,
176}
177
178impl SqlTableSource {
179    /// Initialize a new `EmptyTable` from a schema
180    pub fn new(
181        schema: SchemaRef,
182        statistics: Option<SqlStatistics>,
183        filepaths: Option<Vec<String>>,
184    ) -> Self {
185        Self {
186            schema,
187            statistics,
188            filepaths,
189        }
190    }
191
192    /// Access optional statistics associated with this table source
193    pub fn statistics(&self) -> Option<&SqlStatistics> {
194        self.statistics.as_ref()
195    }
196
197    /// Access optional filepath associated with this table source
198    #[allow(dead_code)]
199    pub fn filepaths(&self) -> Option<&Vec<String>> {
200        self.filepaths.as_ref()
201    }
202}
203
204/// Implement TableSource, used in the logical query plan and in logical query optimizations
205impl TableSource for SqlTableSource {
206    fn as_any(&self) -> &dyn Any {
207        self
208    }
209
210    fn schema(&self) -> SchemaRef {
211        self.schema.clone()
212    }
213
214    fn table_type(&self) -> datafusion::logical_expr::TableType {
215        datafusion::logical_expr::TableType::Base
216    }
217
218    fn supports_filters_pushdown(
219        &self,
220        filters: &[&Expr],
221    ) -> datafusion::common::Result<Vec<TableProviderFilterPushDown>> {
222        filters
223            .iter()
224            .map(|f| {
225                let filters = split_conjunction(f);
226                if filters.iter().all(|f| is_supported_push_down_expr(f)) {
227                    // Push down filters to the tablescan operation if all are supported
228                    Ok(TableProviderFilterPushDown::Exact)
229                } else if filters.iter().any(|f| is_supported_push_down_expr(f)) {
230                    // Partially apply the filter in the TableScan but retain
231                    // the Filter operator in the plan as well
232                    Ok(TableProviderFilterPushDown::Inexact)
233                } else {
234                    Ok(TableProviderFilterPushDown::Unsupported)
235                }
236            })
237            .collect()
238    }
239
240    fn get_logical_plan(&self) -> Option<Cow<'_, datafusion::logical_expr::LogicalPlan>> {
241        None
242    }
243}
244
245fn is_supported_push_down_expr(_expr: &Expr) -> bool {
246    // For now we support all kinds of expr's at this level
247    true
248}
249
250#[pyclass(frozen, name = "SqlStatistics", module = "datafusion.common", subclass)]
251#[derive(Debug, Clone)]
252pub struct SqlStatistics {
253    row_count: f64,
254}
255
256#[pymethods]
257impl SqlStatistics {
258    #[new]
259    pub fn new(row_count: f64) -> Self {
260        Self { row_count }
261    }
262
263    #[pyo3(name = "getRowCount")]
264    pub fn get_row_count(&self) -> f64 {
265        self.row_count
266    }
267}
268
269#[pyclass(frozen, name = "Constraints", module = "datafusion.expr", subclass)]
270#[derive(Clone)]
271pub struct PyConstraints {
272    pub constraints: Constraints,
273}
274
275impl From<PyConstraints> for Constraints {
276    fn from(constraints: PyConstraints) -> Self {
277        constraints.constraints
278    }
279}
280
281impl From<Constraints> for PyConstraints {
282    fn from(constraints: Constraints) -> Self {
283        PyConstraints { constraints }
284    }
285}
286
287impl Display for PyConstraints {
288    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
289        write!(f, "Constraints: {:?}", self.constraints)
290    }
291}
292
293#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
294#[pyclass(frozen, eq, eq_int, name = "TableType", module = "datafusion.common")]
295pub enum PyTableType {
296    Base,
297    View,
298    Temporary,
299}
300
301impl From<PyTableType> for datafusion::logical_expr::TableType {
302    fn from(table_type: PyTableType) -> Self {
303        match table_type {
304            PyTableType::Base => datafusion::logical_expr::TableType::Base,
305            PyTableType::View => datafusion::logical_expr::TableType::View,
306            PyTableType::Temporary => datafusion::logical_expr::TableType::Temporary,
307        }
308    }
309}
310
311impl From<TableType> for PyTableType {
312    fn from(table_type: TableType) -> Self {
313        match table_type {
314            datafusion::logical_expr::TableType::Base => PyTableType::Base,
315            datafusion::logical_expr::TableType::View => PyTableType::View,
316            datafusion::logical_expr::TableType::Temporary => PyTableType::Temporary,
317        }
318    }
319}
320
321#[pyclass(frozen, name = "TableSource", module = "datafusion.common", subclass)]
322#[derive(Clone)]
323pub struct PyTableSource {
324    pub table_source: Arc<dyn TableSource>,
325}
326
327#[pymethods]
328impl PyTableSource {
329    pub fn schema(&self) -> PyArrowType<Schema> {
330        (*self.table_source.schema()).clone().into()
331    }
332
333    pub fn constraints(&self) -> Option<PyConstraints> {
334        self.table_source.constraints().map(|c| PyConstraints {
335            constraints: c.clone(),
336        })
337    }
338
339    pub fn table_type(&self) -> PyTableType {
340        self.table_source.table_type().into()
341    }
342
343    pub fn get_logical_plan(&self) -> Option<PyLogicalPlan> {
344        self.table_source
345            .get_logical_plan()
346            .map(|plan| PyLogicalPlan::new(plan.into_owned()))
347    }
348}