1use super::table_reference_to_table_ref;
2use crate::schema_to_column_fields;
3use alloc::sync::Arc;
4use arrow::datatypes::{Field, Schema};
5use core::any::Any;
6use datafusion::{
7 common::{
8 arrow::datatypes::{DataType, SchemaRef},
9 DataFusionError,
10 },
11 config::ConfigOptions,
12 execution::{context::SessionState, runtime_env::RuntimeEnv},
13 logical_expr::{
14 AggregateUDF, Expr, ScalarUDF, TableProviderFilterPushDown, TableSource, WindowUDF,
15 },
16 prelude::SessionConfig,
17 sql::{planner::ContextProvider, TableReference},
18};
19use proof_of_sql::base::database::{ColumnField, SchemaAccessor};
20
21pub struct PoSqlContextProvider<A: SchemaAccessor> {
25 accessor: A,
26 options: ConfigOptions,
27 state: SessionState,
28}
29
30impl<A: SchemaAccessor> PoSqlContextProvider<A> {
31 #[must_use]
33 pub fn new(accessor: A) -> Self {
34 Self {
35 accessor,
36 options: ConfigOptions::default(),
37 state: SessionState::new_with_config_rt(
38 SessionConfig::default(),
39 RuntimeEnv::default().into(),
40 ),
41 }
42 }
43}
44
45impl<A: SchemaAccessor> ContextProvider for PoSqlContextProvider<A> {
46 fn get_table_source(
47 &self,
48 name: TableReference,
49 ) -> Result<Arc<dyn TableSource>, DataFusionError> {
50 let table_ref = table_reference_to_table_ref(&name)
51 .map_err(|err| DataFusionError::External(Box::new(err)))?;
52 let schema = self.accessor.lookup_schema(&table_ref);
53 let column_fields = schema_to_column_fields(schema);
54 Ok(Arc::new(PoSqlTableSource::new(column_fields)) as Arc<dyn TableSource>)
55 }
56 fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
57 self.state.scalar_functions().get(name).cloned()
58 }
59 fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
61 None
62 }
63 fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
64 None
65 }
66 fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
67 None
68 }
69 fn options(&self) -> &ConfigOptions {
70 &self.options
71 }
72 fn udfs_names(&self) -> Vec<String> {
73 self.state.scalar_functions().keys().cloned().collect()
74 }
75 fn udafs_names(&self) -> Vec<String> {
76 Vec::new()
77 }
78 fn udwfs_names(&self) -> Vec<String> {
79 Vec::new()
80 }
81}
82
83pub(crate) struct PoSqlTableSource {
85 schema: SchemaRef,
86}
87
88impl PoSqlTableSource {
89 pub(crate) fn new(column_fields: Vec<ColumnField>) -> Self {
91 let arrow_schema = Schema::new(
92 column_fields
93 .into_iter()
94 .map(|column_field| {
95 Field::new(
96 column_field.name().value.as_str(),
97 (&column_field.data_type()).into(),
98 false,
99 )
100 })
101 .collect::<Vec<_>>(),
102 );
103 Self {
104 schema: Arc::new(arrow_schema),
105 }
106 }
107}
108
109impl TableSource for PoSqlTableSource {
110 fn as_any(&self) -> &dyn Any {
111 self
112 }
113 fn schema(&self) -> SchemaRef {
114 self.schema.clone()
115 }
116 fn supports_filters_pushdown(
117 &self,
118 filters: &[&Expr],
119 ) -> Result<Vec<TableProviderFilterPushDown>, DataFusionError> {
120 Ok(vec![TableProviderFilterPushDown::Exact; filters.len()])
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use ahash::AHasher;
    use alloc::vec;
    use core::any::TypeId;
    use indexmap::indexmap_with_default;
    use proof_of_sql::base::database::{ColumnType, SchemaAccessorImpl, TableRef};

    /// Asserts that `as_any` on `table_source` reports the `PoSqlTableSource` type.
    fn assert_is_posql_table_source(table_source: &PoSqlTableSource) {
        assert_eq!(
            table_source.as_any().type_id(),
            TypeId::of::<PoSqlTableSource>()
        );
    }

    /// Asserts the provider metadata that does not depend on the accessor's tables.
    fn assert_function_metadata<A: SchemaAccessor>(context_provider: &PoSqlContextProvider<A>) {
        let registered_udfs: Vec<String> = context_provider
            .state
            .scalar_functions()
            .keys()
            .cloned()
            .collect();
        assert_eq!(context_provider.udfs_names(), registered_udfs);
        assert_eq!(context_provider.udafs_names(), Vec::<String>::new());
        assert_eq!(context_provider.udwfs_names(), Vec::<String>::new());
        assert_eq!(context_provider.get_variable_type(&[]), None);
        assert_eq!(context_provider.get_function_meta(""), None);
        assert_eq!(context_provider.get_aggregate_meta(""), None);
        assert_eq!(context_provider.get_window_meta(""), None);
    }

    #[test]
    fn we_can_create_a_posql_table_source() {
        let empty_source = PoSqlTableSource::new(vec![]);
        assert_eq!(empty_source.schema().all_fields(), Vec::<&Field>::new());
        assert_is_posql_table_source(&empty_source);

        let table_source = PoSqlTableSource::new(vec![
            ColumnField::new("a".into(), ColumnType::SmallInt),
            ColumnField::new("b".into(), ColumnType::VarChar),
        ]);
        assert_eq!(
            table_source.schema().all_fields(),
            vec![
                &Field::new("a", DataType::Int16, false),
                &Field::new("b", DataType::Utf8, false),
            ]
        );
        assert_is_posql_table_source(&table_source);
    }

    #[test]
    fn we_can_create_a_posql_context_provider() {
        let empty_accessor = SchemaAccessorImpl::new(indexmap_with_default! {AHasher;});
        assert_function_metadata(&PoSqlContextProvider::new(empty_accessor));

        let accessor = SchemaAccessorImpl::new(indexmap_with_default! {AHasher;
            TableRef::new("namespace", "a") => vec![("a".into(), ColumnType::SmallInt),
                                                    ("b".into(), ColumnType::VarChar)],
            TableRef::new("namespace", "b") => vec![("c".into(), ColumnType::Int),
                                                    ("d".into(), ColumnType::BigInt)]
        });
        let context_provider = PoSqlContextProvider::new(accessor);
        assert_function_metadata(&context_provider);

        let actual_schema = context_provider
            .get_table_source(TableReference::from("namespace.a"))
            .unwrap()
            .schema();
        let expected_schema = PoSqlTableSource::new(vec![
            ColumnField::new("a".into(), ColumnType::SmallInt),
            ColumnField::new("b".into(), ColumnType::VarChar),
        ])
        .schema();
        assert_eq!(actual_schema, expected_schema);
    }

    #[test]
    fn we_cannot_create_a_posql_context_provider_if_catalog_provided() {
        let context_provider =
            PoSqlContextProvider::new(SchemaAccessorImpl::new(indexmap_with_default! {AHasher;}));
        let result =
            context_provider.get_table_source(TableReference::from("catalog.namespace.table"));
        assert!(matches!(result, Err(DataFusionError::External(_))));
    }
}