use std::any::TypeId;
use std::collections::HashMap;
use hamelin_lib::func::def::{
FunctionDef, FunctionTranslationContext, FunctionTranslationFailure, ParameterBinding,
};
use hamelin_lib::sql::expression::SQLExpression;
use hamelin_lib::sql::query::window::WindowExpression;
use hamelin_lib::translation::ExpressionTranslation;
use hamelin_lib::types::Type;
use crate::func;
/// Boxed `Send + Sync` callback stored in a `TranslationRegistry`.
///
/// It receives a `&str` (the registry passes the call name) together with the
/// bound parameter translations, and returns the resulting [`SQLExpression`]
/// or an `anyhow` error.
pub type TranslationFn = Box<
dyn Fn(&str, ParameterBinding<ExpressionTranslation>) -> anyhow::Result<SQLExpression>
+ Send
+ Sync,
>;
/// Table of translation callbacks keyed by [`TypeId`].
///
/// Entries are added with `register` and looked up in `translate` using the
/// `TypeId` reported by the function definition being translated.
pub struct TranslationRegistry {
/// Registered callbacks, one per `TypeId` key.
impls: HashMap<TypeId, TranslationFn>,
}
impl std::fmt::Debug for TranslationRegistry {
    /// Prints the registry as a struct with a single `impls` field holding a
    /// summary string (`<N translations>`), since the stored closures
    /// themselves are not `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let summary = format!("<{} translations>", self.impls.len());
        let mut builder = f.debug_struct("TranslationRegistry");
        // The summary is passed as a `String`, so it is rendered quoted.
        builder.field("impls", &summary);
        builder.finish()
    }
}
impl TranslationRegistry {
pub fn new() -> Self {
Self {
impls: HashMap::new(),
}
}
pub fn register<F: 'static>(
&mut self,
f: impl Fn(&str, ParameterBinding<ExpressionTranslation>) -> anyhow::Result<SQLExpression>
+ Send
+ Sync
+ 'static,
) {
self.impls.insert(TypeId::of::<F>(), Box::new(f));
}
pub fn translate(
&self,
func: &dyn FunctionDef,
call_name: &str,
fctx: &FunctionTranslationContext,
binding: ParameterBinding<ExpressionTranslation>,
typ: Type,
) -> Result<ExpressionTranslation, FunctionTranslationFailure> {
let type_id = func.type_id();
let translation_fn = self.impls.get(&type_id).ok_or_else(|| {
FunctionTranslationFailure::Fatal(
format!("No translation registered for function '{}'", call_name).into(),
)
})?;
let sql = translation_fn(call_name, binding)
.map_err(|e| FunctionTranslationFailure::Fatal(e.into()))?;
let ordered = if func.sortable_input() && !fctx.order_by.is_empty() {
match sql {
SQLExpression::FunctionCallApply(function_call_apply) => function_call_apply
.with_order_by(fctx.order_by.clone())
.into(),
_ => sql,
}
} else {
sql
};
let wrapped = match (func.special_position(), &fctx.window) {
(Some(position), Some(window)) if fctx.specials_allowed.contains(&position) => {
WindowExpression::new(ordered, window.clone()).into()
}
(Some(position), None) if fctx.specials_allowed.contains(&position) => ordered,
(Some(position), _) => {
return Err(FunctionTranslationFailure::Fatal(
format!("{} not allowed here", position).into(),
))
}
(None, _) => ordered,
};
Ok(ExpressionTranslation::with_defaults(typ, wrapped))
}
}
impl Default for TranslationRegistry {
    /// Builds an empty registry and passes it to `func::register` before
    /// returning it.
    fn default() -> Self {
        let mut populated = Self::new();
        func::register(&mut populated);
        populated
    }
}