mod aggregate;
mod arithmetic;
mod array;
mod comparison;
mod conditional;
mod datetime;
mod interval;
mod json;
mod logical;
mod map;
mod math;
mod membership;
mod operators;
mod regex;
mod string;
mod window;
use std::any::TypeId;
use std::collections::HashMap;
use std::sync::Arc;
use datafusion::logical_expr::Expr as DFExpr;
use hamelin_lib::func::def::{FunctionDef, HasType, ParameterBinding};
use hamelin_lib::types::Type;
/// Errors returned when a function cannot be translated into a DataFusion expression.
#[derive(Debug, thiserror::Error)]
pub enum FunctionTranslationFailure {
    /// No translation has been registered for the function with the given name.
    #[error("No DataFusion translation registered for function '{0}'")]
    NoTranslation(String),

    /// A translation was registered for the function, but it returned an error.
    #[error("Translation error for function '{function}': {message}")]
    TranslationError {
        /// Name of the function whose translation failed.
        function: String,
        /// Rendered text of the error returned by the translation.
        message: String,
    },
}
/// A DataFusion expression bundled with the type associated with it.
///
/// This is the argument representation carried by the `ParameterBinding`
/// that translation functions receive.
#[derive(Clone)]
pub struct DFTranslation {
    /// The DataFusion expression.
    pub expr: DFExpr,
    /// The type associated with `expr`, shared behind an `Arc`.
    pub typ: Arc<Type>,
}
impl DFTranslation {
pub fn new(expr: DFExpr, typ: Arc<Type>) -> Self {
Self { expr, typ }
}
}
impl HasType for DFTranslation {
    /// Returns a reference to the type stored alongside the expression.
    fn typ(&self) -> &Type {
        // Explicitly dereference the `Arc` to borrow the inner `Type`.
        &*self.typ
    }
}
/// Boxed translation callback stored in the registry.
///
/// Takes the bound parameters of a function call (each a [`DFTranslation`])
/// and returns the resulting DataFusion expression, or an error.
/// The closure must be `Send + Sync`.
pub type DFTranslationFn =
    Box<dyn Fn(ParameterBinding<DFTranslation>) -> anyhow::Result<DFExpr> + Send + Sync>;
/// Registry of DataFusion translations, keyed by the `TypeId` of the
/// function definition type each translation was registered for.
pub struct DataFusionTranslationRegistry {
    /// Translation callbacks indexed by function definition `TypeId`.
    impls: HashMap<TypeId, DFTranslationFn>,
}
impl std::fmt::Debug for DataFusionTranslationRegistry {
    /// Formats the registry as a struct showing only how many translations
    /// are registered; the boxed closures themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let registered = self.impls.len();
        let mut out = f.debug_struct("DataFusionTranslationRegistry");
        out.field("registered_count", &registered);
        out.finish()
    }
}
impl DataFusionTranslationRegistry {
    /// Creates an empty registry with no translations registered.
    ///
    /// The `Default` implementation, by contrast, returns a registry with the
    /// translations from this module's submodules already registered.
    pub fn new() -> Self {
        Self {
            impls: HashMap::new(),
        }
    }

    /// Registers `f` as the translation for the function definition type `F`.
    ///
    /// The callback is stored under `TypeId::of::<F>()`. Registering a second
    /// callback for the same `F` replaces the earlier one.
    pub fn register<F: FunctionDef>(
        &mut self,
        f: impl Fn(ParameterBinding<DFTranslation>) -> anyhow::Result<DFExpr> + Send + Sync + 'static,
    ) {
        self.impls.insert(TypeId::of::<F>(), Box::new(f));
    }

    /// Translates a call to `func`, with its bound arguments, into a
    /// DataFusion expression.
    ///
    /// The translation is looked up by `func.type_id()`.
    /// NOTE(review): this relies on `func.type_id()` yielding the same id as
    /// the `TypeId::of::<F>()` used in `register` — confirm against
    /// `FunctionDef`.
    ///
    /// # Errors
    ///
    /// * [`FunctionTranslationFailure::NoTranslation`] when nothing is
    ///   registered for `func`.
    /// * [`FunctionTranslationFailure::TranslationError`] when the registered
    ///   callback returns an error.
    pub fn translate(
        &self,
        func: &dyn FunctionDef,
        binding: ParameterBinding<DFTranslation>,
    ) -> Result<DFExpr, FunctionTranslationFailure> {
        let type_id = func.type_id();
        let translate_fn = self
            .impls
            .get(&type_id)
            .ok_or_else(|| FunctionTranslationFailure::NoTranslation(func.name().to_string()))?;

        translate_fn(binding).map_err(|e| FunctionTranslationFailure::TranslationError {
            function: func.name().to_string(),
            // The alternate form `{:#}` renders the whole anyhow cause chain
            // ("outer: inner: root"). Plain `to_string()` would keep only the
            // outermost context and drop the underlying causes.
            message: format!("{:#}", e),
        })
    }
}
impl Default for DataFusionTranslationRegistry {
    /// Builds a registry and passes it to the `register` function of every
    /// translation submodule, in the order listed below.
    fn default() -> Self {
        let mut populated = Self::new();

        // Keep this order as-is: `register` stores callbacks with
        // `HashMap::insert`, so a later registration for the same key would
        // replace an earlier one.
        aggregate::register(&mut populated);
        arithmetic::register(&mut populated);
        array::register(&mut populated);
        comparison::register(&mut populated);
        conditional::register(&mut populated);
        datetime::register(&mut populated);
        interval::register(&mut populated);
        json::register(&mut populated);
        logical::register(&mut populated);
        map::register(&mut populated);
        math::register(&mut populated);
        membership::register(&mut populated);
        operators::register(&mut populated);
        regex::register(&mut populated);
        string::register(&mut populated);
        window::register(&mut populated);

        populated
    }
}