use alloc::{format, string::ToString, sync::Arc, vec};
use qusql_parse::{Expression, Identifier, Spanned, UnaryOperator, Variable, issue_todo};
use crate::{
Type,
schema::parse_column,
type_::{ArgType, BaseType, FullType},
type_binary_expression::type_binary_expression,
type_function::{type_aggregate_function, type_function},
type_select::type_union_select,
typer::{Restrict, Typer},
};
/// Facts the caller allows the typer to assume while typing an expression.
///
/// The flags are threaded through `type_expression`; for example a column
/// reference typed with `not_null` set marks the referenced column as
/// not-null, narrowing its nullability for the rest of the statement.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ExpressionFlags {
    /// The expression is assumed to evaluate to true
    /// (presumably set for filter conditions such as WHERE — confirm with callers).
    pub true_: bool,
    /// The expression is assumed to be non-null.
    pub not_null: bool,
    /// The expression is typed inside an `ON DUPLICATE KEY UPDATE` clause
    /// (per the field name; it is carried along, not interpreted, by the flag helpers).
    pub in_on_duplicate_key_update: bool,
}

impl ExpressionFlags {
    /// Returns a copy of these flags with `true_` replaced.
    #[must_use]
    pub fn with_true(self, true_: bool) -> Self {
        Self { true_, ..self }
    }

    /// Returns a copy of these flags with `not_null` replaced.
    #[must_use]
    pub fn with_not_null(self, not_null: bool) -> Self {
        Self { not_null, ..self }
    }

    /// Returns a copy of these flags with `in_on_duplicate_key_update` replaced.
    #[must_use]
    pub fn with_in_on_duplicate_key_update(self, in_on_duplicate_key_update: bool) -> Self {
        Self {
            in_on_duplicate_key_update,
            ..self
        }
    }

    /// Returns a copy with the value facts (`true_` and `not_null`) cleared,
    /// keeping the clause context (`in_on_duplicate_key_update`).
    ///
    /// Use this when typing a sub-expression whose value is not implied by the
    /// value of the enclosing expression.
    #[must_use]
    pub fn without_values(self) -> Self {
        self.with_true(false).with_not_null(false)
    }
}
/// Types a prefix unary operator applied to `operand`.
///
/// * `Not` types the operand against `BaseType::Bool`, checks that it is a
///   boolean, and yields the operand's own type.
/// * `Binary`, `LogicalNot` and `Minus` require a numeric operand: signed
///   integers, floats and `Invalid` pass through unchanged, unsigned integers
///   become the signed integer of the same width, `NULL` stays `NULL`, and any
///   other type is reported as an error and typed as `Invalid`. Nullability is
///   inherited from the operand.
///
/// The caller's `true_` fact is dropped for the operand: knowing the operator's
/// result is true says nothing direct about the operand itself.
fn type_unary_expression<'a>(
    typer: &mut Typer<'a, '_>,
    op: &UnaryOperator,
    operand: &Expression<'a>,
    flags: ExpressionFlags,
) -> FullType<'a> {
    let operand_flags = flags.with_true(false);
    match op {
        UnaryOperator::Not(_) => {
            let operand_type = type_expression(typer, operand, operand_flags, BaseType::Bool);
            typer.ensure_base(operand, &operand_type, BaseType::Bool);
            operand_type
        }
        UnaryOperator::Binary(_) | UnaryOperator::LogicalNot(_) | UnaryOperator::Minus(_) => {
            let operand_type = type_expression(typer, operand, operand_flags, BaseType::Any);
            let not_null = operand_type.not_null;
            let result = match &operand_type.t {
                Type::Null => Type::Null,
                // Unsigned operands are widened into the signed type of equal width.
                Type::U8 => Type::I8,
                Type::U16 => Type::I16,
                Type::U24 => Type::I24,
                Type::U32 => Type::I32,
                Type::U64 => Type::I64,
                Type::I8
                | Type::I16
                | Type::I24
                | Type::I32
                | Type::I64
                | Type::F32
                | Type::F64
                | Type::Invalid
                | Type::Base(BaseType::Integer)
                | Type::Base(BaseType::Float) => operand_type.t,
                Type::Base(..) | Type::Args(..) | Type::Enum(..) | Type::Set(..) | Type::JSON => {
                    typer.err(
                        format!("Expected numeric type got {}", operand_type.t),
                        &op.span(),
                    );
                    Type::Invalid
                }
            };
            FullType::new(result, not_null)
        }
    }
}
/// Infers the SQL type and nullability of `expression`, reporting problems to `typer`.
///
/// `flags` carries facts the caller allows the typer to assume:
/// * `true_` — the expression evaluates to true; e.g. `x IN (...)` or
///   `x IS NOT NULL` typed with `true_` types `x` with `not_null` set;
/// * `not_null` — the expression is non-null; a column reference typed with this
///   flag marks the matching column in `typer.reference_types` as not-null;
/// * `in_on_duplicate_key_update` — not read directly here, only forwarded.
///
/// `_context` is a hint of the expected base type; this function does not consult it.
///
/// Constructs that are not supported yet are reported with `issue_todo!` and typed
/// as invalid.
pub(crate) fn type_expression<'a>(
    typer: &mut Typer<'a, '_>,
    expression: &Expression<'a>,
    flags: ExpressionFlags,
    _context: BaseType,
) -> FullType<'a> {
    match expression {
        Expression::Binary(e) => type_binary_expression(typer, &e.op, &e.lhs, &e.rhs, flags),
        Expression::Unary(e) => type_unary_expression(typer, &e.op, &e.operand, flags),
        Expression::Subquery(e) => {
            let select_type = type_union_select(typer, &e.expression, false);
            if let [v] = select_type.columns.as_slice() {
                // A scalar subquery yields NULL when it returns no rows.
                let mut r = v.type_.clone();
                r.not_null = false;
                r
            } else {
                typer.err("Subquery should yield one column", &e.expression);
                FullType::invalid()
            }
        }
        Expression::ListHack(v) => {
            typer.err("_LIST_ only allowed in IN ()", v);
            FullType::invalid()
        }
        Expression::Null(_) => FullType::new(Type::Null, false),
        Expression::Bool(_) => FullType::new(BaseType::Bool, true),
        Expression::String(_) => FullType::new(BaseType::String, true),
        Expression::Integer(_) => FullType::new(BaseType::Integer, true),
        Expression::Float(_) => FullType::new(BaseType::Float, true),
        Expression::Function(e) => {
            type_function(typer, &e.function, &e.args, &e.function_span, flags)
        }
        Expression::WindowFunction(e) => {
            // Window specification expressions get no facts from the caller.
            if let Some((_, partition_by)) = &e.over.window_spec.partition_by {
                for e in partition_by {
                    type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
                }
            }
            if let Some((_, order_by)) = &e.over.window_spec.order_by {
                for (e, _) in order_by {
                    type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
                }
            }
            type_function(typer, &e.function, &e.args, &e.function_span, flags)
        }
        Expression::AggregateFunction(e) => {
            if let Some((_, filter)) = &e.filter {
                type_expression(typer, filter, ExpressionFlags::default(), BaseType::Bool);
            }
            if let Some((_, within_group_order)) = &e.within_group {
                for (e, _) in within_group_order {
                    type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
                }
            }
            if let Some(over) = &e.over {
                if let Some((_, partition_by)) = &over.window_spec.partition_by {
                    for e in partition_by {
                        type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
                    }
                }
                if let Some((_, order_by)) = &over.window_spec.order_by {
                    for (e, _) in order_by {
                        type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
                    }
                }
            }
            type_aggregate_function(
                typer,
                &e.function,
                &e.args,
                &e.function_span,
                &e.distinct_span,
                flags,
            )
        }
        Expression::Identifier(e) => {
            let mut t = None;
            match e.parts.as_slice() {
                // Unqualified column: search every reference in scope.
                [part] => {
                    let col = match part {
                        qusql_parse::IdentifierPart::Name(n) => n,
                        qusql_parse::IdentifierPart::Star(v) => {
                            typer.err("Not supported here", v);
                            return FullType::invalid();
                        }
                    };
                    let mut cnt = 0;
                    for r in &mut typer.reference_types {
                        for c in &mut r.columns {
                            if c.0 == *col {
                                cnt += 1;
                                // The caller knows the value is non-null, so narrow
                                // the column for the rest of the statement.
                                if flags.not_null {
                                    c.1.not_null = true;
                                }
                                t = Some(c);
                            }
                        }
                    }
                    if cnt > 1 {
                        let mut issue = typer.issues.err("Ambiguous reference", col);
                        for r in &typer.reference_types {
                            for c in &r.columns {
                                if c.0 == *col {
                                    issue.frag("Defined here", &r.span);
                                }
                            }
                        }
                        return FullType::invalid();
                    }
                }
                // Qualified `table.column`.
                [p1, p2] => {
                    let tbl = match p1 {
                        qusql_parse::IdentifierPart::Name(n) => n,
                        qusql_parse::IdentifierPart::Star(v) => {
                            typer.err("Not supported here", v);
                            return FullType::invalid();
                        }
                    };
                    let col = match p2 {
                        qusql_parse::IdentifierPart::Name(n) => n,
                        qusql_parse::IdentifierPart::Star(v) => {
                            typer.err("Not supported here", v);
                            return FullType::invalid();
                        }
                    };
                    for r in &mut typer.reference_types {
                        if r.name == Some(tbl.clone()) {
                            for c in &mut r.columns {
                                if c.0 == *col {
                                    if flags.not_null {
                                        c.1.not_null = true;
                                    }
                                    t = Some(c);
                                }
                            }
                        }
                    }
                }
                _ => {
                    typer.err("Bad identifier length", expression);
                    return FullType::invalid();
                }
            }
            match t {
                None => {
                    typer.err("Unknown identifier", expression);
                    FullType::invalid()
                }
                Some((_, type_)) => type_.clone(),
            }
        }
        Expression::Arg(e) => FullType::new(
            Type::Args(
                BaseType::Any,
                Arc::new(vec![(e.index, ArgType::Normal, e.span.clone())]),
            ),
            false,
        ),
        Expression::Exists(e) => {
            type_union_select(typer, &e.subquery, false);
            FullType::new(BaseType::Bool, true)
        }
        Expression::In(e) => {
            // If `lhs IN (...)` is known to be true, `lhs` cannot be NULL.
            let f2 = if flags.true_ {
                flags.with_not_null(true).with_true(false)
            } else {
                flags
            };
            let mut lhs_type = type_expression(typer, &e.lhs, f2, BaseType::Any);
            let mut not_null = lhs_type.not_null;
            lhs_type.not_null = false;
            for rhs in &e.rhs {
                let rhs_type = match rhs {
                    Expression::Subquery(q) => {
                        let rhs_type = type_union_select(typer, &q.expression, false);
                        if rhs_type.columns.len() != 1 {
                            typer.err(
                                format!(
                                    "Subquery in IN should yield one column but gave {}",
                                    rhs_type.columns.len()
                                ),
                                &q.expression,
                            );
                        }
                        if let Some(c) = rhs_type.columns.first() {
                            c.type_.clone()
                        } else {
                            FullType::invalid()
                        }
                    }
                    Expression::ListHack(e) => FullType::new(
                        Type::Args(
                            BaseType::Any,
                            Arc::new(vec![(e.index, ArgType::ListHack, e.span.clone())]),
                        ),
                        false,
                    ),
                    _ => type_expression(typer, rhs, flags.without_values(), BaseType::Any),
                };
                not_null &= rhs_type.not_null;
                if typer.matched_type(&lhs_type, &rhs_type).is_none() {
                    typer
                        .err("Incompatible types", &e.in_span)
                        .frag(lhs_type.t.to_string(), &e.lhs)
                        .frag(rhs_type.to_string(), rhs);
                }
            }
            FullType::new(BaseType::Bool, not_null)
        }
        Expression::MemberOf(e) => {
            let lhs_type = type_expression(typer, &e.lhs, flags, BaseType::Any);
            let rhs_type = type_expression(typer, &e.rhs, flags, BaseType::String);
            FullType::new(BaseType::Bool, lhs_type.not_null && rhs_type.not_null)
        }
        Expression::Is(e) => {
            let (flags, base_type) = match e.is {
                qusql_parse::Is::Null => (flags.without_values(), BaseType::Any),
                qusql_parse::Is::NotNull => {
                    // If `x IS NOT NULL` is known to be true, `x` is not null.
                    if flags.true_ {
                        (flags.with_not_null(true).with_true(false), BaseType::Any)
                    } else {
                        (flags.with_not_null(false), BaseType::Any)
                    }
                }
                qusql_parse::Is::True
                | qusql_parse::Is::NotTrue
                | qusql_parse::Is::False
                | qusql_parse::Is::NotFalse => (flags.without_values(), BaseType::Bool),
                qusql_parse::Is::Unknown | qusql_parse::Is::NotUnknown => {
                    (flags.without_values(), BaseType::Any)
                }
                qusql_parse::Is::DistinctFrom(_) | qusql_parse::Is::NotDistinctFrom(_) => {
                    (flags.without_values(), BaseType::Any)
                }
            };
            let t = type_expression(typer, &e.lhs, flags, base_type);
            match e.is {
                qusql_parse::Is::Null => {
                    if t.not_null {
                        typer.warn("Cannot be null", &e.lhs);
                    }
                    FullType::new(BaseType::Bool, true)
                }
                qusql_parse::Is::NotNull
                | qusql_parse::Is::True
                | qusql_parse::Is::NotTrue
                | qusql_parse::Is::DistinctFrom(_)
                | qusql_parse::Is::NotDistinctFrom(_)
                | qusql_parse::Is::False
                | qusql_parse::Is::NotFalse => FullType::new(BaseType::Bool, true),
                qusql_parse::Is::Unknown | qusql_parse::Is::NotUnknown => {
                    issue_todo!(typer.issues, &e.lhs);
                    FullType::invalid()
                }
            }
        }
        Expression::Invalid(_) => FullType::invalid(),
        Expression::Case(e) => {
            if let Some(value) = &e.value {
                issue_todo!(typer.issues, value);
                FullType::invalid()
            } else {
                // Only one branch is taken, so facts the caller assumes about the
                // whole CASE (true_ / not_null) must not be pushed into any single
                // condition or result; keep only the clause context.
                let branch_flags = flags.without_values();
                // Without an ELSE, CASE yields NULL when no condition matches.
                let mut not_null = e.else_.is_some();
                let mut t: Option<Type> = None;
                let mut incompatible = false;
                for when in &e.whens {
                    let op_type = type_expression(typer, &when.when, branch_flags, BaseType::Bool);
                    typer.ensure_base(&when.when, &op_type, BaseType::Bool);
                    let then: &Expression<'a> = &when.then;
                    let t2 = type_expression(typer, then, branch_flags, BaseType::Any);
                    not_null &= t2.not_null;
                    t = match t {
                        None => Some(t2.t),
                        Some(t1) => match typer.matched_type(&t1, &t2.t) {
                            Some(m) => Some(m),
                            None => {
                                typer.err(
                                    format!("Incompatible CASE branch types {} and {}", t1, t2.t),
                                    then,
                                );
                                incompatible = true;
                                Some(t1)
                            }
                        },
                    };
                }
                if let Some((_, else_)) = &e.else_ {
                    let else_: &Expression<'a> = else_;
                    let t2 = type_expression(typer, else_, branch_flags, BaseType::Any);
                    not_null &= t2.not_null;
                    t = match t {
                        None => Some(t2.t),
                        Some(t1) => match typer.matched_type(&t1, &t2.t) {
                            Some(m) => Some(m),
                            None => {
                                typer.err(
                                    format!("Incompatible CASE branch types {} and {}", t1, t2.t),
                                    else_,
                                );
                                incompatible = true;
                                Some(t1)
                            }
                        },
                    };
                }
                match t {
                    Some(t) if !incompatible => FullType::new(t, not_null),
                    _ => FullType::invalid(),
                }
            }
        }
        Expression::Cast(e) => {
            let col = parse_column(
                e.type_.clone(),
                Identifier::new("", e.as_span.clone()),
                typer.issues,
                None,
            );
            // For the MariaDB dialect, reject target types outside the set this
            // checker accepts in CAST.
            if typer.dialect().is_maria() {
                match e.type_.type_ {
                    qusql_parse::Type::Char(_)
                    | qusql_parse::Type::Date
                    | qusql_parse::Type::Inet4
                    | qusql_parse::Type::Inet6
                    | qusql_parse::Type::InetAddr
                    | qusql_parse::Type::Cidr
                    | qusql_parse::Type::Macaddr
                    | qusql_parse::Type::Macaddr8
                    | qusql_parse::Type::TsQuery
                    | qusql_parse::Type::TsVector
                    | qusql_parse::Type::Uuid
                    | qusql_parse::Type::Xml
                    | qusql_parse::Type::Range(_)
                    | qusql_parse::Type::MultiRange(_)
                    | qusql_parse::Type::DateTime(_)
                    | qusql_parse::Type::Double(_)
                    | qusql_parse::Type::Float8
                    | qusql_parse::Type::Float(_)
                    | qusql_parse::Type::Integer(_)
                    | qusql_parse::Type::Int(_)
                    | qusql_parse::Type::Binary(_)
                    | qusql_parse::Type::Timestamptz
                    | qusql_parse::Type::Time(_) => {}
                    qusql_parse::Type::Boolean
                    | qusql_parse::Type::TinyInt(_)
                    | qusql_parse::Type::SmallInt(_)
                    | qusql_parse::Type::BigInt(_)
                    | qusql_parse::Type::MediumInt(_)
                    | qusql_parse::Type::VarChar(_)
                    | qusql_parse::Type::TinyText(_)
                    | qusql_parse::Type::MediumText(_)
                    | qusql_parse::Type::Text(_)
                    | qusql_parse::Type::LongText(_)
                    | qusql_parse::Type::Enum(_)
                    | qusql_parse::Type::Set(_)
                    | qusql_parse::Type::Numeric(_)
                    | qusql_parse::Type::Decimal(_)
                    | qusql_parse::Type::Timestamp(_)
                    | qusql_parse::Type::TinyBlob(_)
                    | qusql_parse::Type::MediumBlob(_)
                    | qusql_parse::Type::Blob(_)
                    | qusql_parse::Type::LongBlob(_)
                    | qusql_parse::Type::Json
                    | qusql_parse::Type::Jsonb
                    | qusql_parse::Type::Bit(_, _)
                    | qusql_parse::Type::VarBit(_)
                    | qusql_parse::Type::Bytea
                    | qusql_parse::Type::Named(_)
                    | qusql_parse::Type::Array(_, _)
                    | qusql_parse::Type::VarBinary(_)
                    | qusql_parse::Type::BigSerial
                    | qusql_parse::Type::Serial
                    | qusql_parse::Type::SmallSerial
                    | qusql_parse::Type::Money
                    | qusql_parse::Type::Timetz(_)
                    | qusql_parse::Type::Interval(_)
                    | qusql_parse::Type::Point
                    | qusql_parse::Type::Line
                    | qusql_parse::Type::Lseg
                    | qusql_parse::Type::Box
                    | qusql_parse::Type::Path
                    | qusql_parse::Type::Polygon
                    | qusql_parse::Type::Circle
                    | qusql_parse::Type::Table(_, _) => {
                        typer.err("Type not allow in cast", &e.type_);
                    }
                }
            }
            let e = type_expression(typer, &e.expr, flags, col.type_.base());
            FullType::new(col.type_.t, e.not_null)
        }
        Expression::GroupConcat(e) => {
            let e = type_expression(typer, &e.expr, flags.without_values(), BaseType::Any);
            FullType::new(BaseType::String, e.not_null)
        }
        Expression::Variable(e) => match &e.variable {
            Variable::TimeZone => FullType::new(BaseType::String, true),
            Variable::Other(_) => {
                typer.err("Unknown variable", e);
                FullType::new(BaseType::Any, false)
            }
        },
        Expression::Interval(e) => {
            // Number of components expected in the interval literal for each unit.
            let cnt = match e.time_unit.0 {
                qusql_parse::TimeUnit::Microsecond
                | qusql_parse::TimeUnit::Second
                | qusql_parse::TimeUnit::Minute
                | qusql_parse::TimeUnit::Hour
                | qusql_parse::TimeUnit::Day
                | qusql_parse::TimeUnit::Week
                | qusql_parse::TimeUnit::Month
                | qusql_parse::TimeUnit::Quarter
                | qusql_parse::TimeUnit::Year => 1,
                qusql_parse::TimeUnit::SecondMicrosecond
                | qusql_parse::TimeUnit::MinuteSecond
                | qusql_parse::TimeUnit::HourMinute
                | qusql_parse::TimeUnit::DayHour
                | qusql_parse::TimeUnit::YearMonth => 2,
                qusql_parse::TimeUnit::MinuteMicrosecond
                | qusql_parse::TimeUnit::HourSecond
                | qusql_parse::TimeUnit::DayMinute => 3,
                qusql_parse::TimeUnit::HourMicrosecond | qusql_parse::TimeUnit::DaySecond => 4,
                qusql_parse::TimeUnit::DayMicrosecond => 5,
            };
            if cnt != e.time_interval.0.len() {
                typer.err(
                    format!(
                        "Expected {} values for {:?} got {}",
                        cnt,
                        e.time_unit.0,
                        e.time_interval.0.len()
                    ),
                    &e.time_interval.1,
                );
            }
            FullType::new(BaseType::TimeInterval, true)
        }
        Expression::Extract(e) => {
            let t = type_expression(typer, &e.date, flags, BaseType::Any);
            FullType::new(BaseType::Integer, t.not_null)
        }
        Expression::TimestampAdd(e) => {
            let t1 = type_expression(typer, &e.interval, flags, BaseType::Integer);
            let t2 = type_expression(typer, &e.datetime, flags, BaseType::Any);
            typer.ensure_base(&e.interval, &t1, BaseType::Integer);
            typer.ensure_datetime(&e.datetime, &t2, Restrict::Require, Restrict::Allow);
            FullType::new(BaseType::DateTime, t1.not_null && t2.not_null)
        }
        Expression::TimestampDiff(e) => {
            let t1 = type_expression(typer, &e.e1, flags, BaseType::Any);
            let t2 = type_expression(typer, &e.e2, flags, BaseType::Any);
            typer.ensure_datetime(&e.e1, &t1, Restrict::Require, Restrict::Allow);
            typer.ensure_datetime(&e.e2, &t2, Restrict::Require, Restrict::Allow);
            FullType::new(BaseType::Integer, t1.not_null && t2.not_null)
        }
        // Not yet supported: reported against the whole expression.
        e @ (Expression::MatchAgainst { .. }
        | Expression::Convert { .. }
        | Expression::UserVariable { .. }
        | Expression::TypeCast(..)
        | Expression::Array(..)
        | Expression::ArraySubscript(..)
        | Expression::Default(_)
        | Expression::Between(_)) => {
            issue_todo!(typer.issues, e);
            FullType::invalid()
        }
        // Not yet supported: reported against the inner node.
        Expression::Quantifier(e) => {
            issue_todo!(typer.issues, e);
            FullType::invalid()
        }
        Expression::FieldAccess(e) => {
            issue_todo!(typer.issues, e);
            FullType::invalid()
        }
        Expression::Trim(e) => {
            issue_todo!(typer.issues, e);
            FullType::invalid()
        }
        Expression::Char(e) => {
            issue_todo!(typer.issues, e);
            FullType::invalid()
        }
    }
}