use alloc::{format, string::ToString, sync::Arc, vec};
use core::ops::Deref;
use qusql_parse::{Expression, Identifier, Span, UnaryOperator, Variable, issue_todo};
use crate::{
Type,
schema::parse_column,
type_::{ArgType, BaseType, FullType},
type_binary_expression::type_binary_expression,
type_function::type_function,
type_select::{resolve_kleene_identifier, type_union_select},
typer::{Restrict, Typer},
};
/// Facts about an expression's value that the surrounding context already
/// guarantees, threaded through expression typing.
///
/// All flags default to `false`. The `with_*` builders return a modified copy,
/// since the type is `Copy`.
#[derive(Clone, Copy, Default)]
pub struct ExpressionFlags {
    /// The context requires the expression to evaluate to true. Typing uses it
    /// to infer that operands are not null (e.g. the left-hand side of an `IN`
    /// that must hold). Presumably set for WHERE-style predicates — confirm at
    /// the call sites.
    pub true_: bool,
    /// The expression's value is known to be non-null. Column references typed
    /// under this flag get their `not_null` marked.
    pub not_null: bool,
    /// Presumably set while typing the `ON DUPLICATE KEY UPDATE` part of an
    /// insert; it is preserved by [`ExpressionFlags::without_values`].
    pub in_on_duplicate_key_update: bool,
}

impl ExpressionFlags {
    /// Returns a copy with `true_` replaced by the given value.
    pub fn with_true(mut self, true_: bool) -> Self {
        self.true_ = true_;
        self
    }

    /// Returns a copy with `not_null` replaced by the given value.
    pub fn with_not_null(mut self, not_null: bool) -> Self {
        self.not_null = not_null;
        self
    }

    /// Returns a copy with `in_on_duplicate_key_update` replaced by the given value.
    pub fn with_in_on_duplicate_key_update(mut self, in_on_duplicate_key_update: bool) -> Self {
        self.in_on_duplicate_key_update = in_on_duplicate_key_update;
        self
    }

    /// Returns a copy that no longer claims anything about the value
    /// (`true_` and `not_null` cleared). Other contextual flags are kept.
    pub fn without_values(self) -> Self {
        self.with_true(false).with_not_null(false)
    }
}
/// Type a prefix (unary) operator applied to `operand`.
///
/// `NOT` requires a boolean operand and yields the operand's type unchanged.
/// Every other operator requires a numeric operand. Unsigned integer types
/// become the signed type of the same width, and `NULL` stays `NULL`.
/// Non-numeric operands are reported at `op_span` and typed as invalid.
/// The operand's nullability is carried through to the result.
///
/// NOTE(review): `BINARY`, `COLLATE` and `!` are typed exactly like numeric
/// negation here, so string operands are rejected — confirm this is intended.
fn type_unary_expression<'a>(
    typer: &mut Typer<'a, '_>,
    op: &UnaryOperator,
    op_span: &Span,
    operand: &Expression<'a>,
    flags: ExpressionFlags,
) -> FullType<'a> {
    // Truthiness of the result says nothing about the operand being true, so
    // only the non-null requirement is passed down.
    let operand_flags = flags.with_true(false);
    match op {
        UnaryOperator::Not => {
            let operand_type = type_expression(typer, operand, operand_flags, BaseType::Bool);
            typer.ensure_base(operand, &operand_type, BaseType::Bool);
            operand_type
        }
        UnaryOperator::Binary
        | UnaryOperator::Collate
        | UnaryOperator::LogicalNot
        | UnaryOperator::Minus => {
            let operand_type = type_expression(typer, operand, operand_flags, BaseType::Any);
            let not_null = operand_type.not_null;
            let result = match operand_type.t {
                Type::U8 => Type::I8,
                Type::U16 => Type::I16,
                Type::U32 => Type::I32,
                Type::U64 => Type::I64,
                Type::Null => Type::Null,
                numeric @ (Type::I8
                | Type::I16
                | Type::I32
                | Type::I64
                | Type::F32
                | Type::F64
                | Type::Invalid
                | Type::Base(BaseType::Integer)
                | Type::Base(BaseType::Float)) => numeric,
                other @ (Type::Args(..)
                | Type::Base(..)
                | Type::Enum(..)
                | Type::JSON
                | Type::Set(..)) => {
                    typer.err(format!("Expected numeric type got {}", other), op_span);
                    Type::Invalid
                }
            };
            FullType::new(result, not_null)
        }
    }
}
pub(crate) fn type_expression<'a>(
typer: &mut Typer<'a, '_>,
expression: &Expression<'a>,
flags: ExpressionFlags,
_context: BaseType,
) -> FullType<'a> {
match expression {
Expression::Binary {
op,
op_span,
lhs,
rhs,
} => type_binary_expression(typer, op, op_span, lhs, rhs, flags),
Expression::Unary {
op,
op_span,
operand,
} => type_unary_expression(typer, op, op_span, operand, flags),
Expression::Subquery(select) => {
let select_type = type_union_select(typer, select, false);
if let [v] = select_type.columns.as_slice() {
let mut r = v.type_.clone();
r.not_null = false;
r
} else {
typer.err("Subquery should yield one column", select);
FullType::invalid()
}
}
Expression::ListHack(v) => {
typer.err("_LIST_ only allowed in IN ()", v);
FullType::invalid()
}
Expression::Null(_) => FullType::new(Type::Null, false),
Expression::Bool(_, _) => FullType::new(BaseType::Bool, true),
Expression::String(_) => FullType::new(BaseType::String, true),
Expression::Integer(_) => FullType::new(BaseType::Integer, true),
Expression::Float(_) => FullType::new(BaseType::Float, true),
Expression::Function(func, args, span) => type_function(typer, func, args, span, flags),
Expression::WindowFunction {
function,
args,
function_span,
over_span: _,
window_spec,
} => {
for (e, _) in &window_spec.order_by.1 {
type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
}
type_function(typer, function, args, function_span, flags)
}
Expression::Identifier(i) => {
let mut t = None;
match i.as_slice() {
[part] => {
let col = match part {
qusql_parse::IdentifierPart::Name(n) => n,
qusql_parse::IdentifierPart::Star(v) => {
typer.err("Not supported here", v);
return FullType::invalid();
}
};
let mut cnt = 0;
for r in &mut typer.reference_types {
for c in &mut r.columns {
if c.0 == *col {
cnt += 1;
if flags.not_null {
c.1.not_null = true;
}
t = Some(c);
}
}
}
if cnt > 1 {
let mut issue = typer.issues.err("Ambiguous reference", col);
for r in &typer.reference_types {
for c in &r.columns {
if c.0 == *col {
issue.frag("Defined here", &r.span);
}
}
}
return FullType::invalid();
}
}
[p1, p2] => {
let tbl = match p1 {
qusql_parse::IdentifierPart::Name(n) => n,
qusql_parse::IdentifierPart::Star(v) => {
typer.err("Not supported here", v);
return FullType::invalid();
}
};
let col = match p2 {
qusql_parse::IdentifierPart::Name(n) => n,
qusql_parse::IdentifierPart::Star(v) => {
typer.err("Not supported here", v);
return FullType::invalid();
}
};
for r in &mut typer.reference_types {
if r.name == Some(tbl.clone()) {
for c in &mut r.columns {
if c.0 == *col {
if flags.not_null {
c.1.not_null = true;
}
t = Some(c);
}
}
}
}
}
_ => {
typer.err("Bad identifier length", expression);
return FullType::invalid();
}
}
match t {
None => {
typer.err("Unknown identifier", expression);
FullType::invalid()
}
Some((_, type_)) => type_.clone(),
}
}
Expression::Arg((idx, span)) => FullType::new(
Type::Args(
BaseType::Any,
Arc::new(vec![(*idx, ArgType::Normal, span.clone())]),
),
false,
),
Expression::Exists(s) => {
type_union_select(typer, s, false);
FullType::new(BaseType::Bool, true)
}
Expression::In {
lhs, rhs, in_span, ..
} => {
let f2 = if flags.true_ {
flags.with_not_null(true).with_true(false)
} else {
flags
};
let mut lhs_type = type_expression(typer, lhs, f2, BaseType::Any);
let mut not_null = lhs_type.not_null;
lhs_type.not_null = false;
for rhs in rhs {
let rhs_type = match rhs {
Expression::Subquery(q) => {
let rhs_type = type_union_select(typer, q, false);
if rhs_type.columns.len() != 1 {
typer.err(
format!(
"Subquery in IN should yield one column but gave {}",
rhs_type.columns.len()
),
q,
);
}
if let Some(c) = rhs_type.columns.first() {
c.type_.clone()
} else {
FullType::invalid()
}
}
Expression::ListHack((idx, span)) => FullType::new(
Type::Args(
BaseType::Any,
Arc::new(vec![(*idx, ArgType::ListHack, span.clone())]),
),
false,
),
_ => type_expression(typer, rhs, flags.without_values(), BaseType::Any),
};
not_null &= rhs_type.not_null;
if typer.matched_type(&lhs_type, &rhs_type).is_none() {
typer
.err("Incompatible types", in_span)
.frag(lhs_type.t.to_string(), lhs)
.frag(rhs_type.to_string(), rhs);
}
}
FullType::new(BaseType::Bool, not_null)
}
Expression::Is(e, is, _) => {
let (flags, base_type) = match is {
qusql_parse::Is::Null => (flags.without_values(), BaseType::Any),
qusql_parse::Is::NotNull => {
if flags.true_ {
(flags.with_not_null(true).with_true(false), BaseType::Any)
} else {
(flags.with_not_null(false), BaseType::Any)
}
}
qusql_parse::Is::True
| qusql_parse::Is::NotTrue
| qusql_parse::Is::False
| qusql_parse::Is::NotFalse => (flags.without_values(), BaseType::Bool),
qusql_parse::Is::Unknown | qusql_parse::Is::NotUnknown => {
(flags.without_values(), BaseType::Any)
}
};
let t = type_expression(typer, e, flags, base_type);
match is {
qusql_parse::Is::Null => {
if t.not_null {
typer.warn("Cannot be null", e);
}
FullType::new(BaseType::Bool, true)
}
qusql_parse::Is::NotNull
| qusql_parse::Is::True
| qusql_parse::Is::NotTrue
| qusql_parse::Is::False
| qusql_parse::Is::NotFalse => FullType::new(BaseType::Bool, true),
qusql_parse::Is::Unknown | qusql_parse::Is::NotUnknown => {
issue_todo!(typer.issues, expression);
FullType::invalid()
}
}
}
Expression::Invalid(_) => FullType::invalid(),
Expression::Case {
value,
whens,
else_,
..
} => {
if value.is_some() {
issue_todo!(typer.issues, expression);
FullType::invalid()
} else {
let not_null = true;
let mut t: Option<Type> = None;
for when in whens {
let op_type = type_expression(typer, &when.when, flags, BaseType::Bool);
typer.ensure_base(&when.when, &op_type, BaseType::Bool);
let t2 = type_expression(typer, &when.then, flags, BaseType::Any);
if let Some(t1) = t {
t = typer.matched_type(&t1, &t2.t)
} else {
t = Some(t2.t);
}
}
if let Some((_, else_)) = else_ {
let t2 = type_expression(typer, else_, flags, BaseType::Any);
if let Some(t1) = t {
t = typer.matched_type(&t1, &t2.t)
} else {
t = Some(t2.t);
}
}
if let Some(t) = t {
FullType::new(t, not_null)
} else {
FullType::invalid()
}
}
}
Expression::Cast {
expr,
as_span,
type_,
..
} => {
let col = parse_column(
type_.clone(),
Identifier::new("", as_span.clone()),
typer.issues,
None,
);
if typer.dialect().is_maria() {
match type_.type_ {
qusql_parse::Type::Char(_)
| qusql_parse::Type::Date
| qusql_parse::Type::Inet4
| qusql_parse::Type::Inet6
| qusql_parse::Type::DateTime(_)
| qusql_parse::Type::Double(_)
| qusql_parse::Type::Float8
| qusql_parse::Type::Float(_)
| qusql_parse::Type::Integer(_)
| qusql_parse::Type::Int(_)
| qusql_parse::Type::Binary(_)
| qusql_parse::Type::Timestamptz
| qusql_parse::Type::Time(_) => {}
qusql_parse::Type::Boolean
| qusql_parse::Type::TinyInt(_)
| qusql_parse::Type::SmallInt(_)
| qusql_parse::Type::BigInt(_)
| qusql_parse::Type::VarChar(_)
| qusql_parse::Type::TinyText(_)
| qusql_parse::Type::MediumText(_)
| qusql_parse::Type::Text(_)
| qusql_parse::Type::LongText(_)
| qusql_parse::Type::Enum(_)
| qusql_parse::Type::Set(_)
| qusql_parse::Type::Numeric(_, _, _)
| qusql_parse::Type::Timestamp(_)
| qusql_parse::Type::TinyBlob(_)
| qusql_parse::Type::MediumBlob(_)
| qusql_parse::Type::Blob(_)
| qusql_parse::Type::LongBlob(_)
| qusql_parse::Type::Json
| qusql_parse::Type::Bit(_, _)
| qusql_parse::Type::Bytea
| qusql_parse::Type::Named(_) | qusql_parse::Type::VarBinary(_) => {
typer
.err("Type not allow in cast", type_);
}
};
} else {
}
let e = type_expression(typer, expr, flags, col.type_.base());
FullType::new(col.type_.t, e.not_null)
}
Expression::Count { expr, .. } => {
match expr.deref() {
Expression::Identifier(parts) => {
resolve_kleene_identifier(typer, parts, &None, |_, _, _, _, _| {})
}
arg => {
type_expression(typer, arg, flags.without_values(), BaseType::Any);
}
}
FullType::new(BaseType::Integer, true)
}
Expression::GroupConcat { expr, .. } => {
let e = type_expression(typer, expr, flags.without_values(), BaseType::Any);
FullType::new(BaseType::String, e.not_null)
}
Expression::Variable {
variable,
variable_span,
..
} => match variable {
Variable::TimeZone => FullType::new(BaseType::String, true),
Variable::Other(_) => {
typer.err("Unknown variable", variable_span);
FullType::new(BaseType::Any, false)
}
},
Expression::Interval {
time_interval,
time_unit,
..
} => {
let cnt = match time_unit.0 {
qusql_parse::TimeUnit::Microsecond => 1,
qusql_parse::TimeUnit::Second => 1,
qusql_parse::TimeUnit::Minute => 1,
qusql_parse::TimeUnit::Hour => 1,
qusql_parse::TimeUnit::Day => 1,
qusql_parse::TimeUnit::Week => 1,
qusql_parse::TimeUnit::Month => 1,
qusql_parse::TimeUnit::Quarter => 1,
qusql_parse::TimeUnit::Year => 1,
qusql_parse::TimeUnit::SecondMicrosecond => 2,
qusql_parse::TimeUnit::MinuteMicrosecond => 3,
qusql_parse::TimeUnit::MinuteSecond => 2,
qusql_parse::TimeUnit::HourMicrosecond => 4,
qusql_parse::TimeUnit::HourSecond => 3,
qusql_parse::TimeUnit::HourMinute => 2,
qusql_parse::TimeUnit::DayMicrosecond => 5,
qusql_parse::TimeUnit::DaySecond => 4,
qusql_parse::TimeUnit::DayMinute => 3,
qusql_parse::TimeUnit::DayHour => 2,
qusql_parse::TimeUnit::YearMonth => 2,
};
if cnt != time_interval.0.len() {
typer.err(
format!(
"Expected {} values for {:?} got {}",
cnt,
time_unit.0,
time_interval.0.len()
),
&time_interval.1,
);
}
FullType::new(BaseType::TimeInterval, true)
}
Expression::Extract { date, .. } => {
let t = type_expression(typer, date, flags, BaseType::Any);
FullType::new(BaseType::Integer, t.not_null)
}
Expression::TimestampAdd {
interval, datetime, ..
} => {
let t1 = type_expression(typer, interval, flags, BaseType::Integer);
let t2 = type_expression(typer, datetime, flags, BaseType::Any);
typer.ensure_base(interval, &t1, BaseType::Integer);
typer.ensure_datetime(interval, &t2, Restrict::Require, Restrict::Allow);
FullType::new(BaseType::DateTime, t1.not_null && t2.not_null)
}
Expression::TimestampDiff { e1, e2, .. } => {
let t1 = type_expression(typer, e1, flags, BaseType::Any);
let t2 = type_expression(typer, e2, flags, BaseType::Any);
typer.ensure_datetime(e1, &t1, Restrict::Require, Restrict::Allow);
typer.ensure_datetime(e2, &t2, Restrict::Require, Restrict::Allow);
FullType::new(BaseType::Integer, t1.not_null && t2.not_null)
}
}
}