Skip to main content

proof_of_sql_planner/
context.rs

1use super::table_reference_to_table_ref;
2use crate::schema_to_column_fields;
3use alloc::sync::Arc;
4use arrow::datatypes::{Field, Schema};
5use core::any::Any;
6use datafusion::{
7    common::{
8        arrow::datatypes::{DataType, SchemaRef},
9        DataFusionError,
10    },
11    config::ConfigOptions,
12    execution::{context::SessionState, runtime_env::RuntimeEnv},
13    logical_expr::{
14        AggregateUDF, Expr, ScalarUDF, TableProviderFilterPushDown, TableSource, WindowUDF,
15    },
16    prelude::SessionConfig,
17    sql::{planner::ContextProvider, TableReference},
18};
19use proof_of_sql::base::database::{ColumnField, SchemaAccessor};
20
21/// A [`ContextProvider`] implementation for Proof of SQL
22///
23/// This provider is used to provide tables to the Proof of SQL planner
24pub struct PoSqlContextProvider<A: SchemaAccessor> {
25    accessor: A,
26    options: ConfigOptions,
27    state: SessionState,
28}
29
30impl<A: SchemaAccessor> PoSqlContextProvider<A> {
31    /// Create a new `PoSqlContextProvider`
32    #[must_use]
33    pub fn new(accessor: A) -> Self {
34        Self {
35            accessor,
36            options: ConfigOptions::default(),
37            state: SessionState::new_with_config_rt(
38                SessionConfig::default(),
39                RuntimeEnv::default().into(),
40            ),
41        }
42    }
43}
44
45impl<A: SchemaAccessor> ContextProvider for PoSqlContextProvider<A> {
46    fn get_table_source(
47        &self,
48        name: TableReference,
49    ) -> Result<Arc<dyn TableSource>, DataFusionError> {
50        let table_ref = table_reference_to_table_ref(&name)
51            .map_err(|err| DataFusionError::External(Box::new(err)))?;
52        let schema = self.accessor.lookup_schema(&table_ref);
53        let column_fields = schema_to_column_fields(schema);
54        Ok(Arc::new(PoSqlTableSource::new(column_fields)) as Arc<dyn TableSource>)
55    }
56    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
57        self.state.scalar_functions().get(name).cloned()
58    }
59    //TODO: add count and sum
60    fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
61        None
62    }
63    fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
64        None
65    }
66    fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
67        None
68    }
69    fn options(&self) -> &ConfigOptions {
70        &self.options
71    }
72    fn udfs_names(&self) -> Vec<String> {
73        self.state.scalar_functions().keys().cloned().collect()
74    }
75    fn udafs_names(&self) -> Vec<String> {
76        Vec::new()
77    }
78    fn udwfs_names(&self) -> Vec<String> {
79        Vec::new()
80    }
81}
82
83/// A [`TableSource`] implementation for Proof of SQL
84pub(crate) struct PoSqlTableSource {
85    schema: SchemaRef,
86}
87
88impl PoSqlTableSource {
89    /// Create a new `PoSqlTableSource`
90    pub(crate) fn new(column_fields: Vec<ColumnField>) -> Self {
91        let arrow_schema = Schema::new(
92            column_fields
93                .into_iter()
94                .map(|column_field| {
95                    Field::new(
96                        column_field.name().value.as_str(),
97                        (&column_field.data_type()).into(),
98                        false,
99                    )
100                })
101                .collect::<Vec<_>>(),
102        );
103        Self {
104            schema: Arc::new(arrow_schema),
105        }
106    }
107}
108
109impl TableSource for PoSqlTableSource {
110    fn as_any(&self) -> &dyn Any {
111        self
112    }
113    fn schema(&self) -> SchemaRef {
114        self.schema.clone()
115    }
116    fn supports_filters_pushdown(
117        &self,
118        filters: &[&Expr],
119    ) -> Result<Vec<TableProviderFilterPushDown>, DataFusionError> {
120        Ok(vec![TableProviderFilterPushDown::Exact; filters.len()])
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use ahash::AHasher;
    use alloc::vec;
    use core::any::TypeId;
    use indexmap::indexmap_with_default;
    use proof_of_sql::base::database::{ColumnType, SchemaAccessorImpl, TableRef};

    /// Check the lookups whose results do not depend on which tables the accessor knows.
    fn assert_accessor_independent_lookups<A: SchemaAccessor>(
        provider: &PoSqlContextProvider<A>,
    ) {
        let registered_udfs = provider
            .state
            .scalar_functions()
            .keys()
            .cloned()
            .collect::<Vec<_>>();
        assert_eq!(provider.udfs_names(), registered_udfs);
        assert_eq!(provider.udafs_names(), Vec::<String>::new());
        assert_eq!(provider.udwfs_names(), Vec::<String>::new());
        assert_eq!(provider.get_variable_type(&[]), None);
        assert_eq!(provider.get_function_meta(""), None);
        assert_eq!(provider.get_aggregate_meta(""), None);
        assert_eq!(provider.get_window_meta(""), None);
    }

    // PoSqlTableSource
    #[test]
    fn we_can_create_a_posql_table_source() {
        // A source built from no columns has an empty schema.
        let empty_source = PoSqlTableSource::new(vec![]);
        assert_eq!(empty_source.schema().all_fields(), Vec::<&Field>::new());
        assert_eq!(
            empty_source.as_any().type_id(),
            TypeId::of::<PoSqlTableSource>()
        );

        // Column fields become non-nullable Arrow fields, in order.
        let populated_source = PoSqlTableSource::new(vec![
            ColumnField::new("a".into(), ColumnType::SmallInt),
            ColumnField::new("b".into(), ColumnType::VarChar),
        ]);
        let expected_a = Field::new("a", DataType::Int16, false);
        let expected_b = Field::new("b", DataType::Utf8, false);
        assert_eq!(
            populated_source.schema().all_fields(),
            vec![&expected_a, &expected_b]
        );
        assert_eq!(
            populated_source.as_any().type_id(),
            TypeId::of::<PoSqlTableSource>()
        );
    }

    // PoSqlContextProvider
    #[test]
    fn we_can_create_a_posql_context_provider() {
        // Accessor without any tables.
        let empty_accessor = SchemaAccessorImpl::new(indexmap_with_default! {AHasher;});
        let empty_provider = PoSqlContextProvider::new(empty_accessor);
        assert_accessor_independent_lookups(&empty_provider);

        // Accessor with two tables.
        let populated_accessor = SchemaAccessorImpl::new(indexmap_with_default! {AHasher;
            TableRef::new("namespace", "a") => vec![("a".into(), ColumnType::SmallInt),
                ("b".into(), ColumnType::VarChar)],
            TableRef::new("namespace", "b") => vec![("c".into(), ColumnType::Int),
                ("d".into(), ColumnType::BigInt)]
        });
        let populated_provider = PoSqlContextProvider::new(populated_accessor);
        assert_accessor_independent_lookups(&populated_provider);

        let resolved_schema = populated_provider
            .get_table_source(TableReference::from("namespace.a"))
            .unwrap()
            .schema();
        let expected_schema = PoSqlTableSource::new(vec![
            ColumnField::new("a".into(), ColumnType::SmallInt),
            ColumnField::new("b".into(), ColumnType::VarChar),
        ])
        .schema();
        assert_eq!(resolved_schema, expected_schema);
    }

    #[test]
    fn we_cannot_create_a_posql_context_provider_if_catalog_provided() {
        let accessor = SchemaAccessorImpl::new(indexmap_with_default! {AHasher;});
        let context_provider = PoSqlContextProvider::new(accessor);
        let lookup =
            context_provider.get_table_source(TableReference::from("catalog.namespace.table"));
        assert!(matches!(lookup, Err(DataFusionError::External(_))));
    }
}