datafusion_python/common/
schema.rs1use std::{any::Any, borrow::Cow};
19
20use datafusion::arrow::datatypes::SchemaRef;
21use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableSource};
22use pyo3::prelude::*;
23
24use datafusion::logical_expr::utils::split_conjunction;
25
26use super::{data_type::DataTypeMap, function::SqlFunction};
27
28#[pyclass(name = "SqlSchema", module = "datafusion.common", subclass)]
29#[derive(Debug, Clone)]
30pub struct SqlSchema {
31 #[pyo3(get, set)]
32 pub name: String,
33 #[pyo3(get, set)]
34 pub tables: Vec<SqlTable>,
35 #[pyo3(get, set)]
36 pub views: Vec<SqlView>,
37 #[pyo3(get, set)]
38 pub functions: Vec<SqlFunction>,
39}
40
41#[pyclass(name = "SqlTable", module = "datafusion.common", subclass)]
42#[derive(Debug, Clone)]
43pub struct SqlTable {
44 #[pyo3(get, set)]
45 pub name: String,
46 #[pyo3(get, set)]
47 pub columns: Vec<(String, DataTypeMap)>,
48 #[pyo3(get, set)]
49 pub primary_key: Option<String>,
50 #[pyo3(get, set)]
51 pub foreign_keys: Vec<String>,
52 #[pyo3(get, set)]
53 pub indexes: Vec<String>,
54 #[pyo3(get, set)]
55 pub constraints: Vec<String>,
56 #[pyo3(get, set)]
57 pub statistics: SqlStatistics,
58 #[pyo3(get, set)]
59 pub filepaths: Option<Vec<String>>,
60}
61
62#[pymethods]
63impl SqlTable {
64 #[new]
65 #[pyo3(signature = (table_name, columns, row_count, filepaths=None))]
66 pub fn new(
67 table_name: String,
68 columns: Vec<(String, DataTypeMap)>,
69 row_count: f64,
70 filepaths: Option<Vec<String>>,
71 ) -> Self {
72 Self {
73 name: table_name,
74 columns,
75 primary_key: None,
76 foreign_keys: Vec::new(),
77 indexes: Vec::new(),
78 constraints: Vec::new(),
79 statistics: SqlStatistics::new(row_count),
80 filepaths,
81 }
82 }
83}
84
85#[pyclass(name = "SqlView", module = "datafusion.common", subclass)]
86#[derive(Debug, Clone)]
87pub struct SqlView {
88 #[pyo3(get, set)]
89 pub name: String,
90 #[pyo3(get, set)]
91 pub definition: String, }
93
94#[pymethods]
95impl SqlSchema {
96 #[new]
97 pub fn new(schema_name: &str) -> Self {
98 Self {
99 name: schema_name.to_owned(),
100 tables: Vec::new(),
101 views: Vec::new(),
102 functions: Vec::new(),
103 }
104 }
105
106 pub fn table_by_name(&self, table_name: &str) -> Option<SqlTable> {
107 for tbl in &self.tables {
108 if tbl.name.eq(table_name) {
109 return Some(tbl.clone());
110 }
111 }
112 None
113 }
114
115 pub fn add_table(&mut self, table: SqlTable) {
116 self.tables.push(table);
117 }
118
119 pub fn drop_table(&mut self, table_name: String) {
120 self.tables.retain(|x| !x.name.eq(&table_name));
121 }
122}
123
124pub struct SqlTableSource {
126 schema: SchemaRef,
127 statistics: Option<SqlStatistics>,
128 filepaths: Option<Vec<String>>,
129}
130
131impl SqlTableSource {
132 pub fn new(
134 schema: SchemaRef,
135 statistics: Option<SqlStatistics>,
136 filepaths: Option<Vec<String>>,
137 ) -> Self {
138 Self {
139 schema,
140 statistics,
141 filepaths,
142 }
143 }
144
145 pub fn statistics(&self) -> Option<&SqlStatistics> {
147 self.statistics.as_ref()
148 }
149
150 #[allow(dead_code)]
152 pub fn filepaths(&self) -> Option<&Vec<String>> {
153 self.filepaths.as_ref()
154 }
155}
156
157impl TableSource for SqlTableSource {
159 fn as_any(&self) -> &dyn Any {
160 self
161 }
162
163 fn schema(&self) -> SchemaRef {
164 self.schema.clone()
165 }
166
167 fn table_type(&self) -> datafusion::logical_expr::TableType {
168 datafusion::logical_expr::TableType::Base
169 }
170
171 fn supports_filters_pushdown(
172 &self,
173 filters: &[&Expr],
174 ) -> datafusion::common::Result<Vec<TableProviderFilterPushDown>> {
175 filters
176 .iter()
177 .map(|f| {
178 let filters = split_conjunction(f);
179 if filters.iter().all(|f| is_supported_push_down_expr(f)) {
180 Ok(TableProviderFilterPushDown::Exact)
182 } else if filters.iter().any(|f| is_supported_push_down_expr(f)) {
183 Ok(TableProviderFilterPushDown::Inexact)
186 } else {
187 Ok(TableProviderFilterPushDown::Unsupported)
188 }
189 })
190 .collect()
191 }
192
193 fn get_logical_plan(&self) -> Option<Cow<datafusion::logical_expr::LogicalPlan>> {
194 None
195 }
196}
197
198fn is_supported_push_down_expr(_expr: &Expr) -> bool {
199 true
201}
202
203#[pyclass(name = "SqlStatistics", module = "datafusion.common", subclass)]
204#[derive(Debug, Clone)]
205pub struct SqlStatistics {
206 row_count: f64,
207}
208
209#[pymethods]
210impl SqlStatistics {
211 #[new]
212 pub fn new(row_count: f64) -> Self {
213 Self { row_count }
214 }
215
216 #[pyo3(name = "getRowCount")]
217 pub fn get_row_count(&self) -> f64 {
218 self.row_count
219 }
220}