1use super::eval::{self, EvalCtx};
8use crate::error::{Error, Result};
9use crate::sql::ast::Expr;
10use crate::value::Value;
11use alloc::string::String;
12use alloc::vec::Vec;
13
/// Reports whether `name` is one of the aggregate function names.
///
/// The comparison ignores ASCII case, so `COUNT`, `Count` and `count` are all
/// recognised.
pub fn is_aggregate(name: &str) -> bool {
    const AGGREGATES: [&str; 7] = [
        "count",
        "sum",
        "total",
        "avg",
        "min",
        "max",
        "group_concat",
    ];
    AGGREGATES
        .iter()
        .any(|candidate| name.eq_ignore_ascii_case(candidate))
}
21
/// Reports whether a call to `name` with `nargs` arguments is an aggregate call.
///
/// `min` and `max` are only aggregates when called with `*` or with exactly
/// one argument; the other aggregate names are aggregates regardless of
/// argument count. The name comparison ignores ASCII case.
pub fn is_aggregate_call(name: &str, nargs: usize, star: bool) -> bool {
    let named = |candidate: &str| name.eq_ignore_ascii_case(candidate);
    if named("min") || named("max") {
        return star || nargs == 1;
    }
    ["count", "sum", "total", "avg", "group_concat"]
        .iter()
        .any(|&candidate| named(candidate))
}
31
32pub fn eval_scalar(name: &str, args: &[Expr], star: bool, ctx: &EvalCtx) -> Result<Value> {
34 let lname = name.to_ascii_lowercase();
35 if is_aggregate_call(&lname, args.len(), star) {
36 return Err(Error::Error(alloc::format!(
37 "aggregate function {name} used outside an aggregate context"
38 )));
39 }
40 if star {
41 return Err(Error::Error(alloc::format!(
42 "{name}(*) is not a scalar call"
43 )));
44 }
45
46 match lname.as_str() {
48 "coalesce" => {
49 for a in args {
50 let v = eval::eval(a, ctx)?;
51 if !matches!(v, Value::Null) {
52 return Ok(v);
53 }
54 }
55 return Ok(Value::Null);
56 }
57 "ifnull" => {
58 arity(&lname, args, 2)?;
59 let a = eval::eval(&args[0], ctx)?;
60 return if matches!(a, Value::Null) {
61 eval::eval(&args[1], ctx)
62 } else {
63 Ok(a)
64 };
65 }
66 _ => {}
67 }
68
69 let v: Vec<Value> = args
70 .iter()
71 .map(|a| eval::eval(a, ctx))
72 .collect::<Result<_>>()?;
73
74 Ok(match lname.as_str() {
75 "abs" => {
76 arity(&lname, args, 1)?;
77 match eval::to_number(&v[0]) {
78 Value::Integer(i) => Value::Integer(i.wrapping_abs()),
79 Value::Real(r) => Value::Real(crate::util::float::abs(r)),
80 _ => Value::Null,
81 }
82 }
83 "length" => {
84 arity(&lname, args, 1)?;
85 match &v[0] {
86 Value::Null => Value::Null,
87 Value::Blob(b) => Value::Integer(b.len() as i64),
88 other => Value::Integer(eval::to_text(other).chars().count() as i64),
89 }
90 }
91 "lower" => {
92 arity(&lname, args, 1)?;
93 str_map(&v[0], |s| s.to_lowercase())
94 }
95 "upper" => {
96 arity(&lname, args, 1)?;
97 str_map(&v[0], |s| s.to_uppercase())
98 }
99 "trim" => trim_fn(&v, true, true),
100 "ltrim" => trim_fn(&v, true, false),
101 "rtrim" => trim_fn(&v, false, true),
102 "typeof" => Value::Text(String::from(type_name(&v[0]))),
103 "nullif" => {
104 arity(&lname, args, 2)?;
105 if eval::compare(&v[0], &v[1]) == core::cmp::Ordering::Equal {
106 Value::Null
107 } else {
108 v[0].clone()
109 }
110 }
111 "n/a" => unreachable!(),
112 "substr" | "substring" => substr(&v)?,
113 "instr" => instr(&v)?,
114 "replace" => replace(&v)?,
115 "round" => round(&v)?,
116 "min" => scalar_min_max(&v, true),
117 "max" => scalar_min_max(&v, false),
118 "hex" => Value::Text(hex_encode(&v[0])),
119 "char" => char_fn(&v),
120 "unicode" => match &v[0] {
121 Value::Null => Value::Null,
122 other => eval::to_text(other)
123 .chars()
124 .next()
125 .map(|c| Value::Integer(c as i64))
126 .unwrap_or(Value::Null),
127 },
128 "iif" => {
129 arity(&lname, args, 3)?;
130 if eval::truth(&v[0]) == Some(true) {
131 v[1].clone()
132 } else {
133 v[2].clone()
134 }
135 }
136 "zeroblob" => {
137 arity(&lname, args, 1)?;
138 match &v[0] {
139 Value::Null => Value::Null,
140 other => {
141 let n = eval::to_i64(other).max(0) as usize;
142 Value::Blob(alloc::vec![0u8; n])
143 }
144 }
145 }
146 "quote" => {
147 arity(&lname, args, 1)?;
148 Value::Text(quote_value(&v[0]))
149 }
150 "sign" => {
151 arity(&lname, args, 1)?;
152 match eval::to_number(&v[0]) {
153 Value::Integer(i) => Value::Integer(i.signum()),
154 Value::Real(r) => Value::Integer(if r > 0.0 {
155 1
156 } else if r < 0.0 {
157 -1
158 } else {
159 0
160 }),
161 _ => Value::Null,
162 }
163 }
164 "concat" => {
165 let mut s = String::new();
167 for x in &v {
168 if !matches!(x, Value::Null) {
169 s.push_str(&eval::to_text(x));
170 }
171 }
172 Value::Text(s)
173 }
174 "concat_ws" => {
175 if v.is_empty() {
176 return Err(Error::Error("concat_ws() needs a separator".into()));
177 }
178 if matches!(v[0], Value::Null) {
179 Value::Null
180 } else {
181 let sep = eval::to_text(&v[0]);
182 let parts: alloc::vec::Vec<String> = v[1..]
183 .iter()
184 .filter(|x| !matches!(x, Value::Null))
185 .map(eval::to_text)
186 .collect();
187 Value::Text(parts.join(&sep))
188 }
189 }
190 "unhex" => {
191 arity(&lname, args, 1)?;
192 match &v[0] {
193 Value::Null => Value::Null,
194 other => match unhex(&eval::to_text(other)) {
195 Some(b) => Value::Blob(b),
196 None => Value::Null,
197 },
198 }
199 }
200 "date" => super::datetime::date(&v),
202 "time" => super::datetime::time(&v),
203 "datetime" => super::datetime::datetime(&v),
204 "julianday" => super::datetime::julianday(&v),
205 "unixepoch" => super::datetime::unixepoch(&v),
206 "strftime" => super::datetime::strftime(&v),
207 "printf" | "format" => super::datetime::printf(&v),
208 _ => return Err(Error::Unsupported("unknown scalar function")),
209 })
210}
211
212fn quote_value(v: &Value) -> String {
214 match v {
215 Value::Null => String::from("NULL"),
216 Value::Integer(i) => alloc::format!("{i}"),
217 Value::Real(r) => eval::format_real(*r),
218 Value::Text(s) => alloc::format!("'{}'", s.replace('\'', "''")),
219 Value::Blob(b) => {
220 let mut s = String::from("x'");
221 for byte in b {
222 s.push_str(&alloc::format!("{byte:02x}"));
223 }
224 s.push('\'');
225 s
226 }
227 }
228}
229
/// Decodes a string of hexadecimal digit pairs into bytes.
///
/// Both upper- and lowercase digits are accepted. Returns `None` when the
/// input has an odd number of bytes or contains a byte that is not a hex
/// digit. An empty input decodes to an empty vector.
fn unhex(s: &str) -> Option<alloc::vec::Vec<u8>> {
    fn digit(c: u8) -> Option<u8> {
        // `to_digit(16)` accepts exactly 0-9, a-f and A-F; the value is below
        // 16, so narrowing it to `u8` is lossless.
        char::from(c).to_digit(16).map(|d| d as u8)
    }

    let pairs = s.as_bytes().chunks_exact(2);
    if !pairs.remainder().is_empty() {
        return None;
    }
    pairs
        .map(|pair| Some((digit(pair[0])? << 4) | digit(pair[1])?))
        .collect()
}
252
253fn arity(name: &str, args: &[Expr], n: usize) -> Result<()> {
254 if args.len() == n {
255 Ok(())
256 } else {
257 Err(Error::Error(alloc::format!(
258 "wrong number of arguments to function {name}() (want {n}, got {})",
259 args.len()
260 )))
261 }
262}
263
264fn str_map(v: &Value, f: impl Fn(&str) -> String) -> Value {
265 match v {
266 Value::Null => Value::Null,
267 other => Value::Text(f(&eval::to_text(other))),
268 }
269}
270
271fn type_name(v: &Value) -> &'static str {
272 match v {
273 Value::Null => "null",
274 Value::Integer(_) => "integer",
275 Value::Real(_) => "real",
276 Value::Text(_) => "text",
277 Value::Blob(_) => "blob",
278 }
279}
280
281fn trim_fn(v: &[Value], left: bool, right: bool) -> Value {
282 if v.is_empty() || matches!(v[0], Value::Null) {
283 return Value::Null;
284 }
285 let s = eval::to_text(&v[0]);
286 let trim_chars: Vec<char> = if v.len() >= 2 {
287 eval::to_text(&v[1]).chars().collect()
288 } else {
289 alloc::vec![' ']
290 };
291 let is_trim = |c: char| trim_chars.contains(&c);
292 let chars: Vec<char> = s.chars().collect();
293 let mut start = 0;
294 let mut end = chars.len();
295 if left {
296 while start < end && is_trim(chars[start]) {
297 start += 1;
298 }
299 }
300 if right {
301 while end > start && is_trim(chars[end - 1]) {
302 end -= 1;
303 }
304 }
305 Value::Text(chars[start..end].iter().collect())
306}
307
308fn substr(v: &[Value]) -> Result<Value> {
309 if v.len() < 2 || v.len() > 3 {
310 return Err(Error::Error("substr() takes 2 or 3 arguments".into()));
311 }
312 if matches!(v[0], Value::Null) {
313 return Ok(Value::Null);
314 }
315 let s: Vec<char> = eval::to_text(&v[0]).chars().collect();
316 let len = s.len() as i64;
317 let mut start = eval::to_i64(&v[1]);
321 if start < 0 {
322 start += len + 1;
323 }
324 let (wstart, wend) = if v.len() == 3 {
325 let z = eval::to_i64(&v[2]);
326 if z < 0 {
327 (start + z, start)
328 } else {
329 (start, start + z)
330 }
331 } else {
332 (start, len + 1)
333 };
334 let b = wstart.max(1);
335 let e = wend.min(len + 1);
336 if b >= e {
337 Ok(Value::Text(String::new()))
338 } else {
339 Ok(Value::Text(
340 s[(b - 1) as usize..(e - 1) as usize].iter().collect(),
341 ))
342 }
343}
344
345fn instr(v: &[Value]) -> Result<Value> {
346 if v.len() != 2 {
347 return Err(Error::Error("instr() takes 2 arguments".into()));
348 }
349 if matches!(v[0], Value::Null) || matches!(v[1], Value::Null) {
350 return Ok(Value::Null);
351 }
352 let hay = eval::to_text(&v[0]);
353 let needle = eval::to_text(&v[1]);
354 match hay.find(&needle) {
356 None => Ok(Value::Integer(0)),
357 Some(byte_idx) => {
358 let char_idx = hay[..byte_idx].chars().count();
359 Ok(Value::Integer(char_idx as i64 + 1))
360 }
361 }
362}
363
364fn replace(v: &[Value]) -> Result<Value> {
365 if v.len() != 3 {
366 return Err(Error::Error("replace() takes 3 arguments".into()));
367 }
368 if v.iter().any(|x| matches!(x, Value::Null)) {
369 return Ok(Value::Null);
370 }
371 let s = eval::to_text(&v[0]);
372 let from = eval::to_text(&v[1]);
373 let to = eval::to_text(&v[2]);
374 if from.is_empty() {
375 return Ok(Value::Text(s));
376 }
377 Ok(Value::Text(s.replace(&from, &to)))
378}
379
380fn round(v: &[Value]) -> Result<Value> {
381 if v.is_empty() || v.len() > 2 {
382 return Err(Error::Error("round() takes 1 or 2 arguments".into()));
383 }
384 if matches!(v[0], Value::Null) {
385 return Ok(Value::Null);
386 }
387 let x = eval::to_f64(&v[0]);
388 let digits = if v.len() == 2 {
389 eval::to_i64(&v[1]).max(0)
390 } else {
391 0
392 };
393 let factor = crate::util::float::powi(10.0, digits as i32);
394 Ok(Value::Real(crate::util::float::round(x * factor) / factor))
395}
396
397fn scalar_min_max(v: &[Value], want_min: bool) -> Value {
398 if v.iter().any(|x| matches!(x, Value::Null)) {
400 return Value::Null;
401 }
402 let mut best = v[0].clone();
403 for x in &v[1..] {
404 let ord = eval::compare(x, &best);
405 let take = if want_min {
406 ord == core::cmp::Ordering::Less
407 } else {
408 ord == core::cmp::Ordering::Greater
409 };
410 if take {
411 best = x.clone();
412 }
413 }
414 best
415}
416
417fn hex_encode(v: &Value) -> String {
418 let bytes = match v {
419 Value::Blob(b) => b.clone(),
420 other => eval::to_text(other).into_bytes(),
421 };
422 let mut s = String::with_capacity(bytes.len() * 2);
423 for b in bytes {
424 s.push(nibble(b >> 4));
425 s.push(nibble(b & 0xf));
426 }
427 s
428}
429
/// Maps a 4-bit value (0..=15) to its uppercase hexadecimal digit.
fn nibble(n: u8) -> char {
    if n <= 9 {
        char::from(b'0' + n)
    } else {
        char::from(b'A' + n - 10)
    }
}
436
437fn char_fn(v: &[Value]) -> Value {
438 let mut s = String::new();
439 for x in v {
440 if let Some(c) = char::from_u32(eval::to_i64(x) as u32) {
441 s.push(c);
442 }
443 }
444 Value::Text(s)
445}