datafusion_python/common/
schema.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::fmt::{self, Display, Formatter};
19use std::sync::Arc;
20use std::{any::Any, borrow::Cow};
21
22use arrow::datatypes::Schema;
23use arrow::pyarrow::PyArrowType;
24use datafusion::arrow::datatypes::SchemaRef;
25use datafusion::common::Constraints;
26use datafusion::datasource::TableType;
27use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableSource};
28use pyo3::prelude::*;
29
30use datafusion::logical_expr::utils::split_conjunction;
31
32use crate::sql::logical::PyLogicalPlan;
33
34use super::{data_type::DataTypeMap, function::SqlFunction};
35
36use parking_lot::RwLock;
37
38#[pyclass(name = "SqlSchema", module = "datafusion.common", subclass, frozen)]
39#[derive(Debug, Clone)]
40pub struct SqlSchema {
41    name: Arc<RwLock<String>>,
42    tables: Arc<RwLock<Vec<SqlTable>>>,
43    views: Arc<RwLock<Vec<SqlView>>>,
44    functions: Arc<RwLock<Vec<SqlFunction>>>,
45}
46
47#[pyclass(name = "SqlTable", module = "datafusion.common", subclass)]
48#[derive(Debug, Clone)]
49pub struct SqlTable {
50    #[pyo3(get, set)]
51    pub name: String,
52    #[pyo3(get, set)]
53    pub columns: Vec<(String, DataTypeMap)>,
54    #[pyo3(get, set)]
55    pub primary_key: Option<String>,
56    #[pyo3(get, set)]
57    pub foreign_keys: Vec<String>,
58    #[pyo3(get, set)]
59    pub indexes: Vec<String>,
60    #[pyo3(get, set)]
61    pub constraints: Vec<String>,
62    #[pyo3(get, set)]
63    pub statistics: SqlStatistics,
64    #[pyo3(get, set)]
65    pub filepaths: Option<Vec<String>>,
66}
67
68#[pymethods]
69impl SqlTable {
70    #[new]
71    #[pyo3(signature = (table_name, columns, row_count, filepaths=None))]
72    pub fn new(
73        table_name: String,
74        columns: Vec<(String, DataTypeMap)>,
75        row_count: f64,
76        filepaths: Option<Vec<String>>,
77    ) -> Self {
78        Self {
79            name: table_name,
80            columns,
81            primary_key: None,
82            foreign_keys: Vec::new(),
83            indexes: Vec::new(),
84            constraints: Vec::new(),
85            statistics: SqlStatistics::new(row_count),
86            filepaths,
87        }
88    }
89}
90
91#[pyclass(name = "SqlView", module = "datafusion.common", subclass)]
92#[derive(Debug, Clone)]
93pub struct SqlView {
94    #[pyo3(get, set)]
95    pub name: String,
96    #[pyo3(get, set)]
97    pub definition: String, // SQL code that defines the view
98}
99
100#[pymethods]
101impl SqlSchema {
102    #[new]
103    pub fn new(schema_name: &str) -> Self {
104        Self {
105            name: Arc::new(RwLock::new(schema_name.to_owned())),
106            tables: Arc::new(RwLock::new(Vec::new())),
107            views: Arc::new(RwLock::new(Vec::new())),
108            functions: Arc::new(RwLock::new(Vec::new())),
109        }
110    }
111
112    #[getter]
113    fn name(&self) -> PyResult<String> {
114        Ok(self.name.read().clone())
115    }
116
117    #[setter]
118    fn set_name(&self, value: String) -> PyResult<()> {
119        *self.name.write() = value;
120        Ok(())
121    }
122
123    #[getter]
124    fn tables(&self) -> PyResult<Vec<SqlTable>> {
125        Ok(self.tables.read().clone())
126    }
127
128    #[setter]
129    fn set_tables(&self, tables: Vec<SqlTable>) -> PyResult<()> {
130        *self.tables.write() = tables;
131        Ok(())
132    }
133
134    #[getter]
135    fn views(&self) -> PyResult<Vec<SqlView>> {
136        Ok(self.views.read().clone())
137    }
138
139    #[setter]
140    fn set_views(&self, views: Vec<SqlView>) -> PyResult<()> {
141        *self.views.write() = views;
142        Ok(())
143    }
144
145    #[getter]
146    fn functions(&self) -> PyResult<Vec<SqlFunction>> {
147        Ok(self.functions.read().clone())
148    }
149
150    #[setter]
151    fn set_functions(&self, functions: Vec<SqlFunction>) -> PyResult<()> {
152        *self.functions.write() = functions;
153        Ok(())
154    }
155
156    pub fn table_by_name(&self, table_name: &str) -> Option<SqlTable> {
157        let tables = self.tables.read();
158        tables.iter().find(|tbl| tbl.name.eq(table_name)).cloned()
159    }
160
161    pub fn add_table(&self, table: SqlTable) {
162        let mut tables = self.tables.write();
163        tables.push(table);
164    }
165
166    pub fn drop_table(&self, table_name: String) {
167        let mut tables = self.tables.write();
168        tables.retain(|x| !x.name.eq(&table_name));
169    }
170}
171
172/// SqlTable wrapper that is compatible with DataFusion logical query plans
173pub struct SqlTableSource {
174    schema: SchemaRef,
175    statistics: Option<SqlStatistics>,
176    filepaths: Option<Vec<String>>,
177}
178
179impl SqlTableSource {
180    /// Initialize a new `EmptyTable` from a schema
181    pub fn new(
182        schema: SchemaRef,
183        statistics: Option<SqlStatistics>,
184        filepaths: Option<Vec<String>>,
185    ) -> Self {
186        Self {
187            schema,
188            statistics,
189            filepaths,
190        }
191    }
192
193    /// Access optional statistics associated with this table source
194    pub fn statistics(&self) -> Option<&SqlStatistics> {
195        self.statistics.as_ref()
196    }
197
198    /// Access optional filepath associated with this table source
199    #[allow(dead_code)]
200    pub fn filepaths(&self) -> Option<&Vec<String>> {
201        self.filepaths.as_ref()
202    }
203}
204
205/// Implement TableSource, used in the logical query plan and in logical query optimizations
206impl TableSource for SqlTableSource {
207    fn as_any(&self) -> &dyn Any {
208        self
209    }
210
211    fn schema(&self) -> SchemaRef {
212        self.schema.clone()
213    }
214
215    fn table_type(&self) -> datafusion::logical_expr::TableType {
216        datafusion::logical_expr::TableType::Base
217    }
218
219    fn supports_filters_pushdown(
220        &self,
221        filters: &[&Expr],
222    ) -> datafusion::common::Result<Vec<TableProviderFilterPushDown>> {
223        filters
224            .iter()
225            .map(|f| {
226                let filters = split_conjunction(f);
227                if filters.iter().all(|f| is_supported_push_down_expr(f)) {
228                    // Push down filters to the tablescan operation if all are supported
229                    Ok(TableProviderFilterPushDown::Exact)
230                } else if filters.iter().any(|f| is_supported_push_down_expr(f)) {
231                    // Partially apply the filter in the TableScan but retain
232                    // the Filter operator in the plan as well
233                    Ok(TableProviderFilterPushDown::Inexact)
234                } else {
235                    Ok(TableProviderFilterPushDown::Unsupported)
236                }
237            })
238            .collect()
239    }
240
241    fn get_logical_plan(&self) -> Option<Cow<'_, datafusion::logical_expr::LogicalPlan>> {
242        None
243    }
244}
245
246fn is_supported_push_down_expr(_expr: &Expr) -> bool {
247    // For now we support all kinds of expr's at this level
248    true
249}
250
251#[pyclass(frozen, name = "SqlStatistics", module = "datafusion.common", subclass)]
252#[derive(Debug, Clone)]
253pub struct SqlStatistics {
254    row_count: f64,
255}
256
257#[pymethods]
258impl SqlStatistics {
259    #[new]
260    pub fn new(row_count: f64) -> Self {
261        Self { row_count }
262    }
263
264    #[pyo3(name = "getRowCount")]
265    pub fn get_row_count(&self) -> f64 {
266        self.row_count
267    }
268}
269
270#[pyclass(frozen, name = "Constraints", module = "datafusion.expr", subclass)]
271#[derive(Clone)]
272pub struct PyConstraints {
273    pub constraints: Constraints,
274}
275
276impl From<PyConstraints> for Constraints {
277    fn from(constraints: PyConstraints) -> Self {
278        constraints.constraints
279    }
280}
281
282impl From<Constraints> for PyConstraints {
283    fn from(constraints: Constraints) -> Self {
284        PyConstraints { constraints }
285    }
286}
287
288impl Display for PyConstraints {
289    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
290        write!(f, "Constraints: {:?}", self.constraints)
291    }
292}
293
294#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
295#[pyclass(frozen, eq, eq_int, name = "TableType", module = "datafusion.common")]
296pub enum PyTableType {
297    Base,
298    View,
299    Temporary,
300}
301
302impl From<PyTableType> for datafusion::logical_expr::TableType {
303    fn from(table_type: PyTableType) -> Self {
304        match table_type {
305            PyTableType::Base => datafusion::logical_expr::TableType::Base,
306            PyTableType::View => datafusion::logical_expr::TableType::View,
307            PyTableType::Temporary => datafusion::logical_expr::TableType::Temporary,
308        }
309    }
310}
311
312impl From<TableType> for PyTableType {
313    fn from(table_type: TableType) -> Self {
314        match table_type {
315            datafusion::logical_expr::TableType::Base => PyTableType::Base,
316            datafusion::logical_expr::TableType::View => PyTableType::View,
317            datafusion::logical_expr::TableType::Temporary => PyTableType::Temporary,
318        }
319    }
320}
321
322#[pyclass(frozen, name = "TableSource", module = "datafusion.common", subclass)]
323#[derive(Clone)]
324pub struct PyTableSource {
325    pub table_source: Arc<dyn TableSource>,
326}
327
328#[pymethods]
329impl PyTableSource {
330    pub fn schema(&self) -> PyArrowType<Schema> {
331        (*self.table_source.schema()).clone().into()
332    }
333
334    pub fn constraints(&self) -> Option<PyConstraints> {
335        self.table_source.constraints().map(|c| PyConstraints {
336            constraints: c.clone(),
337        })
338    }
339
340    pub fn table_type(&self) -> PyTableType {
341        self.table_source.table_type().into()
342    }
343
344    pub fn get_logical_plan(&self) -> Option<PyLogicalPlan> {
345        self.table_source
346            .get_logical_plan()
347            .map(|plan| PyLogicalPlan::new(plan.into_owned()))
348    }
349}