// datafusion_python/common/schema.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::{any::Any, borrow::Cow};
19
20use datafusion::arrow::datatypes::SchemaRef;
21use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableSource};
22use pyo3::prelude::*;
23
24use datafusion::logical_expr::utils::split_conjunction;
25
26use super::{data_type::DataTypeMap, function::SqlFunction};
27
28#[pyclass(name = "SqlSchema", module = "datafusion.common", subclass)]
29#[derive(Debug, Clone)]
30pub struct SqlSchema {
31    #[pyo3(get, set)]
32    pub name: String,
33    #[pyo3(get, set)]
34    pub tables: Vec<SqlTable>,
35    #[pyo3(get, set)]
36    pub views: Vec<SqlView>,
37    #[pyo3(get, set)]
38    pub functions: Vec<SqlFunction>,
39}
40
41#[pyclass(name = "SqlTable", module = "datafusion.common", subclass)]
42#[derive(Debug, Clone)]
43pub struct SqlTable {
44    #[pyo3(get, set)]
45    pub name: String,
46    #[pyo3(get, set)]
47    pub columns: Vec<(String, DataTypeMap)>,
48    #[pyo3(get, set)]
49    pub primary_key: Option<String>,
50    #[pyo3(get, set)]
51    pub foreign_keys: Vec<String>,
52    #[pyo3(get, set)]
53    pub indexes: Vec<String>,
54    #[pyo3(get, set)]
55    pub constraints: Vec<String>,
56    #[pyo3(get, set)]
57    pub statistics: SqlStatistics,
58    #[pyo3(get, set)]
59    pub filepaths: Option<Vec<String>>,
60}
61
62#[pymethods]
63impl SqlTable {
64    #[new]
65    #[pyo3(signature = (table_name, columns, row_count, filepaths=None))]
66    pub fn new(
67        table_name: String,
68        columns: Vec<(String, DataTypeMap)>,
69        row_count: f64,
70        filepaths: Option<Vec<String>>,
71    ) -> Self {
72        Self {
73            name: table_name,
74            columns,
75            primary_key: None,
76            foreign_keys: Vec::new(),
77            indexes: Vec::new(),
78            constraints: Vec::new(),
79            statistics: SqlStatistics::new(row_count),
80            filepaths,
81        }
82    }
83}
84
85#[pyclass(name = "SqlView", module = "datafusion.common", subclass)]
86#[derive(Debug, Clone)]
87pub struct SqlView {
88    #[pyo3(get, set)]
89    pub name: String,
90    #[pyo3(get, set)]
91    pub definition: String, // SQL code that defines the view
92}
93
94#[pymethods]
95impl SqlSchema {
96    #[new]
97    pub fn new(schema_name: &str) -> Self {
98        Self {
99            name: schema_name.to_owned(),
100            tables: Vec::new(),
101            views: Vec::new(),
102            functions: Vec::new(),
103        }
104    }
105
106    pub fn table_by_name(&self, table_name: &str) -> Option<SqlTable> {
107        for tbl in &self.tables {
108            if tbl.name.eq(table_name) {
109                return Some(tbl.clone());
110            }
111        }
112        None
113    }
114
115    pub fn add_table(&mut self, table: SqlTable) {
116        self.tables.push(table);
117    }
118
119    pub fn drop_table(&mut self, table_name: String) {
120        self.tables.retain(|x| !x.name.eq(&table_name));
121    }
122}
123
124/// SqlTable wrapper that is compatible with DataFusion logical query plans
125pub struct SqlTableSource {
126    schema: SchemaRef,
127    statistics: Option<SqlStatistics>,
128    filepaths: Option<Vec<String>>,
129}
130
131impl SqlTableSource {
132    /// Initialize a new `EmptyTable` from a schema
133    pub fn new(
134        schema: SchemaRef,
135        statistics: Option<SqlStatistics>,
136        filepaths: Option<Vec<String>>,
137    ) -> Self {
138        Self {
139            schema,
140            statistics,
141            filepaths,
142        }
143    }
144
145    /// Access optional statistics associated with this table source
146    pub fn statistics(&self) -> Option<&SqlStatistics> {
147        self.statistics.as_ref()
148    }
149
150    /// Access optional filepath associated with this table source
151    #[allow(dead_code)]
152    pub fn filepaths(&self) -> Option<&Vec<String>> {
153        self.filepaths.as_ref()
154    }
155}
156
157/// Implement TableSource, used in the logical query plan and in logical query optimizations
158impl TableSource for SqlTableSource {
159    fn as_any(&self) -> &dyn Any {
160        self
161    }
162
163    fn schema(&self) -> SchemaRef {
164        self.schema.clone()
165    }
166
167    fn table_type(&self) -> datafusion::logical_expr::TableType {
168        datafusion::logical_expr::TableType::Base
169    }
170
171    fn supports_filters_pushdown(
172        &self,
173        filters: &[&Expr],
174    ) -> datafusion::common::Result<Vec<TableProviderFilterPushDown>> {
175        filters
176            .iter()
177            .map(|f| {
178                let filters = split_conjunction(f);
179                if filters.iter().all(|f| is_supported_push_down_expr(f)) {
180                    // Push down filters to the tablescan operation if all are supported
181                    Ok(TableProviderFilterPushDown::Exact)
182                } else if filters.iter().any(|f| is_supported_push_down_expr(f)) {
183                    // Partially apply the filter in the TableScan but retain
184                    // the Filter operator in the plan as well
185                    Ok(TableProviderFilterPushDown::Inexact)
186                } else {
187                    Ok(TableProviderFilterPushDown::Unsupported)
188                }
189            })
190            .collect()
191    }
192
193    fn get_logical_plan(&self) -> Option<Cow<datafusion::logical_expr::LogicalPlan>> {
194        None
195    }
196}
197
198fn is_supported_push_down_expr(_expr: &Expr) -> bool {
199    // For now we support all kinds of expr's at this level
200    true
201}
202
203#[pyclass(name = "SqlStatistics", module = "datafusion.common", subclass)]
204#[derive(Debug, Clone)]
205pub struct SqlStatistics {
206    row_count: f64,
207}
208
209#[pymethods]
210impl SqlStatistics {
211    #[new]
212    pub fn new(row_count: f64) -> Self {
213        Self { row_count }
214    }
215
216    #[pyo3(name = "getRowCount")]
217    pub fn get_row_count(&self) -> f64 {
218        self.row_count
219    }
220}