use std::fmt::{Debug, Display, Formatter};
use std::rc::Rc;
use std::sync::Arc;
use antlr_rust::parser_rule_context::ParserRuleContext;
use antlr_rust::tree::{ParseTree, ParseTreeListener};
use hamelin_lib::completion::Completion;
use hamelin_lib::sql::expression::operator::Operator;
use hamelin_lib::translation::ExpressionTranslation;
use hamelin_sql::range_builder::RangeBuilder;
use once_cell::sync::Lazy;
use ordermap::OrderMap;
use regex::Regex;
use crate::ast::apply::{
CanMatchArgs, HamelinBinaryOperatorApply, HamelinFunctionCallApply, HamelinUnaryPostfixApply,
HamelinUnaryPrefixApply,
};
use crate::ast::deref::HamelinDerefApply;
use crate::ast::index::HamelinIndexAccess;
use crate::ast::string::HamelinStringLiteral;
use crate::ast::ExpressionTranslationContext;
use crate::env::Environment;
use crate::translation::projection_builder::{ProjectionBuilder, ProjectionBuilderExt};
use hamelin_lib::antlr::hamelinlistener::HamelinListener;
use hamelin_lib::antlr::hamelinparser::{
ArrayLiteralContextAttrs, BinaryLiteralContextAttrs, BooleanLiteralContextAttrs,
ColumnReferenceAltContextAttrs, ColumnReferenceContext, ColumnReferenceContextAttrs,
ExpressionContextAll, HamelinParserContextType, HamelinTreeWalker, IntervalLiteralContextAttrs,
NumberContextAll, NumericLiteralContextAttrs, ParenthesizedExpressionContextAttrs,
RowsLiteralContextAttrs, StringLiteralContextAttrs, StructLiteralContextAttrs,
TsTruncContextAttrs, TsTruncTimestampLiteralContextAttrs, TupleLiteralContextAttrs,
};
use hamelin_lib::antlr::{completion_interval, interval};
use hamelin_lib::err::{
Context, NonMergeableTypes, Stage, TranslationError, TranslationErrors, UnexpectedType,
};
use hamelin_lib::func::registry::FunctionRegistry;
use hamelin_lib::sql::expression::apply::{BinaryOperatorApply, FunctionCallApply};
use hamelin_lib::sql::expression::identifier::HamelinSimpleIdentifier;
use hamelin_lib::sql::expression::identifier::Identifier;
use hamelin_lib::sql::expression::literal::{
ArrayLiteral, BinaryLiteral, BooleanLiteral, ColumnReference, DecimalLiteral, IntegerLiteral,
IntervalLiteral, NullLiteral, RowLiteral, ScientificLiteral, StringLiteral, Unit,
};
use hamelin_lib::sql::expression::Cast;
use hamelin_lib::sql::expression::SQLExpression;
use hamelin_lib::sql::types::SQLRowType;
use hamelin_lib::types::array::Array;
use hamelin_lib::types::decimal_type::Decimal;
use hamelin_lib::types::range::Range;
use hamelin_lib::types::struct_type::Struct;
use hamelin_lib::types::tuple::Tuple;
use hamelin_lib::types::{Type, BINARY, BOOLEAN, DOUBLE, INT, STRING, UNKNOWN};
use hamelin_lib::types::{CALENDAR_INTERVAL, INTERVAL, ROWS, TIMESTAMP};
use hamelin_sql::utils::interval_to_timestamp;
use super::cast::translate_hamelin_cast;
use super::column_ref::HamelinColumnRef;
/// A parsed Hamelin expression paired with the shared context needed to translate it to SQL.
///
/// Cloning is cheap: both fields are reference-counted.
#[derive(Clone, Debug)]
pub struct HamelinExpression {
    /// Parse tree of this expression.
    tree: Rc<ExpressionContextAll<'static>>,
    /// Shared translation state (bindings, function registry, cursor position and completion slot).
    context: Rc<ExpressionTranslationContext>,
}
impl HamelinExpression {
    /// Wraps a parsed expression tree together with the context used to translate it.
    pub fn new(
        tree: Rc<ExpressionContextAll<'static>>,
        context: Rc<ExpressionTranslationContext>,
    ) -> Self {
        Self { tree, context }
    }

    /// Returns the text of the expression as reported by its parse tree.
    pub fn text(&self) -> String {
        self.tree.get_text()
    }

    /// Borrows the underlying parse tree.
    pub fn tree(&self) -> &ExpressionContextAll<'static> {
        self.tree.as_ref()
    }

    /// Returns a shared handle to the bindings environment of the translation context.
    pub fn bindings(&self) -> Arc<Environment> {
        self.context.bindings.clone()
    }

    /// Returns a shared handle to the function registry of the translation context.
    pub fn registry(&self) -> Arc<FunctionRegistry> {
        self.context.registry.clone()
    }

    /// Collects every column reference that appears anywhere in this expression tree.
    ///
    /// # Errors
    ///
    /// Returns the accumulated errors if any column reference lacks its identifier or its
    /// identifier cannot be converted to SQL. Errors from all references are gathered and
    /// reported together rather than stopping at the first one.
    pub fn column_references(&self) -> Result<Vec<ColumnReference>, TranslationErrors> {
        /// Tree listener that records column references and any errors met while visiting them.
        struct CRListener {
            column_references: Vec<ColumnReference>,
            errors: Vec<TranslationErrors>,
        }

        impl CRListener {
            pub fn new() -> Self {
                Self {
                    column_references: Vec::new(),
                    errors: Vec::new(),
                }
            }
        }

        impl ParseTreeListener<'static, HamelinParserContextType> for CRListener {}

        impl HamelinListener<'static> for CRListener {
            fn enter_columnReference(&mut self, ctx: &ColumnReferenceContext<'static>) {
                // An error-recovered parse tree may lack the identifier child; report that as
                // an error (like the other failures collected here) instead of panicking.
                let id = match ctx.simpleIdentifier() {
                    Some(id) => id,
                    None => {
                        self.errors.push(
                            TranslationError::msg(ctx, "column reference is missing an identifier")
                                .single(),
                        );
                        return;
                    }
                };
                match HamelinSimpleIdentifier::new(id).to_sql() {
                    Ok(ident) => {
                        let column_reference = ColumnReference::new(ident.into());
                        self.column_references.push(column_reference);
                    }
                    Err(e) => {
                        self.errors.push(e);
                    }
                }
            }
        }

        let listener = HamelinTreeWalker::walk(Box::new(CRListener::new()), self.tree.as_ref());
        // Move the results out of the listener instead of cloning them.
        let CRListener {
            column_references,
            errors,
        } = *listener;
        if errors.is_empty() {
            Ok(column_references)
        } else {
            Err(errors.into())
        }
    }

    /// Creates an expression for a subtree that shares this expression's translation context.
    pub fn inner(&self, tree: Rc<ExpressionContextAll<'static>>) -> Self {
        Self {
            tree,
            context: self.context.clone(),
        }
    }

    /// Translates this expression into SQL together with its inferred Hamelin type.
    ///
    /// Literals are translated inline; operators, function calls, casts, dereferences, index
    /// access and column references are delegated to their dedicated translators. The returned
    /// translation carries the span of this expression's parse tree.
    ///
    /// # Errors
    ///
    /// Returns translation errors for malformed literals, type mismatches, failures reported by
    /// the delegated translators, and parse-error nodes. For a parse-error node, completion
    /// suggestions may also be stored in the shared completion slot before the error is returned.
    pub fn translate(&self) -> Result<ExpressionTranslation, TranslationErrors> {
        let res = match self.tree.as_ref() {
            ExpressionContextAll::UnboundRangeLiteralContext(_) => {
                let typ = Range::new(UNKNOWN).into();
                let sql = RangeBuilder::default().build();
                ExpressionTranslation::with_defaults(typ, sql)
            }
            ExpressionContextAll::NullLiteralContext(_) => {
                ExpressionTranslation::with_defaults(UNKNOWN, NullLiteral::default().into())
            }
            ExpressionContextAll::NumericLiteralContext(nlctx) => {
                let number = nlctx.get_text();
                match nlctx.number().unwrap().as_ref() {
                    NumberContextAll::DecimalLiteralContext(_) => {
                        // Scale is the number of characters after the point; precision adds the
                        // characters before it. `split_once` avoids an out-of-bounds panic if
                        // the text unexpectedly has no point.
                        let (whole, fraction) =
                            number.split_once('.').unwrap_or((number.as_str(), ""));
                        let scale = fraction.len();
                        let precision = whole.len() + scale;
                        let decimal =
                            Decimal::new(precision as i32, scale as i32).map_err(|e| {
                                TranslationError::wrap_box(self.tree.as_ref(), e.into()).single()
                            })?;
                        ExpressionTranslation::with_defaults(
                            decimal.into(),
                            DecimalLiteral::new(&number).into(),
                        )
                    }
                    NumberContextAll::ScientificLiteralContext(_) => {
                        ExpressionTranslation::with_defaults(
                            DOUBLE,
                            ScientificLiteral::new(&number).into(),
                        )
                    }
                    NumberContextAll::IntegerLiteralContext(_) => {
                        ExpressionTranslation::with_defaults(
                            INT,
                            IntegerLiteral::new(&number).into(),
                        )
                    }
                    NumberContextAll::Error(ctx) => {
                        return Err(TranslationError::msg(ctx, "parse error").single())
                    }
                }
            }
            ExpressionContextAll::BooleanLiteralContext(ctx) => {
                ExpressionTranslation::with_defaults(
                    BOOLEAN,
                    BooleanLiteral::new(ctx.TRUE().is_some()).into(),
                )
            }
            ExpressionContextAll::StringLiteralContext(ctx) => {
                ExpressionTranslation::with_defaults(
                    STRING,
                    HamelinStringLiteral::new(
                        TranslationErrors::expect(ctx, ctx.string())?.clone(),
                    )
                    .translate()?,
                )
            }
            ExpressionContextAll::BinaryLiteralContext(ctx) => {
                // Drop the first character of the token text before building the literal.
                ExpressionTranslation::with_defaults(
                    BINARY,
                    BinaryLiteral::new(&ctx.BINARY_LITERAL().unwrap().get_text()[1..]).into(),
                )
            }
            ExpressionContextAll::ColumnReferenceAltContext(cractx) => {
                let hcr = HamelinColumnRef::new(
                    self.tree.clone(),
                    cractx.columnReference().clone(),
                    self.context.clone(),
                );
                hcr.append_completions();
                hcr.translate()?
            }
            ExpressionContextAll::DerefContext(ctx) => {
                let hamelin_deref_apply = HamelinDerefApply::new(
                    ctx.left.clone().unwrap(),
                    ctx.right.clone().unwrap(),
                    self.context.clone(),
                );
                hamelin_deref_apply.append_completions();
                ExpressionTranslation::with_defaults(
                    hamelin_deref_apply.infer_type()?,
                    hamelin_deref_apply.to_sql()?,
                )
            }
            ExpressionContextAll::UnaryPrefixOperatorContext(ctx) => {
                HamelinUnaryPrefixApply::try_new(ctx, self.context.clone())?.match_function(ctx)?
            }
            ExpressionContextAll::UnaryPostfixOperatorContext(ctx) => {
                HamelinUnaryPostfixApply::try_new(ctx, self.context.clone())?.match_function(ctx)?
            }
            ExpressionContextAll::BinaryOperatorContext(ctx) => {
                HamelinBinaryOperatorApply::try_new(ctx, self.context.clone())?
                    .match_function(ctx)?
            }
            ExpressionContextAll::ParenthesizedExpressionContext(ctx) => {
                self.inner(ctx.expression().unwrap()).translate()?
            }
            ExpressionContextAll::StructLiteralContext(ctx) => {
                let pairs = ctx
                    .simpleIdentifier_all()
                    .into_iter()
                    .zip(ctx.expression_all().into_iter());
                let mut types = OrderMap::new();
                let mut sql_fields = OrderMap::new();
                let mut sql_expressions = Vec::new();
                for (id, exp) in pairs {
                    let identifier = HamelinSimpleIdentifier::new(id).to_sql()?;
                    let expression = self.inner(exp);
                    let expression_translation = expression.translate()?;
                    types.insert(identifier.clone(), expression_translation.typ.clone());
                    sql_expressions.push(expression_translation.sql);
                    sql_fields.insert(
                        identifier,
                        expression_translation.typ.to_sql().map_err(|e| {
                            TranslationError::wrap_box(expression.tree.as_ref(), e.into()).single()
                        })?,
                    );
                }
                let typ = Struct::new(types);
                // Cast the positional ROW literal to a row type keyed by the field identifiers.
                let sql = Cast::new(
                    RowLiteral::new(sql_expressions).into(),
                    SQLRowType::new(sql_fields).into(),
                );
                ExpressionTranslation::with_defaults(typ.into(), sql.into())
            }
            ExpressionContextAll::FunctionCallContext(ctx) => {
                HamelinFunctionCallApply::try_new(ctx, self.context.clone())?.match_function(ctx)?
            }
            ExpressionContextAll::IndexAccessContext(ctx) => {
                let hia = HamelinIndexAccess::new(
                    self.inner(ctx.value.clone().unwrap()),
                    self.inner(ctx.index.clone().unwrap()),
                );
                ExpressionTranslation::with_defaults(hia.infer_type()?, hia.to_sql()?)
            }
            ExpressionContextAll::ArrayLiteralContext(ctx) => {
                let elements = ctx
                    .expression_all()
                    .iter()
                    .map(|e| self.inner(e.clone()))
                    .collect::<Vec<HamelinExpression>>();
                let elts_translated =
                    TranslationErrors::from_vec(elements.iter().map(|e| e.translate()).collect())?;
                let error_contexts = elts_translated
                    .iter()
                    .zip(elements.iter().map(|e| e.tree.as_ref()))
                    .map(|(trns, tree)| Context::from(tree, trns.typ.to_string().as_str()))
                    .collect::<Vec<_>>();
                let elt_type = get_array_element_type(&elts_translated[..]).map_err(|e| {
                    TranslationError::msg(
                        self.tree.as_ref(),
                        "could not determine array element type",
                    )
                    .with_context_vec(error_contexts)
                    .with_source(e)
                    .single()
                })?;
                let typ = Array::new(elt_type.clone().into());
                if let Type::Struct(element_struct) = &elt_type {
                    let needs_expansion = elts_translated.iter().any(|elt| elt.typ != elt_type);
                    if needs_expansion {
                        let mut ret = vec![];
                        for (elt, elt_translated) in elements.iter().zip(elts_translated.iter()) {
                            // Struct elements whose type differs from the merged element type
                            // are expanded to it and cast; all other elements are used as is.
                            if elt_translated.typ != elt_type
                                && matches!(elt_translated.typ, Type::Struct(_))
                            {
                                let pb = ProjectionBuilder::from_struct_expression(elt.clone())
                                    .map_err(|e| {
                                        TranslationError::wrap_box(elt.tree.as_ref(), e.into())
                                            .single()
                                    })?;
                                ret.push(
                                    pb.expand(element_struct.clone())
                                        .map_err(|e| {
                                            TranslationError::wrap_box(
                                                elt.tree.as_ref(),
                                                e.into(),
                                            )
                                            .single()
                                        })?
                                        .build_cast()
                                        .map_err(|e| {
                                            TranslationError::wrap_box(
                                                elt.tree.as_ref(),
                                                e.into(),
                                            )
                                            .single()
                                        })?
                                        .into(),
                                );
                            } else {
                                ret.push(elt_translated.sql.clone());
                            }
                        }
                        return Ok(ExpressionTranslation::with_defaults(
                            typ.into(),
                            ArrayLiteral::new(ret).into(),
                        )
                        .with_span(self.tree.as_ref()));
                    }
                }
                let sql =
                    ArrayLiteral::new(elts_translated.into_iter().map(|elt| elt.sql).collect());
                ExpressionTranslation::with_defaults(typ.into(), sql.into())
            }
            ExpressionContextAll::PairLiteralContext(ctx) => {
                let l = self.inner(ctx.left.clone().unwrap()).translate()?;
                let r = self.inner(ctx.right.clone().unwrap()).translate()?;
                ExpressionTranslation::with_defaults(
                    Tuple::new(vec![l.typ, r.typ]).into(),
                    RowLiteral::new(vec![l.sql, r.sql]).into(),
                )
            }
            ExpressionContextAll::TupleLiteralContext(ctx) => {
                let inner = TranslationErrors::from_vec(
                    ctx.expression_all()
                        .iter()
                        .map(|e| self.inner(e.clone()).translate())
                        .collect(),
                )?;
                ExpressionTranslation::with_defaults(
                    Tuple::new(inner.iter().map(|et| et.typ.clone()).collect()).into(),
                    RowLiteral::new(inner.iter().map(|et| et.sql.clone()).collect()).into(),
                )
            }
            ExpressionContextAll::CastContext(ctx) => {
                translate_hamelin_cast(&ctx, self.context.clone())?
            }
            ExpressionContextAll::IntervalLiteralContext(ctx) => {
                // Sub-week units become `parse_duration` calls. Weeks are 7 x a day interval and
                // quarters 3 x a month interval. Months, quarters and years are CALENDAR_INTERVAL;
                // every other unit is INTERVAL.
                let (sql, interval_type): (SQLExpression, Type) = if let Some(ctx) =
                    ctx.NANOSECOND_INTERVAL()
                {
                    (duration(extract_num(ctx.as_ref())?, "ns"), INTERVAL)
                } else if let Some(ctx) = ctx.MICROSECOND_INTERVAL() {
                    (duration(extract_num(ctx.as_ref())?, "us"), INTERVAL)
                } else if let Some(ctx) = ctx.MILLISECOND_INTERVAL() {
                    (duration(extract_num(ctx.as_ref())?, "ms"), INTERVAL)
                } else if let Some(ctx) = ctx.SECOND_INTERVAL() {
                    (duration(extract_num(ctx.as_ref())?, "s"), INTERVAL)
                } else if let Some(ctx) = ctx.MINUTE_INTERVAL() {
                    (duration(extract_num(ctx.as_ref())?, "m"), INTERVAL)
                } else if let Some(ctx) = ctx.HOUR_INTERVAL() {
                    (duration(extract_num(ctx.as_ref())?, "h"), INTERVAL)
                } else if let Some(ctx) = ctx.DAY_INTERVAL() {
                    (duration(extract_num(ctx.as_ref())?, "d"), INTERVAL)
                } else if let Some(ctx) = ctx.WEEK_INTERVAL() {
                    (
                        BinaryOperatorApply::new(
                            Operator::Asterisk,
                            IntegerLiteral::new("7").into(),
                            IntervalLiteral::new(extract_num(ctx.as_ref())?, Unit::Day).into(),
                        )
                        .into(),
                        INTERVAL,
                    )
                } else if let Some(ctx) = ctx.MONTH_INTERVAL() {
                    (
                        IntervalLiteral::new(extract_num(ctx.as_ref())?, Unit::Month).into(),
                        CALENDAR_INTERVAL,
                    )
                } else if let Some(ctx) = ctx.QUARTER_INTERVAL() {
                    (
                        BinaryOperatorApply::new(
                            Operator::Asterisk,
                            IntegerLiteral::new("3").into(),
                            IntervalLiteral::new(extract_num(ctx.as_ref())?, Unit::Month).into(),
                        )
                        .into(),
                        CALENDAR_INTERVAL,
                    )
                } else {
                    (
                        IntervalLiteral::new(
                            extract_num(ctx.YEAR_INTERVAL().unwrap().as_ref())?,
                            Unit::Year,
                        )
                        .into(),
                        CALENDAR_INTERVAL,
                    )
                };
                ExpressionTranslation::with_defaults(interval_type, sql)
            }
            ExpressionContextAll::TsTruncContext(ctx) => {
                let exp = self.inner(ctx.expression().unwrap());
                let exp_translated = exp.translate()?;
                let typ = if exp_translated.typ == TIMESTAMP
                    || exp_translated.typ == INTERVAL
                    || exp_translated.typ == CALENDAR_INTERVAL
                {
                    TIMESTAMP
                } else {
                    // List every accepted input type, not just TIMESTAMP, so the error tells
                    // the user all of their options.
                    return Err(TranslationError::wrap(
                        ctx,
                        UnexpectedType::new(
                            exp_translated.typ,
                            vec![TIMESTAMP.into(), INTERVAL.into(), CALENDAR_INTERVAL.into()],
                        ),
                    )
                    .single());
                };
                let unit = ctx
                    .SECOND_TRUNC()
                    .map(|_| "second")
                    .or_else(|| ctx.MINUTE_TRUNC().map(|_| "minute"))
                    .or_else(|| ctx.HOUR_TRUNC().map(|_| "hour"))
                    .or_else(|| ctx.DAY_TRUNC().map(|_| "day"))
                    .or_else(|| ctx.WEEK_TRUNC().map(|_| "week"))
                    .or_else(|| ctx.MONTH_TRUNC().map(|_| "month"))
                    .or_else(|| ctx.QUARTER_TRUNC().map(|_| "quarter"))
                    .unwrap_or("year");
                // Interval inputs go through `interval_to_timestamp` before truncation.
                let sql =
                    if exp_translated.typ == INTERVAL || exp_translated.typ == CALENDAR_INTERVAL {
                        interval_to_timestamp(exp_translated.sql)
                    } else {
                        exp_translated.sql
                    };
                let sql =
                    FunctionCallApply::with_two("date_trunc", StringLiteral::new(unit).into(), sql)
                        .into();
                ExpressionTranslation::with_defaults(typ, sql)
            }
            ExpressionContextAll::TsTruncTimestampLiteralContext(ctx) => {
                // A bare truncation literal truncates `now()`.
                let typ = TIMESTAMP;
                let unit = ctx
                    .SECOND_TRUNC()
                    .map(|_| "second")
                    .or_else(|| ctx.MINUTE_TRUNC().map(|_| "minute"))
                    .or_else(|| ctx.HOUR_TRUNC().map(|_| "hour"))
                    .or_else(|| ctx.DAY_TRUNC().map(|_| "day"))
                    .or_else(|| ctx.WEEK_TRUNC().map(|_| "week"))
                    .or_else(|| ctx.MONTH_TRUNC().map(|_| "month"))
                    .or_else(|| ctx.QUARTER_TRUNC().map(|_| "quarter"))
                    .unwrap_or("year");
                let sql = FunctionCallApply::with_two(
                    "date_trunc",
                    StringLiteral::new(unit).into(),
                    FunctionCallApply::with_no_arguments("now").into(),
                )
                .into();
                ExpressionTranslation::with_defaults(typ, sql)
            }
            ExpressionContextAll::RowsLiteralContext(ctx) => {
                // Without a ROWS_LITERAL token the row count defaults to 1.
                let number = ctx
                    .ROWS_LITERAL()
                    .map_or(Ok("1".to_string()), |n| extract_num(n.as_ref()))?;
                let sql = DecimalLiteral::new(&number);
                ExpressionTranslation::with_defaults(ROWS, sql.into())
            }
            ExpressionContextAll::Error(ectx) => {
                // If the cursor lies inside this error node and no completion has been recorded
                // yet, offer binding and function-registry suggestions before failing.
                let maybe_guard = self.context.completions.try_borrow_mut();
                let range = completion_interval(ectx);
                match (self.context.at, maybe_guard) {
                    (Some(at), Ok(mut guard)) if range.contains(&at) && guard.is_none() => {
                        let insert_interval = interval(ectx);
                        let mut completion = Completion::new(insert_interval);
                        completion.filter(false);
                        completion.add_items(self.context.bindings.autocomplete_suggestions(false));
                        completion.add_items(self.context.registry.autocomplete_suggestions());
                        *guard = Some(completion);
                    }
                    _ => {}
                }
                return Err(
                    TranslationError::msg(self.tree.as_ref(), "clause missing expression")
                        .with_stage(Stage::Parsing)
                        .single(),
                );
            }
        };
        Ok(res.with_span(self.tree.as_ref()))
    }
}
/// Computes the common element type of an array literal by merging all element types.
///
/// Types are merged left to right, starting from the first element's type. An empty slice
/// yields `UNKNOWN`.
///
/// # Errors
///
/// Returns [`NonMergeableTypes`] as soon as the accumulated type cannot be merged with the
/// next element's type.
fn get_array_element_type(elements: &[ExpressionTranslation]) -> Result<Type, NonMergeableTypes> {
    let mut merged: Option<Type> = None;
    for elt in elements {
        merged = Some(match merged {
            None => elt.typ.clone(),
            Some(ty) => ty.merge(elt.typ.clone())?,
        });
    }
    Ok(merged.unwrap_or(UNKNOWN))
}
/// Extracts the leading run of digits (regex `^\d+`) from a parse node's text.
///
/// For example, a node whose text is `15m` yields `"15"`.
///
/// # Errors
///
/// Returns a translation error anchored at `node` when its text does not start with a digit.
fn extract_num<'a, T>(node: &T) -> Result<String, TranslationErrors>
where
    T: ParserRuleContext<'a>,
{
    // Compiled once and reused across calls.
    static NUMBERS: Lazy<Regex> = Lazy::new(|| Regex::new(r"^\d+").unwrap());
    let text = node.get_text();
    // `ok_or_else` builds the error only when there is no match.
    let digits = NUMBERS.find(&text).ok_or_else(|| {
        TranslationError::msg(node, "Could not extract a number from the text").single()
    })?;
    Ok(digits.as_str().to_string())
}
/// Builds a `parse_duration` SQL call whose single argument is the string literal formed by
/// appending `unit` to `num` (e.g. `"5"` and `"m"` give `"5m"`).
fn duration(num: String, unit: &str) -> SQLExpression {
    let mut spec = num;
    spec.push_str(unit);
    FunctionCallApply::with_one("parse_duration", StringLiteral::new(&spec).into()).into()
}
/// Errors describing why a nested field dereference could not be resolved.
#[derive(Debug, thiserror::Error)]
pub enum NestedDerefError {
    /// The dereferenced expression does not have a struct type.
    #[error("The expression {0} is not a struct")]
    NotAStruct(HamelinExpression),
    /// The named field does not exist in the given struct type.
    #[error("Field {0} not found in {1:#?}")]
    FieldNotFound(Identifier, Struct),
    /// The named field exists but is not itself a struct, so it cannot be dereferenced further.
    #[error("Field {0} in {1:#?} is not a struct")]
    FieldNotAStruct(Identifier, Struct),
}
impl Display for HamelinExpression {
    /// Writes the expression's parse-tree text verbatim; formatter width/fill flags are ignored.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let source = self.text();
        f.write_str(&source)
    }
}