Skip to main content

graphitesql/exec/
func.rs

1//! Built-in scalar functions.
2//!
3//! Aggregate functions (`count`, `sum`, …) are handled by the executor, which
4//! folds over rows; this module covers the per-row scalar functions. The set is
5//! a useful core and grows toward SQLite's full library (`func.c`, `date.c`).
6
7use super::eval::{self, EvalCtx};
8use crate::error::{Error, Result};
9use crate::sql::ast::Expr;
10use crate::value::Value;
11use alloc::string::String;
12use alloc::vec::Vec;
13
/// Report whether `name` is one of the functions that *may* act as an
/// aggregate, ignoring ASCII case.
///
/// This is a name-only check (useful for catalog/name validation); whether a
/// particular call is actually an aggregate depends on its arguments and is
/// decided by [`is_aggregate_call`].
pub fn is_aggregate(name: &str) -> bool {
    const AGGREGATE_NAMES: [&str; 7] = ["count", "sum", "total", "avg", "min", "max", "group_concat"];
    AGGREGATE_NAMES
        .iter()
        .any(|candidate| candidate.eq_ignore_ascii_case(name))
}
21
/// Decide whether one concrete call is an aggregate, ignoring ASCII case.
///
/// `count`, `sum`, `total`, `avg` and `group_concat` are always aggregates.
/// `min`/`max` are aggregates only in the `f(*)` form or with exactly one
/// argument; with two or more arguments they are the scalar functions, as in
/// SQLite. Every other name is not an aggregate.
pub fn is_aggregate_call(name: &str, nargs: usize, star: bool) -> bool {
    let lowered = name.to_ascii_lowercase();
    match lowered.as_str() {
        "min" | "max" => star || nargs == 1,
        other => matches!(other, "count" | "sum" | "total" | "avg" | "group_concat"),
    }
}
31
32/// Evaluate a scalar function call.
33pub fn eval_scalar(name: &str, args: &[Expr], star: bool, ctx: &EvalCtx) -> Result<Value> {
34    let lname = name.to_ascii_lowercase();
35    if is_aggregate_call(&lname, args.len(), star) {
36        return Err(Error::Error(alloc::format!(
37            "aggregate function {name} used outside an aggregate context"
38        )));
39    }
40    if star {
41        return Err(Error::Error(alloc::format!(
42            "{name}(*) is not a scalar call"
43        )));
44    }
45
46    // Functions whose NULL-handling is special are done before arg evaluation.
47    match lname.as_str() {
48        "coalesce" => {
49            for a in args {
50                let v = eval::eval(a, ctx)?;
51                if !matches!(v, Value::Null) {
52                    return Ok(v);
53                }
54            }
55            return Ok(Value::Null);
56        }
57        "ifnull" => {
58            arity(&lname, args, 2)?;
59            let a = eval::eval(&args[0], ctx)?;
60            return if matches!(a, Value::Null) {
61                eval::eval(&args[1], ctx)
62            } else {
63                Ok(a)
64            };
65        }
66        _ => {}
67    }
68
69    let v: Vec<Value> = args
70        .iter()
71        .map(|a| eval::eval(a, ctx))
72        .collect::<Result<_>>()?;
73
74    Ok(match lname.as_str() {
75        "abs" => {
76            arity(&lname, args, 1)?;
77            match eval::to_number(&v[0]) {
78                Value::Integer(i) => Value::Integer(i.wrapping_abs()),
79                Value::Real(r) => Value::Real(crate::util::float::abs(r)),
80                _ => Value::Null,
81            }
82        }
83        "length" => {
84            arity(&lname, args, 1)?;
85            match &v[0] {
86                Value::Null => Value::Null,
87                Value::Blob(b) => Value::Integer(b.len() as i64),
88                other => Value::Integer(eval::to_text(other).chars().count() as i64),
89            }
90        }
91        "lower" => {
92            arity(&lname, args, 1)?;
93            str_map(&v[0], |s| s.to_lowercase())
94        }
95        "upper" => {
96            arity(&lname, args, 1)?;
97            str_map(&v[0], |s| s.to_uppercase())
98        }
99        "trim" => trim_fn(&v, true, true),
100        "ltrim" => trim_fn(&v, true, false),
101        "rtrim" => trim_fn(&v, false, true),
102        "typeof" => Value::Text(String::from(type_name(&v[0]))),
103        "nullif" => {
104            arity(&lname, args, 2)?;
105            if eval::compare(&v[0], &v[1]) == core::cmp::Ordering::Equal {
106                Value::Null
107            } else {
108                v[0].clone()
109            }
110        }
111        "n/a" => unreachable!(),
112        "substr" | "substring" => substr(&v)?,
113        "instr" => instr(&v)?,
114        "replace" => replace(&v)?,
115        "round" => round(&v)?,
116        "min" => scalar_min_max(&v, true),
117        "max" => scalar_min_max(&v, false),
118        "hex" => Value::Text(hex_encode(&v[0])),
119        "char" => char_fn(&v),
120        "unicode" => match &v[0] {
121            Value::Null => Value::Null,
122            other => eval::to_text(other)
123                .chars()
124                .next()
125                .map(|c| Value::Integer(c as i64))
126                .unwrap_or(Value::Null),
127        },
128        "iif" => {
129            arity(&lname, args, 3)?;
130            if eval::truth(&v[0]) == Some(true) {
131                v[1].clone()
132            } else {
133                v[2].clone()
134            }
135        }
136        "zeroblob" => {
137            arity(&lname, args, 1)?;
138            match &v[0] {
139                Value::Null => Value::Null,
140                other => {
141                    let n = eval::to_i64(other).max(0) as usize;
142                    Value::Blob(alloc::vec![0u8; n])
143                }
144            }
145        }
146        "quote" => {
147            arity(&lname, args, 1)?;
148            Value::Text(quote_value(&v[0]))
149        }
150        "sign" => {
151            arity(&lname, args, 1)?;
152            match eval::to_number(&v[0]) {
153                Value::Integer(i) => Value::Integer(i.signum()),
154                Value::Real(r) => Value::Integer(if r > 0.0 {
155                    1
156                } else if r < 0.0 {
157                    -1
158                } else {
159                    0
160                }),
161                _ => Value::Null,
162            }
163        }
164        "concat" => {
165            // SQLite 3.44+: concatenate all args, treating NULL as empty.
166            let mut s = String::new();
167            for x in &v {
168                if !matches!(x, Value::Null) {
169                    s.push_str(&eval::to_text(x));
170                }
171            }
172            Value::Text(s)
173        }
174        "concat_ws" => {
175            if v.is_empty() {
176                return Err(Error::Error("concat_ws() needs a separator".into()));
177            }
178            if matches!(v[0], Value::Null) {
179                Value::Null
180            } else {
181                let sep = eval::to_text(&v[0]);
182                let parts: alloc::vec::Vec<String> = v[1..]
183                    .iter()
184                    .filter(|x| !matches!(x, Value::Null))
185                    .map(eval::to_text)
186                    .collect();
187                Value::Text(parts.join(&sep))
188            }
189        }
190        "unhex" => {
191            arity(&lname, args, 1)?;
192            match &v[0] {
193                Value::Null => Value::Null,
194                other => match unhex(&eval::to_text(other)) {
195                    Some(b) => Value::Blob(b),
196                    None => Value::Null,
197                },
198            }
199        }
200        // Date/time functions (see `super::datetime`).
201        "date" => super::datetime::date(&v),
202        "time" => super::datetime::time(&v),
203        "datetime" => super::datetime::datetime(&v),
204        "julianday" => super::datetime::julianday(&v),
205        "unixepoch" => super::datetime::unixepoch(&v),
206        "strftime" => super::datetime::strftime(&v),
207        "printf" | "format" => super::datetime::printf(&v),
208        _ => return Err(Error::Unsupported("unknown scalar function")),
209    })
210}
211
212/// Render a value as a SQL literal, like SQLite's `quote()`.
213fn quote_value(v: &Value) -> String {
214    match v {
215        Value::Null => String::from("NULL"),
216        Value::Integer(i) => alloc::format!("{i}"),
217        Value::Real(r) => eval::format_real(*r),
218        Value::Text(s) => alloc::format!("'{}'", s.replace('\'', "''")),
219        Value::Blob(b) => {
220            let mut s = String::from("x'");
221            for byte in b {
222                s.push_str(&alloc::format!("{byte:02x}"));
223            }
224            s.push('\'');
225            s
226        }
227    }
228}
229
/// Decode a string of hexadecimal digits into raw bytes.
///
/// Digits may be upper or lower case and must come in pairs. Returns `None`
/// for an odd number of bytes or for any byte that is not an ASCII hex digit.
/// The empty string decodes to an empty vector.
fn unhex(s: &str) -> Option<alloc::vec::Vec<u8>> {
    /// Value of one ASCII hex digit (`to_digit(16)` accepts exactly 0-9, a-f, A-F).
    fn digit(c: u8) -> Option<u8> {
        (c as char).to_digit(16).map(|d| d as u8)
    }

    let pairs = s.as_bytes().chunks_exact(2);
    if !pairs.remainder().is_empty() {
        return None;
    }
    pairs
        .map(|pair| Some((digit(pair[0])? << 4) | digit(pair[1])?))
        .collect()
}
252
253fn arity(name: &str, args: &[Expr], n: usize) -> Result<()> {
254    if args.len() == n {
255        Ok(())
256    } else {
257        Err(Error::Error(alloc::format!(
258            "wrong number of arguments to function {name}() (want {n}, got {})",
259            args.len()
260        )))
261    }
262}
263
264fn str_map(v: &Value, f: impl Fn(&str) -> String) -> Value {
265    match v {
266        Value::Null => Value::Null,
267        other => Value::Text(f(&eval::to_text(other))),
268    }
269}
270
271fn type_name(v: &Value) -> &'static str {
272    match v {
273        Value::Null => "null",
274        Value::Integer(_) => "integer",
275        Value::Real(_) => "real",
276        Value::Text(_) => "text",
277        Value::Blob(_) => "blob",
278    }
279}
280
281fn trim_fn(v: &[Value], left: bool, right: bool) -> Value {
282    if v.is_empty() || matches!(v[0], Value::Null) {
283        return Value::Null;
284    }
285    let s = eval::to_text(&v[0]);
286    let trim_chars: Vec<char> = if v.len() >= 2 {
287        eval::to_text(&v[1]).chars().collect()
288    } else {
289        alloc::vec![' ']
290    };
291    let is_trim = |c: char| trim_chars.contains(&c);
292    let chars: Vec<char> = s.chars().collect();
293    let mut start = 0;
294    let mut end = chars.len();
295    if left {
296        while start < end && is_trim(chars[start]) {
297            start += 1;
298        }
299    }
300    if right {
301        while end > start && is_trim(chars[end - 1]) {
302            end -= 1;
303        }
304    }
305    Value::Text(chars[start..end].iter().collect())
306}
307
308fn substr(v: &[Value]) -> Result<Value> {
309    if v.len() < 2 || v.len() > 3 {
310        return Err(Error::Error("substr() takes 2 or 3 arguments".into()));
311    }
312    if matches!(v[0], Value::Null) {
313        return Ok(Value::Null);
314    }
315    let s: Vec<char> = eval::to_text(&v[0]).chars().collect();
316    let len = s.len() as i64;
317    // 1-based start; a negative start counts from the end. Unlike a naive clamp,
318    // SQLite keeps the requested window and only the positions in 1..=len are
319    // returned, so `substr('hello',0,3)` yields "he" (positions 0,1,2 → 1,2).
320    let mut start = eval::to_i64(&v[1]);
321    if start < 0 {
322        start += len + 1;
323    }
324    let (wstart, wend) = if v.len() == 3 {
325        let z = eval::to_i64(&v[2]);
326        if z < 0 {
327            (start + z, start)
328        } else {
329            (start, start + z)
330        }
331    } else {
332        (start, len + 1)
333    };
334    let b = wstart.max(1);
335    let e = wend.min(len + 1);
336    if b >= e {
337        Ok(Value::Text(String::new()))
338    } else {
339        Ok(Value::Text(
340            s[(b - 1) as usize..(e - 1) as usize].iter().collect(),
341        ))
342    }
343}
344
345fn instr(v: &[Value]) -> Result<Value> {
346    if v.len() != 2 {
347        return Err(Error::Error("instr() takes 2 arguments".into()));
348    }
349    if matches!(v[0], Value::Null) || matches!(v[1], Value::Null) {
350        return Ok(Value::Null);
351    }
352    let hay = eval::to_text(&v[0]);
353    let needle = eval::to_text(&v[1]);
354    // SQLite returns a 1-based character index, 0 if not found.
355    match hay.find(&needle) {
356        None => Ok(Value::Integer(0)),
357        Some(byte_idx) => {
358            let char_idx = hay[..byte_idx].chars().count();
359            Ok(Value::Integer(char_idx as i64 + 1))
360        }
361    }
362}
363
364fn replace(v: &[Value]) -> Result<Value> {
365    if v.len() != 3 {
366        return Err(Error::Error("replace() takes 3 arguments".into()));
367    }
368    if v.iter().any(|x| matches!(x, Value::Null)) {
369        return Ok(Value::Null);
370    }
371    let s = eval::to_text(&v[0]);
372    let from = eval::to_text(&v[1]);
373    let to = eval::to_text(&v[2]);
374    if from.is_empty() {
375        return Ok(Value::Text(s));
376    }
377    Ok(Value::Text(s.replace(&from, &to)))
378}
379
380fn round(v: &[Value]) -> Result<Value> {
381    if v.is_empty() || v.len() > 2 {
382        return Err(Error::Error("round() takes 1 or 2 arguments".into()));
383    }
384    if matches!(v[0], Value::Null) {
385        return Ok(Value::Null);
386    }
387    let x = eval::to_f64(&v[0]);
388    let digits = if v.len() == 2 {
389        eval::to_i64(&v[1]).max(0)
390    } else {
391        0
392    };
393    let factor = crate::util::float::powi(10.0, digits as i32);
394    Ok(Value::Real(crate::util::float::round(x * factor) / factor))
395}
396
397fn scalar_min_max(v: &[Value], want_min: bool) -> Value {
398    // Scalar min()/max() with 2+ args; NULL if any arg is NULL.
399    if v.iter().any(|x| matches!(x, Value::Null)) {
400        return Value::Null;
401    }
402    let mut best = v[0].clone();
403    for x in &v[1..] {
404        let ord = eval::compare(x, &best);
405        let take = if want_min {
406            ord == core::cmp::Ordering::Less
407        } else {
408            ord == core::cmp::Ordering::Greater
409        };
410        if take {
411            best = x.clone();
412        }
413    }
414    best
415}
416
417fn hex_encode(v: &Value) -> String {
418    let bytes = match v {
419        Value::Blob(b) => b.clone(),
420        other => eval::to_text(other).into_bytes(),
421    };
422    let mut s = String::with_capacity(bytes.len() * 2);
423    for b in bytes {
424        s.push(nibble(b >> 4));
425        s.push(nibble(b & 0xf));
426    }
427    s
428}
429
/// Upper-case hex digit for a 4-bit value (`0..=15` → `'0'..='9'`, `'A'..='F'`).
fn nibble(n: u8) -> char {
    if n < 10 {
        char::from(b'0' + n)
    } else {
        char::from(b'A' + n - 10)
    }
}
436
437fn char_fn(v: &[Value]) -> Value {
438    let mut s = String::new();
439    for x in v {
440        if let Some(c) = char::from_u32(eval::to_i64(x) as u32) {
441            s.push(c);
442        }
443    }
444    Value::Text(s)
445}