proof_of_sql_planner/
context.rs

1use super::table_reference_to_table_ref;
2use crate::schema_to_column_fields;
3use alloc::sync::Arc;
4use arrow::datatypes::{Field, Schema};
5use core::any::Any;
6use datafusion::{
7    common::{
8        arrow::datatypes::{DataType, SchemaRef},
9        DataFusionError,
10    },
11    config::ConfigOptions,
12    logical_expr::{
13        AggregateUDF, Expr, ScalarUDF, TableProviderFilterPushDown, TableSource, WindowUDF,
14    },
15    sql::{planner::ContextProvider, TableReference},
16};
17use proof_of_sql::base::database::{ColumnField, SchemaAccessor};
18
19/// A [`ContextProvider`] implementation for Proof of SQL
20///
21/// This provider is used to provide tables to the Proof of SQL planner
22pub struct PoSqlContextProvider<A: SchemaAccessor> {
23    accessor: A,
24    options: ConfigOptions,
25}
26
27impl<A: SchemaAccessor> PoSqlContextProvider<A> {
28    /// Create a new `PoSqlContextProvider`
29    #[must_use]
30    pub fn new(accessor: A) -> Self {
31        Self {
32            accessor,
33            options: ConfigOptions::default(),
34        }
35    }
36}
37
38impl<A: SchemaAccessor> ContextProvider for PoSqlContextProvider<A> {
39    fn get_table_source(
40        &self,
41        name: TableReference,
42    ) -> Result<Arc<dyn TableSource>, DataFusionError> {
43        let table_ref = table_reference_to_table_ref(&name)
44            .map_err(|err| DataFusionError::External(Box::new(err)))?;
45        let schema = self.accessor.lookup_schema(&table_ref);
46        let column_fields = schema_to_column_fields(schema);
47        Ok(Arc::new(PoSqlTableSource::new(column_fields)) as Arc<dyn TableSource>)
48    }
49    fn get_function_meta(&self, _name: &str) -> Option<Arc<ScalarUDF>> {
50        None
51    }
52    //TODO: add count and sum
53    fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
54        None
55    }
56    fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
57        None
58    }
59    fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
60        None
61    }
62    fn options(&self) -> &ConfigOptions {
63        &self.options
64    }
65    fn udfs_names(&self) -> Vec<String> {
66        Vec::new()
67    }
68    fn udafs_names(&self) -> Vec<String> {
69        Vec::new()
70    }
71    fn udwfs_names(&self) -> Vec<String> {
72        Vec::new()
73    }
74}
75
76/// A [`TableSource`] implementation for Proof of SQL
77pub(crate) struct PoSqlTableSource {
78    schema: SchemaRef,
79}
80
81impl PoSqlTableSource {
82    /// Create a new `PoSqlTableSource`
83    pub(crate) fn new(column_fields: Vec<ColumnField>) -> Self {
84        let arrow_schema = Schema::new(
85            column_fields
86                .into_iter()
87                .map(|column_field| {
88                    Field::new(
89                        column_field.name().value.as_str(),
90                        (&column_field.data_type()).into(),
91                        false,
92                    )
93                })
94                .collect::<Vec<_>>(),
95        );
96        Self {
97            schema: Arc::new(arrow_schema),
98        }
99    }
100}
101
102impl TableSource for PoSqlTableSource {
103    fn as_any(&self) -> &dyn Any {
104        self
105    }
106    fn schema(&self) -> SchemaRef {
107        self.schema.clone()
108    }
109    fn supports_filters_pushdown(
110        &self,
111        filters: &[&Expr],
112    ) -> Result<Vec<TableProviderFilterPushDown>, DataFusionError> {
113        Ok(vec![TableProviderFilterPushDown::Exact; filters.len()])
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use ahash::AHasher;
    use alloc::vec;
    use core::any::TypeId;
    use indexmap::indexmap_with_default;
    use proof_of_sql::base::database::{ColumnType, TableRef, TestSchemaAccessor};

    /// Check that `table_source` exposes exactly the `expected` fields and that its
    /// `as_any` view still identifies it as a `PoSqlTableSource`.
    fn assert_table_source_fields(table_source: &PoSqlTableSource, expected: &[Field]) {
        let schema = table_source.schema();
        assert_eq!(
            schema.all_fields(),
            expected.iter().collect::<Vec<&Field>>()
        );
        assert_eq!(
            table_source.as_any().type_id(),
            TypeId::of::<PoSqlTableSource>()
        );
    }

    /// Check that `provider` exposes no functions and no variables.
    fn assert_nothing_registered<A: SchemaAccessor>(provider: &PoSqlContextProvider<A>) {
        assert_eq!(provider.udfs_names(), Vec::<String>::new());
        assert_eq!(provider.udafs_names(), Vec::<String>::new());
        assert_eq!(provider.udwfs_names(), Vec::<String>::new());
        assert_eq!(provider.get_variable_type(&[]), None);
        assert_eq!(provider.get_function_meta(""), None);
        assert_eq!(provider.get_aggregate_meta(""), None);
        assert_eq!(provider.get_window_meta(""), None);
    }

    /// Resolve `table` through `provider` and return the schema of the resulting source.
    fn resolved_schema<A: SchemaAccessor>(
        provider: &PoSqlContextProvider<A>,
        table: &'static str,
    ) -> SchemaRef {
        provider
            .get_table_source(TableReference::from(table))
            .unwrap()
            .schema()
    }

    // PoSqlTableSource
    #[test]
    fn we_can_create_a_posql_table_source() {
        // Empty
        let empty_source = PoSqlTableSource::new(vec![]);
        assert_table_source_fields(&empty_source, &[]);

        // Non-empty
        let populated_source = PoSqlTableSource::new(vec![
            ColumnField::new("a".into(), ColumnType::SmallInt),
            ColumnField::new("b".into(), ColumnType::VarChar),
        ]);
        assert_table_source_fields(
            &populated_source,
            &[
                Field::new("a", DataType::Int16, false),
                Field::new("b", DataType::Utf8, false),
            ],
        );
    }

    // PoSqlContextProvider
    #[test]
    fn we_can_create_a_posql_context_provider() {
        // Empty
        let empty_accessor = TestSchemaAccessor::new(indexmap_with_default! {AHasher;});
        let empty_provider = PoSqlContextProvider::new(empty_accessor);
        assert_nothing_registered(&empty_provider);
        assert_eq!(
            resolved_schema(&empty_provider, "namespace.table"),
            PoSqlTableSource::new(Vec::new()).schema()
        );

        // Non-empty
        let populated_accessor = TestSchemaAccessor::new(indexmap_with_default! {AHasher;
            TableRef::new("namespace", "a") => indexmap_with_default! {AHasher;
                "a".into() => ColumnType::SmallInt,
                "b".into() => ColumnType::VarChar
            },
            TableRef::new("namespace", "b") => indexmap_with_default! {AHasher;
                "c".into() => ColumnType::Int,
                "d".into() => ColumnType::BigInt
            },
        });
        let populated_provider = PoSqlContextProvider::new(populated_accessor);
        assert_nothing_registered(&populated_provider);
        let expected_source = PoSqlTableSource::new(vec![
            ColumnField::new("a".into(), ColumnType::SmallInt),
            ColumnField::new("b".into(), ColumnType::VarChar),
        ]);
        assert_eq!(
            resolved_schema(&populated_provider, "namespace.a"),
            expected_source.schema()
        );
    }

    #[test]
    fn we_cannot_create_a_posql_context_provider_if_catalog_provided() {
        let accessor = TestSchemaAccessor::new(indexmap_with_default! {AHasher;});
        let context_provider = PoSqlContextProvider::new(accessor);
        // A three-part reference (with a catalog) must surface as an external error.
        let lookup =
            context_provider.get_table_source(TableReference::from("catalog.namespace.table"));
        assert!(matches!(lookup, Err(DataFusionError::External(_))));
    }
}