use alloc::{format, vec::Vec};
use qusql_parse::{Expression, Function, Span};
use crate::{
Type,
type_::{BaseType, FullType},
type_expression::{ExpressionFlags, type_expression},
typer::{Restrict, Typer},
};
/// Validate the number of arguments passed to a function call.
///
/// `rng` is interpreted as an *inclusive* range: a call is accepted when
/// `rng.start <= args.len() <= rng.end`, so `n..n` means "exactly `n`".
/// When the count is out of bounds an error is reported on `span`, and every
/// argument past `rng.end` is attached to that error as a fragment labelled
/// with its zero-based position.
fn arg_cnt<'a>(
    typer: &mut Typer<'a, '_>,
    rng: core::ops::Range<usize>,
    args: &[Expression<'a>],
    span: &Span,
) {
    let got = args.len();
    let (min, max) = (rng.start, rng.end);
    if (min..=max).contains(&got) {
        return;
    }

    // A degenerate (`min >= max`) range denotes an exact argument count.
    let message = if min >= max {
        format!("Expected {} arguments got {}", min, got)
    } else {
        format!("Expected between {} and {} arguments got {}", min, max, got)
    };
    let mut issue = typer.err(message, span);

    // Point at each surplus argument; there are none when too few were given.
    let surplus = args.get(max..).unwrap_or_default();
    for (offset, arg) in surplus.iter().enumerate() {
        issue.frag(format!("Argument {}", max + offset), arg);
    }
}
/// Type every argument expression with no expected base type.
///
/// Returns each argument paired with its inferred [`FullType`], in the same
/// order as `args`. Arguments are typed one after another, so any issues
/// they raise are reported in source order.
fn typed_args<'a, 'b, 'c>(
    typer: &mut Typer<'a, 'b>,
    args: &'c [Expression<'a>],
    flags: ExpressionFlags,
) -> Vec<(&'c Expression<'a>, FullType<'a>)> {
    args.iter()
        .map(|arg| {
            let arg_type = type_expression(typer, arg, flags.without_values(), BaseType::Any);
            (arg, arg_type)
        })
        .collect()
}
/// Infer the result type of the SQL function call `func(args)`.
///
/// For every function with typing rules, the argument count is validated
/// with [`arg_cnt`] and each argument is typed and checked against the type
/// the function expects. Problems are reported through `typer`.
/// `span` is the span of the whole call and is used for arity errors.
///
/// A type is always returned. [`FullType::invalid`] is used when a required
/// argument is missing, or when `func` has no typing rule (an error is
/// reported in that case).
pub(crate) fn type_function<'a, 'b>(
    typer: &mut Typer<'a, 'b>,
    func: &Function<'a>,
    args: &[Expression<'a>],
    span: &Span,
    flags: ExpressionFlags,
) -> FullType<'a> {
    // Generic typer for "simple" functions. `required_args` are matched
    // first, then `optional_args`. The result is `return_type`, non-null
    // only when every matched argument is non-null. Surplus arguments are
    // still typed so that errors inside them are reported.
    let mut tf = |return_type: Type<'a>,
                  required_args: &[BaseType],
                  optional_args: &[BaseType]|
     -> FullType<'a> {
        let mut not_null = true;
        let mut arg_iter = args.iter();
        arg_cnt(
            typer,
            required_args.len()..required_args.len() + optional_args.len(),
            args,
            span,
        );
        for et in required_args {
            if let Some(arg) = arg_iter.next() {
                let t = type_expression(typer, arg, flags.without_values(), *et);
                not_null = not_null && t.not_null;
                typer.ensure_base(arg, &t, *et);
            }
        }
        for et in optional_args {
            if let Some(arg) = arg_iter.next() {
                let t = type_expression(typer, arg, flags.without_values(), *et);
                not_null = not_null && t.not_null;
                typer.ensure_base(arg, &t, *et);
            }
        }
        for arg in arg_iter {
            type_expression(typer, arg, flags.without_values(), BaseType::Any);
        }
        FullType::new(return_type, not_null)
    };
    match func {
        Function::Rand => tf(Type::F64, &[], &[BaseType::Integer]),
        Function::Right | Function::Left => tf(
            BaseType::String.into(),
            &[BaseType::String, BaseType::Integer],
            &[],
        ),
        Function::SubStr => {
            arg_cnt(typer, 2..3, args, span);
            // The result has the type of the first argument (String or Bytes).
            let mut return_type = if let Some(arg) = args.first() {
                let t = type_expression(typer, arg, flags.without_values(), BaseType::Any);
                if !matches!(t.base(), BaseType::Any | BaseType::String | BaseType::Bytes) {
                    typer.err(format!("Expected type String or Bytes got {t}"), arg);
                }
                t
            } else {
                FullType::invalid()
            };
            if let Some(arg) = args.get(1) {
                let t = type_expression(typer, arg, flags.without_values(), BaseType::Integer);
                return_type.not_null = return_type.not_null && t.not_null;
                typer.ensure_base(arg, &t, BaseType::Integer);
            };
            if let Some(arg) = args.get(2) {
                let t = type_expression(typer, arg, flags.without_values(), BaseType::Integer);
                return_type.not_null = return_type.not_null && t.not_null;
                typer.ensure_base(arg, &t, BaseType::Integer);
            };
            return_type
        }
        Function::FindInSet => tf(
            BaseType::Integer.into(),
            &[BaseType::String, BaseType::String],
            &[],
        ),
        Function::SubStringIndex => tf(
            BaseType::String.into(),
            &[BaseType::String, BaseType::String, BaseType::Integer],
            &[],
        ),
        Function::ExtractValue => tf(
            BaseType::String.into(),
            &[BaseType::String, BaseType::String],
            &[],
        ),
        Function::Replace => tf(
            BaseType::String.into(),
            &[BaseType::String, BaseType::String, BaseType::String],
            &[],
        ),
        Function::CharacterLength => tf(BaseType::Integer.into(), &[BaseType::String], &[]),
        Function::UnixTimestamp => {
            let mut not_null = true;
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 0..1, args, span);
            if let Some((a, t)) = typed.first() {
                not_null = not_null && t.not_null;
                typer.ensure_base(*a, t, BaseType::DateTime);
            }
            FullType::new(Type::I64, not_null)
        }
        Function::IfNull => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 2..2, args, span);
            let t = if let Some((e, t)) = typed.first() {
                if t.not_null {
                    typer.warn("Cannot be null", *e);
                }
                t.clone()
            } else {
                FullType::invalid()
            };
            if let Some((e, t2)) = typed.get(1) {
                typer.ensure_type(*e, t2, &t);
                // IFNULL(a, b) is NULL only when both `a` and `b` are NULL.
                let mut result = t2.clone();
                result.not_null = result.not_null || t.not_null;
                result
            } else {
                t
            }
        }
        Function::Lead | Function::Lag => {
            let typed = typed_args(typer, args, flags);
            // LEAD/LAG(expr [, offset [, default]])
            arg_cnt(typer, 1..3, args, span);
            if let Some((a, t)) = typed.get(1) {
                typer.ensure_base(*a, t, BaseType::Integer);
            }
            if let Some((_, t)) = typed.first() {
                // The row at the offset may not exist, so the result is nullable.
                let mut t = t.clone();
                t.not_null = false;
                if let Some((d, dt)) = typed.get(2) {
                    // The default replaces `expr`, so it must have a compatible type.
                    typer.ensure_type(*d, dt, &t);
                }
                t
            } else {
                FullType::invalid()
            }
        }
        Function::JsonExtract => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 2..999, args, span);
            for (a, t) in &typed {
                typer.ensure_base(*a, t, BaseType::String);
            }
            FullType::new(Type::JSON, false)
        }
        Function::JsonValue => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 2..2, args, span);
            for (a, t) in &typed {
                typer.ensure_base(*a, t, BaseType::String);
            }
            FullType::new(Type::JSON, false)
        }
        Function::JsonReplace => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 3..999, args, span);
            // The document (index 0) and each path (odd indices) must be strings;
            // the replacement values (even indices > 0) may be any type.
            for (i, (a, t)) in typed.iter().enumerate() {
                if i == 0 || i % 2 == 1 {
                    typer.ensure_base(*a, t, BaseType::String);
                }
            }
            FullType::new(Type::JSON, false)
        }
        Function::JsonSet => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 3..999, args, span);
            // Same layout as JSON_REPLACE: doc, then (path, value) pairs.
            for (i, (a, t)) in typed.iter().enumerate() {
                if i == 0 || i % 2 == 1 {
                    typer.ensure_base(*a, t, BaseType::String);
                }
            }
            FullType::new(Type::JSON, false)
        }
        Function::JsonUnquote => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 1..1, args, span);
            for (a, t) in &typed {
                typer.ensure_base(*a, t, BaseType::String);
            }
            FullType::new(BaseType::String, false)
        }
        Function::JsonQuery => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 2..2, args, span);
            for (a, t) in &typed {
                typer.ensure_base(*a, t, BaseType::String);
            }
            FullType::new(Type::JSON, false)
        }
        Function::JsonRemove => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 2..999, args, span);
            for (a, t) in &typed {
                typer.ensure_base(*a, t, BaseType::String);
            }
            FullType::new(Type::JSON, false)
        }
        Function::JsonContains => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 2..3, args, span);
            for (a, t) in &typed {
                typer.ensure_base(*a, t, BaseType::String);
            }
            if let (Some(t0), Some(t1), t2) = (typed.first(), typed.get(1), typed.get(2)) {
                let not_null =
                    t0.1.not_null && t1.1.not_null && t2.map(|t| t.1.not_null).unwrap_or(true);
                FullType::new(Type::Base(BaseType::Bool), not_null)
            } else {
                FullType::invalid()
            }
        }
        Function::JsonContainsPath => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 3..999, args, span);
            let mut not_null = true;
            for (a, t) in &typed {
                not_null = not_null && t.not_null;
                typer.ensure_base(*a, t, BaseType::String);
            }
            // Like JSON_CONTAINS this yields 0/1 (or NULL), not a JSON document.
            FullType::new(Type::Base(BaseType::Bool), not_null)
        }
        Function::JsonOverlaps => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 2..2, args, span);
            for (a, t) in &typed {
                typer.ensure_base(*a, t, BaseType::String);
            }
            if let (Some(t0), Some(t1)) = (typed.first(), typed.get(1)) {
                let not_null = t0.1.not_null && t1.1.not_null;
                FullType::new(Type::Base(BaseType::Bool), not_null)
            } else {
                FullType::invalid()
            }
        }
        Function::Min | Function::Max | Function::Sum => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 1..1, args, span);
            // Aggregates over an empty group yield NULL.
            if let Some((_, t2)) = typed.first() {
                let mut v = t2.clone();
                v.not_null = false;
                v
            } else {
                FullType::invalid()
            }
        }
        Function::Now => tf(BaseType::DateTime.into(), &[], &[BaseType::Integer]),
        Function::CurDate | Function::UtcDate => tf(BaseType::Date.into(), &[], &[]),
        Function::CurTime | Function::UtcTime => tf(BaseType::Time.into(), &[], &[]),
        Function::UtcTimeStamp => tf(BaseType::DateTime.into(), &[], &[]),
        Function::CurrentTimestamp => tf(BaseType::TimeStamp.into(), &[], &[BaseType::Integer]),
        Function::Concat => {
            let typed = typed_args(typer, args, flags);
            let mut not_null = true;
            for (a, t) in &typed {
                typer.ensure_base(*a, t, BaseType::Any);
                not_null = not_null && t.not_null;
            }
            FullType::new(BaseType::String, not_null)
        }
        Function::Least | Function::Greatest => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 1..9999, args, span);
            if let Some(((a, at), rest)) = typed.split_first() {
                // The result is NULL if any argument is NULL, so the first
                // argument's nullability counts too.
                let mut not_null = at.not_null;
                let mut t = at.t.clone();
                for (b, bt) in rest {
                    not_null = not_null && bt.not_null;
                    if bt.t == t {
                        continue;
                    };
                    if let Some(tt) = typer.matched_type(&bt.t, &t) {
                        t = tt;
                    } else {
                        typer
                            .err("None matching input types", span)
                            .frag(format!("Type {}", at.t), *a)
                            .frag(format!("Type {}", bt.t), *b);
                    }
                }
                FullType::new(t, not_null)
            } else {
                FullType::invalid()
            }
        }
        Function::If => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 3..3, args, span);
            let mut not_null = true;
            if let Some((e, t)) = typed.first() {
                not_null = not_null && t.not_null;
                typer.ensure_base(*e, t, BaseType::Bool);
            }
            let mut ans = FullType::invalid();
            if let Some((e1, t1)) = typed.get(1) {
                not_null = not_null && t1.not_null;
                if let Some((e2, t2)) = typed.get(2) {
                    not_null = not_null && t2.not_null;
                    if let Some(t) = typer.matched_type(t1, t2) {
                        ans = FullType::new(t, not_null);
                    } else {
                        typer
                            .err("Incompatible types", span)
                            .frag(format!("Of type {}", t1.t), *e1)
                            .frag(format!("Of type {}", t2.t), *e2);
                    }
                }
            }
            ans
        }
        Function::FromUnixTime => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 1..2, args, span);
            let mut not_null = true;
            if let Some((e, t)) = typed.first() {
                not_null = not_null && t.not_null;
                typer.ensure_base(*e, t, BaseType::Float);
            }
            // With a format string the result is a formatted string.
            if let Some((e, t)) = typed.get(1) {
                not_null = not_null && t.not_null;
                typer.ensure_base(*e, t, BaseType::String);
                FullType::new(BaseType::String, not_null)
            } else {
                FullType::new(BaseType::DateTime, not_null)
            }
        }
        Function::DateFormat => tf(
            BaseType::String.into(),
            &[BaseType::DateTime, BaseType::String],
            &[BaseType::String],
        ),
        Function::Value => {
            let typed = typed_args(typer, args, flags);
            if !flags.in_on_duplicate_key_update {
                typer.err("VALUE is only allowed within ON DUPLICATE KEY UPDATE", span);
            }
            arg_cnt(typer, 1..1, args, span);
            if let Some((_, t)) = typed.first() {
                t.clone()
            } else {
                FullType::invalid()
            }
        }
        Function::Length => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 1..1, args, span);
            let mut not_null = true;
            for (a, t) in &typed {
                not_null = not_null && t.not_null;
                if typer
                    .matched_type(t, &FullType::new(BaseType::String, false))
                    .is_none()
                    && typer
                        .matched_type(t, &FullType::new(BaseType::Bytes, false))
                        .is_none()
                {
                    // Anchor the error at the offending argument, not the whole call.
                    typer.err(format!("Expected type Bytes or String got {t}"), *a);
                }
            }
            FullType::new(Type::I64, not_null)
        }
        Function::Strftime => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 2..2, args, span);
            let mut not_null = true;
            if let Some((e, t)) = typed.first() {
                not_null = not_null && t.not_null;
                typer.ensure_base(*e, t, BaseType::String);
            }
            // Use index 1 (not `last`) so a lone argument is not checked twice.
            if let Some((e, t)) = typed.get(1) {
                not_null = not_null && t.not_null;
                typer.ensure_base(*e, t, BaseType::DateTime);
            }
            FullType::new(BaseType::String, not_null)
        }
        Function::StartsWith => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 2..2, args, span);
            let mut not_null = true;
            if let Some((e, t)) = typed.first() {
                not_null = not_null && t.not_null;
                typer.ensure_base(*e, t, BaseType::String);
            }
            if let Some((e, t)) = typed.get(1) {
                not_null = not_null && t.not_null;
                typer.ensure_base(*e, t, BaseType::String);
            }
            FullType::new(BaseType::Bool, not_null)
        }
        Function::Datetime => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 1..1, args, span);
            let mut not_null = true;
            if let Some((e, t)) = typed.first() {
                not_null = not_null && t.not_null;
                typer.ensure_base(*e, t, BaseType::String);
            }
            FullType::new(BaseType::DateTime, not_null)
        }
        Function::AddMonths => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 2..2, args, span);
            let mut not_null = true;
            if let Some((e, t)) = typed.get(1) {
                not_null = not_null && t.not_null;
                typer.ensure_base(*e, t, BaseType::Integer);
            }
            if let Some((e, t)) = typed.first() {
                not_null = not_null && t.not_null;
                let t = typer.ensure_datetime(*e, t, Restrict::Require, Restrict::Allow);
                FullType::new(t, not_null)
            } else {
                FullType::invalid()
            }
        }
        Function::AddDate | Function::DateSub => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 2..2, args, span);
            // The amount may be a plain integer (days) or an interval.
            if let Some((e, t)) = typed.get(1)
                && t.base() != BaseType::Integer
            {
                typer.ensure_base(*e, t, BaseType::TimeInterval);
            }
            if let Some((e, t)) = typed.first() {
                let t = typer.ensure_datetime(*e, t, Restrict::Require, Restrict::Allow);
                FullType::new(t, false)
            } else {
                FullType::invalid()
            }
        }
        Function::AddTime | Function::SubTime => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 2..2, args, span);
            let mut not_null = true;
            if let Some((e, t)) = typed.get(1) {
                not_null = not_null && t.not_null;
                typer.ensure_base(*e, t, BaseType::Time);
            }
            if let Some((e, t)) = typed.first() {
                not_null = not_null && t.not_null;
                let t = typer.ensure_datetime(*e, t, Restrict::Allow, Restrict::Require);
                FullType::new(t, not_null)
            } else {
                FullType::invalid()
            }
        }
        Function::ConvertTz => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 3..3, args, span);
            let mut not_null = true;
            if let Some((e, t)) = typed.get(1) {
                not_null = not_null && t.not_null;
                typer.ensure_base(*e, t, BaseType::String);
            }
            if let Some((e, t)) = typed.get(2) {
                not_null = not_null && t.not_null;
                typer.ensure_base(*e, t, BaseType::String);
            }
            if let Some((e, t)) = typed.first() {
                not_null = not_null && t.not_null;
                let t = typer.ensure_datetime(*e, t, Restrict::Require, Restrict::Require);
                FullType::new(t, not_null)
            } else {
                FullType::invalid()
            }
        }
        Function::Date | Function::LastDay => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 1..1, args, span);
            let not_null = if let Some((e, t)) = typed.first() {
                typer.ensure_datetime(*e, t, Restrict::Require, Restrict::Allow);
                t.not_null
            } else {
                true
            };
            FullType::new(BaseType::Date, not_null)
        }
        Function::Time => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 1..1, args, span);
            let not_null = if let Some((e, t)) = typed.first() {
                typer.ensure_datetime(*e, t, Restrict::Allow, Restrict::Require);
                t.not_null
            } else {
                true
            };
            FullType::new(BaseType::Time, not_null)
        }
        Function::DateDiff => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 2..2, args, span);
            let mut not_null = true;
            if let Some((e, t)) = typed.get(1) {
                not_null = not_null && t.not_null;
                typer.ensure_datetime(*e, t, Restrict::Require, Restrict::Allow);
            }
            if let Some((e, t)) = typed.first() {
                not_null = not_null && t.not_null;
                typer.ensure_datetime(*e, t, Restrict::Require, Restrict::Allow);
            }
            FullType::new(BaseType::Integer, not_null)
        }
        Function::DayName | Function::MonthName => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 1..1, args, span);
            let not_null = if let Some((e, t)) = typed.first() {
                typer.ensure_datetime(*e, t, Restrict::Require, Restrict::Allow);
                t.not_null
            } else {
                true
            };
            FullType::new(BaseType::String, not_null)
        }
        Function::DayOfMonth
        | Function::DayOfWeek
        | Function::DayOfYear
        | Function::Month
        | Function::Quarter
        | Function::ToDays
        | Function::ToSeconds
        | Function::Year
        | Function::Weekday => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 1..1, args, span);
            let not_null = if let Some((e, t)) = typed.first() {
                typer.ensure_datetime(*e, t, Restrict::Require, Restrict::Allow);
                t.not_null
            } else {
                true
            };
            FullType::new(BaseType::Integer, not_null)
        }
        Function::Week | Function::YearWeek => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 1..2, args, span);
            let not_null = if let Some((e, t)) = typed.first() {
                typer.ensure_datetime(*e, t, Restrict::Require, Restrict::Allow);
                t.not_null
            } else {
                true
            };
            // The optional mode is the *second* argument (index 1).
            if let Some((e, t)) = typed.get(1) {
                typer.ensure_base(*e, t, BaseType::Integer);
            };
            FullType::new(BaseType::Integer, not_null)
        }
        Function::Hour | Function::MicroSecond | Function::Minute | Function::Second => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 1..1, args, span);
            let not_null = if let Some((e, t)) = typed.first() {
                typer.ensure_datetime(*e, t, Restrict::Allow, Restrict::Require);
                t.not_null
            } else {
                true
            };
            FullType::new(BaseType::Integer, not_null)
        }
        Function::FromDays => tf(BaseType::Date.into(), &[BaseType::Integer], &[]),
        Function::SecToTime => tf(BaseType::Time.into(), &[BaseType::Integer], &[]),
        // MAKEDATE/MAKETIME can yield NULL for out-of-range components even with
        // non-null inputs, so only the type from `tf` is kept.
        Function::MakeDate => FullType::new(
            tf(
                BaseType::Date.into(),
                &[BaseType::Integer, BaseType::Integer],
                &[],
            )
            .t,
            false,
        ),
        Function::MakeTime => FullType::new(
            tf(
                BaseType::Time.into(),
                &[BaseType::Integer, BaseType::Integer, BaseType::Integer],
                &[],
            )
            .t,
            false,
        ),
        Function::PeriodAdd | Function::PeriodDiff => tf(
            BaseType::Integer.into(),
            &[BaseType::Integer, BaseType::Integer],
            &[],
        ),
        Function::StrToDate => tf(
            BaseType::DateTime.into(),
            &[BaseType::String, BaseType::String],
            &[],
        ),
        Function::SysDate => tf(BaseType::DateTime.into(), &[], &[]),
        Function::TimeFormat => tf(
            BaseType::String.into(),
            &[BaseType::Time, BaseType::String],
            &[],
        ),
        Function::TimeToSec => tf(BaseType::Float.into(), &[BaseType::Time], &[]),
        Function::TimeDiff => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 2..2, args, span);
            let not_null = if let Some((e, t)) = typed.first() {
                typer.ensure_datetime(*e, t, Restrict::Allow, Restrict::Require);
                t.not_null
            } else {
                true
            };
            let not_null = if let Some((e, t)) = typed.get(1) {
                typer.ensure_datetime(*e, t, Restrict::Allow, Restrict::Require);
                t.not_null & not_null
            } else {
                not_null
            };
            FullType::new(BaseType::Time, not_null)
        }
        Function::Timestamp => {
            let typed = typed_args(typer, args, flags);
            arg_cnt(typer, 1..2, args, span);
            let not_null = if let Some((e, t)) = typed.first() {
                typer.ensure_datetime(*e, t, Restrict::Require, Restrict::Allow);
                t.not_null
            } else {
                true
            };
            // The optional time to add is the *second* argument (index 1).
            let not_null = if let Some((e, t)) = typed.get(1) {
                typer.ensure_datetime(*e, t, Restrict::Disallow, Restrict::Require);
                t.not_null & not_null
            } else {
                not_null
            };
            FullType::new(BaseType::DateTime, not_null)
        }
        Function::Sleep => tf(BaseType::Integer.into(), &[BaseType::Float], &[]),
        _ => {
            typer.err("Typing for function not implemented", span);
            FullType::invalid()
        }
    }
}