Skip to main content

citadel_sql/
eval.rs

1//! Expression evaluator with SQL three-valued logic.
2
3use std::collections::HashMap;
4
5use crate::error::{Result, SqlError};
6use crate::parser::{BinOp, Expr, UnaryOp};
7use crate::types::{ColumnDef, CompactString, DataType, Value};
8
/// Maps column names to their index in a row.
///
/// `exact` is keyed by the full lower-cased column name (e.g. `t.id`), while
/// `short` is keyed by the part after the last `.` (e.g. `id`) and remembers
/// whether that short name is shared by more than one column.
#[derive(Clone, Debug)]
pub struct ColumnMap {
    exact: HashMap<String, usize>,
    short: HashMap<String, ShortMatch>,
}

/// Result of looking up an unqualified column name.
#[derive(Clone, Debug)]
enum ShortMatch {
    /// Exactly one column has this short name; holds its row index.
    Unique(usize),
    /// Several columns share this short name, so it cannot be used alone.
    Ambiguous,
}
28
29impl ColumnMap {
30    pub fn new(columns: &[ColumnDef]) -> Self {
31        let mut exact = HashMap::with_capacity(columns.len() * 2);
32        let mut short: HashMap<String, ShortMatch> = HashMap::with_capacity(columns.len());
33
34        for (i, col) in columns.iter().enumerate() {
35            let lower = col.name.to_ascii_lowercase();
36            exact.insert(lower.clone(), i);
37
38            let unqualified = if let Some(dot) = lower.rfind('.') {
39                &lower[dot + 1..]
40            } else {
41                &lower
42            };
43            short
44                .entry(unqualified.to_string())
45                .and_modify(|e| *e = ShortMatch::Ambiguous)
46                .or_insert(ShortMatch::Unique(i));
47        }
48
49        Self { exact, short }
50    }
51
52    pub(crate) fn resolve(&self, name: &str) -> Result<usize> {
53        if let Some(&idx) = self.exact.get(name) {
54            return Ok(idx);
55        }
56        match self.short.get(name) {
57            Some(ShortMatch::Unique(idx)) => Ok(*idx),
58            Some(ShortMatch::Ambiguous) => Err(SqlError::AmbiguousColumn(name.to_string())),
59            None => Err(SqlError::ColumnNotFound(name.to_string())),
60        }
61    }
62
63    pub(crate) fn resolve_qualified(&self, table: &str, column: &str) -> Result<usize> {
64        let qualified = format!("{table}.{column}");
65        if let Some(&idx) = self.exact.get(&qualified) {
66            return Ok(idx);
67        }
68        match self.short.get(column) {
69            Some(ShortMatch::Unique(idx)) => Ok(*idx),
70            _ => Err(SqlError::ColumnNotFound(format!("{table}.{column}"))),
71        }
72    }
73}
74
75pub fn eval_expr(expr: &Expr, col_map: &ColumnMap, row: &[Value]) -> Result<Value> {
76    match expr {
77        Expr::Literal(v) => Ok(v.clone()),
78
79        Expr::Column(name) => {
80            let idx = col_map.resolve(name)?;
81            Ok(row[idx].clone())
82        }
83
84        Expr::QualifiedColumn { table, column } => {
85            let idx = col_map.resolve_qualified(table, column)?;
86            Ok(row[idx].clone())
87        }
88
89        Expr::BinaryOp { left, op, right } => {
90            let lval = eval_expr(left, col_map, row)?;
91            let rval = eval_expr(right, col_map, row)?;
92            eval_binary_op(&lval, *op, &rval)
93        }
94
95        Expr::UnaryOp { op, expr } => {
96            let val = eval_expr(expr, col_map, row)?;
97            eval_unary_op(*op, &val)
98        }
99
100        Expr::IsNull(e) => {
101            let val = eval_expr(e, col_map, row)?;
102            Ok(Value::Boolean(val.is_null()))
103        }
104
105        Expr::IsNotNull(e) => {
106            let val = eval_expr(e, col_map, row)?;
107            Ok(Value::Boolean(!val.is_null()))
108        }
109
110        Expr::Function { name, args } => eval_scalar_function(name, args, col_map, row),
111
112        Expr::CountStar => Err(SqlError::Unsupported(
113            "COUNT(*) in non-aggregate context".into(),
114        )),
115
116        Expr::InList {
117            expr: e,
118            list,
119            negated,
120        } => {
121            let lhs = eval_expr(e, col_map, row)?;
122            eval_in_values(&lhs, list, col_map, row, *negated)
123        }
124
125        Expr::InSet {
126            expr: e,
127            values,
128            has_null,
129            negated,
130        } => {
131            let lhs = eval_expr(e, col_map, row)?;
132            eval_in_set(&lhs, values, *has_null, *negated)
133        }
134
135        Expr::Between {
136            expr: e,
137            low,
138            high,
139            negated,
140        } => {
141            let val = eval_expr(e, col_map, row)?;
142            let lo = eval_expr(low, col_map, row)?;
143            let hi = eval_expr(high, col_map, row)?;
144            eval_between(&val, &lo, &hi, *negated)
145        }
146
147        Expr::Like {
148            expr: e,
149            pattern,
150            escape,
151            negated,
152        } => {
153            let val = eval_expr(e, col_map, row)?;
154            let pat = eval_expr(pattern, col_map, row)?;
155            let esc = escape
156                .as_ref()
157                .map(|e| eval_expr(e, col_map, row))
158                .transpose()?;
159            eval_like(&val, &pat, esc.as_ref(), *negated)
160        }
161
162        Expr::Case {
163            operand,
164            conditions,
165            else_result,
166        } => eval_case(
167            operand.as_deref(),
168            conditions,
169            else_result.as_deref(),
170            col_map,
171            row,
172        ),
173
174        Expr::Coalesce(args) => {
175            for arg in args {
176                let val = eval_expr(arg, col_map, row)?;
177                if !val.is_null() {
178                    return Ok(val);
179                }
180            }
181            Ok(Value::Null)
182        }
183
184        Expr::Cast { expr: e, data_type } => {
185            let val = eval_expr(e, col_map, row)?;
186            eval_cast(&val, *data_type)
187        }
188
189        Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::ScalarSubquery(_) => Err(
190            SqlError::Unsupported("subquery not materialized (internal error)".into()),
191        ),
192
193        Expr::Parameter(n) => Err(SqlError::Parse(format!("unbound parameter ${n}"))),
194
195        Expr::WindowFunction { .. } => Err(SqlError::Unsupported(
196            "window functions are only allowed in SELECT columns".into(),
197        )),
198    }
199}
200
201fn eval_binary_op(left: &Value, op: BinOp, right: &Value) -> Result<Value> {
202    // SQL three-valued logic for AND/OR
203    match op {
204        BinOp::And => return eval_and(left, right),
205        BinOp::Or => return eval_or(left, right),
206        _ => {}
207    }
208
209    // NULL propagation for all other ops (including || per SQL standard)
210    if left.is_null() || right.is_null() {
211        return Ok(Value::Null);
212    }
213
214    match op {
215        BinOp::Eq => Ok(Value::Boolean(left == right)),
216        BinOp::NotEq => Ok(Value::Boolean(left != right)),
217        BinOp::Lt => Ok(Value::Boolean(left < right)),
218        BinOp::Gt => Ok(Value::Boolean(left > right)),
219        BinOp::LtEq => Ok(Value::Boolean(left <= right)),
220        BinOp::GtEq => Ok(Value::Boolean(left >= right)),
221        BinOp::Add => eval_arithmetic(left, right, i64::checked_add, |a, b| a + b),
222        BinOp::Sub => eval_arithmetic(left, right, i64::checked_sub, |a, b| a - b),
223        BinOp::Mul => eval_arithmetic(left, right, i64::checked_mul, |a, b| a * b),
224        BinOp::Div => {
225            match right {
226                Value::Integer(0) => return Err(SqlError::DivisionByZero),
227                Value::Real(r) if *r == 0.0 => return Err(SqlError::DivisionByZero),
228                _ => {}
229            }
230            eval_arithmetic(left, right, i64::checked_div, |a, b| a / b)
231        }
232        BinOp::Mod => {
233            match right {
234                Value::Integer(0) => return Err(SqlError::DivisionByZero),
235                Value::Real(r) if *r == 0.0 => return Err(SqlError::DivisionByZero),
236                _ => {}
237            }
238            eval_arithmetic(left, right, i64::checked_rem, |a, b| a % b)
239        }
240        BinOp::Concat => {
241            let ls = value_to_text(left);
242            let rs = value_to_text(right);
243            Ok(Value::Text(format!("{ls}{rs}").into()))
244        }
245        BinOp::And | BinOp::Or => unreachable!(),
246    }
247}
248
249/// SQL three-valued AND: NULL AND false = false, NULL AND true = NULL
250fn eval_and(left: &Value, right: &Value) -> Result<Value> {
251    let l = to_bool_or_null(left)?;
252    let r = to_bool_or_null(right)?;
253    match (l, r) {
254        (Some(false), _) | (_, Some(false)) => Ok(Value::Boolean(false)),
255        (Some(true), Some(true)) => Ok(Value::Boolean(true)),
256        _ => Ok(Value::Null),
257    }
258}
259
260/// SQL three-valued OR: NULL OR true = true, NULL OR false = NULL
261fn eval_or(left: &Value, right: &Value) -> Result<Value> {
262    let l = to_bool_or_null(left)?;
263    let r = to_bool_or_null(right)?;
264    match (l, r) {
265        (Some(true), _) | (_, Some(true)) => Ok(Value::Boolean(true)),
266        (Some(false), Some(false)) => Ok(Value::Boolean(false)),
267        _ => Ok(Value::Null),
268    }
269}
270
271fn to_bool_or_null(val: &Value) -> Result<Option<bool>> {
272    match val {
273        Value::Boolean(b) => Ok(Some(*b)),
274        Value::Null => Ok(None),
275        Value::Integer(i) => Ok(Some(*i != 0)),
276        _ => Err(SqlError::TypeMismatch {
277            expected: "BOOLEAN".into(),
278            got: format!("{}", val.data_type()),
279        }),
280    }
281}
282
283fn eval_arithmetic(
284    left: &Value,
285    right: &Value,
286    int_op: fn(i64, i64) -> Option<i64>,
287    real_op: fn(f64, f64) -> f64,
288) -> Result<Value> {
289    match (left, right) {
290        (Value::Integer(a), Value::Integer(b)) => int_op(*a, *b)
291            .map(Value::Integer)
292            .ok_or(SqlError::IntegerOverflow),
293        (Value::Real(a), Value::Real(b)) => Ok(Value::Real(real_op(*a, *b))),
294        (Value::Integer(a), Value::Real(b)) => Ok(Value::Real(real_op(*a as f64, *b))),
295        (Value::Real(a), Value::Integer(b)) => Ok(Value::Real(real_op(*a, *b as f64))),
296        _ => Err(SqlError::TypeMismatch {
297            expected: "numeric".into(),
298            got: format!("{} and {}", left.data_type(), right.data_type()),
299        }),
300    }
301}
302
303fn eval_in_values(
304    lhs: &Value,
305    list: &[Expr],
306    col_map: &ColumnMap,
307    row: &[Value],
308    negated: bool,
309) -> Result<Value> {
310    if list.is_empty() {
311        return Ok(Value::Boolean(negated));
312    }
313    if lhs.is_null() {
314        return Ok(Value::Null);
315    }
316    let mut has_null = false;
317    for item in list {
318        let rhs = eval_expr(item, col_map, row)?;
319        if rhs.is_null() {
320            has_null = true;
321        } else if lhs == &rhs {
322            return Ok(Value::Boolean(!negated));
323        }
324    }
325    if has_null {
326        Ok(Value::Null)
327    } else {
328        Ok(Value::Boolean(negated))
329    }
330}
331
332fn eval_in_set(
333    lhs: &Value,
334    values: &std::collections::HashSet<Value>,
335    has_null: bool,
336    negated: bool,
337) -> Result<Value> {
338    if values.is_empty() && !has_null {
339        return Ok(Value::Boolean(negated));
340    }
341    if lhs.is_null() {
342        return Ok(Value::Null);
343    }
344    if values.contains(lhs) {
345        return Ok(Value::Boolean(!negated));
346    }
347    if has_null {
348        Ok(Value::Null)
349    } else {
350        Ok(Value::Boolean(negated))
351    }
352}
353
354fn eval_unary_op(op: UnaryOp, val: &Value) -> Result<Value> {
355    if val.is_null() {
356        return Ok(Value::Null);
357    }
358    match op {
359        UnaryOp::Neg => match val {
360            Value::Integer(i) => i
361                .checked_neg()
362                .map(Value::Integer)
363                .ok_or(SqlError::IntegerOverflow),
364            Value::Real(r) => Ok(Value::Real(-r)),
365            _ => Err(SqlError::TypeMismatch {
366                expected: "numeric".into(),
367                got: format!("{}", val.data_type()),
368            }),
369        },
370        UnaryOp::Not => match val {
371            Value::Boolean(b) => Ok(Value::Boolean(!b)),
372            Value::Integer(i) => Ok(Value::Boolean(*i == 0)),
373            _ => Err(SqlError::TypeMismatch {
374                expected: "BOOLEAN".into(),
375                got: format!("{}", val.data_type()),
376            }),
377        },
378    }
379}
380
381fn value_to_text(val: &Value) -> String {
382    match val {
383        Value::Text(s) => s.to_string(),
384        Value::Integer(i) => i.to_string(),
385        Value::Real(r) => {
386            if r.fract() == 0.0 && r.is_finite() {
387                format!("{r:.1}")
388            } else {
389                format!("{r}")
390            }
391        }
392        Value::Boolean(b) => if *b { "TRUE" } else { "FALSE" }.into(),
393        Value::Null => String::new(),
394        Value::Blob(b) => {
395            let mut s = String::with_capacity(b.len() * 2);
396            for byte in b {
397                s.push_str(&format!("{byte:02X}"));
398            }
399            s
400        }
401    }
402}
403
404fn eval_between(val: &Value, low: &Value, high: &Value, negated: bool) -> Result<Value> {
405    if val.is_null() || low.is_null() || high.is_null() {
406        let ge = if val.is_null() || low.is_null() {
407            None
408        } else {
409            Some(*val >= *low)
410        };
411        let le = if val.is_null() || high.is_null() {
412            None
413        } else {
414            Some(*val <= *high)
415        };
416
417        let result = match (ge, le) {
418            (Some(false), _) | (_, Some(false)) => Some(false),
419            (Some(true), Some(true)) => Some(true),
420            _ => None,
421        };
422
423        return match result {
424            Some(b) => Ok(Value::Boolean(if negated { !b } else { b })),
425            None => Ok(Value::Null),
426        };
427    }
428
429    let in_range = *val >= *low && *val <= *high;
430    Ok(Value::Boolean(if negated { !in_range } else { in_range }))
431}
432
433const MAX_LIKE_PATTERN_LEN: usize = 10_000;
434
435fn eval_like(val: &Value, pattern: &Value, escape: Option<&Value>, negated: bool) -> Result<Value> {
436    if val.is_null() || pattern.is_null() {
437        return Ok(Value::Null);
438    }
439    let text = match val {
440        Value::Text(s) => s.as_str(),
441        _ => {
442            return Err(SqlError::TypeMismatch {
443                expected: "TEXT".into(),
444                got: val.data_type().to_string(),
445            })
446        }
447    };
448    let pat = match pattern {
449        Value::Text(s) => s.as_str(),
450        _ => {
451            return Err(SqlError::TypeMismatch {
452                expected: "TEXT".into(),
453                got: pattern.data_type().to_string(),
454            })
455        }
456    };
457
458    if pat.len() > MAX_LIKE_PATTERN_LEN {
459        return Err(SqlError::InvalidValue(format!(
460            "LIKE pattern too long ({} chars, max {MAX_LIKE_PATTERN_LEN})",
461            pat.len()
462        )));
463    }
464
465    let esc_char = match escape {
466        Some(Value::Text(s)) => {
467            let mut chars = s.chars();
468            let c = chars.next().ok_or_else(|| {
469                SqlError::InvalidValue("ESCAPE must be a single character".into())
470            })?;
471            if chars.next().is_some() {
472                return Err(SqlError::InvalidValue(
473                    "ESCAPE must be a single character".into(),
474                ));
475            }
476            Some(c)
477        }
478        Some(Value::Null) => return Ok(Value::Null),
479        Some(_) => {
480            return Err(SqlError::TypeMismatch {
481                expected: "TEXT".into(),
482                got: "non-text".into(),
483            })
484        }
485        None => None,
486    };
487
488    let matched = like_match(text, pat, esc_char);
489    Ok(Value::Boolean(if negated { !matched } else { matched }))
490}
491
/// ASCII case-insensitive SQL `LIKE` match of `text` against `pattern`.
///
/// `%` matches any run of characters (including none) and `_` matches exactly
/// one character. When `escape` is given, the character following it in the
/// pattern is matched literally. An escape character in the last position has
/// nothing to escape and is compared like an ordinary character.
fn like_match(text: &str, pattern: &str, escape: Option<char>) -> bool {
    let text_chars: Vec<char> = text.chars().collect();
    let pattern_chars: Vec<char> = pattern.chars().collect();
    like_match_impl(&text_chars, &pattern_chars, 0, 0, escape)
}

/// Iterative matcher behind `like_match`, starting at text index `ti` and
/// pattern index `pi`.
///
/// Only the most recent `%` is remembered. On a mismatch, that `%` is made to
/// absorb one more text character and matching resumes just after it.
fn like_match_impl(
    t: &[char],
    p: &[char],
    mut ti: usize,
    mut pi: usize,
    esc: Option<char>,
) -> bool {
    // (pattern index just after the latest `%`, last text index it absorbed up to)
    let mut restart: Option<(usize, usize)> = None;

    while ti < t.len() {
        let current = t[ti];
        // Pattern characters consumed if `current` is accepted; `None` on mismatch.
        let advance = match p.get(pi).copied() {
            None => None,
            Some(c) if esc == Some(c) && pi + 1 < p.len() => {
                if p[pi + 1].eq_ignore_ascii_case(&current) {
                    Some(2)
                } else {
                    None
                }
            }
            Some('%') => {
                restart = Some((pi + 1, ti));
                pi += 1;
                continue;
            }
            Some('_') => Some(1),
            Some(c) => {
                if c.eq_ignore_ascii_case(&current) {
                    Some(1)
                } else {
                    None
                }
            }
        };

        if let Some(step) = advance {
            pi += step;
            ti += 1;
        } else if let Some((after_star, absorbed)) = restart {
            restart = Some((after_star, absorbed + 1));
            pi = after_star;
            ti = absorbed + 1;
        } else {
            return false;
        }
    }

    // Text is exhausted; the leftover pattern matches only if it is all `%`.
    p.get(pi..)
        .map_or(false, |rest| rest.iter().all(|&c| c == '%'))
}
560
561fn eval_case(
562    operand: Option<&Expr>,
563    conditions: &[(Expr, Expr)],
564    else_result: Option<&Expr>,
565    col_map: &ColumnMap,
566    row: &[Value],
567) -> Result<Value> {
568    if let Some(op_expr) = operand {
569        let op_val = eval_expr(op_expr, col_map, row)?;
570        for (cond, result) in conditions {
571            let cond_val = eval_expr(cond, col_map, row)?;
572            if !op_val.is_null() && !cond_val.is_null() && op_val == cond_val {
573                return eval_expr(result, col_map, row);
574            }
575        }
576    } else {
577        for (cond, result) in conditions {
578            let cond_val = eval_expr(cond, col_map, row)?;
579            if is_truthy(&cond_val) {
580                return eval_expr(result, col_map, row);
581            }
582        }
583    }
584    match else_result {
585        Some(e) => eval_expr(e, col_map, row),
586        None => Ok(Value::Null),
587    }
588}
589
590fn eval_cast(val: &Value, target: DataType) -> Result<Value> {
591    if val.is_null() {
592        return Ok(Value::Null);
593    }
594    match target {
595        DataType::Integer => match val {
596            Value::Integer(_) => Ok(val.clone()),
597            Value::Real(r) => Ok(Value::Integer(*r as i64)),
598            Value::Boolean(b) => Ok(Value::Integer(if *b { 1 } else { 0 })),
599            Value::Text(s) => s
600                .trim()
601                .parse::<i64>()
602                .map(Value::Integer)
603                .or_else(|_| s.trim().parse::<f64>().map(|f| Value::Integer(f as i64)))
604                .map_err(|_| SqlError::InvalidValue(format!("cannot cast '{s}' to INTEGER"))),
605            _ => Err(SqlError::InvalidValue(format!(
606                "cannot cast {} to INTEGER",
607                val.data_type()
608            ))),
609        },
610        DataType::Real => match val {
611            Value::Real(_) => Ok(val.clone()),
612            Value::Integer(i) => Ok(Value::Real(*i as f64)),
613            Value::Boolean(b) => Ok(Value::Real(if *b { 1.0 } else { 0.0 })),
614            Value::Text(s) => s
615                .trim()
616                .parse::<f64>()
617                .map(Value::Real)
618                .map_err(|_| SqlError::InvalidValue(format!("cannot cast '{s}' to REAL"))),
619            _ => Err(SqlError::InvalidValue(format!(
620                "cannot cast {} to REAL",
621                val.data_type()
622            ))),
623        },
624        DataType::Text => Ok(Value::Text(value_to_text(val).into())),
625        DataType::Boolean => match val {
626            Value::Boolean(_) => Ok(val.clone()),
627            Value::Integer(i) => Ok(Value::Boolean(*i != 0)),
628            Value::Text(s) => {
629                let lower = s.trim().to_ascii_lowercase();
630                match lower.as_str() {
631                    "true" | "1" | "yes" | "on" => Ok(Value::Boolean(true)),
632                    "false" | "0" | "no" | "off" => Ok(Value::Boolean(false)),
633                    _ => Err(SqlError::InvalidValue(format!(
634                        "cannot cast '{s}' to BOOLEAN"
635                    ))),
636                }
637            }
638            _ => Err(SqlError::InvalidValue(format!(
639                "cannot cast {} to BOOLEAN",
640                val.data_type()
641            ))),
642        },
643        DataType::Blob => match val {
644            Value::Blob(_) => Ok(val.clone()),
645            Value::Text(s) => Ok(Value::Blob(s.as_bytes().to_vec())),
646            _ => Err(SqlError::InvalidValue(format!(
647                "cannot cast {} to BLOB",
648                val.data_type()
649            ))),
650        },
651        DataType::Null => Ok(Value::Null),
652    }
653}
654
655fn eval_scalar_function(
656    name: &str,
657    args: &[Expr],
658    col_map: &ColumnMap,
659    row: &[Value],
660) -> Result<Value> {
661    let evaluated: Vec<Value> = args
662        .iter()
663        .map(|a| eval_expr(a, col_map, row))
664        .collect::<Result<Vec<_>>>()?;
665
666    match name {
667        "LENGTH" => {
668            check_args(name, &evaluated, 1)?;
669            match &evaluated[0] {
670                Value::Null => Ok(Value::Null),
671                Value::Text(s) => Ok(Value::Integer(s.chars().count() as i64)),
672                Value::Blob(b) => Ok(Value::Integer(b.len() as i64)),
673                _ => Ok(Value::Integer(
674                    value_to_text(&evaluated[0]).chars().count() as i64
675                )),
676            }
677        }
678        "UPPER" => {
679            check_args(name, &evaluated, 1)?;
680            match &evaluated[0] {
681                Value::Null => Ok(Value::Null),
682                Value::Text(s) => Ok(Value::Text(s.to_ascii_uppercase())),
683                _ => Ok(Value::Text(
684                    value_to_text(&evaluated[0]).to_ascii_uppercase().into(),
685                )),
686            }
687        }
688        "LOWER" => {
689            check_args(name, &evaluated, 1)?;
690            match &evaluated[0] {
691                Value::Null => Ok(Value::Null),
692                Value::Text(s) => Ok(Value::Text(s.to_ascii_lowercase())),
693                _ => Ok(Value::Text(
694                    value_to_text(&evaluated[0]).to_ascii_lowercase().into(),
695                )),
696            }
697        }
698        "SUBSTR" | "SUBSTRING" => {
699            if evaluated.len() < 2 || evaluated.len() > 3 {
700                return Err(SqlError::InvalidValue(format!(
701                    "{name} requires 2 or 3 arguments"
702                )));
703            }
704            if evaluated.iter().any(|v| v.is_null()) {
705                return Ok(Value::Null);
706            }
707            let s = value_to_text(&evaluated[0]);
708            let chars: Vec<char> = s.chars().collect();
709            let start = match &evaluated[1] {
710                Value::Integer(i) => *i,
711                _ => {
712                    return Err(SqlError::TypeMismatch {
713                        expected: "INTEGER".into(),
714                        got: evaluated[1].data_type().to_string(),
715                    })
716                }
717            };
718            let len = chars.len() as i64;
719
720            let (begin, count) = if evaluated.len() == 3 {
721                let cnt = match &evaluated[2] {
722                    Value::Integer(i) => *i,
723                    _ => {
724                        return Err(SqlError::TypeMismatch {
725                            expected: "INTEGER".into(),
726                            got: evaluated[2].data_type().to_string(),
727                        })
728                    }
729                };
730                if start >= 1 {
731                    let b = (start - 1).min(len) as usize;
732                    let c = cnt.max(0) as usize;
733                    (b, c)
734                } else if start == 0 {
735                    let c = (cnt - 1).max(0) as usize;
736                    (0usize, c)
737                } else {
738                    let adjusted_cnt = (cnt + start - 1).max(0) as usize;
739                    (0usize, adjusted_cnt)
740                }
741            } else if start >= 1 {
742                let b = (start - 1).min(len) as usize;
743                (b, chars.len() - b)
744            } else if start == 0 {
745                (0usize, chars.len())
746            } else {
747                let b = (len + start).max(0) as usize;
748                (b, chars.len() - b)
749            };
750
751            let result: String = chars.iter().skip(begin).take(count).collect();
752            Ok(Value::Text(result.into()))
753        }
754        "TRIM" | "LTRIM" | "RTRIM" => {
755            if evaluated.is_empty() || evaluated.len() > 2 {
756                return Err(SqlError::InvalidValue(format!(
757                    "{name} requires 1 or 2 arguments"
758                )));
759            }
760            if evaluated[0].is_null() {
761                return Ok(Value::Null);
762            }
763            let s = value_to_text(&evaluated[0]);
764            let trim_chars: Vec<char> = if evaluated.len() == 2 {
765                if evaluated[1].is_null() {
766                    return Ok(Value::Null);
767                }
768                value_to_text(&evaluated[1]).chars().collect()
769            } else {
770                vec![' ']
771            };
772            let result = match name {
773                "TRIM" => s
774                    .trim_matches(|c: char| trim_chars.contains(&c))
775                    .to_string(),
776                "LTRIM" => s
777                    .trim_start_matches(|c: char| trim_chars.contains(&c))
778                    .to_string(),
779                "RTRIM" => s
780                    .trim_end_matches(|c: char| trim_chars.contains(&c))
781                    .to_string(),
782                _ => unreachable!(),
783            };
784            Ok(Value::Text(result.into()))
785        }
786        "REPLACE" => {
787            check_args(name, &evaluated, 3)?;
788            if evaluated.iter().any(|v| v.is_null()) {
789                return Ok(Value::Null);
790            }
791            let s = value_to_text(&evaluated[0]);
792            let from = value_to_text(&evaluated[1]);
793            let to = value_to_text(&evaluated[2]);
794            if from.is_empty() {
795                return Ok(Value::Text(s.into()));
796            }
797            Ok(Value::Text(s.replace(&from, &to).into()))
798        }
799        "INSTR" => {
800            check_args(name, &evaluated, 2)?;
801            if evaluated.iter().any(|v| v.is_null()) {
802                return Ok(Value::Null);
803            }
804            let haystack = value_to_text(&evaluated[0]);
805            let needle = value_to_text(&evaluated[1]);
806            let pos = haystack
807                .find(&needle)
808                .map(|i| haystack[..i].chars().count() as i64 + 1)
809                .unwrap_or(0);
810            Ok(Value::Integer(pos))
811        }
812        "CONCAT" => {
813            if evaluated.is_empty() {
814                return Ok(Value::Text(CompactString::default()));
815            }
816            let mut result = String::new();
817            for v in &evaluated {
818                match v {
819                    Value::Null => {}
820                    _ => result.push_str(&value_to_text(v)),
821                }
822            }
823            Ok(Value::Text(result.into()))
824        }
825        "ABS" => {
826            check_args(name, &evaluated, 1)?;
827            match &evaluated[0] {
828                Value::Null => Ok(Value::Null),
829                Value::Integer(i) => i
830                    .checked_abs()
831                    .map(Value::Integer)
832                    .ok_or(SqlError::IntegerOverflow),
833                Value::Real(r) => Ok(Value::Real(r.abs())),
834                _ => Err(SqlError::TypeMismatch {
835                    expected: "numeric".into(),
836                    got: evaluated[0].data_type().to_string(),
837                }),
838            }
839        }
840        "ROUND" => {
841            if evaluated.is_empty() || evaluated.len() > 2 {
842                return Err(SqlError::InvalidValue(
843                    "ROUND requires 1 or 2 arguments".into(),
844                ));
845            }
846            if evaluated[0].is_null() {
847                return Ok(Value::Null);
848            }
849            let val = match &evaluated[0] {
850                Value::Integer(i) => *i as f64,
851                Value::Real(r) => *r,
852                _ => {
853                    return Err(SqlError::TypeMismatch {
854                        expected: "numeric".into(),
855                        got: evaluated[0].data_type().to_string(),
856                    })
857                }
858            };
859            let places = if evaluated.len() == 2 {
860                match &evaluated[1] {
861                    Value::Null => return Ok(Value::Null),
862                    Value::Integer(i) => *i,
863                    _ => {
864                        return Err(SqlError::TypeMismatch {
865                            expected: "INTEGER".into(),
866                            got: evaluated[1].data_type().to_string(),
867                        })
868                    }
869                }
870            } else {
871                0
872            };
873            let factor = 10f64.powi(places as i32);
874            let rounded = (val * factor).round() / factor;
875            Ok(Value::Real(rounded))
876        }
877        "CEIL" | "CEILING" => {
878            check_args(name, &evaluated, 1)?;
879            match &evaluated[0] {
880                Value::Null => Ok(Value::Null),
881                Value::Integer(i) => Ok(Value::Integer(*i)),
882                Value::Real(r) => Ok(Value::Integer(r.ceil() as i64)),
883                _ => Err(SqlError::TypeMismatch {
884                    expected: "numeric".into(),
885                    got: evaluated[0].data_type().to_string(),
886                }),
887            }
888        }
889        "FLOOR" => {
890            check_args(name, &evaluated, 1)?;
891            match &evaluated[0] {
892                Value::Null => Ok(Value::Null),
893                Value::Integer(i) => Ok(Value::Integer(*i)),
894                Value::Real(r) => Ok(Value::Integer(r.floor() as i64)),
895                _ => Err(SqlError::TypeMismatch {
896                    expected: "numeric".into(),
897                    got: evaluated[0].data_type().to_string(),
898                }),
899            }
900        }
901        "SIGN" => {
902            check_args(name, &evaluated, 1)?;
903            match &evaluated[0] {
904                Value::Null => Ok(Value::Null),
905                Value::Integer(i) => Ok(Value::Integer(i.signum())),
906                Value::Real(r) => {
907                    if *r > 0.0 {
908                        Ok(Value::Integer(1))
909                    } else if *r < 0.0 {
910                        Ok(Value::Integer(-1))
911                    } else {
912                        Ok(Value::Integer(0))
913                    }
914                }
915                _ => Err(SqlError::TypeMismatch {
916                    expected: "numeric".into(),
917                    got: evaluated[0].data_type().to_string(),
918                }),
919            }
920        }
921        "SQRT" => {
922            check_args(name, &evaluated, 1)?;
923            match &evaluated[0] {
924                Value::Null => Ok(Value::Null),
925                Value::Integer(i) => {
926                    if *i < 0 {
927                        Ok(Value::Null)
928                    } else {
929                        Ok(Value::Real((*i as f64).sqrt()))
930                    }
931                }
932                Value::Real(r) => {
933                    if *r < 0.0 {
934                        Ok(Value::Null)
935                    } else {
936                        Ok(Value::Real(r.sqrt()))
937                    }
938                }
939                _ => Err(SqlError::TypeMismatch {
940                    expected: "numeric".into(),
941                    got: evaluated[0].data_type().to_string(),
942                }),
943            }
944        }
945        "RANDOM" => {
946            check_args(name, &evaluated, 0)?;
947            use std::collections::hash_map::DefaultHasher;
948            use std::hash::{Hash, Hasher};
949            use std::time::SystemTime;
950            let mut hasher = DefaultHasher::new();
951            SystemTime::now().hash(&mut hasher);
952            std::thread::current().id().hash(&mut hasher);
953            let mut val = hasher.finish() as i64;
954            if val == i64::MIN {
955                val = i64::MAX;
956            }
957            Ok(Value::Integer(val))
958        }
959        "TYPEOF" => {
960            check_args(name, &evaluated, 1)?;
961            let type_name = match &evaluated[0] {
962                Value::Null => "null",
963                Value::Integer(_) => "integer",
964                Value::Real(_) => "real",
965                Value::Text(_) => "text",
966                Value::Blob(_) => "blob",
967                Value::Boolean(_) => "boolean",
968            };
969            Ok(Value::Text(type_name.into()))
970        }
971        "MIN" => {
972            check_args(name, &evaluated, 2)?;
973            if evaluated[0].is_null() {
974                return Ok(evaluated[1].clone());
975            }
976            if evaluated[1].is_null() {
977                return Ok(evaluated[0].clone());
978            }
979            if evaluated[0] <= evaluated[1] {
980                Ok(evaluated[0].clone())
981            } else {
982                Ok(evaluated[1].clone())
983            }
984        }
985        "MAX" => {
986            check_args(name, &evaluated, 2)?;
987            if evaluated[0].is_null() {
988                return Ok(evaluated[1].clone());
989            }
990            if evaluated[1].is_null() {
991                return Ok(evaluated[0].clone());
992            }
993            if evaluated[0] >= evaluated[1] {
994                Ok(evaluated[0].clone())
995            } else {
996                Ok(evaluated[1].clone())
997            }
998        }
999        "HEX" => {
1000            check_args(name, &evaluated, 1)?;
1001            match &evaluated[0] {
1002                Value::Null => Ok(Value::Null),
1003                Value::Blob(b) => {
1004                    let mut s = String::with_capacity(b.len() * 2);
1005                    for byte in b {
1006                        s.push_str(&format!("{byte:02X}"));
1007                    }
1008                    Ok(Value::Text(s.into()))
1009                }
1010                Value::Text(s) => {
1011                    let mut r = String::with_capacity(s.len() * 2);
1012                    for byte in s.as_bytes() {
1013                        r.push_str(&format!("{byte:02X}"));
1014                    }
1015                    Ok(Value::Text(r.into()))
1016                }
1017                _ => Ok(Value::Text(value_to_text(&evaluated[0]).into())),
1018            }
1019        }
1020        _ => Err(SqlError::Unsupported(format!("scalar function: {name}"))),
1021    }
1022}
1023
1024fn check_args(name: &str, args: &[Value], expected: usize) -> Result<()> {
1025    if args.len() != expected {
1026        Err(SqlError::InvalidValue(format!(
1027            "{name} requires {expected} argument(s), got {}",
1028            args.len()
1029        )))
1030    } else {
1031        Ok(())
1032    }
1033}
1034
1035pub fn referenced_columns(expr: &Expr, columns: &[ColumnDef]) -> Vec<usize> {
1036    let mut indices = Vec::new();
1037    collect_column_refs(expr, columns, &mut indices);
1038    indices.sort_unstable();
1039    indices.dedup();
1040    indices
1041}
1042
1043fn collect_column_refs(expr: &Expr, columns: &[ColumnDef], out: &mut Vec<usize>) {
1044    match expr {
1045        Expr::Column(name) => {
1046            for (i, c) in columns.iter().enumerate() {
1047                if c.name == *name || c.name.ends_with(&format!(".{name}")) {
1048                    out.push(i);
1049                    break;
1050                }
1051            }
1052        }
1053        Expr::QualifiedColumn { table, column } => {
1054            let qualified = format!("{table}.{column}");
1055            if let Some(idx) = columns.iter().position(|c| c.name == qualified) {
1056                out.push(idx);
1057            } else {
1058                let matches: Vec<usize> = columns
1059                    .iter()
1060                    .enumerate()
1061                    .filter(|(_, c)| c.name == *column)
1062                    .map(|(i, _)| i)
1063                    .collect();
1064                if matches.len() == 1 {
1065                    out.push(matches[0]);
1066                }
1067            }
1068        }
1069        Expr::BinaryOp { left, right, .. } => {
1070            collect_column_refs(left, columns, out);
1071            collect_column_refs(right, columns, out);
1072        }
1073        Expr::UnaryOp { expr, .. } => {
1074            collect_column_refs(expr, columns, out);
1075        }
1076        Expr::IsNull(e) | Expr::IsNotNull(e) => {
1077            collect_column_refs(e, columns, out);
1078        }
1079        Expr::Function { args, .. } => {
1080            for arg in args {
1081                collect_column_refs(arg, columns, out);
1082            }
1083        }
1084        Expr::InSubquery { expr, .. } => {
1085            collect_column_refs(expr, columns, out);
1086        }
1087        Expr::InList { expr, list, .. } => {
1088            collect_column_refs(expr, columns, out);
1089            for item in list {
1090                collect_column_refs(item, columns, out);
1091            }
1092        }
1093        Expr::InSet { expr, .. } => {
1094            collect_column_refs(expr, columns, out);
1095        }
1096        Expr::Between {
1097            expr, low, high, ..
1098        } => {
1099            collect_column_refs(expr, columns, out);
1100            collect_column_refs(low, columns, out);
1101            collect_column_refs(high, columns, out);
1102        }
1103        Expr::Like {
1104            expr,
1105            pattern,
1106            escape,
1107            ..
1108        } => {
1109            collect_column_refs(expr, columns, out);
1110            collect_column_refs(pattern, columns, out);
1111            if let Some(esc) = escape {
1112                collect_column_refs(esc, columns, out);
1113            }
1114        }
1115        Expr::Case {
1116            operand,
1117            conditions,
1118            else_result,
1119        } => {
1120            if let Some(op) = operand {
1121                collect_column_refs(op, columns, out);
1122            }
1123            for (when, then) in conditions {
1124                collect_column_refs(when, columns, out);
1125                collect_column_refs(then, columns, out);
1126            }
1127            if let Some(e) = else_result {
1128                collect_column_refs(e, columns, out);
1129            }
1130        }
1131        Expr::Coalesce(args) => {
1132            for arg in args {
1133                collect_column_refs(arg, columns, out);
1134            }
1135        }
1136        Expr::Cast { expr, .. } => {
1137            collect_column_refs(expr, columns, out);
1138        }
1139        Expr::WindowFunction { args, spec, .. } => {
1140            for arg in args {
1141                collect_column_refs(arg, columns, out);
1142            }
1143            for pb in &spec.partition_by {
1144                collect_column_refs(pb, columns, out);
1145            }
1146            for ob in &spec.order_by {
1147                collect_column_refs(&ob.expr, columns, out);
1148            }
1149        }
1150        Expr::Literal(_)
1151        | Expr::Parameter(_)
1152        | Expr::CountStar
1153        | Expr::Exists { .. }
1154        | Expr::ScalarSubquery(_) => {}
1155    }
1156}
1157
1158/// Check if an expression result is truthy (for WHERE/HAVING).
1159pub fn is_truthy(val: &Value) -> bool {
1160    match val {
1161        Value::Boolean(b) => *b,
1162        Value::Integer(i) => *i != 0,
1163        Value::Null => false,
1164        _ => true,
1165    }
1166}
1167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DataType;

    /// Column definition with no default, CHECK constraint, or constraint name.
    fn col(name: &str, dt: DataType, nullable: bool, pos: u16) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            data_type: dt,
            nullable,
            position: pos,
            default_expr: None,
            default_sql: None,
            check_expr: None,
            check_sql: None,
            check_name: None,
        }
    }

    /// Schema shared by every test: `id`, `name`, `score`, `active`.
    fn test_columns() -> Vec<ColumnDef> {
        vec![
            col("id", DataType::Integer, false, 0),
            col("name", DataType::Text, true, 1),
            col("score", DataType::Real, true, 2),
            col("active", DataType::Boolean, false, 3),
        ]
    }

    /// Column resolver built over `test_columns()`.
    fn column_map() -> ColumnMap {
        ColumnMap::new(&test_columns())
    }

    /// A fully populated row matching `test_columns()`.
    fn test_row() -> Vec<Value> {
        vec![
            Value::Integer(1),
            Value::Text("Alice".into()),
            Value::Real(95.5),
            Value::Boolean(true),
        ]
    }

    /// Same shape as `test_row()`, but `name` and `score` are NULL.
    fn sparse_row() -> Vec<Value> {
        vec![
            Value::Integer(1),
            Value::Null,
            Value::Null,
            Value::Boolean(true),
        ]
    }

    fn column(name: &str) -> Expr {
        Expr::Column(name.into())
    }

    fn lit(v: Value) -> Expr {
        Expr::Literal(v)
    }

    fn binary(lhs: Expr, op: BinOp, rhs: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(lhs),
            op,
            right: Box::new(rhs),
        }
    }

    fn unary(op: UnaryOp, operand: Expr) -> Expr {
        Expr::UnaryOp {
            op,
            expr: Box::new(operand),
        }
    }

    #[test]
    fn eval_literal() {
        let map = column_map();
        let values = test_row();
        let forty_two = lit(Value::Integer(42));
        assert_eq!(
            eval_expr(&forty_two, &map, &values).unwrap(),
            Value::Integer(42)
        );
    }

    #[test]
    fn eval_column_ref() {
        let map = column_map();
        let values = test_row();
        let name_ref = column("name");
        assert_eq!(
            eval_expr(&name_ref, &map, &values).unwrap(),
            Value::Text("Alice".into())
        );
    }

    #[test]
    fn eval_column_case_insensitive() {
        let map = column_map();
        let values = test_row();
        let name_ref = column("name");
        assert_eq!(
            eval_expr(&name_ref, &map, &values).unwrap(),
            Value::Text("Alice".into())
        );
    }

    #[test]
    fn eval_arithmetic_int() {
        let map = column_map();
        let values = test_row();
        let sum = binary(column("id"), BinOp::Add, lit(Value::Integer(10)));
        assert_eq!(eval_expr(&sum, &map, &values).unwrap(), Value::Integer(11));
    }

    #[test]
    fn eval_comparison() {
        let map = column_map();
        let values = test_row();
        let above_ninety = binary(column("score"), BinOp::Gt, lit(Value::Real(90.0)));
        assert_eq!(
            eval_expr(&above_ninety, &map, &values).unwrap(),
            Value::Boolean(true)
        );
    }

    #[test]
    fn eval_null_propagation() {
        let map = column_map();
        let values = sparse_row();
        let equals = binary(column("name"), BinOp::Eq, lit(Value::Text("test".into())));
        assert!(eval_expr(&equals, &map, &values).unwrap().is_null());
    }

    #[test]
    fn eval_and_three_valued() {
        let map = column_map();
        let values = sparse_row();

        // NULL AND false = false
        let null_and_false = binary(column("name"), BinOp::And, lit(Value::Boolean(false)));
        assert_eq!(
            eval_expr(&null_and_false, &map, &values).unwrap(),
            Value::Boolean(false)
        );

        // NULL AND true = NULL
        let null_and_true = binary(column("name"), BinOp::And, lit(Value::Boolean(true)));
        assert!(eval_expr(&null_and_true, &map, &values).unwrap().is_null());
    }

    #[test]
    fn eval_or_three_valued() {
        let map = column_map();
        let values = sparse_row();

        // NULL OR true = true
        let null_or_true = binary(column("name"), BinOp::Or, lit(Value::Boolean(true)));
        assert_eq!(
            eval_expr(&null_or_true, &map, &values).unwrap(),
            Value::Boolean(true)
        );

        // NULL OR false = NULL
        let null_or_false = binary(column("name"), BinOp::Or, lit(Value::Boolean(false)));
        assert!(eval_expr(&null_or_false, &map, &values).unwrap().is_null());
    }

    #[test]
    fn eval_is_null() {
        let map = column_map();
        let values = sparse_row();

        let name_is_null = Expr::IsNull(Box::new(column("name")));
        assert_eq!(
            eval_expr(&name_is_null, &map, &values).unwrap(),
            Value::Boolean(true)
        );

        let id_is_not_null = Expr::IsNotNull(Box::new(column("id")));
        assert_eq!(
            eval_expr(&id_is_not_null, &map, &values).unwrap(),
            Value::Boolean(true)
        );
    }

    #[test]
    fn eval_not() {
        let map = column_map();
        let values = test_row();
        let negated = unary(UnaryOp::Not, column("active"));
        assert_eq!(
            eval_expr(&negated, &map, &values).unwrap(),
            Value::Boolean(false)
        );
    }

    #[test]
    fn eval_neg() {
        let map = column_map();
        let values = test_row();
        let minus_id = unary(UnaryOp::Neg, column("id"));
        assert_eq!(eval_expr(&minus_id, &map, &values).unwrap(), Value::Integer(-1));
    }

    #[test]
    fn eval_division_by_zero() {
        let map = column_map();
        let values = test_row();
        let quotient = binary(column("id"), BinOp::Div, lit(Value::Integer(0)));
        assert!(matches!(
            eval_expr(&quotient, &map, &values),
            Err(SqlError::DivisionByZero)
        ));
    }

    #[test]
    fn eval_mixed_numeric() {
        let map = column_map();
        let values = test_row();
        // id (int 1) + score (real 95.5) = real 96.5
        let mixed = binary(column("id"), BinOp::Add, column("score"));
        assert_eq!(eval_expr(&mixed, &map, &values).unwrap(), Value::Real(96.5));
    }

    #[test]
    fn is_truthy_values() {
        assert!(is_truthy(&Value::Boolean(true)));
        assert!(!is_truthy(&Value::Boolean(false)));
        assert!(!is_truthy(&Value::Null));
        assert!(is_truthy(&Value::Integer(1)));
        assert!(!is_truthy(&Value::Integer(0)));
    }
}