hamelin_lib/
lib.rs

1extern crate core;
2
3use std::cell::RefCell;
4use std::fmt::{Display, Formatter};
5use std::rc::Rc;
6use std::sync::Arc;
7
8use antlr::hamelinparser::{
9    CommandContextAll, CommandEOFContextAttrs, ExpressionContextAll, ExpressionEOFContextAttrs,
10    IdentifierContextAll, IdentifierEOFContextAll, IdentifierEOFContextAttrs,
11    SimpleIdentifierContextAll, SimpleIdentifierEOFContextAttrs,
12};
13use antlr_rust::errors::ANTLRError;
14use anyhow::anyhow;
15use ast::command::within;
16use ast::err::ContextualTranslationErrors;
17use ast::ExpressionTranslationContext;
18use ast::QueryTranslationContext;
19use chrono::Utc;
20use func::def::FunctionTranslationContext;
21use func::utils::within_range;
22use serde::Serialize;
23use sql::expression::identifier::SimpleIdentifier;
24use sql::expression::literal::{ColumnReference, TimestampLiteral};
25use sql::expression::SQLExpression;
26use sql::query::SQLQuery;
27use sql::types::SQLTimestampTzType;
28use translation::range_builder::RangeBuilder;
29use translation::{ContextualResult, PendingQuery, Translation};
30use tsify_next::Tsify;
31use types::TIMESTAMP;
32
33use crate::antlr::hamelinparser::{
34    ExpressionEOFContextAll, HamelintypeContextAll, PipelineEOFContextAll, QueryContextAll,
35    QueryEOFContextAll, QueryEOFContextAttrs,
36};
37use crate::ast::err::{Context, TranslationError, TranslationErrors};
38use crate::ast::expression::HamelinExpression;
39use crate::ast::query::HamelinQuery;
40use crate::catalog::CatalogProvider;
41use crate::func::registry::FunctionRegistry;
42use crate::parser::{make_hamelin_parser_from_input, HamelinStringParser};
43use crate::provider::EnvironmentProvider;
44use crate::sql::statement::Statement;
45use crate::translation::env::Environment;
46use crate::translation::{
47    DMLTranslation, ExpressionTranslation, QueryTranslation, StatementTranslation,
48};
49
50pub mod antlr;
51pub mod ast;
52pub mod catalog;
53pub mod eval;
54pub mod func;
55pub mod incremental;
56pub mod interner;
57pub mod parser;
58pub mod provider;
59pub mod reverse_eval;
60pub mod sql;
61pub mod translation;
62pub mod tree;
63pub mod types;
64mod write_utils;
65
/// Version of this crate, read from the `CARGO_PKG_VERSION` environment variable at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
67
68pub fn execute_resillient_parse<T>(
69    input: String,
70    parse_fn: fn(HamelinStringParser) -> Result<T, ANTLRError>,
71) -> anyhow::Result<(T, TranslationErrors)> {
72    let (node, errors) = {
73        let (parser, errors) = make_hamelin_parser_from_input(input);
74        parse_fn(parser)
75            .map_err(|e| anyhow!(e.to_string()))
76            .map(|node| (node, errors))?
77    };
78    let unwrapped = Rc::try_unwrap(errors)
79        .expect("could not unwrap Rc<TranslationErrors> after parsing complete.")
80        .into_inner();
81    Ok((node, unwrapped))
82}
83
84pub fn execute_parse<T>(
85    input: String,
86    parse_fn: fn(HamelinStringParser) -> Result<T, ANTLRError>,
87) -> Result<T, TranslationErrors> {
88    let len = input.len();
89    let (node, errors) = {
90        let (parser, errors) = make_hamelin_parser_from_input(input);
91        parse_fn(parser)
92            .map_err(|e| {
93                TranslationError::new(Context::new(0..=len - 1, "failed to load grammars"))
94                    .with_source_boxed(e.to_string().into())
95                    .single()
96            })
97            .map(|node| (node, errors))?
98    };
99    let unwrapped = Rc::try_unwrap(errors)
100        .expect("could not unwrap Rc<TranslationErrors> after parsing complete.")
101        .into_inner();
102    unwrapped.or_ok(node)
103}
104
105pub fn parse_command(command: String) -> Result<Rc<CommandContextAll<'static>>, TranslationErrors> {
106    execute_parse(command, |mut parser| parser.commandEOF())
107        .map(|ctx| ctx.command().expect("required"))
108}
109
110pub fn parse_query(statement: String) -> Result<Rc<QueryContextAll<'static>>, TranslationErrors> {
111    execute_parse(statement, |mut parser| parser.queryEOF())
112        .map(|ctx| ctx.query().expect("required"))
113}
114
115pub fn resilient_parse_query(
116    statement: String,
117) -> anyhow::Result<(Rc<QueryEOFContextAll<'static>>, TranslationErrors)> {
118    execute_resillient_parse(statement, |mut parser| parser.queryEOF())
119}
120
121pub fn parse_expression(
122    expression: String,
123) -> Result<Rc<ExpressionContextAll<'static>>, TranslationErrors> {
124    execute_parse(expression, |mut parser| parser.expressionEOF())
125        .map(|ctx| ctx.expression().expect("required"))
126}
127
128pub fn resilient_parse_expression(
129    expression: String,
130) -> anyhow::Result<(Rc<ExpressionEOFContextAll<'static>>, TranslationErrors)> {
131    execute_resillient_parse(expression, |mut parser| parser.expressionEOF())
132}
133
134pub fn resilient_parse_pipeline(
135    pipeline: String,
136) -> anyhow::Result<(Rc<PipelineEOFContextAll<'static>>, TranslationErrors)> {
137    execute_resillient_parse(pipeline, |mut parser| parser.pipelineEOF())
138}
139
140pub fn parse_identifier(
141    identifier: String,
142) -> Result<Rc<IdentifierContextAll<'static>>, TranslationErrors> {
143    execute_parse(identifier, |mut parser| parser.identifierEOF())
144        .map(|ctx| ctx.identifier().expect("required"))
145}
146
147pub fn resilient_parse_identifier(
148    identifier: String,
149) -> anyhow::Result<(Rc<IdentifierEOFContextAll<'static>>, TranslationErrors)> {
150    execute_resillient_parse(identifier, |mut parser| parser.identifierEOF())
151}
152
153pub fn parse_simple_identifier(
154    identifier: String,
155) -> Result<Rc<SimpleIdentifierContextAll<'static>>, TranslationErrors> {
156    execute_parse(identifier, |mut parser| parser.simpleIdentifierEOF())
157        .map(|ctx| ctx.simpleIdentifier().expect("required"))
158}
159
160pub fn parse_type(
161    type_string: String,
162) -> Result<Rc<HamelintypeContextAll<'static>>, TranslationErrors> {
163    execute_parse(type_string, |mut parser| parser.hamelintype())
164}
165
166// TODO: Refactor to usu a Rust RangeInclusive, because it's more idiomatic and easier to use
167#[derive(Debug, Clone, Default)]
168pub struct TimeRange {
169    pub start: Option<chrono::DateTime<Utc>>,
170    pub end: Option<chrono::DateTime<Utc>>,
171}
172
173#[derive(Clone)]
174pub struct Compiler {
175    pub expression_environment: Arc<Environment>,
176    pub query_environment_provider: Arc<dyn EnvironmentProvider>,
177    pub time_range_filter: Option<SQLExpression>,
178    pub registry: Arc<FunctionRegistry>,
179}
180
181#[derive(Serialize, Tsify)]
182#[tsify(into_wasm_abi)]
183pub struct FunctionDescription {
184    pub name: String,
185    pub parameters: String,
186}
187
188impl Compiler {
189    pub fn new() -> Self {
190        Self {
191            expression_environment: Arc::new(Environment::default()),
192            query_environment_provider: Arc::new(CatalogProvider::default()),
193            time_range_filter: None,
194            registry: Arc::new(FunctionRegistry::default()),
195        }
196    }
197
198    pub fn get_function_descriptions(&self) -> Vec<FunctionDescription> {
199        self.registry
200            .function_defs
201            .iter()
202            .flat_map(|(name, defs)| {
203                defs.iter().map(|def| FunctionDescription {
204                    name: name.clone(),
205                    parameters: def.parameters.to_string(),
206                })
207            })
208            .collect()
209    }
210
211    /// Use the given environment for compiling expressions
212    pub fn set_environment(&mut self, environment: Arc<Environment>) {
213        self.expression_environment = environment.clone();
214    }
215
216    /// Use the given environment provider for compiling queries.
217    pub fn set_environment_provider(&mut self, provider: Arc<dyn EnvironmentProvider>) {
218        self.query_environment_provider = provider;
219    }
220
221    pub fn set_time_range(&mut self, time_range: TimeRange) {
222        let ident = ColumnReference::new(SimpleIdentifier::new("timestamp").into());
223        let mut range = RangeBuilder::default();
224
225        if let Some(from) = time_range.start {
226            range = range.with_begin(
227                TimestampLiteral::new(from).into(),
228                SQLTimestampTzType::new(3).into(),
229            )
230        }
231
232        if let Some(to) = time_range.end {
233            range = range.with_end(
234                TimestampLiteral::new(to).into(),
235                SQLTimestampTzType::new(3).into(),
236            );
237        }
238
239        self.time_range_filter = Some(within_range(ident.into(), range))
240    }
241
242    pub fn set_time_range_expression(
243        &mut self,
244        hamelin_range_expression: String,
245    ) -> Result<(), ContextualTranslationErrors> {
246        let templated = format!("WITHIN {}", hamelin_range_expression);
247        let tree = parse_command(templated.clone())
248            .map_err(|e| ContextualTranslationErrors::new(templated.clone(), e))?;
249        let previous = PendingQuery::new(
250            SQLQuery::default(),
251            Environment::default().with_binding("timestamp".parse().unwrap(), TIMESTAMP),
252        );
253
254        let filter = if let CommandContextAll::WithinCommandContext(ctx) = tree.as_ref() {
255            within::translate(
256                ctx,
257                &previous,
258                QueryTranslationContext::new(
259                    None,
260                    self.query_environment_provider.clone(),
261                    self.registry.clone(),
262                    None,
263                ),
264            )
265            .map_err(|e| ContextualTranslationErrors::new(templated.clone(), e))?
266            .query
267            .where_
268            .expect("withins always have a where clause")
269        } else {
270            unreachable!()
271        };
272        self.time_range_filter = Some(filter);
273        Ok(())
274    }
275
276    pub fn compile_expression(
277        &self,
278        expression: String,
279    ) -> Result<ExpressionTranslation, ContextualTranslationErrors> {
280        let ctx = parse_expression(expression.clone())
281            .map_err(|e| ContextualTranslationErrors::new(expression.clone(), e))?;
282
283        HamelinExpression::new(
284            ctx,
285            ExpressionTranslationContext::new(
286                self.expression_environment.clone(),
287                self.registry.clone(),
288                FunctionTranslationContext::default(),
289                None,
290                Rc::new(RefCell::new(None)),
291            ),
292        )
293        .translate()
294        .map_err(|e| ContextualTranslationErrors::new(expression, e))
295    }
296
297    pub fn compile(
298        &self,
299        hmln: String,
300    ) -> Result<StatementTranslation, ContextualTranslationErrors> {
301        let ctx = parse_query(hmln.clone())
302            .map_err(|e| ContextualTranslationErrors::new(hmln.clone(), e))?;
303
304        let translation_context = QueryTranslationContext::new(
305            self.time_range_filter.clone(),
306            self.query_environment_provider.clone(),
307            self.registry.clone(),
308            None,
309        );
310        let pending = HamelinQuery::new(ctx.clone(), translation_context.clone()).translate();
311        if !pending.errors.is_empty() {
312            Err(ContextualTranslationErrors::new(hmln, pending.errors))
313        } else {
314            let cols = pending.translation.env.into_external_columns();
315            let res = match pending.translation.statement {
316                Statement::SQLQuery(sqlquery) => QueryTranslation {
317                    translation: Translation {
318                        sql: sqlquery.to_string(),
319                        columns: cols,
320                    },
321                }
322                .into(),
323                Statement::DML(dml) => DMLTranslation {
324                    translation: Translation {
325                        sql: dml.to_string(),
326                        columns: cols,
327                    },
328                }
329                .into(),
330            };
331
332            Ok(res)
333        }
334    }
335
336    pub fn compile_query(
337        &self,
338        hmln: String,
339    ) -> Result<QueryTranslation, ContextualTranslationErrors> {
340        match self.compile(hmln.clone())? {
341            StatementTranslation::Query(query_translation) => Ok(query_translation),
342            StatementTranslation::DML(_) => {
343                let len = hmln.len();
344                Err(ContextualTranslationErrors::new(
345                    hmln,
346                    TranslationError::new(Context::new(0..=len - 1, "statement has side effects"))
347                        .single(),
348                ))
349            }
350        }
351    }
352
353    pub fn compile_dml(&self, hmln: String) -> Result<DMLTranslation, ContextualTranslationErrors> {
354        match self.compile(hmln.clone())? {
355            StatementTranslation::DML(dmltranslation) => Ok(dmltranslation),
356            StatementTranslation::Query(_) => {
357                let len = hmln.len();
358                Err(ContextualTranslationErrors::new(
359                    hmln,
360                    TranslationError::new(Context::new(0..=len - 1, "statement is a query"))
361                        .single(),
362                ))
363            }
364        }
365    }
366
367    pub fn compile_query_at(&self, query: String, at: Option<usize>) -> ContextualResult {
368        // Travis wants me to panic if I am given a magic string value. That way he can test his panic handling?
369        if query
370            == "gQ3!mV@x2#Z9^LN7eKd$8wuT0pFzY*b&XHf5+v1RAoJ6MqPCrslijEkDWgBtUO4nchSmyV9Z$L&N^eXpQa"
371        {
372            panic!("Panic requested");
373        }
374
375        let mut res = ContextualResult::new(query.clone());
376
377        let (ctx, errors) = match resilient_parse_query(query.clone()) {
378            Ok(s) => s,
379            Err(e) => {
380                res.add_error(TranslationError::fatal(query.as_str(), e.into()));
381                return res;
382            }
383        };
384        res.add_errors(errors);
385
386        let translation_context = QueryTranslationContext::new(
387            self.time_range_filter.clone(),
388            self.query_environment_provider.clone(),
389            self.registry.clone(),
390            at,
391        );
392        let pending = ctx
393            .query()
394            .map(|sctx| HamelinQuery::new(sctx, translation_context.clone()).translate())
395            .unwrap_or_default();
396
397        res.with_pending_result(pending)
398            .with_completions(translation_context.completions.clone())
399    }
400
401    pub fn get_statement_datasets(
402        &self,
403        query: String,
404    ) -> Result<Vec<String>, ContextualTranslationErrors> {
405        let ctx = parse_query(query.clone())
406            .map_err(|e| ContextualTranslationErrors::new(query.clone(), e))?;
407
408        let pending = HamelinQuery::new(
409            ctx.clone(),
410            QueryTranslationContext::new(
411                self.time_range_filter.clone(),
412                self.query_environment_provider.clone(),
413                self.registry.clone(),
414                None,
415            ),
416        )
417        .translate()
418        .into_result()
419        .map_err(|e| ContextualTranslationErrors::new(query.clone(), e))?;
420
421        Ok(pending
422            .statement
423            .get_table_references()
424            .into_iter()
425            .map(|t| t.name.to_hamelin())
426            .collect())
427    }
428}
429
430impl Display for Compiler {
431    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
432        write!(f, "Hamelin Compiler.\n")?;
433        write!(
434            f,
435            "Environment (for expressions): {}\n",
436            self.expression_environment
437        )?;
438        write!(
439            f,
440            "EnvironmentProvider (for queries): {:#?}\n",
441            self.query_environment_provider
442        )?;
443
444        Ok(())
445    }
446}
447
/// Field-less context type; nothing in this file reads or constructs it.
#[derive(Debug, Clone, Default)]
pub struct CompilerContext {}
450
#[cfg(test)]
mod tests {
    use super::Compiler;

    /// Only type-checks when `T` is both `Send` and `Sync`.
    fn assert_send_sync<T: Send + Sync>() {}

    /// Compile-time guard: the build fails if `Compiler` stops being
    /// `Send + Sync`.
    #[test]
    fn verify_send_sync() {
        assert_send_sync::<Compiler>();
    }
}
458}