// datafusion_kql/session.rs
use datafusion::arrow::datatypes::DataType;
use datafusion::config::ConfigOptions;
use datafusion::dataframe::DataFrame;
use datafusion::datasource::DefaultTableSource;
use datafusion::execution::SessionState;
use datafusion::execution::context::SessionContext;

use datafusion_common::{not_impl_err, plan_datafusion_err, DataFusionError, ResolvedTableReference, Result, TableReference};

use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF};
use datafusion_expr::planner::ContextProvider;
use datafusion_expr::registry::FunctionRegistry;

use kqlparser::ast::Statement;
use kqlparser::parser::parse;

use std::collections::HashMap;
use std::sync::Arc;

use crate::planner::KqlToRel;

/// Extension trait that adds KQL query support to a session context.
// `async fn` in a public trait trips the `async_fn_in_trait` lint; it is
// deliberately allowed for this extension trait.
#[allow(async_fn_in_trait)]
pub trait SessionContextExt {
    /// Plans and executes the KQL query text `kql`, returning the resulting
    /// [`DataFrame`].
    ///
    /// # Errors
    ///
    /// Returns an error when the query cannot be turned into a logical plan
    /// or when executing that plan fails.
    async fn kql(&self, kql: &str) -> Result<DataFrame>;
}

/// Extension trait that adds KQL parsing and planning entry points to a
/// session state.
#[allow(async_fn_in_trait)]
pub trait SessionStateExt {
    /// Builds a [`LogicalPlan`] directly from the KQL query text `kql`.
    async fn create_logical_plan_kql(&self, kql: &str) -> Result<LogicalPlan>;
    /// Parses the KQL query text `kql` into a single [`Statement`].
    fn kql_to_statement(&self, kql: &str) -> Result<Statement>;
    /// Converts an already parsed KQL `statement` into a [`LogicalPlan`].
    async fn kql_statement_to_plan(&self, statement: Statement) -> Result<LogicalPlan>;
}

impl SessionContextExt for SessionContext {
    /// Builds a logical plan for `kql` from the state returned by
    /// `self.state()`, then executes that plan on this context.
    async fn kql(&self, kql: &str) -> Result<DataFrame> {
        let state = self.state();
        let logical_plan = state.create_logical_plan_kql(kql).await?;
        self.execute_logical_plan(logical_plan).await
    }
}

impl SessionStateExt for SessionState {
    /// Parses `kql` into a single statement and converts it into a
    /// [`LogicalPlan`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`SessionStateExt::kql_to_statement`] or
    /// [`SessionStateExt::kql_statement_to_plan`].
    async fn create_logical_plan_kql(&self, kql: &str) -> Result<LogicalPlan> {
        let statement = self.kql_to_statement(kql)?;
        self.kql_statement_to_plan(statement).await
    }

    /// Parses `kql` and returns the one statement it contains.
    ///
    /// # Errors
    ///
    /// - A plan error when the parser rejects the text.
    /// - A not-implemented error when the text holds more than one statement.
    /// - A plan error when the text holds no statement at all.
    fn kql_to_statement(&self, kql: &str) -> Result<Statement> {
        // The query text is user input, so a parse failure must be reported as
        // an error instead of panicking the caller.
        let mut statements = parse(kql)
            .map_err(|e| DataFusionError::Plan(format!("Failed to parse KQL query: {e:?}")))?
            .1;
        if statements.len() > 1 {
            return not_impl_err!(
                "The context currently only supports a single KQL statement"
            );
        }
        statements.pop().ok_or_else(|| {
            plan_datafusion_err!("No KQL statements were provided in the query string")
        })
    }

    /// Converts a parsed KQL `statement` into a [`LogicalPlan`].
    ///
    /// Every table of every schema in every registered catalog is resolved up
    /// front and handed to the planner, keyed by the string form of its fully
    /// resolved reference.
    ///
    /// # Errors
    ///
    /// - A plan error when a listed catalog, schema, or table cannot be
    ///   fetched.
    /// - A not-implemented error for any statement other than a tabular
    ///   expression.
    /// - Any error produced while looking up a table or planning the query.
    async fn kql_statement_to_plan(&self, statement: Statement) -> Result<LogicalPlan> {
        let mut provider = SessionContextProvider {
            state: self,
            tables: HashMap::with_capacity(10),
        };

        let catalog_list = self.catalog_list();
        for catalog_name in catalog_list.catalog_names() {
            let catalog = catalog_list.catalog(&catalog_name).ok_or_else(|| {
                DataFusionError::Plan(format!("Catalog '{catalog_name}' not found"))
            })?;
            for schema_name in catalog.schema_names() {
                let schema = catalog.schema(&schema_name).ok_or_else(|| {
                    DataFusionError::Plan(format!("Schema '{schema_name}' not found"))
                })?;
                for table_name in schema.table_names() {
                    let table_provider = schema.table(&table_name).await?.ok_or_else(|| {
                        DataFusionError::Plan(format!("Table '{table_name}' not found"))
                    })?;
                    // Build the `Arc<str>` parts straight from the borrowed
                    // names; no intermediate `String` clone or `Box` is needed.
                    let resolved_ref = ResolvedTableReference {
                        catalog: Arc::from(catalog_name.as_str()),
                        schema: Arc::from(schema_name.as_str()),
                        table: Arc::from(table_name.as_str()),
                    };
                    provider.tables.insert(
                        resolved_ref.to_string(),
                        Arc::new(DefaultTableSource::new(table_provider)),
                    );
                }
            }
        }

        let planner = KqlToRel::new(&provider);
        match statement {
            Statement::TabularExpression(query) => planner.query_to_plan(&query),
            _ => not_impl_err!("Statement type not supported"),
        }
    }
}

/// Planner context backed by a borrowed [`SessionState`] and a pre-resolved
/// map of table sources.
struct SessionContextProvider<'a> {
    /// Session state consulted for function lookups and configuration.
    state: &'a SessionState,
    /// Table sources keyed by the string form of their resolved table reference.
    tables: HashMap<String, Arc<dyn TableSource>>,
}

/// Lets the KQL planner resolve tables, functions, and configuration through
/// the borrowed session state.
impl ContextProvider for SessionContextProvider<'_> {
    /// Looks up the table source for `name`.
    ///
    /// A partially qualified `name` is completed with the session's default
    /// catalog and schema before the lookup.
    ///
    /// # Errors
    ///
    /// Returns a plan error when no table is stored under the resolved name.
    fn get_table_source(&self, name: TableReference) -> Result<Arc<dyn TableSource>> {
        let catalog = &self.state.config_options().catalog;
        let name = name
            .resolve(&catalog.default_catalog, &catalog.default_schema)
            .to_string();
        self.tables
            .get(&name)
            .cloned()
            .ok_or_else(|| DataFusionError::Plan(format!("Table '{}' not found", name)))
    }

    /// Returns the scalar function registered as `name`, if any.
    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
        self.state.udf(name).ok()
    }

    /// Returns the aggregate function registered as `name`, if any.
    fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
        self.state.udaf(name).ok()
    }

    /// Always returns `None`; this provider resolves no variable types.
    fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
        None
    }

    /// Returns the window function registered as `name`, if any.
    fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
        self.state.udwf(name).ok()
    }

    /// Returns the configuration options of the underlying session state.
    fn options(&self) -> &ConfigOptions {
        // Previously `unimplemented!()`, which panicked whenever the planner
        // asked for configuration; the session state already holds the options.
        self.state.config_options()
    }

    /// Names of all scalar functions registered in the session state.
    fn udf_names(&self) -> Vec<String> {
        self.state.scalar_functions().keys().cloned().collect()
    }

    /// Names of all aggregate functions registered in the session state.
    fn udaf_names(&self) -> Vec<String> {
        self.state.aggregate_functions().keys().cloned().collect()
    }

    /// Names of all window functions registered in the session state.
    fn udwf_names(&self) -> Vec<String> {
        self.state.window_functions().keys().cloned().collect()
    }
}