use super::eval::{self, EvalCtx};
use crate::error::{Error, Result};
use crate::sql::ast::Expr;
use crate::value::Value;
use alloc::string::String;
use alloc::vec::Vec;
/// Reports whether `name` is one of the aggregate function names handled by
/// this engine.
///
/// The comparison ignores ASCII case, so `COUNT`, `Count` and `count` are all
/// recognised.
pub fn is_aggregate(name: &str) -> bool {
    const AGGREGATES: [&str; 7] = [
        "count",
        "sum",
        "total",
        "avg",
        "min",
        "max",
        "group_concat",
    ];
    AGGREGATES
        .iter()
        .any(|candidate| name.eq_ignore_ascii_case(candidate))
}
pub fn eval_scalar(name: &str, args: &[Expr], star: bool, ctx: &EvalCtx) -> Result<Value> {
let lname = name.to_ascii_lowercase();
if is_aggregate(&lname) {
return Err(Error::Error(alloc::format!(
"aggregate function {name} used outside an aggregate context"
)));
}
if star {
return Err(Error::Error(alloc::format!(
"{name}(*) is not a scalar call"
)));
}
match lname.as_str() {
"coalesce" => {
for a in args {
let v = eval::eval(a, ctx)?;
if !matches!(v, Value::Null) {
return Ok(v);
}
}
return Ok(Value::Null);
}
"ifnull" => {
arity(&lname, args, 2)?;
let a = eval::eval(&args[0], ctx)?;
return if matches!(a, Value::Null) {
eval::eval(&args[1], ctx)
} else {
Ok(a)
};
}
_ => {}
}
let v: Vec<Value> = args
.iter()
.map(|a| eval::eval(a, ctx))
.collect::<Result<_>>()?;
Ok(match lname.as_str() {
"abs" => {
arity(&lname, args, 1)?;
match eval::to_number(&v[0]) {
Value::Integer(i) => Value::Integer(i.wrapping_abs()),
Value::Real(r) => Value::Real(crate::util::float::abs(r)),
_ => Value::Null,
}
}
"length" => {
arity(&lname, args, 1)?;
match &v[0] {
Value::Null => Value::Null,
Value::Blob(b) => Value::Integer(b.len() as i64),
other => Value::Integer(eval::to_text(other).chars().count() as i64),
}
}
"lower" => {
arity(&lname, args, 1)?;
str_map(&v[0], |s| s.to_lowercase())
}
"upper" => {
arity(&lname, args, 1)?;
str_map(&v[0], |s| s.to_uppercase())
}
"trim" => trim_fn(&v, true, true),
"ltrim" => trim_fn(&v, true, false),
"rtrim" => trim_fn(&v, false, true),
"typeof" => Value::Text(String::from(type_name(&v[0]))),
"nullif" => {
arity(&lname, args, 2)?;
if eval::compare(&v[0], &v[1]) == core::cmp::Ordering::Equal {
Value::Null
} else {
v[0].clone()
}
}
"n/a" => unreachable!(),
"substr" | "substring" => substr(&v)?,
"instr" => instr(&v)?,
"replace" => replace(&v)?,
"round" => round(&v)?,
"min" => scalar_min_max(&v, true),
"max" => scalar_min_max(&v, false),
"hex" => Value::Text(hex_encode(&v[0])),
"char" => char_fn(&v),
"unicode" => match &v[0] {
Value::Null => Value::Null,
other => eval::to_text(other)
.chars()
.next()
.map(|c| Value::Integer(c as i64))
.unwrap_or(Value::Null),
},
"iif" => {
arity(&lname, args, 3)?;
if eval::truth(&v[0]) == Some(true) {
v[1].clone()
} else {
v[2].clone()
}
}
_ => return Err(Error::Unsupported("unknown scalar function")),
})
}
fn arity(name: &str, args: &[Expr], n: usize) -> Result<()> {
if args.len() == n {
Ok(())
} else {
Err(Error::Error(alloc::format!(
"wrong number of arguments to function {name}() (want {n}, got {})",
args.len()
)))
}
}
/// Applies the string transformation `f` to the text form of `v`.
///
/// NULL is passed through unchanged; any other value is converted with
/// `eval::to_text`, transformed, and returned as `Value::Text`.
fn str_map(v: &Value, f: impl Fn(&str) -> String) -> Value {
    if matches!(v, Value::Null) {
        return Value::Null;
    }
    let text = eval::to_text(v);
    Value::Text(f(&text))
}
/// Returns the lower-case storage-class name of `v`: one of `"null"`,
/// `"integer"`, `"real"`, `"text"` or `"blob"`.
fn type_name(v: &Value) -> &'static str {
    match *v {
        Value::Integer(..) => "integer",
        Value::Real(..) => "real",
        Value::Text(..) => "text",
        Value::Blob(..) => "blob",
        Value::Null => "null",
    }
}
/// Shared implementation of `trim`, `ltrim` and `rtrim`.
///
/// `v[0]` is the value to trim. The optional `v[1]` is the set of characters
/// to strip; when it is absent only the space character is stripped. `left`
/// and `right` select which ends of the text are trimmed.
///
/// Returns NULL when there are no arguments or when either the subject or the
/// character set is NULL; otherwise returns the trimmed text.
fn trim_fn(v: &[Value], left: bool, right: bool) -> Value {
    let subject = match v.first() {
        None | Some(Value::Null) => return Value::Null,
        Some(value) => value,
    };
    let trim_chars: Vec<char> = match v.get(1) {
        // A NULL character set makes the whole result NULL rather than being
        // stringified and used as a set of characters.
        Some(Value::Null) => return Value::Null,
        Some(set) => eval::to_text(set).chars().collect(),
        None => alloc::vec![' '],
    };

    let text = eval::to_text(subject);
    let mut trimmed: &str = &text;
    // A `&[char]` pattern matches any one of the listed characters.
    if left {
        trimmed = trimmed.trim_start_matches(&trim_chars[..]);
    }
    if right {
        trimmed = trimmed.trim_end_matches(&trim_chars[..]);
    }
    Value::Text(String::from(trimmed))
}
/// Implements `substr(X, Y)` / `substr(X, Y, Z)` on the text form of `X`,
/// counting in characters.
///
/// * `Y` is the 1-based start position. A negative `Y` counts back from the
///   end of the text (`-1` is the last character).
/// * `Z` is the number of characters to return; when omitted, everything up
///   to the end of the text is returned. A negative `Z` returns the `|Z|`
///   characters that precede position `Y`.
/// * Positions that fall before the first character still consume length, so
///   `substr('hello', 0, 3)` yields `'he'` and `substr('hello', -7, 4)`
///   yields `'he'`.
///
/// Returns NULL if any argument is NULL.
///
/// # Errors
///
/// Returns an error unless exactly 2 or 3 arguments are supplied.
fn substr(v: &[Value]) -> Result<Value> {
    if v.len() < 2 || v.len() > 3 {
        return Err(Error::Error("substr() takes 2 or 3 arguments".into()));
    }
    if v.iter().any(|x| matches!(x, Value::Null)) {
        return Ok(Value::Null);
    }

    let chars: Vec<char> = eval::to_text(&v[0]).chars().collect();
    let len = chars.len() as i64;

    // `start` becomes a 0-based offset below; `count` is kept non-negative
    // and `backwards` remembers whether the caller asked for the characters
    // *before* the start position.
    let mut start = eval::to_i64(&v[1]);
    let (mut count, backwards) = match v.get(2) {
        Some(z) => {
            let z = eval::to_i64(z);
            // `saturating_abs` keeps `i64::MIN` from overflowing.
            (z.saturating_abs(), z < 0)
        }
        // No length given: "to the end of the text".
        None => (i64::MAX, false),
    };

    if start < 0 {
        start += len;
        if start < 0 {
            // The window begins left of the text; the part that hangs off
            // the front is deducted from the requested length.
            count = (count + start).max(0);
            start = 0;
        }
    } else if start > 0 {
        start -= 1;
    } else if count > 0 {
        // Position 0 is an imaginary slot before the first character and
        // uses up one unit of length.
        count -= 1;
    }

    if backwards {
        start -= count;
        if start < 0 {
            count += start;
            start = 0;
        }
    }

    // Here `start >= 0` and `count >= 0`; clamp the window to the text.
    let begin = start.min(len) as usize;
    let end = start.saturating_add(count).min(len) as usize;
    Ok(Value::Text(chars[begin..end].iter().collect()))
}
/// Implements `instr(X, Y)`: the 1-based character position of the first
/// occurrence of the text of `Y` inside the text of `X`, or 0 when `Y` does
/// not occur.
///
/// Returns NULL if either argument is NULL.
///
/// # Errors
///
/// Returns an error unless exactly 2 arguments are supplied.
fn instr(v: &[Value]) -> Result<Value> {
    if v.len() != 2 {
        return Err(Error::Error("instr() takes 2 arguments".into()));
    }
    if v.iter().any(|x| matches!(x, Value::Null)) {
        return Ok(Value::Null);
    }
    let haystack = eval::to_text(&v[0]);
    let needle = eval::to_text(&v[1]);
    // `find` reports a byte offset; convert it to a character position.
    let position = haystack
        .find(needle.as_str())
        .map_or(0, |byte_idx| haystack[..byte_idx].chars().count() as i64 + 1);
    Ok(Value::Integer(position))
}
fn replace(v: &[Value]) -> Result<Value> {
if v.len() != 3 {
return Err(Error::Error("replace() takes 3 arguments".into()));
}
if v.iter().any(|x| matches!(x, Value::Null)) {
return Ok(Value::Null);
}
let s = eval::to_text(&v[0]);
let from = eval::to_text(&v[1]);
let to = eval::to_text(&v[2]);
if from.is_empty() {
return Ok(Value::Text(s));
}
Ok(Value::Text(s.replace(&from, &to)))
}
/// Implements `round(X)` / `round(X, Y)`: `X` rounded to `Y` digits after the
/// decimal point (0 when `Y` is omitted), returned as `Value::Real`.
///
/// `Y` is clamped to `0..=30`. Returns NULL if any argument is NULL. Values
/// whose magnitude exceeds 2^52 are returned unchanged.
///
/// # Errors
///
/// Returns an error unless 1 or 2 arguments are supplied.
fn round(v: &[Value]) -> Result<Value> {
    // Beyond 2^52 an f64 has no fractional bits left, so there is nothing to
    // round and scaling could only lose precision or overflow.
    const NO_FRACTION_LIMIT: f64 = 4_503_599_627_370_496.0;

    if v.is_empty() || v.len() > 2 {
        return Err(Error::Error("round() takes 1 or 2 arguments".into()));
    }
    if v.iter().any(|x| matches!(x, Value::Null)) {
        return Ok(Value::Null);
    }

    let x = eval::to_f64(&v[0]);
    if x < -NO_FRACTION_LIMIT || x > NO_FRACTION_LIMIT {
        return Ok(Value::Real(x));
    }

    // Clamp before narrowing: an unclamped `as i32` would wrap large values,
    // and a huge exponent would push the factor to infinity and yield NaN.
    let digits = match v.get(1) {
        Some(d) => eval::to_i64(d).clamp(0, 30) as i32,
        None => 0,
    };
    let factor = crate::util::float::powi(10.0, digits);
    Ok(Value::Real(crate::util::float::round(x * factor) / factor))
}
/// Multi-argument scalar `min`/`max`: returns the smallest (`want_min`) or
/// largest argument according to `eval::compare`.
///
/// Returns NULL if any argument is NULL or if `v` is empty. When several
/// arguments compare equal, the earliest one is returned.
fn scalar_min_max(v: &[Value], want_min: bool) -> Value {
    if v.iter().any(|x| matches!(x, Value::Null)) {
        return Value::Null;
    }
    // `split_first` avoids the out-of-bounds panic `v[0]` would cause on an
    // empty argument list.
    let (first, rest) = match v.split_first() {
        Some(parts) => parts,
        None => return Value::Null,
    };
    let wanted = if want_min {
        core::cmp::Ordering::Less
    } else {
        core::cmp::Ordering::Greater
    };
    // Track the winner by reference and clone only once at the end.
    let mut best = first;
    for candidate in rest {
        if eval::compare(candidate, best) == wanted {
            best = candidate;
        }
    }
    best.clone()
}
/// Renders `v` as upper-case hexadecimal, two digits per byte.
///
/// A blob is encoded byte for byte; any other value is first converted with
/// `eval::to_text` and the bytes of that text are encoded.
fn hex_encode(v: &Value) -> String {
    /// Hex-encodes a raw byte slice.
    fn encode(bytes: &[u8]) -> String {
        let mut out = String::with_capacity(bytes.len() * 2);
        for &byte in bytes {
            out.push(nibble(byte >> 4));
            out.push(nibble(byte & 0xf));
        }
        out
    }

    match v {
        Value::Blob(blob) => encode(&blob[..]),
        other => encode(eval::to_text(other).as_bytes()),
    }
}
/// Maps a 4-bit value (`0..=15`) to its upper-case hexadecimal digit.
fn nibble(n: u8) -> char {
    if n <= 9 {
        char::from(b'0' + n)
    } else {
        char::from(b'A' + n - 10)
    }
}
/// Implements `char(X1, X2, ...)`: builds a string with one character per
/// argument, using the integer value of each argument as a Unicode code
/// point.
///
/// Arguments that are not valid Unicode scalar values (negative, above
/// `0x10FFFF`, or in the surrogate range) produce U+FFFD, the replacement
/// character.
fn char_fn(v: &[Value]) -> Value {
    let text: String = v
        .iter()
        .map(|x| {
            let code = eval::to_i64(x);
            // Range-check before narrowing: a bare `as u32` would wrap
            // out-of-range values onto unrelated valid code points.
            let valid = if (0..=0x10_FFFF).contains(&code) {
                char::from_u32(code as u32)
            } else {
                None
            };
            valid.unwrap_or('\u{FFFD}')
        })
        .collect();
    Value::Text(text)
}