use std::collections::HashMap;
use std::fmt::{Debug, Display, Formatter};
use std::rc::Rc;
use std::sync::Arc;
use antlr_rust::parser_rule_context::ParserRuleContext;
use antlr_rust::token::{CommonToken, Token};
use anyhow::anyhow;
use hamelin_lib::func::err::MatchTestFailure;
use hamelin_lib::translation::ExpressionTranslation;
use ordermap::OrderMap;
use strsim;
use thiserror::Error;
use crate::ast::expression::HamelinExpression;
use crate::env::Environment;
use hamelin_lib::antlr::hamelinparser::{
BinaryOperatorContext, ExpressionContextAll, FunctionCallContext, FunctionCallContextAttrs,
NamedArgumentContextAll, NamedArgumentContextAttrs, PositionalArgumentContextAll,
PositionalArgumentContextAttrs, SimpleIdentifierContextAll, UnaryPostfixOperatorContext,
UnaryPostfixOperatorContextAttrs, UnaryPrefixOperatorContext, UnaryPrefixOperatorContextAttrs,
};
use hamelin_lib::err::{Context, TranslationError, TranslationErrors};
use hamelin_lib::func::def::{
DirectResolution, FunctionBindFailure, FunctionDef, FunctionParameterBindingFailure,
FunctionParameterBindingFailures, FunctionResolution, FunctionTranslationFailure,
};
use hamelin_lib::operator::Operator;
use hamelin_lib::parse_expression;
use hamelin_lib::sql::expression::apply::{FunctionCallApply, Lambda};
use hamelin_lib::sql::expression::identifier::HamelinSimpleIdentifier;
use hamelin_lib::sql::expression::identifier::{Identifier, SimpleIdentifier};
use hamelin_lib::types::array::Array;
use hamelin_lib::types::struct_type::Struct;
use hamelin_lib::types::Type;
use super::ExpressionTranslationContext;
pub trait CanMatchArgs {
fn function_name(&self) -> Result<String, TranslationErrors>;
fn positional(&self) -> Result<Vec<ExpressionTranslation>, TranslationErrors>;
fn named(&self) -> Result<OrderMap<String, ExpressionTranslation>, TranslationErrors>;
fn functions(&self) -> &HashMap<String, Vec<Arc<dyn FunctionDef>>>;
fn expression_translation_context(&self) -> Rc<ExpressionTranslationContext>;
fn match_function<'a, T>(&self, ctx: &T) -> Result<ExpressionTranslation, TranslationErrors>
where
T: ParserRuleContext<'a>,
{
let (function_name, positional, named) =
TranslationErrors::from_3(self.function_name(), self.positional(), self.named())?;
let mut attempts = vec![];
let lookup_key = function_name.to_lowercase();
for f in self.functions().get(&lookup_key).unwrap_or(&vec![]).iter() {
match bind(f.clone(), positional.clone(), named.clone()) {
Ok(resolution) => {
if let Some(special_position) = (**f).special_position() {
let fctx = self.expression_translation_context();
if !fctx.fctx.specials_allowed.contains(&special_position) {
attempts.push(ApplyAttempt {
function_def: (**f).to_string(),
binding_failures: FunctionParameterBindingFailures(vec![
FunctionParameterBindingFailure::DoesNotMatch(
format!("{} not allowed here", special_position).into(),
),
]),
});
continue;
}
}
let ectx = self.expression_translation_context();
match ectx.translation_registry.translate(
&**f,
&function_name,
&ectx.fctx,
resolution,
) {
Ok(t) => {
let nested_special = positional
.iter()
.chain(named.values())
.flat_map(|t| {
t.special
.as_ref()
.map(|s| (s.clone(), t.span.clone().unwrap()))
.into_iter()
.chain(t.nested_special.clone().into_iter())
})
.collect::<Vec<_>>();
match (**f).special_position() {
Some(s) => {
if !nested_special.is_empty() {
let err = TranslationError::msg(
ctx,
format!(
"Nested special function calls not allowed in {}",
s
)
.as_str(),
)
.with_context_vec(
nested_special
.into_iter()
.map(|(sp, rng)| {
Context::new(rng, sp.to_string().as_str())
})
.collect(),
)
.single();
return Err(err);
} else {
return Ok(t.clone().with_special(s));
}
}
None => {
if !nested_special.is_empty() {
let mut res = t.clone();
for (sp, rng) in nested_special.into_iter() {
res = res.with_nested_special(sp, rng);
}
return Ok(res);
} else {
return Ok(t);
}
}
}
}
Err(FunctionTranslationFailure::Fatal(e)) => {
return Err(TranslationError::wrap_box(ctx, e).into())
}
}
}
Err(FunctionBindFailure::ParameterBindingFailures(bf)) => {
attempts.push(ApplyAttempt {
function_def: (**f).to_string(),
binding_failures: bf,
});
}
Err(FunctionBindFailure::Fatal(e)) => {
return Err(TranslationError::wrap_box(ctx, e).into())
}
}
}
if let Some(t) = self.match_function_array_passthrough(&positional, &named)? {
return Ok(t);
}
if attempts.is_empty() {
let all_functions = self.functions();
let all_function_names: Vec<&String> = all_functions.keys().collect();
let suggestions = find_similar_function_names(&function_name, &all_function_names, 5);
let mut e = TranslationError::msg(ctx, "No matching function found");
if !suggestions.is_empty() {
e = e.with_source_boxed(
anyhow!(
"Did you mean one of these?\n{}",
suggestions
.iter()
.map(|s| format!(" - {}", s))
.collect::<Vec<_>>()
.join("\n")
)
.into(),
);
}
Err(e.single())
} else {
let mut error =
TranslationError::msg(ctx, "could not find a matching function definition")
.with_source(ApplyFailure(attempts));
let args = positional
.iter()
.map(|p| (p.span.clone(), &p.typ))
.chain(named.iter().map(|(_, v)| (v.span.clone(), &v.typ)));
for (span, typ) in args {
if let Some(s) = span {
error.add_context(s, &format!("{}", typ));
}
}
Err(error.into())
}
}
fn match_function_array_passthrough(
&self,
positional: &Vec<ExpressionTranslation>,
named: &OrderMap<String, ExpressionTranslation>,
) -> Result<Option<ExpressionTranslation>, TranslationErrors> {
let all_overrides = self.array_passthrough_overrides(positional)?;
if all_overrides.is_empty() {
return Ok(None);
}
let override_combinations = make_combinations(&all_overrides);
for overrides in override_combinations {
let mut new_positional = positional.clone();
let mut new_named = named.clone();
for override_ in &overrides {
match override_ {
ArgumentOverride::NamedOverride(no) => {
new_named.insert(no.name.clone(), no.translation.clone());
}
ArgumentOverride::PositionalOverride(po) => {
new_positional[po.position] = po.translation.clone();
}
}
}
let call_name = self.function_name()?;
let call_name_lower = call_name.to_lowercase();
let ectx = self.expression_translation_context();
for f in self
.functions()
.get(&call_name_lower)
.unwrap_or(&vec![])
.iter()
{
if let Ok(res) = bind(f.clone(), new_positional.clone(), new_named.clone()) {
if let Ok(translation) = ectx
.translation_registry
.translate(&**f, &call_name, &ectx.fctx, res)
{
let mut expression = translation.sql;
for override_ in &overrides {
match override_ {
ArgumentOverride::NamedOverride(no) => {
expression = FunctionCallApply::with_two(
"transform",
new_named.get(&no.name).unwrap().sql.clone(),
Lambda::from_single_argument(
no.identifier.clone(),
expression,
)
.into(),
)
.into();
}
ArgumentOverride::PositionalOverride(po) => {
expression = FunctionCallApply::with_two(
"transform",
positional.get(po.position).unwrap().sql.clone(),
Lambda::from_single_argument(
po.identifier.clone(),
expression,
)
.into(),
)
.into();
}
}
}
return Ok(Some(ExpressionTranslation::with_defaults(
Array::new(translation.typ).into(),
expression,
)));
}
}
}
}
Ok(None)
}
fn array_passthrough_overrides(
&self,
positional: &Vec<ExpressionTranslation>,
) -> Result<Vec<ArgumentOverride>, TranslationErrors> {
let mut overrides = vec![];
for (i, arg) in positional.iter().enumerate() {
if let Type::Array(a) = &arg.typ {
let reference_name = format!("e_{}", i);
let reference_ident: Identifier = reference_name.parse()?;
let reference_name_parsed = parse_expression(reference_name.clone())?;
let new_env = Arc::new(Environment::new(
Struct::default().with(reference_ident, (*a.element_type).clone()),
));
let new_context = self
.expression_translation_context()
.without_completion()
.with_env(new_env);
let translation =
HamelinExpression::new(reference_name_parsed, new_context).translate()?;
overrides.push(ArgumentOverride::PositionalOverride(PositionalOverride {
identifier: SimpleIdentifier::new(&reference_name),
translation,
position: i,
}));
}
}
for (key, value) in self.named()?.into_iter() {
if let Type::Array(a) = value.typ {
let reference_name = format!("e_{}", key);
let reference_ident: Identifier = reference_name.parse()?;
let reference_name_parsed = parse_expression(reference_name.clone())?;
let new_env = Arc::new(Environment::new(
Struct::default().with(reference_ident, *a.element_type),
));
let new_context = self
.expression_translation_context()
.without_completion()
.with_env(new_env);
let translation =
HamelinExpression::new(reference_name_parsed, new_context).translate()?;
overrides.push(ArgumentOverride::NamedOverride(NamedOverride {
identifier: SimpleIdentifier::new(&reference_name),
translation,
name: key.clone(),
}));
}
}
Ok(overrides)
}
}
fn bind(
func: Arc<dyn FunctionDef>,
positional: Vec<ExpressionTranslation>,
named: impl IntoIterator<Item = (String, ExpressionTranslation)>,
) -> Result<FunctionResolution<ExpressionTranslation>, FunctionBindFailure> {
let parameters = func.parameters();
let binding = parameters.bind_and_check(positional, named).map_err(|e| {
FunctionBindFailure::ParameterBindingFailures(FunctionParameterBindingFailures(e))
})?;
let typ = match func.return_type(&binding) {
Ok(t) => Ok(t),
Err(e) => match e.downcast::<MatchTestFailure>() {
Ok(mtf) => Err(FunctionBindFailure::ParameterBindingFailures(
FunctionParameterBindingFailures(vec![
FunctionParameterBindingFailure::DoesNotMatch(mtf.0),
]),
)),
Err(other) => Err(FunctionBindFailure::Fatal(other.into())),
},
}?;
Ok(DirectResolution {
function_def: func,
binding,
typ,
}
.into())
}
/// A unary prefix operator applied to a single operand expression.
pub struct HamelinUnaryPrefixApply {
    /// Parse tree of the operand.
    expression: Rc<ExpressionContextAll<'static>>,
    /// The operator token.
    operator: Box<CommonToken<'static>>,
    /// Context in which the operand is translated and the operator resolved.
    expression_translation_context: Rc<ExpressionTranslationContext>,
}

impl HamelinUnaryPrefixApply {
    /// Captures the operand and operator from `tree`.
    ///
    /// # Errors
    /// Fails (via `TranslationErrors::expect`) when the operand or the operator is
    /// missing from the parse tree.
    pub fn try_new(
        tree: &UnaryPrefixOperatorContext<'static>,
        expression_translation_context: Rc<ExpressionTranslationContext>,
    ) -> Result<Self, TranslationErrors> {
        let expression = TranslationErrors::expect(tree, tree.expression().clone())?;
        let operator = TranslationErrors::expect(tree, tree.operator.clone())?;
        Ok(Self {
            expression,
            operator,
            expression_translation_context,
        })
    }
}
impl CanMatchArgs for HamelinUnaryPrefixApply {
    // The operator's `Display` form, resolved from the upper-cased token text.
    // An unrecognised token is reported at the token's span.
    fn function_name(&self) -> Result<String, TranslationErrors> {
        let token_text = self.operator.get_text().to_uppercase();
        match Operator::of(token_text.as_str()) {
            Ok(op) => Ok(op.to_string()),
            Err(err) => {
                let span =
                    self.operator.get_start() as usize..=self.operator.get_stop() as usize;
                Err(TranslationError::new(Context::new(span, "unexpected operator"))
                    .with_source_boxed(err.into())
                    .into())
            }
        }
    }

    // The single operand, translated in this call's context.
    fn positional(&self) -> Result<Vec<ExpressionTranslation>, TranslationErrors> {
        let operand = HamelinExpression::new(
            Rc::clone(&self.expression),
            Rc::clone(&self.expression_translation_context),
        )
        .translate();
        TranslationErrors::from_vec(vec![operand])
    }

    // Operators never take named arguments.
    fn named(&self) -> Result<OrderMap<String, ExpressionTranslation>, TranslationErrors> {
        Ok(OrderMap::new())
    }

    // Candidates come from the registry's unary-prefix operator table.
    fn functions(&self) -> &HashMap<String, Vec<Arc<dyn FunctionDef>>> {
        let registry = &self.expression_translation_context.registry;
        &registry.unary_prefix_operation_defs
    }

    fn expression_translation_context(&self) -> Rc<ExpressionTranslationContext> {
        Rc::clone(&self.expression_translation_context)
    }
}
/// A unary postfix operator applied to a single operand expression.
pub struct HamelinUnaryPostfixApply {
    /// Parse tree of the operand.
    expression: Rc<ExpressionContextAll<'static>>,
    /// The operator token.
    operator: Box<CommonToken<'static>>,
    /// Context in which the operand is translated and the operator resolved.
    expression_translation_context: Rc<ExpressionTranslationContext>,
}

impl HamelinUnaryPostfixApply {
    /// Captures the operand and operator from `tree`.
    ///
    /// # Errors
    /// Fails (via `TranslationErrors::expect`) when the operand or the operator is
    /// missing from the parse tree.
    pub fn try_new(
        tree: &UnaryPostfixOperatorContext<'static>,
        expression_translation_context: Rc<ExpressionTranslationContext>,
    ) -> Result<Self, TranslationErrors> {
        let expression = TranslationErrors::expect(tree, tree.expression().clone())?;
        let operator = TranslationErrors::expect(tree, tree.operator.clone())?;
        Ok(Self {
            expression,
            operator,
            expression_translation_context,
        })
    }
}

impl CanMatchArgs for HamelinUnaryPostfixApply {
    // The operator's `Display` form, resolved from the upper-cased token text.
    fn function_name(&self) -> Result<String, TranslationErrors> {
        let token = &self.operator;
        let upper = token.get_text().to_uppercase();
        Operator::of(upper.as_str())
            .map(|op| op.to_string())
            .map_err(|err| {
                let span = token.get_start() as usize..=token.get_stop() as usize;
                TranslationError::new(Context::new(span, "unexpected operator"))
                    .with_source_boxed(err.into())
                    .into()
            })
    }

    // The single operand, translated in this call's context.
    fn positional(&self) -> Result<Vec<ExpressionTranslation>, TranslationErrors> {
        let ctx = Rc::clone(&self.expression_translation_context);
        let operand = HamelinExpression::new(Rc::clone(&self.expression), ctx).translate();
        TranslationErrors::from_vec(vec![operand])
    }

    // Operators never take named arguments.
    fn named(&self) -> Result<OrderMap<String, ExpressionTranslation>, TranslationErrors> {
        Ok(OrderMap::new())
    }

    // Candidates come from the registry's unary-postfix operator table.
    fn functions(&self) -> &HashMap<String, Vec<Arc<dyn FunctionDef>>> {
        let registry = &self.expression_translation_context.registry;
        &registry.unary_postfix_operation_defs
    }

    fn expression_translation_context(&self) -> Rc<ExpressionTranslationContext> {
        Rc::clone(&self.expression_translation_context)
    }
}
/// A binary operator applied to a left and a right operand.
pub struct HamelinBinaryOperatorApply {
    /// Parse tree of the left operand.
    left: Rc<ExpressionContextAll<'static>>,
    /// Parse tree of the right operand.
    right: Rc<ExpressionContextAll<'static>>,
    /// The operator token.
    operator: Box<CommonToken<'static>>,
    /// Context in which the operands are translated and the operator resolved.
    expression_translation_context: Rc<ExpressionTranslationContext>,
}

impl HamelinBinaryOperatorApply {
    /// Captures both operands and the operator from `tree`.
    ///
    /// # Errors
    /// Fails (via `TranslationErrors::expect`) on the first of left operand, right
    /// operand or operator that is missing from the parse tree.
    pub fn try_new(
        tree: &BinaryOperatorContext<'static>,
        expression_translation_context: Rc<ExpressionTranslationContext>,
    ) -> Result<Self, TranslationErrors> {
        let left = TranslationErrors::expect(tree, tree.left.clone())?;
        let right = TranslationErrors::expect(tree, tree.right.clone())?;
        let operator = TranslationErrors::expect(tree, tree.operator.clone())?;
        Ok(Self {
            left,
            right,
            operator,
            expression_translation_context,
        })
    }
}
impl CanMatchArgs for HamelinBinaryOperatorApply {
fn function_name(&self) -> Result<String, TranslationErrors> {
Operator::of(self.operator.get_text().to_uppercase().as_str())
.map(|o| o.to_string())
.map_err(|e| {
TranslationError::new(Context::new(
self.operator.get_start() as usize..=self.operator.get_stop() as usize,
"unexpected operator",
))
.with_source_boxed(e.into())
.into()
})
}
fn positional(&self) -> Result<Vec<ExpressionTranslation>, TranslationErrors> {
let left = HamelinExpression::new(
self.left.clone(),
self.expression_translation_context.clone(),
);
let right = HamelinExpression::new(
self.right.clone(),
self.expression_translation_context.clone(),
);
TranslationErrors::from_vec(vec![left.translate(), right.translate()])
}
fn named(&self) -> Result<OrderMap<String, ExpressionTranslation>, TranslationErrors> {
Ok(OrderMap::new())
}
fn functions(&self) -> &HashMap<String, Vec<Arc<dyn FunctionDef>>> {
&self
.expression_translation_context
.registry
.binary_operation_defs
}
fn expression_translation_context(&self) -> Rc<ExpressionTranslationContext> {
self.expression_translation_context.clone()
}
}
/// A call of a named function with positional and/or named arguments.
pub struct HamelinFunctionCallApply {
    /// Parse tree of the called function's name.
    function_name: Rc<SimpleIdentifierContextAll<'static>>,
    /// Positional argument parse trees, in call order.
    positional_arguments: Vec<Rc<PositionalArgumentContextAll<'static>>>,
    /// Named argument parse trees, in call order.
    named_arguments: Vec<Rc<NamedArgumentContextAll<'static>>>,
    /// Context in which arguments are translated and the function resolved.
    expression_translation_context: Rc<ExpressionTranslationContext>,
}

impl HamelinFunctionCallApply {
    /// Captures the function name and argument lists from `tree`.
    ///
    /// # Errors
    /// Fails (via `TranslationErrors::expect`) when the function name is missing
    /// from the parse tree.
    pub fn try_new(
        tree: &FunctionCallContext<'static>,
        expression_translation_context: Rc<ExpressionTranslationContext>,
    ) -> Result<Self, TranslationErrors> {
        let function_name = TranslationErrors::expect(tree, tree.functionName.clone())?;
        let named_arguments = tree.namedArgument_all().clone();
        let positional_arguments = tree.positionalArgument_all().clone();
        Ok(Self {
            function_name,
            positional_arguments,
            named_arguments,
            expression_translation_context,
        })
    }
}
impl CanMatchArgs for HamelinFunctionCallApply {
fn function_name(&self) -> Result<String, TranslationErrors> {
Ok(HamelinSimpleIdentifier::new(self.function_name.clone())
.to_sql()?
.name
.to_lowercase())
}
fn positional(&self) -> Result<Vec<ExpressionTranslation>, TranslationErrors> {
TranslationErrors::from_vec(
self.positional_arguments
.iter()
.map(|pa| {
TranslationErrors::expect(pa.as_ref(), pa.expression()).and_then(|exp| {
HamelinExpression::new(exp, self.expression_translation_context.clone())
.translate()
})
})
.collect(),
)
}
fn named(&self) -> Result<OrderMap<String, ExpressionTranslation>, TranslationErrors> {
let keys = TranslationErrors::from_vec(
self.named_arguments
.iter()
.map(|ctx| {
TranslationErrors::expect(ctx.as_ref(), ctx.simpleIdentifier())
.and_then(|si| HamelinSimpleIdentifier::new(si).to_sql())
})
.collect(),
);
let values = TranslationErrors::from_vec(
self.named_arguments
.iter()
.map(|ctx| {
TranslationErrors::expect(ctx.as_ref(), ctx.expression()).and_then(|e| {
HamelinExpression::new(e, self.expression_translation_context.clone())
.translate()
})
})
.collect(),
);
let (keys, values) = TranslationErrors::from_2(keys, values)?;
Ok(keys
.into_iter()
.map(|k| k.name.to_lowercase())
.zip(values.into_iter())
.collect())
}
fn functions(&self) -> &HashMap<String, Vec<Arc<dyn FunctionDef>>> {
&self.expression_translation_context.registry.function_defs
}
fn expression_translation_context(&self) -> Rc<ExpressionTranslationContext> {
self.expression_translation_context.clone()
}
}
#[derive(Error, Debug)]
#[error("Attempted {function_def}\n{binding_failures}")]
pub struct ApplyAttempt {
pub function_def: String,
pub binding_failures: FunctionParameterBindingFailures,
}
#[derive(Error, Debug)]
pub struct ApplyFailure(pub Vec<ApplyAttempt>);
impl Display for ApplyFailure {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
for attempt in self.0.iter() {
write!(f, "{}\n", attempt)?;
}
Ok(())
}
}
/// Replacement of one array-typed argument by a reference to a single element,
/// used when retrying a call element-wise ("array passthrough").
#[derive(Clone)]
pub enum ArgumentOverride {
    /// Replaces a positional argument.
    PositionalOverride(PositionalOverride),
    /// Replaces a named argument.
    NamedOverride(NamedOverride),
}

/// Override for the positional argument at `position`.
#[derive(Clone)]
pub struct PositionalOverride {
    /// Index of the overridden positional argument.
    pub position: usize,
    /// Lambda parameter bound to each element (e.g. `e_0`).
    pub identifier: SimpleIdentifier,
    /// Translation of a reference to that element; substituted for the argument.
    pub translation: ExpressionTranslation,
}

/// Override for the named argument `name`.
#[derive(Clone)]
pub struct NamedOverride {
    /// Name of the overridden named argument.
    pub name: String,
    /// Lambda parameter bound to each element (e.g. `e_<name>`).
    pub identifier: SimpleIdentifier,
    /// Translation of a reference to that element; substituted for the argument.
    pub translation: ExpressionTranslation,
}
/// Returns every non-empty subset of `items`; each subset keeps the input order.
///
/// Subsets are ordered like binary counting over item indices: the `k`-th subset
/// (1-based) holds the items whose bit is set in `k`. For `[a, b, c]` this yields
/// `[a], [b], [a, b], [c], [a, c], [b, c], [a, b, c]`, so all single-item subsets
/// of earlier items come first.
///
/// The output has `2^n - 1` entries, so `items` should stay small.
fn make_combinations<T: Clone>(items: &[T]) -> Vec<Vec<T>> {
    // Start from the empty subset. Each item doubles the list by appending itself
    // to a copy of every subset seen so far.
    let mut subsets: Vec<Vec<T>> = vec![Vec::new()];
    for item in items {
        let existing = subsets.len();
        for i in 0..existing {
            let mut extended = subsets[i].clone();
            extended.push(item.clone());
            subsets.push(extended);
        }
    }
    // Drop the seed empty subset; callers only want non-empty combinations.
    subsets.remove(0);
    subsets
}
/// Suggests names from `available` that resemble `target`, best match first.
///
/// Matching is case-insensitive. Each candidate is scored as follows:
/// - `1.0` for an exact match;
/// - `0.9 - 0.3 * penalty` when the candidate contains `target`;
/// - `0.8 - 0.3 * penalty` when `target` contains the candidate;
/// - otherwise `0.7 * s`, where `s = 0.7 * jaro_winkler + 0.3 * normalized
///   Damerau-Levenshtein`. The candidate is kept only when `s >= 0.6`.
///
/// `penalty` is the length difference divided by the longer length, measured in
/// characters of the lower-cased strings.
///
/// Results are sorted by descending score, ties broken alphabetically, and truncated
/// to `max_suggestions`.
fn find_similar_function_names(
    target: &str,
    available: &[&String],
    max_suggestions: usize,
) -> Vec<String> {
    let target_lower = target.to_lowercase();
    let target_chars = target_lower.chars().count();

    let mut candidates: Vec<(String, f64)> = available
        .iter()
        .filter_map(|&func_name| {
            let func_lower = func_name.to_lowercase();
            let func_chars = func_lower.chars().count();
            // Lengths are taken from the lower-cased strings the containment checks
            // use, so the subtractions below cannot underflow. Lower-casing can change
            // byte length (e.g. U+212A KELVIN SIGN becomes ASCII 'k'), so the original
            // byte lengths are not safe here. The non-empty divisor is guaranteed
            // because the equal case is handled first.
            let score = if target_lower == func_lower {
                1.0
            } else if func_lower.contains(&target_lower) {
                let penalty = (func_chars - target_chars) as f64 / func_chars as f64;
                0.9 - penalty * 0.3
            } else if target_lower.contains(&func_lower) {
                let penalty = (target_chars - func_chars) as f64 / target_chars as f64;
                0.8 - penalty * 0.3
            } else {
                let jaro_winkler = strsim::jaro_winkler(&target_lower, &func_lower);
                let normalized_levenshtein =
                    strsim::normalized_damerau_levenshtein(&target_lower, &func_lower);
                let combined = jaro_winkler * 0.7 + normalized_levenshtein * 0.3;
                // Fuzzy matches are scaled down so they rank below substring matches.
                if combined >= 0.6 {
                    combined * 0.7
                } else {
                    0.0
                }
            };
            (score > 0.0).then(|| (func_name.clone(), score))
        })
        .collect();

    candidates.sort_by(|a, b| {
        b.1.partial_cmp(&a.1)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then(a.0.cmp(&b.0))
    });
    candidates
        .into_iter()
        .take(max_suggestions)
        .map(|(name, _)| name)
        .collect()
}