use std::cell::RefCell;
use std::fmt::{Display, Formatter};
use std::rc::Rc;
use std::sync::Arc;
use chrono::Utc;
use hamelin_sql::range_builder::RangeBuilder;
use serde::Serialize;
use tsify_next::Tsify;
use hamelin_lib::antlr::hamelinparser::{CommandContextAll, QueryEOFContextAttrs};
use hamelin_lib::catalog::CatalogProvider;
use hamelin_lib::err::ContextualTranslationErrors;
use hamelin_lib::err::{Context, TranslationError};
use hamelin_lib::func::def::FunctionTranslationContext;
use hamelin_lib::func::registry::FunctionRegistry;
use hamelin_lib::parse_command;
use hamelin_lib::parse_expression;
use hamelin_lib::parse_query;
use hamelin_lib::provider::EnvironmentProvider;
use hamelin_lib::resilient_parse_query;
use hamelin_lib::sql::expression::identifier::SimpleIdentifier;
use hamelin_lib::sql::expression::literal::{ColumnReference, TimestampLiteral};
use hamelin_lib::sql::expression::SQLExpression;
use hamelin_lib::sql::query::SQLQuery;
use hamelin_lib::sql::statement::Statement;
use hamelin_lib::sql::types::SQLTimestampTzType;
use hamelin_lib::translation::ExpressionTranslation;
use hamelin_lib::types::TIMESTAMP;
use hamelin_sql::utils::within_range;
use hamelin_sql::TranslationRegistry;
use crate::ast::command::within;
use crate::ast::expression::HamelinExpression;
use crate::ast::query::HamelinQuery;
use crate::ast::{ExpressionTranslationContext, QueryTranslationContext};
use crate::env::Environment;
use crate::translation::{
ContextualResult, DMLTranslation, PendingQuery, QueryTranslation, StatementTranslation,
Translation,
};
#[derive(Debug, Clone, Default)]
pub struct TimeRange {
pub start: Option<chrono::DateTime<Utc>>,
pub end: Option<chrono::DateTime<Utc>>,
}
/// Entry point for translating Hamelin expressions and queries into SQL.
///
/// All shared state is held behind `Arc`, so cloning a `Compiler` is cheap and
/// clones share the same environments and registries.
#[derive(Clone)]
pub struct Compiler {
    /// Bindings visible to [`Compiler::compile_expression`].
    pub expression_environment: Arc<Environment>,
    /// Provider handed to every query translation context
    /// (presumably used to resolve datasets — confirm against `EnvironmentProvider`).
    pub query_environment_provider: Arc<dyn EnvironmentProvider>,
    /// Optional SQL predicate passed into query translation; set via
    /// `set_time_range` / `set_time_range_expression`.
    pub time_range_filter: Option<SQLExpression>,
    /// Function definitions available to translation.
    pub registry: Arc<FunctionRegistry>,
    /// Translation registry handed to every translation context.
    pub translation_registry: Arc<TranslationRegistry>,
}
/// One overload of a registered function, as reported by
/// [`Compiler::get_function_descriptions`].
///
/// Serializable with serde and exposed across the wasm ABI via `tsify`.
#[derive(Serialize, Tsify)]
#[tsify(into_wasm_abi)]
pub struct FunctionDescription {
    /// The function's registered name.
    pub name: String,
    /// The overload's parameter list rendered through its `Display` impl.
    pub parameters: String,
}
impl Compiler {
    /// Creates a compiler with default state.
    ///
    /// The defaults are:
    /// - a default expression [`Environment`];
    /// - a default [`CatalogProvider`] as the query environment provider;
    /// - no time-range filter;
    /// - default function and translation registries.
    pub fn new() -> Self {
        Self {
            expression_environment: Arc::new(Environment::default()),
            query_environment_provider: Arc::new(CatalogProvider::default()),
            time_range_filter: None,
            registry: Arc::new(FunctionRegistry::default()),
            translation_registry: Arc::new(TranslationRegistry::default()),
        }
    }

    /// Lists every registered function overload.
    ///
    /// A name with several definitions yields one entry per definition. Entries
    /// follow the iteration order of `registry.function_defs`.
    /// NOTE(review): that order may be unspecified if it is a hash map — confirm.
    pub fn get_function_descriptions(&self) -> Vec<FunctionDescription> {
        self.registry
            .function_defs
            .iter()
            .flat_map(|(name, defs)| {
                defs.iter().map(|def| FunctionDescription {
                    name: name.clone(),
                    parameters: def.parameters().to_string(),
                })
            })
            .collect()
    }

    /// Replaces the environment used by [`Compiler::compile_expression`].
    pub fn set_environment(&mut self, environment: Arc<Environment>) {
        // The Arc is already owned; move it in rather than bumping the refcount.
        self.expression_environment = environment;
    }

    /// Replaces the environment provider handed to query translation.
    pub fn set_environment_provider(&mut self, provider: Arc<dyn EnvironmentProvider>) {
        self.query_environment_provider = provider;
    }

    /// Installs a time-range filter on the `timestamp` column from explicit bounds.
    ///
    /// Each present bound becomes a timestamp literal typed as
    /// `SQLTimestampTzType::new(3)` (presumably millisecond precision — confirm).
    /// A `None` bound is not added to the range.
    ///
    /// NOTE(review): when both bounds are `None`, a filter built from an empty
    /// `RangeBuilder` is still installed. What `within_range` emits for that case
    /// is not visible here.
    pub fn set_time_range(&mut self, time_range: TimeRange) {
        let ident = ColumnReference::new(SimpleIdentifier::new("timestamp").into());
        let mut range = RangeBuilder::default();
        if let Some(from) = time_range.start {
            range = range.with_begin(
                TimestampLiteral::new(from).into(),
                SQLTimestampTzType::new(3).into(),
            );
        }
        if let Some(to) = time_range.end {
            range = range.with_end(
                TimestampLiteral::new(to).into(),
                SQLTimestampTzType::new(3).into(),
            );
        }
        self.time_range_filter = Some(within_range(ident.into(), range));
    }

    /// Installs a time-range filter written as a Hamelin range expression.
    ///
    /// Steps:
    /// 1. Wrap the expression as `WITHIN <expr>` and parse it as a single command.
    /// 2. Translate it against a synthetic upstream query whose only binding is
    ///    `timestamp: TIMESTAMP`.
    /// 3. Use the resulting `WHERE` predicate as the filter.
    ///
    /// # Errors
    ///
    /// Returns parse or translation errors. Their context text is the templated
    /// `WITHIN …` command, so reported positions include the `WITHIN ` prefix.
    /// On error, the previously installed filter is left untouched.
    pub fn set_time_range_expression(
        &mut self,
        hamelin_range_expression: String,
    ) -> Result<(), ContextualTranslationErrors> {
        let templated = format!("WITHIN {}", hamelin_range_expression);
        let tree = parse_command(templated.clone())
            .map_err(|e| ContextualTranslationErrors::new(templated.clone(), e))?;

        // Minimal upstream query exposing only the column WITHIN filters on.
        let previous = PendingQuery::new(
            SQLQuery::default(),
            Environment::default().with_binding("timestamp".parse().unwrap(), TIMESTAMP),
        );

        let filter = if let CommandContextAll::WithinCommandContext(ctx) = tree.as_ref() {
            within::translate(
                ctx,
                &previous,
                QueryTranslationContext::new(
                    None,
                    self.query_environment_provider.clone(),
                    self.registry.clone(),
                    self.translation_registry.clone(),
                    None,
                ),
            )
            .map_err(|e| ContextualTranslationErrors::new(templated.clone(), e))?
            .query
            .where_
            .expect("withins always have a where clause")
        } else {
            // The text always begins with `WITHIN`, so a successful parse is
            // presumably always a WITHIN command.
            unreachable!()
        };

        self.time_range_filter = Some(filter);
        Ok(())
    }

    /// Parses and translates a standalone Hamelin expression.
    ///
    /// The expression is resolved against `expression_environment`.
    ///
    /// # Errors
    ///
    /// Returns parse or translation errors, with `expression` as their context.
    pub fn compile_expression(
        &self,
        expression: String,
    ) -> Result<ExpressionTranslation, ContextualTranslationErrors> {
        let ctx = parse_expression(expression.clone())
            .map_err(|e| ContextualTranslationErrors::new(expression.clone(), e))?;
        HamelinExpression::new(
            ctx,
            ExpressionTranslationContext::new(
                self.expression_environment.clone(),
                self.registry.clone(),
                self.translation_registry.clone(),
                FunctionTranslationContext::default(),
                None,
                Rc::new(RefCell::new(None)),
            ),
        )
        .translate()
        .map_err(|e| ContextualTranslationErrors::new(expression, e))
    }

    /// Compiles a full Hamelin statement into SQL.
    ///
    /// The current `time_range_filter` is passed into translation. The result is
    /// a query or a DML translation depending on the statement, together with the
    /// statement's external columns.
    ///
    /// # Errors
    ///
    /// Returns the parse error, or every translation error collected, with
    /// `hmln` as their context.
    pub fn compile(
        &self,
        hmln: String,
    ) -> Result<StatementTranslation, ContextualTranslationErrors> {
        let ctx = parse_query(hmln.clone())
            .map_err(|e| ContextualTranslationErrors::new(hmln.clone(), e))?;
        let translation_context = QueryTranslationContext::new(
            self.time_range_filter.clone(),
            self.query_environment_provider.clone(),
            self.registry.clone(),
            self.translation_registry.clone(),
            None,
        );
        let pending = HamelinQuery::new(ctx, translation_context).translate();
        if !pending.errors.is_empty() {
            return Err(ContextualTranslationErrors::new(hmln, pending.errors));
        }

        let cols = pending.translation.env.into_external_columns();
        let res = match pending.translation.statement {
            Statement::SQLQuery(sqlquery) => QueryTranslation {
                translation: Translation {
                    sql: sqlquery.to_string(),
                    columns: cols,
                },
            }
            .into(),
            Statement::DML(dml) => DMLTranslation {
                translation: Translation {
                    sql: dml.to_string(),
                    columns: cols,
                },
            }
            .into(),
        };
        Ok(res)
    }

    /// Compiles `hmln` and requires the result to be a side-effect-free query.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Compiler::compile`]. Also fails with
    /// "statement has side effects", spanning the whole statement, when the
    /// statement is DML.
    pub fn compile_query(
        &self,
        hmln: String,
    ) -> Result<QueryTranslation, ContextualTranslationErrors> {
        match self.compile(hmln.clone())? {
            StatementTranslation::Query(query_translation) => Ok(query_translation),
            StatementTranslation::DML(_) => {
                Err(Self::whole_statement_error(hmln, "statement has side effects"))
            }
        }
    }

    /// Compiles `hmln` and requires the result to be a DML statement.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Compiler::compile`]. Also fails with
    /// "statement is a query", spanning the whole statement, when the statement
    /// is a plain query.
    pub fn compile_dml(&self, hmln: String) -> Result<DMLTranslation, ContextualTranslationErrors> {
        match self.compile(hmln.clone())? {
            StatementTranslation::DML(dmltranslation) => Ok(dmltranslation),
            StatementTranslation::Query(_) => {
                Err(Self::whole_statement_error(hmln, "statement is a query"))
            }
        }
    }

    /// Builds a single error whose span covers all of `hmln`.
    fn whole_statement_error(hmln: String, message: &'static str) -> ContextualTranslationErrors {
        // `saturating_sub` keeps the span valid for an empty statement, where
        // `len - 1` would underflow (panic in debug, wrap in release).
        let last = hmln.len().saturating_sub(1);
        ContextualTranslationErrors::new(
            hmln,
            TranslationError::new(Context::new(0..=last, message)).single(),
        )
    }

    /// Error-tolerant compilation intended for editor tooling.
    ///
    /// Uses the resilient parser and accumulates its recoverable errors.
    /// Translates whatever query was recovered, and attaches the completions the
    /// translation context collected for cursor position `at`. A fatal parse
    /// failure yields a result containing only that error.
    ///
    /// # Panics
    ///
    /// Panics deliberately when `query` equals a hard-coded sentinel string.
    /// NOTE(review): presumably a hook for exercising the host's panic handling —
    /// confirm.
    pub fn compile_query_at(&self, query: String, at: Option<usize>) -> ContextualResult {
        if query
            == "gQ3!mV@x2#Z9^LN7eKd$8wuT0pFzY*b&XHf5+v1RAoJ6MqPCrslijEkDWgBtUO4nchSmyV9Z$L&N^eXpQa"
        {
            panic!("Panic requested");
        }
        let mut res = ContextualResult::new(query.clone());
        let (ctx, errors) = match resilient_parse_query(query.clone()) {
            Ok(s) => s,
            Err(e) => {
                res.add_error(TranslationError::fatal(query.as_str(), e.into()));
                return res;
            }
        };
        res.add_errors(errors);
        let translation_context = QueryTranslationContext::new(
            self.time_range_filter.clone(),
            self.query_environment_provider.clone(),
            self.registry.clone(),
            self.translation_registry.clone(),
            at,
        );
        // The context is cloned into translation and read again afterwards for
        // its completions.
        let pending = ctx
            .query()
            .map(|sctx| HamelinQuery::new(sctx, translation_context.clone()).translate())
            .unwrap_or_default();
        res.with_pending_result(pending)
            .with_completions(translation_context.completions.clone())
    }

    /// Returns the Hamelin names of all tables the compiled statement references.
    ///
    /// # Errors
    ///
    /// Returns parse or translation errors, with `query` as their context.
    pub fn get_statement_datasets(
        &self,
        query: String,
    ) -> Result<Vec<String>, ContextualTranslationErrors> {
        let ctx = parse_query(query.clone())
            .map_err(|e| ContextualTranslationErrors::new(query.clone(), e))?;
        let pending = HamelinQuery::new(
            ctx,
            QueryTranslationContext::new(
                self.time_range_filter.clone(),
                self.query_environment_provider.clone(),
                self.registry.clone(),
                self.translation_registry.clone(),
                None,
            ),
        )
        .translate()
        .into_result()
        .map_err(|e| ContextualTranslationErrors::new(query, e))?;
        Ok(pending
            .statement
            .get_table_references()
            .into_iter()
            .map(|t| t.name.to_hamelin())
            .collect())
    }
}
impl Default for Compiler {
fn default() -> Self {
Self::new()
}
}
/// Human-readable summary of the compiler, one item per line.
///
/// The output has three lines:
/// 1. a header;
/// 2. the expression environment, via its `Display` impl;
/// 3. the query environment provider, via its pretty-printed `Debug` impl.
impl Display for Compiler {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // `writeln!` appends the same trailing "\n" the original `write!` calls
        // embedded (clippy: write_with_newline).
        writeln!(f, "Hamelin Compiler.")?;
        writeln!(
            f,
            "Environment (for expressions): {}",
            self.expression_environment
        )?;
        writeln!(
            f,
            "EnvironmentProvider (for queries): {:#?}",
            self.query_environment_provider
        )
    }
}
#[cfg(test)]
mod tests {
    use super::Compiler;

    /// Compile-time guard: `Compiler` must remain `Send + Sync`.
    ///
    /// The test fails to build, rather than at runtime, if a non-thread-safe
    /// field is introduced.
    #[test]
    fn verify_send_sync() {
        fn assert_thread_safe<T: Send + Sync>() {}
        assert_thread_safe::<Compiler>();
    }
}