proof_of_sql_planner/
context.rs1use super::table_reference_to_table_ref;
2use crate::schema_to_column_fields;
3use alloc::sync::Arc;
4use arrow::datatypes::{Field, Schema};
5use core::any::Any;
6use datafusion::{
7 common::{
8 arrow::datatypes::{DataType, SchemaRef},
9 DataFusionError,
10 },
11 config::ConfigOptions,
12 logical_expr::{
13 AggregateUDF, Expr, ScalarUDF, TableProviderFilterPushDown, TableSource, WindowUDF,
14 },
15 sql::{planner::ContextProvider, TableReference},
16};
17use proof_of_sql::base::database::{ColumnField, SchemaAccessor};
18
19pub struct PoSqlContextProvider<A: SchemaAccessor> {
23 accessor: A,
24 options: ConfigOptions,
25}
26
27impl<A: SchemaAccessor> PoSqlContextProvider<A> {
28 #[must_use]
30 pub fn new(accessor: A) -> Self {
31 Self {
32 accessor,
33 options: ConfigOptions::default(),
34 }
35 }
36}
37
38impl<A: SchemaAccessor> ContextProvider for PoSqlContextProvider<A> {
39 fn get_table_source(
40 &self,
41 name: TableReference,
42 ) -> Result<Arc<dyn TableSource>, DataFusionError> {
43 let table_ref = table_reference_to_table_ref(&name)
44 .map_err(|err| DataFusionError::External(Box::new(err)))?;
45 let schema = self.accessor.lookup_schema(&table_ref);
46 let column_fields = schema_to_column_fields(schema);
47 Ok(Arc::new(PoSqlTableSource::new(column_fields)) as Arc<dyn TableSource>)
48 }
49 fn get_function_meta(&self, _name: &str) -> Option<Arc<ScalarUDF>> {
50 None
51 }
52 fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
54 None
55 }
56 fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
57 None
58 }
59 fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
60 None
61 }
62 fn options(&self) -> &ConfigOptions {
63 &self.options
64 }
65 fn udfs_names(&self) -> Vec<String> {
66 Vec::new()
67 }
68 fn udafs_names(&self) -> Vec<String> {
69 Vec::new()
70 }
71 fn udwfs_names(&self) -> Vec<String> {
72 Vec::new()
73 }
74}
75
76pub(crate) struct PoSqlTableSource {
78 schema: SchemaRef,
79}
80
81impl PoSqlTableSource {
82 pub(crate) fn new(column_fields: Vec<ColumnField>) -> Self {
84 let arrow_schema = Schema::new(
85 column_fields
86 .into_iter()
87 .map(|column_field| {
88 Field::new(
89 column_field.name().value.as_str(),
90 (&column_field.data_type()).into(),
91 false,
92 )
93 })
94 .collect::<Vec<_>>(),
95 );
96 Self {
97 schema: Arc::new(arrow_schema),
98 }
99 }
100}
101
102impl TableSource for PoSqlTableSource {
103 fn as_any(&self) -> &dyn Any {
104 self
105 }
106 fn schema(&self) -> SchemaRef {
107 self.schema.clone()
108 }
109 fn supports_filters_pushdown(
110 &self,
111 filters: &[&Expr],
112 ) -> Result<Vec<TableProviderFilterPushDown>, DataFusionError> {
113 Ok(vec![TableProviderFilterPushDown::Exact; filters.len()])
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use ahash::AHasher;
    use alloc::vec;
    use core::any::TypeId;
    // Expression builders used only to construct filters for the pushdown test.
    use datafusion::logical_expr::{col, lit};
    use indexmap::indexmap_with_default;
    use proof_of_sql::base::database::{ColumnType, TableRef, TestSchemaAccessor};

    #[test]
    fn we_can_create_a_posql_table_source() {
        let table_source = PoSqlTableSource::new(vec![]);
        assert_eq!(table_source.schema().all_fields(), Vec::<&Field>::new());
        assert_eq!(
            table_source.as_any().type_id(),
            TypeId::of::<PoSqlTableSource>()
        );

        let column_fields = vec![
            ColumnField::new("a".into(), ColumnType::SmallInt),
            ColumnField::new("b".into(), ColumnType::VarChar),
        ];
        let table_source = PoSqlTableSource::new(column_fields);
        assert_eq!(
            table_source.schema().all_fields(),
            vec![
                &Field::new("a", DataType::Int16, false),
                &Field::new("b", DataType::Utf8, false),
            ]
        );
        assert_eq!(
            table_source.as_any().type_id(),
            TypeId::of::<PoSqlTableSource>()
        );
    }

    #[test]
    fn we_can_create_a_posql_context_provider() {
        // An empty accessor: lookups succeed but produce an empty schema.
        let accessor = TestSchemaAccessor::new(indexmap_with_default! {AHasher;});
        let context_provider = PoSqlContextProvider::new(accessor);
        assert_eq!(context_provider.udfs_names(), Vec::<String>::new());
        assert_eq!(context_provider.udafs_names(), Vec::<String>::new());
        assert_eq!(context_provider.udwfs_names(), Vec::<String>::new());
        assert_eq!(context_provider.get_variable_type(&[]), None);
        assert_eq!(context_provider.get_function_meta(""), None);
        assert_eq!(context_provider.get_aggregate_meta(""), None);
        assert_eq!(context_provider.get_window_meta(""), None);
        assert_eq!(
            context_provider
                .get_table_source(TableReference::from("namespace.table"))
                .unwrap()
                .schema(),
            PoSqlTableSource::new(Vec::new()).schema()
        );

        let accessor = TestSchemaAccessor::new(indexmap_with_default! {AHasher;
            TableRef::new("namespace", "a") => indexmap_with_default! {AHasher;
                "a".into() => ColumnType::SmallInt,
                "b".into() => ColumnType::VarChar
            },
            TableRef::new("namespace", "b") => indexmap_with_default! {AHasher;
                "c".into() => ColumnType::Int,
                "d".into() => ColumnType::BigInt
            },
        });
        let context_provider = PoSqlContextProvider::new(accessor);
        assert_eq!(context_provider.udfs_names(), Vec::<String>::new());
        assert_eq!(context_provider.udafs_names(), Vec::<String>::new());
        assert_eq!(context_provider.udwfs_names(), Vec::<String>::new());
        assert_eq!(context_provider.get_variable_type(&[]), None);
        assert_eq!(context_provider.get_function_meta(""), None);
        assert_eq!(context_provider.get_aggregate_meta(""), None);
        assert_eq!(context_provider.get_window_meta(""), None);
        assert_eq!(
            context_provider
                .get_table_source(TableReference::from("namespace.a"))
                .unwrap()
                .schema(),
            PoSqlTableSource::new(vec![
                ColumnField::new("a".into(), ColumnType::SmallInt),
                ColumnField::new("b".into(), ColumnType::VarChar),
            ])
            .schema()
        );
        // The second registered table must resolve independently of the first.
        assert_eq!(
            context_provider
                .get_table_source(TableReference::from("namespace.b"))
                .unwrap()
                .schema(),
            PoSqlTableSource::new(vec![
                ColumnField::new("c".into(), ColumnType::Int),
                ColumnField::new("d".into(), ColumnType::BigInt),
            ])
            .schema()
        );
    }

    #[test]
    fn we_cannot_create_a_posql_context_provider_if_catalog_provided() {
        let accessor = TestSchemaAccessor::new(indexmap_with_default! {AHasher;});
        let context_provider = PoSqlContextProvider::new(accessor);
        assert!(matches!(
            context_provider.get_table_source(TableReference::from("catalog.namespace.table")),
            Err(DataFusionError::External(_))
        ));
    }

    #[test]
    fn we_can_push_down_every_filter_exactly() {
        let table_source =
            PoSqlTableSource::new(vec![ColumnField::new("a".into(), ColumnType::SmallInt)]);
        assert_eq!(
            table_source.supports_filters_pushdown(&[]).unwrap(),
            Vec::<TableProviderFilterPushDown>::new()
        );

        let lower_bound = col("a").gt(lit(0_i16));
        let upper_bound = col("a").lt(lit(10_i16));
        assert_eq!(
            table_source
                .supports_filters_pushdown(&[&lower_bound, &upper_bound])
                .unwrap(),
            vec![
                TableProviderFilterPushDown::Exact,
                TableProviderFilterPushDown::Exact,
            ]
        );
    }
}