Skip to main content

datafusion_python/common/
schema.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::any::Any;
19use std::borrow::Cow;
20use std::fmt::{self, Display, Formatter};
21use std::sync::Arc;
22
23use arrow::datatypes::Schema;
24use arrow::pyarrow::PyArrowType;
25use datafusion::arrow::datatypes::SchemaRef;
26use datafusion::common::Constraints;
27use datafusion::datasource::TableType;
28use datafusion::logical_expr::utils::split_conjunction;
29use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableSource};
30use parking_lot::RwLock;
31use pyo3::prelude::*;
32
33use super::data_type::DataTypeMap;
34use super::function::SqlFunction;
35use crate::sql::logical::PyLogicalPlan;
36
37#[pyclass(
38    from_py_object,
39    name = "SqlSchema",
40    module = "datafusion.common",
41    subclass,
42    frozen
43)]
44#[derive(Debug, Clone)]
45pub struct SqlSchema {
46    name: Arc<RwLock<String>>,
47    tables: Arc<RwLock<Vec<SqlTable>>>,
48    views: Arc<RwLock<Vec<SqlView>>>,
49    functions: Arc<RwLock<Vec<SqlFunction>>>,
50}
51
52#[pyclass(
53    from_py_object,
54    name = "SqlTable",
55    module = "datafusion.common",
56    subclass
57)]
58#[derive(Debug, Clone)]
59pub struct SqlTable {
60    #[pyo3(get, set)]
61    pub name: String,
62    #[pyo3(get, set)]
63    pub columns: Vec<(String, DataTypeMap)>,
64    #[pyo3(get, set)]
65    pub primary_key: Option<String>,
66    #[pyo3(get, set)]
67    pub foreign_keys: Vec<String>,
68    #[pyo3(get, set)]
69    pub indexes: Vec<String>,
70    #[pyo3(get, set)]
71    pub constraints: Vec<String>,
72    #[pyo3(get, set)]
73    pub statistics: SqlStatistics,
74    #[pyo3(get, set)]
75    pub filepaths: Option<Vec<String>>,
76}
77
78#[pymethods]
79impl SqlTable {
80    #[new]
81    #[pyo3(signature = (table_name, columns, row_count, filepaths=None))]
82    pub fn new(
83        table_name: String,
84        columns: Vec<(String, DataTypeMap)>,
85        row_count: f64,
86        filepaths: Option<Vec<String>>,
87    ) -> Self {
88        Self {
89            name: table_name,
90            columns,
91            primary_key: None,
92            foreign_keys: Vec::new(),
93            indexes: Vec::new(),
94            constraints: Vec::new(),
95            statistics: SqlStatistics::new(row_count),
96            filepaths,
97        }
98    }
99}
100
101#[pyclass(
102    from_py_object,
103    name = "SqlView",
104    module = "datafusion.common",
105    subclass
106)]
107#[derive(Debug, Clone)]
108pub struct SqlView {
109    #[pyo3(get, set)]
110    pub name: String,
111    #[pyo3(get, set)]
112    pub definition: String, // SQL code that defines the view
113}
114
115#[pymethods]
116impl SqlSchema {
117    #[new]
118    pub fn new(schema_name: &str) -> Self {
119        Self {
120            name: Arc::new(RwLock::new(schema_name.to_owned())),
121            tables: Arc::new(RwLock::new(Vec::new())),
122            views: Arc::new(RwLock::new(Vec::new())),
123            functions: Arc::new(RwLock::new(Vec::new())),
124        }
125    }
126
127    #[getter]
128    fn name(&self) -> PyResult<String> {
129        Ok(self.name.read().clone())
130    }
131
132    #[setter]
133    fn set_name(&self, value: String) -> PyResult<()> {
134        *self.name.write() = value;
135        Ok(())
136    }
137
138    #[getter]
139    fn tables(&self) -> PyResult<Vec<SqlTable>> {
140        Ok(self.tables.read().clone())
141    }
142
143    #[setter]
144    fn set_tables(&self, tables: Vec<SqlTable>) -> PyResult<()> {
145        *self.tables.write() = tables;
146        Ok(())
147    }
148
149    #[getter]
150    fn views(&self) -> PyResult<Vec<SqlView>> {
151        Ok(self.views.read().clone())
152    }
153
154    #[setter]
155    fn set_views(&self, views: Vec<SqlView>) -> PyResult<()> {
156        *self.views.write() = views;
157        Ok(())
158    }
159
160    #[getter]
161    fn functions(&self) -> PyResult<Vec<SqlFunction>> {
162        Ok(self.functions.read().clone())
163    }
164
165    #[setter]
166    fn set_functions(&self, functions: Vec<SqlFunction>) -> PyResult<()> {
167        *self.functions.write() = functions;
168        Ok(())
169    }
170
171    pub fn table_by_name(&self, table_name: &str) -> Option<SqlTable> {
172        let tables = self.tables.read();
173        tables.iter().find(|tbl| tbl.name.eq(table_name)).cloned()
174    }
175
176    pub fn add_table(&self, table: SqlTable) {
177        let mut tables = self.tables.write();
178        tables.push(table);
179    }
180
181    pub fn drop_table(&self, table_name: String) {
182        let mut tables = self.tables.write();
183        tables.retain(|x| !x.name.eq(&table_name));
184    }
185}
186
187/// SqlTable wrapper that is compatible with DataFusion logical query plans
188pub struct SqlTableSource {
189    schema: SchemaRef,
190    statistics: Option<SqlStatistics>,
191    filepaths: Option<Vec<String>>,
192}
193
194impl SqlTableSource {
195    /// Initialize a new `EmptyTable` from a schema
196    pub fn new(
197        schema: SchemaRef,
198        statistics: Option<SqlStatistics>,
199        filepaths: Option<Vec<String>>,
200    ) -> Self {
201        Self {
202            schema,
203            statistics,
204            filepaths,
205        }
206    }
207
208    /// Access optional statistics associated with this table source
209    pub fn statistics(&self) -> Option<&SqlStatistics> {
210        self.statistics.as_ref()
211    }
212
213    /// Access optional filepath associated with this table source
214    #[allow(dead_code)]
215    pub fn filepaths(&self) -> Option<&Vec<String>> {
216        self.filepaths.as_ref()
217    }
218}
219
220/// Implement TableSource, used in the logical query plan and in logical query optimizations
221impl TableSource for SqlTableSource {
222    fn as_any(&self) -> &dyn Any {
223        self
224    }
225
226    fn schema(&self) -> SchemaRef {
227        self.schema.clone()
228    }
229
230    fn table_type(&self) -> datafusion::logical_expr::TableType {
231        datafusion::logical_expr::TableType::Base
232    }
233
234    fn supports_filters_pushdown(
235        &self,
236        filters: &[&Expr],
237    ) -> datafusion::common::Result<Vec<TableProviderFilterPushDown>> {
238        filters
239            .iter()
240            .map(|f| {
241                let filters = split_conjunction(f);
242                if filters.iter().all(|f| is_supported_push_down_expr(f)) {
243                    // Push down filters to the tablescan operation if all are supported
244                    Ok(TableProviderFilterPushDown::Exact)
245                } else if filters.iter().any(|f| is_supported_push_down_expr(f)) {
246                    // Partially apply the filter in the TableScan but retain
247                    // the Filter operator in the plan as well
248                    Ok(TableProviderFilterPushDown::Inexact)
249                } else {
250                    Ok(TableProviderFilterPushDown::Unsupported)
251                }
252            })
253            .collect()
254    }
255
256    fn get_logical_plan(&self) -> Option<Cow<'_, datafusion::logical_expr::LogicalPlan>> {
257        None
258    }
259}
260
261fn is_supported_push_down_expr(_expr: &Expr) -> bool {
262    // For now we support all kinds of expr's at this level
263    true
264}
265
266#[pyclass(
267    from_py_object,
268    frozen,
269    name = "SqlStatistics",
270    module = "datafusion.common",
271    subclass
272)]
273#[derive(Debug, Clone)]
274pub struct SqlStatistics {
275    row_count: f64,
276}
277
278#[pymethods]
279impl SqlStatistics {
280    #[new]
281    pub fn new(row_count: f64) -> Self {
282        Self { row_count }
283    }
284
285    #[pyo3(name = "getRowCount")]
286    pub fn get_row_count(&self) -> f64 {
287        self.row_count
288    }
289}
290
291#[pyclass(
292    from_py_object,
293    frozen,
294    name = "Constraints",
295    module = "datafusion.expr",
296    subclass
297)]
298#[derive(Clone)]
299pub struct PyConstraints {
300    pub constraints: Constraints,
301}
302
303impl From<PyConstraints> for Constraints {
304    fn from(constraints: PyConstraints) -> Self {
305        constraints.constraints
306    }
307}
308
309impl From<Constraints> for PyConstraints {
310    fn from(constraints: Constraints) -> Self {
311        PyConstraints { constraints }
312    }
313}
314
315impl Display for PyConstraints {
316    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
317        write!(f, "Constraints: {:?}", self.constraints)
318    }
319}
320
321#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
322#[pyclass(
323    from_py_object,
324    frozen,
325    eq,
326    eq_int,
327    name = "TableType",
328    module = "datafusion.common"
329)]
330pub enum PyTableType {
331    Base,
332    View,
333    Temporary,
334}
335
336impl From<PyTableType> for datafusion::logical_expr::TableType {
337    fn from(table_type: PyTableType) -> Self {
338        match table_type {
339            PyTableType::Base => datafusion::logical_expr::TableType::Base,
340            PyTableType::View => datafusion::logical_expr::TableType::View,
341            PyTableType::Temporary => datafusion::logical_expr::TableType::Temporary,
342        }
343    }
344}
345
346impl From<TableType> for PyTableType {
347    fn from(table_type: TableType) -> Self {
348        match table_type {
349            datafusion::logical_expr::TableType::Base => PyTableType::Base,
350            datafusion::logical_expr::TableType::View => PyTableType::View,
351            datafusion::logical_expr::TableType::Temporary => PyTableType::Temporary,
352        }
353    }
354}
355
356#[pyclass(
357    from_py_object,
358    frozen,
359    name = "TableSource",
360    module = "datafusion.common",
361    subclass
362)]
363#[derive(Clone)]
364pub struct PyTableSource {
365    pub table_source: Arc<dyn TableSource>,
366}
367
368#[pymethods]
369impl PyTableSource {
370    pub fn schema(&self) -> PyArrowType<Schema> {
371        (*self.table_source.schema()).clone().into()
372    }
373
374    pub fn constraints(&self) -> Option<PyConstraints> {
375        self.table_source.constraints().map(|c| PyConstraints {
376            constraints: c.clone(),
377        })
378    }
379
380    pub fn table_type(&self) -> PyTableType {
381        self.table_source.table_type().into()
382    }
383
384    pub fn get_logical_plan(&self) -> Option<PyLogicalPlan> {
385        self.table_source
386            .get_logical_plan()
387            .map(|plan| PyLogicalPlan::new(plan.into_owned()))
388    }
389}