use alloc::{format, vec::Vec};
use qusql_parse::{BinaryOperator, Expression, Spanned};
use crate::{
Type,
type_::{BaseType, FullType},
type_expression::{ExpressionFlags, type_expression},
typer::{Restrict, Typer},
};
pub(crate) fn type_binary_expression<'a>(
typer: &mut Typer<'a, '_>,
op: &BinaryOperator<'a>,
lhs: &Expression<'a>,
rhs: &Expression<'a>,
flags: ExpressionFlags,
) -> FullType<'a> {
let op_span = op.span();
let (flags, context) = match op {
BinaryOperator::Assignment(_) => (flags, BaseType::Any),
BinaryOperator::And(_) => {
if flags.true_ {
(flags.with_not_null(true), BaseType::Bool)
} else {
(flags, BaseType::Bool)
}
}
BinaryOperator::Or(_) if flags.true_ => {
let child_flags = flags.with_not_null(true);
let snapshot_not_null: Vec<Vec<bool>> = typer
.reference_types
.iter()
.map(|rt| rt.columns.iter().map(|c| c.1.not_null).collect())
.collect();
let lhs_type = type_expression(typer, lhs, child_flags, BaseType::Bool);
let after_lhs_not_null: Vec<Vec<bool>> = typer
.reference_types
.iter()
.map(|rt| rt.columns.iter().map(|c| c.1.not_null).collect())
.collect();
for (rt, snap_nn) in typer
.reference_types
.iter_mut()
.zip(snapshot_not_null.iter())
{
for (col, &nn) in rt.columns.iter_mut().zip(snap_nn.iter()) {
col.1.not_null = nn;
}
}
let rhs_type = type_expression(typer, rhs, child_flags, BaseType::Bool);
for (cur, lhs_nn) in typer
.reference_types
.iter_mut()
.zip(after_lhs_not_null.iter())
{
for (cur_col, &lhs_col_nn) in cur.columns.iter_mut().zip(lhs_nn.iter()) {
cur_col.1.not_null = cur_col.1.not_null && lhs_col_nn;
}
}
typer.ensure_base(lhs, &lhs_type, BaseType::Bool);
typer.ensure_base(rhs, &rhs_type, BaseType::Bool);
return FullType::new(BaseType::Bool, lhs_type.not_null && rhs_type.not_null);
}
BinaryOperator::Or(_) => (flags.without_values(), BaseType::Bool),
BinaryOperator::Xor(_) => (flags.without_values(), BaseType::Bool),
BinaryOperator::NullSafeEq(_) => (flags.without_values(), BaseType::Any),
BinaryOperator::Eq(_)
| BinaryOperator::GtEq(_)
| BinaryOperator::Gt(_)
| BinaryOperator::LtEq(_)
| BinaryOperator::Lt(_)
| BinaryOperator::Neq(_)
| BinaryOperator::Add(_)
| BinaryOperator::Subtract(_)
| BinaryOperator::Divide(_)
| BinaryOperator::Div(_)
| BinaryOperator::Mod(_)
| BinaryOperator::Mult(_) => {
if flags.true_ {
(flags.with_not_null(true).with_true(false), BaseType::Any)
} else {
(flags, BaseType::Any)
}
}
BinaryOperator::Like(_)
| BinaryOperator::NotLike(_)
| BinaryOperator::Regexp(_)
| BinaryOperator::NotRegexp(_)
| BinaryOperator::Rlike(_)
| BinaryOperator::NotRlike(_) => {
if flags.true_ {
(flags.with_not_null(true).with_true(false), BaseType::String)
} else {
(flags, BaseType::String)
}
}
BinaryOperator::ShiftLeft(_)
| BinaryOperator::ShiftRight(_)
| BinaryOperator::BitAnd(_)
| BinaryOperator::BitOr(_)
| BinaryOperator::BitXor(_) => {
if flags.true_ {
(
flags.with_not_null(true).with_true(false),
BaseType::Integer,
)
} else {
(flags, BaseType::Integer)
}
}
BinaryOperator::Collate(_) => (flags, BaseType::String),
BinaryOperator::Concat(_) => (flags.without_values(), BaseType::String),
BinaryOperator::JsonExtract(_) => (flags, BaseType::String), BinaryOperator::JsonExtractUnquote(_) => (flags, BaseType::String), BinaryOperator::User(_, _) => (flags, BaseType::Any),
BinaryOperator::Contains(_)
| BinaryOperator::ContainedBy(_)
| BinaryOperator::JsonPathMatch(_)
| BinaryOperator::JsonPathExists(_)
| BinaryOperator::JsonbKeyExists(_)
| BinaryOperator::JsonbAnyKeyExists(_)
| BinaryOperator::JsonbAllKeyExists(_) => (flags.without_values(), BaseType::Any),
BinaryOperator::JsonGetPath(_)
| BinaryOperator::JsonGetPathText(_)
| BinaryOperator::JsonDeletePath(_) => (flags, BaseType::Any),
BinaryOperator::RegexMatch(_)
| BinaryOperator::RegexIMatch(_)
| BinaryOperator::NotRegexMatch(_)
| BinaryOperator::NotRegexIMatch(_) => (flags.without_values(), BaseType::String),
BinaryOperator::Operator(_, _) => (flags, BaseType::Any),
};
let lhs_type = type_expression(typer, lhs, flags, context);
let rhs_type = type_expression(typer, rhs, flags, context);
match op {
BinaryOperator::Or(_) => {
typer.ensure_base(lhs, &lhs_type, BaseType::Bool);
typer.ensure_base(rhs, &rhs_type, BaseType::Bool);
FullType::new(BaseType::Bool, lhs_type.not_null && rhs_type.not_null)
}
BinaryOperator::Concat(_) => {
if let Some(t) = typer.matched_type(&lhs_type, &rhs_type) {
return FullType::new(t, lhs_type.not_null && rhs_type.not_null);
}
typer.ensure_base(lhs, &lhs_type, BaseType::String);
typer.ensure_base(rhs, &rhs_type, BaseType::String);
FullType::new(BaseType::String, lhs_type.not_null && rhs_type.not_null)
}
BinaryOperator::Xor(_) | BinaryOperator::And(_) => {
typer.ensure_base(lhs, &lhs_type, BaseType::Bool);
typer.ensure_base(rhs, &rhs_type, BaseType::Bool);
FullType::new(BaseType::Bool, lhs_type.not_null && rhs_type.not_null)
}
BinaryOperator::Eq(_)
| BinaryOperator::Neq(_)
| BinaryOperator::GtEq(_)
| BinaryOperator::Gt(_)
| BinaryOperator::LtEq(_)
| BinaryOperator::Lt(_) => {
if lhs_type.t == Type::Null {
typer.warn("Comparison with null", lhs);
}
if rhs_type.t == Type::Null {
typer.warn("Comparison with null", rhs);
}
if typer.matched_type(&lhs_type, &rhs_type).is_none() {
typer
.err("Type error in comparison", &op_span)
.frag(format!("Of type {}", lhs_type.t), lhs)
.frag(format!("Of type {}", rhs_type.t), rhs);
}
FullType::new(BaseType::Bool, lhs_type.not_null && rhs_type.not_null)
}
BinaryOperator::NullSafeEq(_) => {
if typer.matched_type(&lhs_type, &rhs_type).is_none() {
typer
.err("Type error in comparison", &op_span)
.frag(format!("Of type {}", lhs_type.t), lhs)
.frag(format!("Of type {}", rhs_type.t), rhs);
}
FullType::new(BaseType::Bool, true)
}
BinaryOperator::ShiftLeft(_)
| BinaryOperator::ShiftRight(_)
| BinaryOperator::BitAnd(_)
| BinaryOperator::BitOr(_)
| BinaryOperator::BitXor(_) => {
typer.ensure_base(lhs, &lhs_type, BaseType::Integer);
typer.ensure_base(rhs, &rhs_type, BaseType::Integer);
FullType::new(BaseType::Integer, lhs_type.not_null && rhs_type.not_null)
}
BinaryOperator::Add(_) | BinaryOperator::Subtract(_) => {
if matches!(lhs_type.base(), BaseType::TimeInterval) {
let t =
typer.ensure_datetime(&op_span, &rhs_type, Restrict::Allow, Restrict::Allow);
FullType::new(t, lhs_type.not_null && rhs_type.not_null)
} else if matches!(rhs_type.base(), BaseType::TimeInterval) {
let t =
typer.ensure_datetime(&op_span, &lhs_type, Restrict::Allow, Restrict::Allow);
FullType::new(t, lhs_type.not_null && rhs_type.not_null)
} else if let Some(t) = typer.matched_type(&lhs_type, &rhs_type) {
match t.base() {
BaseType::Any | BaseType::Float | BaseType::Integer => {
FullType::new(t, lhs_type.not_null && rhs_type.not_null)
}
_ => {
typer
.err("Type error in addition/subtraction", &op_span)
.frag(format!("type {}", lhs_type.t), lhs)
.frag(format!("type {}", rhs_type.t), rhs);
FullType::invalid()
}
}
} else {
typer
.err("Type error in addition/subtraction", &op_span)
.frag(format!("type {}", lhs_type.t), lhs)
.frag(format!("type {}", rhs_type.t), rhs);
FullType::invalid()
}
}
BinaryOperator::Divide(_)
| BinaryOperator::Div(_)
| BinaryOperator::Mod(_)
| BinaryOperator::Mult(_) => {
if let Some(t) = typer.matched_type(&lhs_type, &rhs_type) {
match t.base() {
BaseType::Any | BaseType::Float | BaseType::Integer => {
FullType::new(t, lhs_type.not_null && rhs_type.not_null)
}
_ => {
typer
.err("Type error in multiplication/division", &op_span)
.frag(format!("type {}", lhs_type.t), lhs)
.frag(format!("type {}", rhs_type.t), rhs);
FullType::invalid()
}
}
} else {
typer
.err("Type error in multiplication/division", &op_span)
.frag(format!("type {}", lhs_type.t), lhs)
.frag(format!("type {}", rhs_type.t), rhs);
FullType::invalid()
}
}
BinaryOperator::Like(_)
| BinaryOperator::NotLike(_)
| BinaryOperator::Regexp(_)
| BinaryOperator::NotRegexp(_)
| BinaryOperator::Rlike(_)
| BinaryOperator::NotRlike(_) => {
typer.ensure_base(lhs, &lhs_type, BaseType::String);
typer.ensure_base(rhs, &rhs_type, BaseType::String);
FullType::new(BaseType::Bool, lhs_type.not_null && rhs_type.not_null)
}
BinaryOperator::Collate(_) => {
typer.ensure_base(lhs, &lhs_type, BaseType::String);
lhs_type
}
BinaryOperator::JsonExtract(_) | BinaryOperator::JsonExtractUnquote(_) => {
typer.ensure_base(rhs, &rhs_type, BaseType::String);
FullType::new(BaseType::String, lhs_type.not_null && rhs_type.not_null)
}
BinaryOperator::Contains(_)
| BinaryOperator::ContainedBy(_)
| BinaryOperator::JsonPathMatch(_)
| BinaryOperator::JsonPathExists(_)
| BinaryOperator::JsonbKeyExists(_)
| BinaryOperator::JsonbAnyKeyExists(_)
| BinaryOperator::JsonbAllKeyExists(_) => {
FullType::new(BaseType::Bool, lhs_type.not_null && rhs_type.not_null)
}
BinaryOperator::JsonGetPath(_) | BinaryOperator::JsonGetPathText(_) => {
FullType::new(BaseType::String, lhs_type.not_null && rhs_type.not_null)
}
BinaryOperator::JsonDeletePath(_) => {
FullType::new(BaseType::Any, lhs_type.not_null && rhs_type.not_null)
}
BinaryOperator::RegexMatch(_)
| BinaryOperator::RegexIMatch(_)
| BinaryOperator::NotRegexMatch(_)
| BinaryOperator::NotRegexIMatch(_) => {
typer.ensure_base(lhs, &lhs_type, BaseType::String);
typer.ensure_base(rhs, &rhs_type, BaseType::String);
FullType::new(BaseType::Bool, lhs_type.not_null && rhs_type.not_null)
}
BinaryOperator::Assignment(_) => {
rhs_type
}
BinaryOperator::User(_, _) => {
FullType::new(BaseType::Any, lhs_type.not_null && rhs_type.not_null)
}
o @ BinaryOperator::Operator(_, _) => {
typer.err("Not supported", o);
FullType::invalid()
}
}
}