Skip to main content

citadel_sql/
eval.rs

1//! Expression evaluator with SQL three-valued logic.
2
3use crate::error::{Result, SqlError};
4use crate::parser::{BinOp, Expr, UnaryOp};
5use crate::types::{ColumnDef, DataType, Value};
6
7/// Evaluate an expression against a row.
8///
9/// `columns` maps column names to their positions.
10/// `row` is the full row of values (all columns).
11pub fn eval_expr(expr: &Expr, columns: &[ColumnDef], row: &[Value]) -> Result<Value> {
12    match expr {
13        Expr::Literal(v) => Ok(v.clone()),
14
15        Expr::Column(name) => {
16            let lower = name.to_ascii_lowercase();
17            let matches: Vec<usize> = columns
18                .iter()
19                .enumerate()
20                .filter(|(_, c)| {
21                    let cn = c.name.to_ascii_lowercase();
22                    cn == lower || cn.ends_with(&format!(".{lower}"))
23                })
24                .map(|(i, _)| i)
25                .collect();
26            match matches.len() {
27                0 => Err(SqlError::ColumnNotFound(name.clone())),
28                1 => Ok(row[matches[0]].clone()),
29                _ => Err(SqlError::AmbiguousColumn(name.clone())),
30            }
31        }
32
33        Expr::QualifiedColumn { table, column } => {
34            let qualified = format!(
35                "{}.{}",
36                table.to_ascii_lowercase(),
37                column.to_ascii_lowercase()
38            );
39            let idx = columns
40                .iter()
41                .position(|c| c.name.to_ascii_lowercase() == qualified)
42                .or_else(|| {
43                    let lower_col = column.to_ascii_lowercase();
44                    let matches: Vec<usize> = columns
45                        .iter()
46                        .enumerate()
47                        .filter(|(_, c)| c.name.to_ascii_lowercase() == lower_col)
48                        .map(|(i, _)| i)
49                        .collect();
50                    if matches.len() == 1 {
51                        Some(matches[0])
52                    } else {
53                        None
54                    }
55                })
56                .ok_or_else(|| SqlError::ColumnNotFound(format!("{table}.{column}")))?;
57            Ok(row[idx].clone())
58        }
59
60        Expr::BinaryOp { left, op, right } => {
61            let lval = eval_expr(left, columns, row)?;
62            let rval = eval_expr(right, columns, row)?;
63            eval_binary_op(&lval, *op, &rval)
64        }
65
66        Expr::UnaryOp { op, expr } => {
67            let val = eval_expr(expr, columns, row)?;
68            eval_unary_op(*op, &val)
69        }
70
71        Expr::IsNull(e) => {
72            let val = eval_expr(e, columns, row)?;
73            Ok(Value::Boolean(val.is_null()))
74        }
75
76        Expr::IsNotNull(e) => {
77            let val = eval_expr(e, columns, row)?;
78            Ok(Value::Boolean(!val.is_null()))
79        }
80
81        Expr::Function { name, args } => eval_scalar_function(name, args, columns, row),
82
83        Expr::CountStar => Err(SqlError::Unsupported(
84            "COUNT(*) in non-aggregate context".into(),
85        )),
86
87        Expr::InList {
88            expr: e,
89            list,
90            negated,
91        } => {
92            let lhs = eval_expr(e, columns, row)?;
93            eval_in_values(&lhs, list, columns, row, *negated)
94        }
95
96        Expr::InSet {
97            expr: e,
98            values,
99            has_null,
100            negated,
101        } => {
102            let lhs = eval_expr(e, columns, row)?;
103            eval_in_set(&lhs, values, *has_null, *negated)
104        }
105
106        Expr::Between {
107            expr: e,
108            low,
109            high,
110            negated,
111        } => {
112            let val = eval_expr(e, columns, row)?;
113            let lo = eval_expr(low, columns, row)?;
114            let hi = eval_expr(high, columns, row)?;
115            eval_between(&val, &lo, &hi, *negated)
116        }
117
118        Expr::Like {
119            expr: e,
120            pattern,
121            escape,
122            negated,
123        } => {
124            let val = eval_expr(e, columns, row)?;
125            let pat = eval_expr(pattern, columns, row)?;
126            let esc = escape
127                .as_ref()
128                .map(|e| eval_expr(e, columns, row))
129                .transpose()?;
130            eval_like(&val, &pat, esc.as_ref(), *negated)
131        }
132
133        Expr::Case {
134            operand,
135            conditions,
136            else_result,
137        } => eval_case(
138            operand.as_deref(),
139            conditions,
140            else_result.as_deref(),
141            columns,
142            row,
143        ),
144
145        Expr::Coalesce(args) => {
146            for arg in args {
147                let val = eval_expr(arg, columns, row)?;
148                if !val.is_null() {
149                    return Ok(val);
150                }
151            }
152            Ok(Value::Null)
153        }
154
155        Expr::Cast { expr: e, data_type } => {
156            let val = eval_expr(e, columns, row)?;
157            eval_cast(&val, *data_type)
158        }
159
160        Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::ScalarSubquery(_) => Err(
161            SqlError::Unsupported("subquery not materialized (internal error)".into()),
162        ),
163
164        Expr::Parameter(n) => Err(SqlError::Parse(format!("unbound parameter ${n}"))),
165    }
166}
167
168fn eval_binary_op(left: &Value, op: BinOp, right: &Value) -> Result<Value> {
169    // SQL three-valued logic for AND/OR
170    match op {
171        BinOp::And => return eval_and(left, right),
172        BinOp::Or => return eval_or(left, right),
173        _ => {}
174    }
175
176    // NULL propagation for all other ops (including || per SQL standard)
177    if left.is_null() || right.is_null() {
178        return Ok(Value::Null);
179    }
180
181    match op {
182        BinOp::Eq => Ok(Value::Boolean(left == right)),
183        BinOp::NotEq => Ok(Value::Boolean(left != right)),
184        BinOp::Lt => Ok(Value::Boolean(left < right)),
185        BinOp::Gt => Ok(Value::Boolean(left > right)),
186        BinOp::LtEq => Ok(Value::Boolean(left <= right)),
187        BinOp::GtEq => Ok(Value::Boolean(left >= right)),
188        BinOp::Add => eval_arithmetic(left, right, i64::checked_add, |a, b| a + b),
189        BinOp::Sub => eval_arithmetic(left, right, i64::checked_sub, |a, b| a - b),
190        BinOp::Mul => eval_arithmetic(left, right, i64::checked_mul, |a, b| a * b),
191        BinOp::Div => {
192            match right {
193                Value::Integer(0) => return Err(SqlError::DivisionByZero),
194                Value::Real(r) if *r == 0.0 => return Err(SqlError::DivisionByZero),
195                _ => {}
196            }
197            eval_arithmetic(left, right, i64::checked_div, |a, b| a / b)
198        }
199        BinOp::Mod => {
200            match right {
201                Value::Integer(0) => return Err(SqlError::DivisionByZero),
202                Value::Real(r) if *r == 0.0 => return Err(SqlError::DivisionByZero),
203                _ => {}
204            }
205            eval_arithmetic(left, right, i64::checked_rem, |a, b| a % b)
206        }
207        BinOp::Concat => {
208            let ls = value_to_text(left);
209            let rs = value_to_text(right);
210            Ok(Value::Text(format!("{ls}{rs}")))
211        }
212        BinOp::And | BinOp::Or => unreachable!(),
213    }
214}
215
216/// SQL three-valued AND: NULL AND false = false, NULL AND true = NULL
217fn eval_and(left: &Value, right: &Value) -> Result<Value> {
218    let l = to_bool_or_null(left)?;
219    let r = to_bool_or_null(right)?;
220    match (l, r) {
221        (Some(false), _) | (_, Some(false)) => Ok(Value::Boolean(false)),
222        (Some(true), Some(true)) => Ok(Value::Boolean(true)),
223        _ => Ok(Value::Null),
224    }
225}
226
227/// SQL three-valued OR: NULL OR true = true, NULL OR false = NULL
228fn eval_or(left: &Value, right: &Value) -> Result<Value> {
229    let l = to_bool_or_null(left)?;
230    let r = to_bool_or_null(right)?;
231    match (l, r) {
232        (Some(true), _) | (_, Some(true)) => Ok(Value::Boolean(true)),
233        (Some(false), Some(false)) => Ok(Value::Boolean(false)),
234        _ => Ok(Value::Null),
235    }
236}
237
238fn to_bool_or_null(val: &Value) -> Result<Option<bool>> {
239    match val {
240        Value::Boolean(b) => Ok(Some(*b)),
241        Value::Null => Ok(None),
242        Value::Integer(i) => Ok(Some(*i != 0)),
243        _ => Err(SqlError::TypeMismatch {
244            expected: "BOOLEAN".into(),
245            got: format!("{}", val.data_type()),
246        }),
247    }
248}
249
250fn eval_arithmetic(
251    left: &Value,
252    right: &Value,
253    int_op: fn(i64, i64) -> Option<i64>,
254    real_op: fn(f64, f64) -> f64,
255) -> Result<Value> {
256    match (left, right) {
257        (Value::Integer(a), Value::Integer(b)) => int_op(*a, *b)
258            .map(Value::Integer)
259            .ok_or(SqlError::IntegerOverflow),
260        (Value::Real(a), Value::Real(b)) => Ok(Value::Real(real_op(*a, *b))),
261        (Value::Integer(a), Value::Real(b)) => Ok(Value::Real(real_op(*a as f64, *b))),
262        (Value::Real(a), Value::Integer(b)) => Ok(Value::Real(real_op(*a, *b as f64))),
263        _ => Err(SqlError::TypeMismatch {
264            expected: "numeric".into(),
265            got: format!("{} and {}", left.data_type(), right.data_type()),
266        }),
267    }
268}
269
270fn eval_in_values(
271    lhs: &Value,
272    list: &[Expr],
273    columns: &[ColumnDef],
274    row: &[Value],
275    negated: bool,
276) -> Result<Value> {
277    if list.is_empty() {
278        return Ok(Value::Boolean(negated));
279    }
280    if lhs.is_null() {
281        return Ok(Value::Null);
282    }
283    let mut has_null = false;
284    for item in list {
285        let rhs = eval_expr(item, columns, row)?;
286        if rhs.is_null() {
287            has_null = true;
288        } else if lhs == &rhs {
289            return Ok(Value::Boolean(!negated));
290        }
291    }
292    if has_null {
293        Ok(Value::Null)
294    } else {
295        Ok(Value::Boolean(negated))
296    }
297}
298
299fn eval_in_set(
300    lhs: &Value,
301    values: &std::collections::HashSet<Value>,
302    has_null: bool,
303    negated: bool,
304) -> Result<Value> {
305    if values.is_empty() && !has_null {
306        return Ok(Value::Boolean(negated));
307    }
308    if lhs.is_null() {
309        return Ok(Value::Null);
310    }
311    if values.contains(lhs) {
312        return Ok(Value::Boolean(!negated));
313    }
314    if has_null {
315        Ok(Value::Null)
316    } else {
317        Ok(Value::Boolean(negated))
318    }
319}
320
321fn eval_unary_op(op: UnaryOp, val: &Value) -> Result<Value> {
322    if val.is_null() {
323        return Ok(Value::Null);
324    }
325    match op {
326        UnaryOp::Neg => match val {
327            Value::Integer(i) => i
328                .checked_neg()
329                .map(Value::Integer)
330                .ok_or(SqlError::IntegerOverflow),
331            Value::Real(r) => Ok(Value::Real(-r)),
332            _ => Err(SqlError::TypeMismatch {
333                expected: "numeric".into(),
334                got: format!("{}", val.data_type()),
335            }),
336        },
337        UnaryOp::Not => match val {
338            Value::Boolean(b) => Ok(Value::Boolean(!b)),
339            Value::Integer(i) => Ok(Value::Boolean(*i == 0)),
340            _ => Err(SqlError::TypeMismatch {
341                expected: "BOOLEAN".into(),
342                got: format!("{}", val.data_type()),
343            }),
344        },
345    }
346}
347
348fn value_to_text(val: &Value) -> String {
349    match val {
350        Value::Text(s) => s.clone(),
351        Value::Integer(i) => i.to_string(),
352        Value::Real(r) => {
353            if r.fract() == 0.0 && r.is_finite() {
354                format!("{r:.1}")
355            } else {
356                format!("{r}")
357            }
358        }
359        Value::Boolean(b) => if *b { "TRUE" } else { "FALSE" }.into(),
360        Value::Null => String::new(),
361        Value::Blob(b) => {
362            let mut s = String::with_capacity(b.len() * 2);
363            for byte in b {
364                s.push_str(&format!("{byte:02X}"));
365            }
366            s
367        }
368    }
369}
370
371fn eval_between(val: &Value, low: &Value, high: &Value, negated: bool) -> Result<Value> {
372    if val.is_null() || low.is_null() || high.is_null() {
373        let ge = if val.is_null() || low.is_null() {
374            None
375        } else {
376            Some(*val >= *low)
377        };
378        let le = if val.is_null() || high.is_null() {
379            None
380        } else {
381            Some(*val <= *high)
382        };
383
384        let result = match (ge, le) {
385            (Some(false), _) | (_, Some(false)) => Some(false),
386            (Some(true), Some(true)) => Some(true),
387            _ => None,
388        };
389
390        return match result {
391            Some(b) => Ok(Value::Boolean(if negated { !b } else { b })),
392            None => Ok(Value::Null),
393        };
394    }
395
396    let in_range = *val >= *low && *val <= *high;
397    Ok(Value::Boolean(if negated { !in_range } else { in_range }))
398}
399
400const MAX_LIKE_PATTERN_LEN: usize = 10_000;
401
402fn eval_like(val: &Value, pattern: &Value, escape: Option<&Value>, negated: bool) -> Result<Value> {
403    if val.is_null() || pattern.is_null() {
404        return Ok(Value::Null);
405    }
406    let text = match val {
407        Value::Text(s) => s.as_str(),
408        _ => {
409            return Err(SqlError::TypeMismatch {
410                expected: "TEXT".into(),
411                got: val.data_type().to_string(),
412            })
413        }
414    };
415    let pat = match pattern {
416        Value::Text(s) => s.as_str(),
417        _ => {
418            return Err(SqlError::TypeMismatch {
419                expected: "TEXT".into(),
420                got: pattern.data_type().to_string(),
421            })
422        }
423    };
424
425    if pat.len() > MAX_LIKE_PATTERN_LEN {
426        return Err(SqlError::InvalidValue(format!(
427            "LIKE pattern too long ({} chars, max {MAX_LIKE_PATTERN_LEN})",
428            pat.len()
429        )));
430    }
431
432    let esc_char = match escape {
433        Some(Value::Text(s)) => {
434            let mut chars = s.chars();
435            let c = chars.next().ok_or_else(|| {
436                SqlError::InvalidValue("ESCAPE must be a single character".into())
437            })?;
438            if chars.next().is_some() {
439                return Err(SqlError::InvalidValue(
440                    "ESCAPE must be a single character".into(),
441                ));
442            }
443            Some(c)
444        }
445        Some(Value::Null) => return Ok(Value::Null),
446        Some(_) => {
447            return Err(SqlError::TypeMismatch {
448                expected: "TEXT".into(),
449                got: "non-text".into(),
450            })
451        }
452        None => None,
453    };
454
455    let matched = like_match(text, pat, esc_char);
456    Ok(Value::Boolean(if negated { !matched } else { matched }))
457}
458
/// Match `text` against a SQL LIKE `pattern`, ASCII case-insensitively.
///
/// Both strings are compared character by character, not byte by byte, so `_`
/// matches one `char`.
fn like_match(text: &str, pattern: &str, escape: Option<char>) -> bool {
    let text_chars = text.chars().collect::<Vec<_>>();
    let pattern_chars = pattern.chars().collect::<Vec<_>>();
    like_match_impl(&text_chars, &pattern_chars, 0, 0, escape)
}

/// Greedy LIKE matcher with single-anchor backtracking.
///
/// Matching starts at text index `ti` and pattern index `pi`.
/// - `%` matches any run of characters, including an empty one.
/// - `_` matches exactly one character.
/// - `esc` followed by another pattern character makes that character a
///   literal. An escape in the last pattern position is treated as an
///   ordinary character.
/// - Literals compare ASCII case-insensitively.
///
/// On a mismatch, the most recent `%` absorbs one more text character and
/// matching resumes just after it.
fn like_match_impl(
    t: &[char],
    p: &[char],
    mut ti: usize,
    mut pi: usize,
    esc: Option<char>,
) -> bool {
    // Most recent `%`: (its index in `p`, text index it currently absorbs up to).
    let mut anchor: Option<(usize, usize)> = None;

    while ti < t.len() {
        let progressed = match p.get(pi).copied() {
            None => false,
            Some(pc) if Some(pc) == esc && pi + 1 < p.len() => {
                // Escaped literal: compare the character after the escape.
                if p[pi + 1].eq_ignore_ascii_case(&t[ti]) {
                    pi += 2;
                    ti += 1;
                    true
                } else {
                    false
                }
            }
            Some('%') => {
                anchor = Some((pi, ti));
                pi += 1;
                true
            }
            Some('_') => {
                pi += 1;
                ti += 1;
                true
            }
            Some(pc) if pc.eq_ignore_ascii_case(&t[ti]) => {
                pi += 1;
                ti += 1;
                true
            }
            Some(_) => false,
        };

        if progressed {
            continue;
        }

        match anchor.as_mut() {
            Some((star_p, star_t)) => {
                *star_t += 1;
                pi = *star_p + 1;
                ti = *star_t;
            }
            None => return false,
        }
    }

    // The text is exhausted: only trailing `%` may remain in the pattern.
    p.get(pi..)
        .map_or(false, |rest| rest.iter().all(|&c| c == '%'))
}
527
528fn eval_case(
529    operand: Option<&Expr>,
530    conditions: &[(Expr, Expr)],
531    else_result: Option<&Expr>,
532    columns: &[ColumnDef],
533    row: &[Value],
534) -> Result<Value> {
535    if let Some(op_expr) = operand {
536        let op_val = eval_expr(op_expr, columns, row)?;
537        for (cond, result) in conditions {
538            let cond_val = eval_expr(cond, columns, row)?;
539            if !op_val.is_null() && !cond_val.is_null() && op_val == cond_val {
540                return eval_expr(result, columns, row);
541            }
542        }
543    } else {
544        for (cond, result) in conditions {
545            let cond_val = eval_expr(cond, columns, row)?;
546            if is_truthy(&cond_val) {
547                return eval_expr(result, columns, row);
548            }
549        }
550    }
551    match else_result {
552        Some(e) => eval_expr(e, columns, row),
553        None => Ok(Value::Null),
554    }
555}
556
557fn eval_cast(val: &Value, target: DataType) -> Result<Value> {
558    if val.is_null() {
559        return Ok(Value::Null);
560    }
561    match target {
562        DataType::Integer => match val {
563            Value::Integer(_) => Ok(val.clone()),
564            Value::Real(r) => Ok(Value::Integer(*r as i64)),
565            Value::Boolean(b) => Ok(Value::Integer(if *b { 1 } else { 0 })),
566            Value::Text(s) => s
567                .trim()
568                .parse::<i64>()
569                .map(Value::Integer)
570                .or_else(|_| s.trim().parse::<f64>().map(|f| Value::Integer(f as i64)))
571                .map_err(|_| SqlError::InvalidValue(format!("cannot cast '{s}' to INTEGER"))),
572            _ => Err(SqlError::InvalidValue(format!(
573                "cannot cast {} to INTEGER",
574                val.data_type()
575            ))),
576        },
577        DataType::Real => match val {
578            Value::Real(_) => Ok(val.clone()),
579            Value::Integer(i) => Ok(Value::Real(*i as f64)),
580            Value::Boolean(b) => Ok(Value::Real(if *b { 1.0 } else { 0.0 })),
581            Value::Text(s) => s
582                .trim()
583                .parse::<f64>()
584                .map(Value::Real)
585                .map_err(|_| SqlError::InvalidValue(format!("cannot cast '{s}' to REAL"))),
586            _ => Err(SqlError::InvalidValue(format!(
587                "cannot cast {} to REAL",
588                val.data_type()
589            ))),
590        },
591        DataType::Text => Ok(Value::Text(value_to_text(val))),
592        DataType::Boolean => match val {
593            Value::Boolean(_) => Ok(val.clone()),
594            Value::Integer(i) => Ok(Value::Boolean(*i != 0)),
595            Value::Text(s) => {
596                let lower = s.trim().to_ascii_lowercase();
597                match lower.as_str() {
598                    "true" | "1" | "yes" | "on" => Ok(Value::Boolean(true)),
599                    "false" | "0" | "no" | "off" => Ok(Value::Boolean(false)),
600                    _ => Err(SqlError::InvalidValue(format!(
601                        "cannot cast '{s}' to BOOLEAN"
602                    ))),
603                }
604            }
605            _ => Err(SqlError::InvalidValue(format!(
606                "cannot cast {} to BOOLEAN",
607                val.data_type()
608            ))),
609        },
610        DataType::Blob => match val {
611            Value::Blob(_) => Ok(val.clone()),
612            Value::Text(s) => Ok(Value::Blob(s.as_bytes().to_vec())),
613            _ => Err(SqlError::InvalidValue(format!(
614                "cannot cast {} to BLOB",
615                val.data_type()
616            ))),
617        },
618        DataType::Null => Ok(Value::Null),
619    }
620}
621
622fn eval_scalar_function(
623    name: &str,
624    args: &[Expr],
625    columns: &[ColumnDef],
626    row: &[Value],
627) -> Result<Value> {
628    let evaluated: Vec<Value> = args
629        .iter()
630        .map(|a| eval_expr(a, columns, row))
631        .collect::<Result<Vec<_>>>()?;
632
633    match name {
634        "LENGTH" => {
635            check_args(name, &evaluated, 1)?;
636            match &evaluated[0] {
637                Value::Null => Ok(Value::Null),
638                Value::Text(s) => Ok(Value::Integer(s.chars().count() as i64)),
639                Value::Blob(b) => Ok(Value::Integer(b.len() as i64)),
640                _ => Ok(Value::Integer(
641                    value_to_text(&evaluated[0]).chars().count() as i64
642                )),
643            }
644        }
645        "UPPER" => {
646            check_args(name, &evaluated, 1)?;
647            match &evaluated[0] {
648                Value::Null => Ok(Value::Null),
649                Value::Text(s) => Ok(Value::Text(s.to_ascii_uppercase())),
650                _ => Ok(Value::Text(
651                    value_to_text(&evaluated[0]).to_ascii_uppercase(),
652                )),
653            }
654        }
655        "LOWER" => {
656            check_args(name, &evaluated, 1)?;
657            match &evaluated[0] {
658                Value::Null => Ok(Value::Null),
659                Value::Text(s) => Ok(Value::Text(s.to_ascii_lowercase())),
660                _ => Ok(Value::Text(
661                    value_to_text(&evaluated[0]).to_ascii_lowercase(),
662                )),
663            }
664        }
665        "SUBSTR" | "SUBSTRING" => {
666            if evaluated.len() < 2 || evaluated.len() > 3 {
667                return Err(SqlError::InvalidValue(format!(
668                    "{name} requires 2 or 3 arguments"
669                )));
670            }
671            if evaluated.iter().any(|v| v.is_null()) {
672                return Ok(Value::Null);
673            }
674            let s = value_to_text(&evaluated[0]);
675            let chars: Vec<char> = s.chars().collect();
676            let start = match &evaluated[1] {
677                Value::Integer(i) => *i,
678                _ => {
679                    return Err(SqlError::TypeMismatch {
680                        expected: "INTEGER".into(),
681                        got: evaluated[1].data_type().to_string(),
682                    })
683                }
684            };
685            let len = chars.len() as i64;
686
687            let (begin, count) = if evaluated.len() == 3 {
688                let cnt = match &evaluated[2] {
689                    Value::Integer(i) => *i,
690                    _ => {
691                        return Err(SqlError::TypeMismatch {
692                            expected: "INTEGER".into(),
693                            got: evaluated[2].data_type().to_string(),
694                        })
695                    }
696                };
697                if start >= 1 {
698                    let b = (start - 1).min(len) as usize;
699                    let c = cnt.max(0) as usize;
700                    (b, c)
701                } else if start == 0 {
702                    let c = (cnt - 1).max(0) as usize;
703                    (0usize, c)
704                } else {
705                    let adjusted_cnt = (cnt + start - 1).max(0) as usize;
706                    (0usize, adjusted_cnt)
707                }
708            } else if start >= 1 {
709                let b = (start - 1).min(len) as usize;
710                (b, chars.len() - b)
711            } else if start == 0 {
712                (0usize, chars.len())
713            } else {
714                let b = (len + start).max(0) as usize;
715                (b, chars.len() - b)
716            };
717
718            let result: String = chars.iter().skip(begin).take(count).collect();
719            Ok(Value::Text(result))
720        }
721        "TRIM" | "LTRIM" | "RTRIM" => {
722            if evaluated.is_empty() || evaluated.len() > 2 {
723                return Err(SqlError::InvalidValue(format!(
724                    "{name} requires 1 or 2 arguments"
725                )));
726            }
727            if evaluated[0].is_null() {
728                return Ok(Value::Null);
729            }
730            let s = value_to_text(&evaluated[0]);
731            let trim_chars: Vec<char> = if evaluated.len() == 2 {
732                if evaluated[1].is_null() {
733                    return Ok(Value::Null);
734                }
735                value_to_text(&evaluated[1]).chars().collect()
736            } else {
737                vec![' ']
738            };
739            let result = match name {
740                "TRIM" => s
741                    .trim_matches(|c: char| trim_chars.contains(&c))
742                    .to_string(),
743                "LTRIM" => s
744                    .trim_start_matches(|c: char| trim_chars.contains(&c))
745                    .to_string(),
746                "RTRIM" => s
747                    .trim_end_matches(|c: char| trim_chars.contains(&c))
748                    .to_string(),
749                _ => unreachable!(),
750            };
751            Ok(Value::Text(result))
752        }
753        "REPLACE" => {
754            check_args(name, &evaluated, 3)?;
755            if evaluated.iter().any(|v| v.is_null()) {
756                return Ok(Value::Null);
757            }
758            let s = value_to_text(&evaluated[0]);
759            let from = value_to_text(&evaluated[1]);
760            let to = value_to_text(&evaluated[2]);
761            if from.is_empty() {
762                return Ok(Value::Text(s));
763            }
764            Ok(Value::Text(s.replace(&from, &to)))
765        }
766        "INSTR" => {
767            check_args(name, &evaluated, 2)?;
768            if evaluated.iter().any(|v| v.is_null()) {
769                return Ok(Value::Null);
770            }
771            let haystack = value_to_text(&evaluated[0]);
772            let needle = value_to_text(&evaluated[1]);
773            let pos = haystack
774                .find(&needle)
775                .map(|i| haystack[..i].chars().count() as i64 + 1)
776                .unwrap_or(0);
777            Ok(Value::Integer(pos))
778        }
779        "CONCAT" => {
780            if evaluated.is_empty() {
781                return Ok(Value::Text(String::new()));
782            }
783            let mut result = String::new();
784            for v in &evaluated {
785                match v {
786                    Value::Null => {}
787                    _ => result.push_str(&value_to_text(v)),
788                }
789            }
790            Ok(Value::Text(result))
791        }
792        "ABS" => {
793            check_args(name, &evaluated, 1)?;
794            match &evaluated[0] {
795                Value::Null => Ok(Value::Null),
796                Value::Integer(i) => i
797                    .checked_abs()
798                    .map(Value::Integer)
799                    .ok_or(SqlError::IntegerOverflow),
800                Value::Real(r) => Ok(Value::Real(r.abs())),
801                _ => Err(SqlError::TypeMismatch {
802                    expected: "numeric".into(),
803                    got: evaluated[0].data_type().to_string(),
804                }),
805            }
806        }
807        "ROUND" => {
808            if evaluated.is_empty() || evaluated.len() > 2 {
809                return Err(SqlError::InvalidValue(
810                    "ROUND requires 1 or 2 arguments".into(),
811                ));
812            }
813            if evaluated[0].is_null() {
814                return Ok(Value::Null);
815            }
816            let val = match &evaluated[0] {
817                Value::Integer(i) => *i as f64,
818                Value::Real(r) => *r,
819                _ => {
820                    return Err(SqlError::TypeMismatch {
821                        expected: "numeric".into(),
822                        got: evaluated[0].data_type().to_string(),
823                    })
824                }
825            };
826            let places = if evaluated.len() == 2 {
827                match &evaluated[1] {
828                    Value::Null => return Ok(Value::Null),
829                    Value::Integer(i) => *i,
830                    _ => {
831                        return Err(SqlError::TypeMismatch {
832                            expected: "INTEGER".into(),
833                            got: evaluated[1].data_type().to_string(),
834                        })
835                    }
836                }
837            } else {
838                0
839            };
840            let factor = 10f64.powi(places as i32);
841            let rounded = (val * factor).round() / factor;
842            Ok(Value::Real(rounded))
843        }
844        "CEIL" | "CEILING" => {
845            check_args(name, &evaluated, 1)?;
846            match &evaluated[0] {
847                Value::Null => Ok(Value::Null),
848                Value::Integer(i) => Ok(Value::Integer(*i)),
849                Value::Real(r) => Ok(Value::Integer(r.ceil() as i64)),
850                _ => Err(SqlError::TypeMismatch {
851                    expected: "numeric".into(),
852                    got: evaluated[0].data_type().to_string(),
853                }),
854            }
855        }
856        "FLOOR" => {
857            check_args(name, &evaluated, 1)?;
858            match &evaluated[0] {
859                Value::Null => Ok(Value::Null),
860                Value::Integer(i) => Ok(Value::Integer(*i)),
861                Value::Real(r) => Ok(Value::Integer(r.floor() as i64)),
862                _ => Err(SqlError::TypeMismatch {
863                    expected: "numeric".into(),
864                    got: evaluated[0].data_type().to_string(),
865                }),
866            }
867        }
868        "SIGN" => {
869            check_args(name, &evaluated, 1)?;
870            match &evaluated[0] {
871                Value::Null => Ok(Value::Null),
872                Value::Integer(i) => Ok(Value::Integer(i.signum())),
873                Value::Real(r) => {
874                    if *r > 0.0 {
875                        Ok(Value::Integer(1))
876                    } else if *r < 0.0 {
877                        Ok(Value::Integer(-1))
878                    } else {
879                        Ok(Value::Integer(0))
880                    }
881                }
882                _ => Err(SqlError::TypeMismatch {
883                    expected: "numeric".into(),
884                    got: evaluated[0].data_type().to_string(),
885                }),
886            }
887        }
888        "SQRT" => {
889            check_args(name, &evaluated, 1)?;
890            match &evaluated[0] {
891                Value::Null => Ok(Value::Null),
892                Value::Integer(i) => {
893                    if *i < 0 {
894                        Ok(Value::Null)
895                    } else {
896                        Ok(Value::Real((*i as f64).sqrt()))
897                    }
898                }
899                Value::Real(r) => {
900                    if *r < 0.0 {
901                        Ok(Value::Null)
902                    } else {
903                        Ok(Value::Real(r.sqrt()))
904                    }
905                }
906                _ => Err(SqlError::TypeMismatch {
907                    expected: "numeric".into(),
908                    got: evaluated[0].data_type().to_string(),
909                }),
910            }
911        }
912        "RANDOM" => {
913            check_args(name, &evaluated, 0)?;
914            use std::collections::hash_map::DefaultHasher;
915            use std::hash::{Hash, Hasher};
916            use std::time::SystemTime;
917            let mut hasher = DefaultHasher::new();
918            SystemTime::now().hash(&mut hasher);
919            std::thread::current().id().hash(&mut hasher);
920            let mut val = hasher.finish() as i64;
921            if val == i64::MIN {
922                val = i64::MAX;
923            }
924            Ok(Value::Integer(val))
925        }
926        "TYPEOF" => {
927            check_args(name, &evaluated, 1)?;
928            let type_name = match &evaluated[0] {
929                Value::Null => "null",
930                Value::Integer(_) => "integer",
931                Value::Real(_) => "real",
932                Value::Text(_) => "text",
933                Value::Blob(_) => "blob",
934                Value::Boolean(_) => "boolean",
935            };
936            Ok(Value::Text(type_name.into()))
937        }
938        "MIN" => {
939            check_args(name, &evaluated, 2)?;
940            if evaluated[0].is_null() {
941                return Ok(evaluated[1].clone());
942            }
943            if evaluated[1].is_null() {
944                return Ok(evaluated[0].clone());
945            }
946            if evaluated[0] <= evaluated[1] {
947                Ok(evaluated[0].clone())
948            } else {
949                Ok(evaluated[1].clone())
950            }
951        }
952        "MAX" => {
953            check_args(name, &evaluated, 2)?;
954            if evaluated[0].is_null() {
955                return Ok(evaluated[1].clone());
956            }
957            if evaluated[1].is_null() {
958                return Ok(evaluated[0].clone());
959            }
960            if evaluated[0] >= evaluated[1] {
961                Ok(evaluated[0].clone())
962            } else {
963                Ok(evaluated[1].clone())
964            }
965        }
966        "HEX" => {
967            check_args(name, &evaluated, 1)?;
968            match &evaluated[0] {
969                Value::Null => Ok(Value::Null),
970                Value::Blob(b) => {
971                    let mut s = String::with_capacity(b.len() * 2);
972                    for byte in b {
973                        s.push_str(&format!("{byte:02X}"));
974                    }
975                    Ok(Value::Text(s))
976                }
977                Value::Text(s) => {
978                    let mut r = String::with_capacity(s.len() * 2);
979                    for byte in s.as_bytes() {
980                        r.push_str(&format!("{byte:02X}"));
981                    }
982                    Ok(Value::Text(r))
983                }
984                _ => Ok(Value::Text(value_to_text(&evaluated[0]))),
985            }
986        }
987        _ => Err(SqlError::Unsupported(format!("scalar function: {name}"))),
988    }
989}
990
991fn check_args(name: &str, args: &[Value], expected: usize) -> Result<()> {
992    if args.len() != expected {
993        Err(SqlError::InvalidValue(format!(
994            "{name} requires {expected} argument(s), got {}",
995            args.len()
996        )))
997    } else {
998        Ok(())
999    }
1000}
1001
1002/// Check if an expression result is truthy (for WHERE/HAVING).
1003pub fn is_truthy(val: &Value) -> bool {
1004    match val {
1005        Value::Boolean(b) => *b,
1006        Value::Integer(i) => *i != 0,
1007        Value::Null => false,
1008        _ => true,
1009    }
1010}
1011
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DataType;

    /// Fixture schema: `id` (INTEGER, not null), `name` (TEXT, nullable),
    /// `score` (REAL, nullable), `active` (BOOLEAN, not null), in that order.
    fn test_columns() -> Vec<ColumnDef> {
        vec![
            ("id", DataType::Integer, false),
            ("name", DataType::Text, true),
            ("score", DataType::Real, true),
            ("active", DataType::Boolean, false),
        ]
        .into_iter()
        .enumerate()
        .map(|(position, (name, data_type, nullable))| ColumnDef {
            name: name.into(),
            data_type,
            nullable,
            position: position as _,
        })
        .collect()
    }

    /// Fully populated row matching `test_columns`.
    fn test_row() -> Vec<Value> {
        vec![
            Value::Integer(1),
            Value::Text("Alice".into()),
            Value::Real(95.5),
            Value::Boolean(true),
        ]
    }

    /// Row matching `test_columns` whose nullable `name` and `score` are NULL.
    fn null_row() -> Vec<Value> {
        vec![
            Value::Integer(1),
            Value::Null,
            Value::Null,
            Value::Boolean(true),
        ]
    }

    /// Boxed unqualified column reference.
    fn col(name: &str) -> Box<Expr> {
        Box::new(Expr::Column(name.into()))
    }

    /// Boxed literal.
    fn lit(value: Value) -> Box<Expr> {
        Box::new(Expr::Literal(value))
    }

    /// `left <op> right`.
    fn binary(left: Box<Expr>, op: BinOp, right: Box<Expr>) -> Expr {
        Expr::BinaryOp { left, op, right }
    }

    /// `<op> operand`.
    fn unary(op: UnaryOp, operand: Box<Expr>) -> Expr {
        Expr::UnaryOp { op, expr: operand }
    }

    /// Evaluate `expr` against `row` using the fixture schema.
    fn eval(expr: &Expr, row: &[Value]) -> Result<Value> {
        eval_expr(expr, &test_columns(), row)
    }

    #[test]
    fn eval_literal() {
        let forty_two = Expr::Literal(Value::Integer(42));
        assert_eq!(eval(&forty_two, &test_row()).unwrap(), Value::Integer(42));
    }

    #[test]
    fn eval_column_ref() {
        assert_eq!(
            eval(&col("name"), &test_row()).unwrap(),
            Value::Text("Alice".into())
        );
    }

    #[test]
    fn eval_column_case_insensitive() {
        // Column lookup ignores ASCII case.
        assert_eq!(
            eval(&col("NAME"), &test_row()).unwrap(),
            Value::Text("Alice".into())
        );
    }

    #[test]
    fn eval_arithmetic_int() {
        let sum = binary(col("id"), BinOp::Add, lit(Value::Integer(10)));
        assert_eq!(eval(&sum, &test_row()).unwrap(), Value::Integer(11));
    }

    #[test]
    fn eval_comparison() {
        let greater = binary(col("score"), BinOp::Gt, lit(Value::Real(90.0)));
        assert_eq!(eval(&greater, &test_row()).unwrap(), Value::Boolean(true));
    }

    #[test]
    fn eval_null_propagation() {
        // NULL = 'test' is unknown (NULL), not false.
        let equal = binary(col("name"), BinOp::Eq, lit(Value::Text("test".into())));
        assert!(eval(&equal, &null_row()).unwrap().is_null());
    }

    #[test]
    fn eval_and_three_valued() {
        let row = null_row();

        // NULL AND false = false
        let and_false = binary(col("name"), BinOp::And, lit(Value::Boolean(false)));
        assert_eq!(eval(&and_false, &row).unwrap(), Value::Boolean(false));

        // NULL AND true = NULL
        let and_true = binary(col("name"), BinOp::And, lit(Value::Boolean(true)));
        assert!(eval(&and_true, &row).unwrap().is_null());
    }

    #[test]
    fn eval_or_three_valued() {
        let row = null_row();

        // NULL OR true = true
        let or_true = binary(col("name"), BinOp::Or, lit(Value::Boolean(true)));
        assert_eq!(eval(&or_true, &row).unwrap(), Value::Boolean(true));

        // NULL OR false = NULL
        let or_false = binary(col("name"), BinOp::Or, lit(Value::Boolean(false)));
        assert!(eval(&or_false, &row).unwrap().is_null());
    }

    #[test]
    fn eval_is_null() {
        let row = null_row();

        let name_is_null = Expr::IsNull(col("name"));
        assert_eq!(eval(&name_is_null, &row).unwrap(), Value::Boolean(true));

        let id_is_not_null = Expr::IsNotNull(col("id"));
        assert_eq!(eval(&id_is_not_null, &row).unwrap(), Value::Boolean(true));
    }

    #[test]
    fn eval_not() {
        let not_active = unary(UnaryOp::Not, col("active"));
        assert_eq!(
            eval(&not_active, &test_row()).unwrap(),
            Value::Boolean(false)
        );
    }

    #[test]
    fn eval_neg() {
        let negated_id = unary(UnaryOp::Neg, col("id"));
        assert_eq!(eval(&negated_id, &test_row()).unwrap(), Value::Integer(-1));
    }

    #[test]
    fn eval_division_by_zero() {
        let quotient = binary(col("id"), BinOp::Div, lit(Value::Integer(0)));
        assert!(matches!(
            eval(&quotient, &test_row()),
            Err(SqlError::DivisionByZero)
        ));
    }

    #[test]
    fn eval_mixed_numeric() {
        // INTEGER 1 + REAL 95.5 widens to REAL 96.5.
        let sum = binary(col("id"), BinOp::Add, col("score"));
        assert_eq!(eval(&sum, &test_row()).unwrap(), Value::Real(96.5));
    }

    #[test]
    fn is_truthy_values() {
        let cases = [
            (Value::Boolean(true), true),
            (Value::Boolean(false), false),
            (Value::Null, false),
            (Value::Integer(1), true),
            (Value::Integer(0), false),
        ];
        for (value, expected) in &cases {
            assert_eq!(is_truthy(value), *expected, "is_truthy({value:?})");
        }
    }
}