datafusion_python/common/
schema.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::fmt::{self, Display, Formatter};
19use std::sync::Arc;
20use std::{any::Any, borrow::Cow};
21
22use arrow::datatypes::Schema;
23use arrow::pyarrow::PyArrowType;
24use datafusion::arrow::datatypes::SchemaRef;
25use datafusion::common::Constraints;
26use datafusion::datasource::TableType;
27use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableSource};
28use pyo3::prelude::*;
29
30use datafusion::logical_expr::utils::split_conjunction;
31
32use crate::sql::logical::PyLogicalPlan;
33
34use super::{data_type::DataTypeMap, function::SqlFunction};
35
36#[pyclass(name = "SqlSchema", module = "datafusion.common", subclass)]
37#[derive(Debug, Clone)]
38pub struct SqlSchema {
39    #[pyo3(get, set)]
40    pub name: String,
41    #[pyo3(get, set)]
42    pub tables: Vec<SqlTable>,
43    #[pyo3(get, set)]
44    pub views: Vec<SqlView>,
45    #[pyo3(get, set)]
46    pub functions: Vec<SqlFunction>,
47}
48
49#[pyclass(name = "SqlTable", module = "datafusion.common", subclass)]
50#[derive(Debug, Clone)]
51pub struct SqlTable {
52    #[pyo3(get, set)]
53    pub name: String,
54    #[pyo3(get, set)]
55    pub columns: Vec<(String, DataTypeMap)>,
56    #[pyo3(get, set)]
57    pub primary_key: Option<String>,
58    #[pyo3(get, set)]
59    pub foreign_keys: Vec<String>,
60    #[pyo3(get, set)]
61    pub indexes: Vec<String>,
62    #[pyo3(get, set)]
63    pub constraints: Vec<String>,
64    #[pyo3(get, set)]
65    pub statistics: SqlStatistics,
66    #[pyo3(get, set)]
67    pub filepaths: Option<Vec<String>>,
68}
69
70#[pymethods]
71impl SqlTable {
72    #[new]
73    #[pyo3(signature = (table_name, columns, row_count, filepaths=None))]
74    pub fn new(
75        table_name: String,
76        columns: Vec<(String, DataTypeMap)>,
77        row_count: f64,
78        filepaths: Option<Vec<String>>,
79    ) -> Self {
80        Self {
81            name: table_name,
82            columns,
83            primary_key: None,
84            foreign_keys: Vec::new(),
85            indexes: Vec::new(),
86            constraints: Vec::new(),
87            statistics: SqlStatistics::new(row_count),
88            filepaths,
89        }
90    }
91}
92
93#[pyclass(name = "SqlView", module = "datafusion.common", subclass)]
94#[derive(Debug, Clone)]
95pub struct SqlView {
96    #[pyo3(get, set)]
97    pub name: String,
98    #[pyo3(get, set)]
99    pub definition: String, // SQL code that defines the view
100}
101
102#[pymethods]
103impl SqlSchema {
104    #[new]
105    pub fn new(schema_name: &str) -> Self {
106        Self {
107            name: schema_name.to_owned(),
108            tables: Vec::new(),
109            views: Vec::new(),
110            functions: Vec::new(),
111        }
112    }
113
114    pub fn table_by_name(&self, table_name: &str) -> Option<SqlTable> {
115        for tbl in &self.tables {
116            if tbl.name.eq(table_name) {
117                return Some(tbl.clone());
118            }
119        }
120        None
121    }
122
123    pub fn add_table(&mut self, table: SqlTable) {
124        self.tables.push(table);
125    }
126
127    pub fn drop_table(&mut self, table_name: String) {
128        self.tables.retain(|x| !x.name.eq(&table_name));
129    }
130}
131
132/// SqlTable wrapper that is compatible with DataFusion logical query plans
133pub struct SqlTableSource {
134    schema: SchemaRef,
135    statistics: Option<SqlStatistics>,
136    filepaths: Option<Vec<String>>,
137}
138
139impl SqlTableSource {
140    /// Initialize a new `EmptyTable` from a schema
141    pub fn new(
142        schema: SchemaRef,
143        statistics: Option<SqlStatistics>,
144        filepaths: Option<Vec<String>>,
145    ) -> Self {
146        Self {
147            schema,
148            statistics,
149            filepaths,
150        }
151    }
152
153    /// Access optional statistics associated with this table source
154    pub fn statistics(&self) -> Option<&SqlStatistics> {
155        self.statistics.as_ref()
156    }
157
158    /// Access optional filepath associated with this table source
159    #[allow(dead_code)]
160    pub fn filepaths(&self) -> Option<&Vec<String>> {
161        self.filepaths.as_ref()
162    }
163}
164
165/// Implement TableSource, used in the logical query plan and in logical query optimizations
166impl TableSource for SqlTableSource {
167    fn as_any(&self) -> &dyn Any {
168        self
169    }
170
171    fn schema(&self) -> SchemaRef {
172        self.schema.clone()
173    }
174
175    fn table_type(&self) -> datafusion::logical_expr::TableType {
176        datafusion::logical_expr::TableType::Base
177    }
178
179    fn supports_filters_pushdown(
180        &self,
181        filters: &[&Expr],
182    ) -> datafusion::common::Result<Vec<TableProviderFilterPushDown>> {
183        filters
184            .iter()
185            .map(|f| {
186                let filters = split_conjunction(f);
187                if filters.iter().all(|f| is_supported_push_down_expr(f)) {
188                    // Push down filters to the tablescan operation if all are supported
189                    Ok(TableProviderFilterPushDown::Exact)
190                } else if filters.iter().any(|f| is_supported_push_down_expr(f)) {
191                    // Partially apply the filter in the TableScan but retain
192                    // the Filter operator in the plan as well
193                    Ok(TableProviderFilterPushDown::Inexact)
194                } else {
195                    Ok(TableProviderFilterPushDown::Unsupported)
196                }
197            })
198            .collect()
199    }
200
201    fn get_logical_plan(&self) -> Option<Cow<datafusion::logical_expr::LogicalPlan>> {
202        None
203    }
204}
205
206fn is_supported_push_down_expr(_expr: &Expr) -> bool {
207    // For now we support all kinds of expr's at this level
208    true
209}
210
211#[pyclass(name = "SqlStatistics", module = "datafusion.common", subclass)]
212#[derive(Debug, Clone)]
213pub struct SqlStatistics {
214    row_count: f64,
215}
216
217#[pymethods]
218impl SqlStatistics {
219    #[new]
220    pub fn new(row_count: f64) -> Self {
221        Self { row_count }
222    }
223
224    #[pyo3(name = "getRowCount")]
225    pub fn get_row_count(&self) -> f64 {
226        self.row_count
227    }
228}
229
230#[pyclass(name = "Constraints", module = "datafusion.expr", subclass)]
231#[derive(Clone)]
232pub struct PyConstraints {
233    pub constraints: Constraints,
234}
235
236impl From<PyConstraints> for Constraints {
237    fn from(constraints: PyConstraints) -> Self {
238        constraints.constraints
239    }
240}
241
242impl From<Constraints> for PyConstraints {
243    fn from(constraints: Constraints) -> Self {
244        PyConstraints { constraints }
245    }
246}
247
248impl Display for PyConstraints {
249    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
250        write!(f, "Constraints: {:?}", self.constraints)
251    }
252}
253
254#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
255#[pyclass(eq, eq_int, name = "TableType", module = "datafusion.common")]
256pub enum PyTableType {
257    Base,
258    View,
259    Temporary,
260}
261
262impl From<PyTableType> for datafusion::logical_expr::TableType {
263    fn from(table_type: PyTableType) -> Self {
264        match table_type {
265            PyTableType::Base => datafusion::logical_expr::TableType::Base,
266            PyTableType::View => datafusion::logical_expr::TableType::View,
267            PyTableType::Temporary => datafusion::logical_expr::TableType::Temporary,
268        }
269    }
270}
271
272impl From<TableType> for PyTableType {
273    fn from(table_type: TableType) -> Self {
274        match table_type {
275            datafusion::logical_expr::TableType::Base => PyTableType::Base,
276            datafusion::logical_expr::TableType::View => PyTableType::View,
277            datafusion::logical_expr::TableType::Temporary => PyTableType::Temporary,
278        }
279    }
280}
281
282#[pyclass(name = "TableSource", module = "datafusion.common", subclass)]
283#[derive(Clone)]
284pub struct PyTableSource {
285    pub table_source: Arc<dyn TableSource>,
286}
287
288#[pymethods]
289impl PyTableSource {
290    pub fn schema(&self) -> PyArrowType<Schema> {
291        (*self.table_source.schema()).clone().into()
292    }
293
294    pub fn constraints(&self) -> Option<PyConstraints> {
295        self.table_source.constraints().map(|c| PyConstraints {
296            constraints: c.clone(),
297        })
298    }
299
300    pub fn table_type(&self) -> PyTableType {
301        self.table_source.table_type().into()
302    }
303
304    pub fn get_logical_plan(&self) -> Option<PyLogicalPlan> {
305        self.table_source
306            .get_logical_plan()
307            .map(|plan| PyLogicalPlan::new(plan.into_owned()))
308    }
309}