Skip to main content

citadel_sql/
eval.rs

1//! Expression evaluator with SQL three-valued logic.
2
3use std::collections::HashMap;
4
5use crate::error::{Result, SqlError};
6use crate::parser::{BinOp, Expr, UnaryOp};
7use crate::types::{ColumnDef, CompactString, DataType, Value};
8
9pub struct ColumnMap {
10    exact: HashMap<String, usize>,
11    short: HashMap<String, ShortMatch>,
12}
13
14enum ShortMatch {
15    Unique(usize),
16    Ambiguous,
17}
18
19impl ColumnMap {
20    pub fn new(columns: &[ColumnDef]) -> Self {
21        let mut exact = HashMap::with_capacity(columns.len() * 2);
22        let mut short: HashMap<String, ShortMatch> = HashMap::with_capacity(columns.len());
23
24        for (i, col) in columns.iter().enumerate() {
25            let lower = col.name.to_ascii_lowercase();
26            exact.insert(lower.clone(), i);
27
28            let unqualified = if let Some(dot) = lower.rfind('.') {
29                &lower[dot + 1..]
30            } else {
31                &lower
32            };
33            short
34                .entry(unqualified.to_string())
35                .and_modify(|e| *e = ShortMatch::Ambiguous)
36                .or_insert(ShortMatch::Unique(i));
37        }
38
39        Self { exact, short }
40    }
41
42    pub(crate) fn resolve(&self, name: &str) -> Result<usize> {
43        if let Some(&idx) = self.exact.get(name) {
44            return Ok(idx);
45        }
46        match self.short.get(name) {
47            Some(ShortMatch::Unique(idx)) => Ok(*idx),
48            Some(ShortMatch::Ambiguous) => Err(SqlError::AmbiguousColumn(name.to_string())),
49            None => Err(SqlError::ColumnNotFound(name.to_string())),
50        }
51    }
52
53    pub(crate) fn resolve_qualified(&self, table: &str, column: &str) -> Result<usize> {
54        let qualified = format!("{table}.{column}");
55        if let Some(&idx) = self.exact.get(&qualified) {
56            return Ok(idx);
57        }
58        match self.short.get(column) {
59            Some(ShortMatch::Unique(idx)) => Ok(*idx),
60            _ => Err(SqlError::ColumnNotFound(format!("{table}.{column}"))),
61        }
62    }
63}
64
65pub fn eval_expr(expr: &Expr, col_map: &ColumnMap, row: &[Value]) -> Result<Value> {
66    match expr {
67        Expr::Literal(v) => Ok(v.clone()),
68
69        Expr::Column(name) => {
70            let idx = col_map.resolve(name)?;
71            Ok(row[idx].clone())
72        }
73
74        Expr::QualifiedColumn { table, column } => {
75            let idx = col_map.resolve_qualified(table, column)?;
76            Ok(row[idx].clone())
77        }
78
79        Expr::BinaryOp { left, op, right } => {
80            let lval = eval_expr(left, col_map, row)?;
81            let rval = eval_expr(right, col_map, row)?;
82            eval_binary_op(&lval, *op, &rval)
83        }
84
85        Expr::UnaryOp { op, expr } => {
86            let val = eval_expr(expr, col_map, row)?;
87            eval_unary_op(*op, &val)
88        }
89
90        Expr::IsNull(e) => {
91            let val = eval_expr(e, col_map, row)?;
92            Ok(Value::Boolean(val.is_null()))
93        }
94
95        Expr::IsNotNull(e) => {
96            let val = eval_expr(e, col_map, row)?;
97            Ok(Value::Boolean(!val.is_null()))
98        }
99
100        Expr::Function { name, args } => eval_scalar_function(name, args, col_map, row),
101
102        Expr::CountStar => Err(SqlError::Unsupported(
103            "COUNT(*) in non-aggregate context".into(),
104        )),
105
106        Expr::InList {
107            expr: e,
108            list,
109            negated,
110        } => {
111            let lhs = eval_expr(e, col_map, row)?;
112            eval_in_values(&lhs, list, col_map, row, *negated)
113        }
114
115        Expr::InSet {
116            expr: e,
117            values,
118            has_null,
119            negated,
120        } => {
121            let lhs = eval_expr(e, col_map, row)?;
122            eval_in_set(&lhs, values, *has_null, *negated)
123        }
124
125        Expr::Between {
126            expr: e,
127            low,
128            high,
129            negated,
130        } => {
131            let val = eval_expr(e, col_map, row)?;
132            let lo = eval_expr(low, col_map, row)?;
133            let hi = eval_expr(high, col_map, row)?;
134            eval_between(&val, &lo, &hi, *negated)
135        }
136
137        Expr::Like {
138            expr: e,
139            pattern,
140            escape,
141            negated,
142        } => {
143            let val = eval_expr(e, col_map, row)?;
144            let pat = eval_expr(pattern, col_map, row)?;
145            let esc = escape
146                .as_ref()
147                .map(|e| eval_expr(e, col_map, row))
148                .transpose()?;
149            eval_like(&val, &pat, esc.as_ref(), *negated)
150        }
151
152        Expr::Case {
153            operand,
154            conditions,
155            else_result,
156        } => eval_case(
157            operand.as_deref(),
158            conditions,
159            else_result.as_deref(),
160            col_map,
161            row,
162        ),
163
164        Expr::Coalesce(args) => {
165            for arg in args {
166                let val = eval_expr(arg, col_map, row)?;
167                if !val.is_null() {
168                    return Ok(val);
169                }
170            }
171            Ok(Value::Null)
172        }
173
174        Expr::Cast { expr: e, data_type } => {
175            let val = eval_expr(e, col_map, row)?;
176            eval_cast(&val, *data_type)
177        }
178
179        Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::ScalarSubquery(_) => Err(
180            SqlError::Unsupported("subquery not materialized (internal error)".into()),
181        ),
182
183        Expr::Parameter(n) => Err(SqlError::Parse(format!("unbound parameter ${n}"))),
184
185        Expr::WindowFunction { .. } => Err(SqlError::Unsupported(
186            "window functions are only allowed in SELECT columns".into(),
187        )),
188    }
189}
190
191fn eval_binary_op(left: &Value, op: BinOp, right: &Value) -> Result<Value> {
192    // SQL three-valued logic for AND/OR
193    match op {
194        BinOp::And => return eval_and(left, right),
195        BinOp::Or => return eval_or(left, right),
196        _ => {}
197    }
198
199    // NULL propagation for all other ops (including || per SQL standard)
200    if left.is_null() || right.is_null() {
201        return Ok(Value::Null);
202    }
203
204    match op {
205        BinOp::Eq => Ok(Value::Boolean(left == right)),
206        BinOp::NotEq => Ok(Value::Boolean(left != right)),
207        BinOp::Lt => Ok(Value::Boolean(left < right)),
208        BinOp::Gt => Ok(Value::Boolean(left > right)),
209        BinOp::LtEq => Ok(Value::Boolean(left <= right)),
210        BinOp::GtEq => Ok(Value::Boolean(left >= right)),
211        BinOp::Add => eval_arithmetic(left, right, i64::checked_add, |a, b| a + b),
212        BinOp::Sub => eval_arithmetic(left, right, i64::checked_sub, |a, b| a - b),
213        BinOp::Mul => eval_arithmetic(left, right, i64::checked_mul, |a, b| a * b),
214        BinOp::Div => {
215            match right {
216                Value::Integer(0) => return Err(SqlError::DivisionByZero),
217                Value::Real(r) if *r == 0.0 => return Err(SqlError::DivisionByZero),
218                _ => {}
219            }
220            eval_arithmetic(left, right, i64::checked_div, |a, b| a / b)
221        }
222        BinOp::Mod => {
223            match right {
224                Value::Integer(0) => return Err(SqlError::DivisionByZero),
225                Value::Real(r) if *r == 0.0 => return Err(SqlError::DivisionByZero),
226                _ => {}
227            }
228            eval_arithmetic(left, right, i64::checked_rem, |a, b| a % b)
229        }
230        BinOp::Concat => {
231            let ls = value_to_text(left);
232            let rs = value_to_text(right);
233            Ok(Value::Text(format!("{ls}{rs}").into()))
234        }
235        BinOp::And | BinOp::Or => unreachable!(),
236    }
237}
238
239/// SQL three-valued AND: NULL AND false = false, NULL AND true = NULL
240fn eval_and(left: &Value, right: &Value) -> Result<Value> {
241    let l = to_bool_or_null(left)?;
242    let r = to_bool_or_null(right)?;
243    match (l, r) {
244        (Some(false), _) | (_, Some(false)) => Ok(Value::Boolean(false)),
245        (Some(true), Some(true)) => Ok(Value::Boolean(true)),
246        _ => Ok(Value::Null),
247    }
248}
249
250/// SQL three-valued OR: NULL OR true = true, NULL OR false = NULL
251fn eval_or(left: &Value, right: &Value) -> Result<Value> {
252    let l = to_bool_or_null(left)?;
253    let r = to_bool_or_null(right)?;
254    match (l, r) {
255        (Some(true), _) | (_, Some(true)) => Ok(Value::Boolean(true)),
256        (Some(false), Some(false)) => Ok(Value::Boolean(false)),
257        _ => Ok(Value::Null),
258    }
259}
260
261fn to_bool_or_null(val: &Value) -> Result<Option<bool>> {
262    match val {
263        Value::Boolean(b) => Ok(Some(*b)),
264        Value::Null => Ok(None),
265        Value::Integer(i) => Ok(Some(*i != 0)),
266        _ => Err(SqlError::TypeMismatch {
267            expected: "BOOLEAN".into(),
268            got: format!("{}", val.data_type()),
269        }),
270    }
271}
272
273fn eval_arithmetic(
274    left: &Value,
275    right: &Value,
276    int_op: fn(i64, i64) -> Option<i64>,
277    real_op: fn(f64, f64) -> f64,
278) -> Result<Value> {
279    match (left, right) {
280        (Value::Integer(a), Value::Integer(b)) => int_op(*a, *b)
281            .map(Value::Integer)
282            .ok_or(SqlError::IntegerOverflow),
283        (Value::Real(a), Value::Real(b)) => Ok(Value::Real(real_op(*a, *b))),
284        (Value::Integer(a), Value::Real(b)) => Ok(Value::Real(real_op(*a as f64, *b))),
285        (Value::Real(a), Value::Integer(b)) => Ok(Value::Real(real_op(*a, *b as f64))),
286        _ => Err(SqlError::TypeMismatch {
287            expected: "numeric".into(),
288            got: format!("{} and {}", left.data_type(), right.data_type()),
289        }),
290    }
291}
292
293fn eval_in_values(
294    lhs: &Value,
295    list: &[Expr],
296    col_map: &ColumnMap,
297    row: &[Value],
298    negated: bool,
299) -> Result<Value> {
300    if list.is_empty() {
301        return Ok(Value::Boolean(negated));
302    }
303    if lhs.is_null() {
304        return Ok(Value::Null);
305    }
306    let mut has_null = false;
307    for item in list {
308        let rhs = eval_expr(item, col_map, row)?;
309        if rhs.is_null() {
310            has_null = true;
311        } else if lhs == &rhs {
312            return Ok(Value::Boolean(!negated));
313        }
314    }
315    if has_null {
316        Ok(Value::Null)
317    } else {
318        Ok(Value::Boolean(negated))
319    }
320}
321
322fn eval_in_set(
323    lhs: &Value,
324    values: &std::collections::HashSet<Value>,
325    has_null: bool,
326    negated: bool,
327) -> Result<Value> {
328    if values.is_empty() && !has_null {
329        return Ok(Value::Boolean(negated));
330    }
331    if lhs.is_null() {
332        return Ok(Value::Null);
333    }
334    if values.contains(lhs) {
335        return Ok(Value::Boolean(!negated));
336    }
337    if has_null {
338        Ok(Value::Null)
339    } else {
340        Ok(Value::Boolean(negated))
341    }
342}
343
344fn eval_unary_op(op: UnaryOp, val: &Value) -> Result<Value> {
345    if val.is_null() {
346        return Ok(Value::Null);
347    }
348    match op {
349        UnaryOp::Neg => match val {
350            Value::Integer(i) => i
351                .checked_neg()
352                .map(Value::Integer)
353                .ok_or(SqlError::IntegerOverflow),
354            Value::Real(r) => Ok(Value::Real(-r)),
355            _ => Err(SqlError::TypeMismatch {
356                expected: "numeric".into(),
357                got: format!("{}", val.data_type()),
358            }),
359        },
360        UnaryOp::Not => match val {
361            Value::Boolean(b) => Ok(Value::Boolean(!b)),
362            Value::Integer(i) => Ok(Value::Boolean(*i == 0)),
363            _ => Err(SqlError::TypeMismatch {
364                expected: "BOOLEAN".into(),
365                got: format!("{}", val.data_type()),
366            }),
367        },
368    }
369}
370
371fn value_to_text(val: &Value) -> String {
372    match val {
373        Value::Text(s) => s.to_string(),
374        Value::Integer(i) => i.to_string(),
375        Value::Real(r) => {
376            if r.fract() == 0.0 && r.is_finite() {
377                format!("{r:.1}")
378            } else {
379                format!("{r}")
380            }
381        }
382        Value::Boolean(b) => if *b { "TRUE" } else { "FALSE" }.into(),
383        Value::Null => String::new(),
384        Value::Blob(b) => {
385            let mut s = String::with_capacity(b.len() * 2);
386            for byte in b {
387                s.push_str(&format!("{byte:02X}"));
388            }
389            s
390        }
391    }
392}
393
394fn eval_between(val: &Value, low: &Value, high: &Value, negated: bool) -> Result<Value> {
395    if val.is_null() || low.is_null() || high.is_null() {
396        let ge = if val.is_null() || low.is_null() {
397            None
398        } else {
399            Some(*val >= *low)
400        };
401        let le = if val.is_null() || high.is_null() {
402            None
403        } else {
404            Some(*val <= *high)
405        };
406
407        let result = match (ge, le) {
408            (Some(false), _) | (_, Some(false)) => Some(false),
409            (Some(true), Some(true)) => Some(true),
410            _ => None,
411        };
412
413        return match result {
414            Some(b) => Ok(Value::Boolean(if negated { !b } else { b })),
415            None => Ok(Value::Null),
416        };
417    }
418
419    let in_range = *val >= *low && *val <= *high;
420    Ok(Value::Boolean(if negated { !in_range } else { in_range }))
421}
422
423const MAX_LIKE_PATTERN_LEN: usize = 10_000;
424
425fn eval_like(val: &Value, pattern: &Value, escape: Option<&Value>, negated: bool) -> Result<Value> {
426    if val.is_null() || pattern.is_null() {
427        return Ok(Value::Null);
428    }
429    let text = match val {
430        Value::Text(s) => s.as_str(),
431        _ => {
432            return Err(SqlError::TypeMismatch {
433                expected: "TEXT".into(),
434                got: val.data_type().to_string(),
435            })
436        }
437    };
438    let pat = match pattern {
439        Value::Text(s) => s.as_str(),
440        _ => {
441            return Err(SqlError::TypeMismatch {
442                expected: "TEXT".into(),
443                got: pattern.data_type().to_string(),
444            })
445        }
446    };
447
448    if pat.len() > MAX_LIKE_PATTERN_LEN {
449        return Err(SqlError::InvalidValue(format!(
450            "LIKE pattern too long ({} chars, max {MAX_LIKE_PATTERN_LEN})",
451            pat.len()
452        )));
453    }
454
455    let esc_char = match escape {
456        Some(Value::Text(s)) => {
457            let mut chars = s.chars();
458            let c = chars.next().ok_or_else(|| {
459                SqlError::InvalidValue("ESCAPE must be a single character".into())
460            })?;
461            if chars.next().is_some() {
462                return Err(SqlError::InvalidValue(
463                    "ESCAPE must be a single character".into(),
464                ));
465            }
466            Some(c)
467        }
468        Some(Value::Null) => return Ok(Value::Null),
469        Some(_) => {
470            return Err(SqlError::TypeMismatch {
471                expected: "TEXT".into(),
472                got: "non-text".into(),
473            })
474        }
475        None => None,
476    };
477
478    let matched = like_match(text, pat, esc_char);
479    Ok(Value::Boolean(if negated { !matched } else { matched }))
480}
481
/// SQL `LIKE` match of `text` against `pattern`, ignoring ASCII case.
///
/// `%` matches any run of characters (possibly empty) and `_` matches exactly
/// one character. When `escape` is set, the escape character followed by any
/// character matches that character literally. An escape character in the last
/// pattern position is treated as an ordinary pattern character. Matching works
/// on `char`s, so `_` consumes one Unicode scalar value.
fn like_match(text: &str, pattern: &str, escape: Option<char>) -> bool {
    let text_chars: Vec<char> = text.chars().collect();
    let pattern_chars: Vec<char> = pattern.chars().collect();
    like_match_impl(&text_chars, &pattern_chars, 0, 0, escape)
}

/// Iterative wildcard matcher behind `like_match`, starting at text index `ti`
/// and pattern index `pi`.
///
/// Only the most recent `%` is remembered. On a mismatch the matcher jumps back
/// to just after that `%` and lets it swallow one more text character. No
/// recursion and no allocation; worst case O(t.len() * p.len()) steps.
fn like_match_impl(
    t: &[char],
    p: &[char],
    mut ti: usize,
    mut pi: usize,
    esc: Option<char>,
) -> bool {
    // (pattern index of the last `%`, text index that `%` currently extends to)
    let mut resume: Option<(usize, usize)> = None;

    while ti < t.len() {
        let tc = t[ti];
        if let Some(&pc) = p.get(pi) {
            if esc == Some(pc) && pi + 1 < p.len() {
                // Escaped character: literal comparison, still ASCII case-insensitive.
                if p[pi + 1].eq_ignore_ascii_case(&tc) {
                    pi += 2;
                    ti += 1;
                    continue;
                }
            } else if pc == '%' {
                resume = Some((pi, ti));
                pi += 1;
                continue;
            } else if pc == '_' || pc.eq_ignore_ascii_case(&tc) {
                pi += 1;
                ti += 1;
                continue;
            }
        }

        // Mismatch or exhausted pattern: let the last `%` absorb one more char.
        match resume.as_mut() {
            Some((star_pi, star_ti)) => {
                *star_ti += 1;
                ti = *star_ti;
                pi = *star_pi + 1;
            }
            None => return false,
        }
    }

    // Text fully consumed: succeed only if what remains of the pattern is all `%`.
    matches!(p.get(pi..), Some(rest) if rest.iter().all(|&c| c == '%'))
}
550
551fn eval_case(
552    operand: Option<&Expr>,
553    conditions: &[(Expr, Expr)],
554    else_result: Option<&Expr>,
555    col_map: &ColumnMap,
556    row: &[Value],
557) -> Result<Value> {
558    if let Some(op_expr) = operand {
559        let op_val = eval_expr(op_expr, col_map, row)?;
560        for (cond, result) in conditions {
561            let cond_val = eval_expr(cond, col_map, row)?;
562            if !op_val.is_null() && !cond_val.is_null() && op_val == cond_val {
563                return eval_expr(result, col_map, row);
564            }
565        }
566    } else {
567        for (cond, result) in conditions {
568            let cond_val = eval_expr(cond, col_map, row)?;
569            if is_truthy(&cond_val) {
570                return eval_expr(result, col_map, row);
571            }
572        }
573    }
574    match else_result {
575        Some(e) => eval_expr(e, col_map, row),
576        None => Ok(Value::Null),
577    }
578}
579
580fn eval_cast(val: &Value, target: DataType) -> Result<Value> {
581    if val.is_null() {
582        return Ok(Value::Null);
583    }
584    match target {
585        DataType::Integer => match val {
586            Value::Integer(_) => Ok(val.clone()),
587            Value::Real(r) => Ok(Value::Integer(*r as i64)),
588            Value::Boolean(b) => Ok(Value::Integer(if *b { 1 } else { 0 })),
589            Value::Text(s) => s
590                .trim()
591                .parse::<i64>()
592                .map(Value::Integer)
593                .or_else(|_| s.trim().parse::<f64>().map(|f| Value::Integer(f as i64)))
594                .map_err(|_| SqlError::InvalidValue(format!("cannot cast '{s}' to INTEGER"))),
595            _ => Err(SqlError::InvalidValue(format!(
596                "cannot cast {} to INTEGER",
597                val.data_type()
598            ))),
599        },
600        DataType::Real => match val {
601            Value::Real(_) => Ok(val.clone()),
602            Value::Integer(i) => Ok(Value::Real(*i as f64)),
603            Value::Boolean(b) => Ok(Value::Real(if *b { 1.0 } else { 0.0 })),
604            Value::Text(s) => s
605                .trim()
606                .parse::<f64>()
607                .map(Value::Real)
608                .map_err(|_| SqlError::InvalidValue(format!("cannot cast '{s}' to REAL"))),
609            _ => Err(SqlError::InvalidValue(format!(
610                "cannot cast {} to REAL",
611                val.data_type()
612            ))),
613        },
614        DataType::Text => Ok(Value::Text(value_to_text(val).into())),
615        DataType::Boolean => match val {
616            Value::Boolean(_) => Ok(val.clone()),
617            Value::Integer(i) => Ok(Value::Boolean(*i != 0)),
618            Value::Text(s) => {
619                let lower = s.trim().to_ascii_lowercase();
620                match lower.as_str() {
621                    "true" | "1" | "yes" | "on" => Ok(Value::Boolean(true)),
622                    "false" | "0" | "no" | "off" => Ok(Value::Boolean(false)),
623                    _ => Err(SqlError::InvalidValue(format!(
624                        "cannot cast '{s}' to BOOLEAN"
625                    ))),
626                }
627            }
628            _ => Err(SqlError::InvalidValue(format!(
629                "cannot cast {} to BOOLEAN",
630                val.data_type()
631            ))),
632        },
633        DataType::Blob => match val {
634            Value::Blob(_) => Ok(val.clone()),
635            Value::Text(s) => Ok(Value::Blob(s.as_bytes().to_vec())),
636            _ => Err(SqlError::InvalidValue(format!(
637                "cannot cast {} to BLOB",
638                val.data_type()
639            ))),
640        },
641        DataType::Null => Ok(Value::Null),
642    }
643}
644
645fn eval_scalar_function(
646    name: &str,
647    args: &[Expr],
648    col_map: &ColumnMap,
649    row: &[Value],
650) -> Result<Value> {
651    let evaluated: Vec<Value> = args
652        .iter()
653        .map(|a| eval_expr(a, col_map, row))
654        .collect::<Result<Vec<_>>>()?;
655
656    match name {
657        "LENGTH" => {
658            check_args(name, &evaluated, 1)?;
659            match &evaluated[0] {
660                Value::Null => Ok(Value::Null),
661                Value::Text(s) => Ok(Value::Integer(s.chars().count() as i64)),
662                Value::Blob(b) => Ok(Value::Integer(b.len() as i64)),
663                _ => Ok(Value::Integer(
664                    value_to_text(&evaluated[0]).chars().count() as i64
665                )),
666            }
667        }
668        "UPPER" => {
669            check_args(name, &evaluated, 1)?;
670            match &evaluated[0] {
671                Value::Null => Ok(Value::Null),
672                Value::Text(s) => Ok(Value::Text(s.to_ascii_uppercase())),
673                _ => Ok(Value::Text(
674                    value_to_text(&evaluated[0]).to_ascii_uppercase().into(),
675                )),
676            }
677        }
678        "LOWER" => {
679            check_args(name, &evaluated, 1)?;
680            match &evaluated[0] {
681                Value::Null => Ok(Value::Null),
682                Value::Text(s) => Ok(Value::Text(s.to_ascii_lowercase())),
683                _ => Ok(Value::Text(
684                    value_to_text(&evaluated[0]).to_ascii_lowercase().into(),
685                )),
686            }
687        }
688        "SUBSTR" | "SUBSTRING" => {
689            if evaluated.len() < 2 || evaluated.len() > 3 {
690                return Err(SqlError::InvalidValue(format!(
691                    "{name} requires 2 or 3 arguments"
692                )));
693            }
694            if evaluated.iter().any(|v| v.is_null()) {
695                return Ok(Value::Null);
696            }
697            let s = value_to_text(&evaluated[0]);
698            let chars: Vec<char> = s.chars().collect();
699            let start = match &evaluated[1] {
700                Value::Integer(i) => *i,
701                _ => {
702                    return Err(SqlError::TypeMismatch {
703                        expected: "INTEGER".into(),
704                        got: evaluated[1].data_type().to_string(),
705                    })
706                }
707            };
708            let len = chars.len() as i64;
709
710            let (begin, count) = if evaluated.len() == 3 {
711                let cnt = match &evaluated[2] {
712                    Value::Integer(i) => *i,
713                    _ => {
714                        return Err(SqlError::TypeMismatch {
715                            expected: "INTEGER".into(),
716                            got: evaluated[2].data_type().to_string(),
717                        })
718                    }
719                };
720                if start >= 1 {
721                    let b = (start - 1).min(len) as usize;
722                    let c = cnt.max(0) as usize;
723                    (b, c)
724                } else if start == 0 {
725                    let c = (cnt - 1).max(0) as usize;
726                    (0usize, c)
727                } else {
728                    let adjusted_cnt = (cnt + start - 1).max(0) as usize;
729                    (0usize, adjusted_cnt)
730                }
731            } else if start >= 1 {
732                let b = (start - 1).min(len) as usize;
733                (b, chars.len() - b)
734            } else if start == 0 {
735                (0usize, chars.len())
736            } else {
737                let b = (len + start).max(0) as usize;
738                (b, chars.len() - b)
739            };
740
741            let result: String = chars.iter().skip(begin).take(count).collect();
742            Ok(Value::Text(result.into()))
743        }
744        "TRIM" | "LTRIM" | "RTRIM" => {
745            if evaluated.is_empty() || evaluated.len() > 2 {
746                return Err(SqlError::InvalidValue(format!(
747                    "{name} requires 1 or 2 arguments"
748                )));
749            }
750            if evaluated[0].is_null() {
751                return Ok(Value::Null);
752            }
753            let s = value_to_text(&evaluated[0]);
754            let trim_chars: Vec<char> = if evaluated.len() == 2 {
755                if evaluated[1].is_null() {
756                    return Ok(Value::Null);
757                }
758                value_to_text(&evaluated[1]).chars().collect()
759            } else {
760                vec![' ']
761            };
762            let result = match name {
763                "TRIM" => s
764                    .trim_matches(|c: char| trim_chars.contains(&c))
765                    .to_string(),
766                "LTRIM" => s
767                    .trim_start_matches(|c: char| trim_chars.contains(&c))
768                    .to_string(),
769                "RTRIM" => s
770                    .trim_end_matches(|c: char| trim_chars.contains(&c))
771                    .to_string(),
772                _ => unreachable!(),
773            };
774            Ok(Value::Text(result.into()))
775        }
776        "REPLACE" => {
777            check_args(name, &evaluated, 3)?;
778            if evaluated.iter().any(|v| v.is_null()) {
779                return Ok(Value::Null);
780            }
781            let s = value_to_text(&evaluated[0]);
782            let from = value_to_text(&evaluated[1]);
783            let to = value_to_text(&evaluated[2]);
784            if from.is_empty() {
785                return Ok(Value::Text(s.into()));
786            }
787            Ok(Value::Text(s.replace(&from, &to).into()))
788        }
789        "INSTR" => {
790            check_args(name, &evaluated, 2)?;
791            if evaluated.iter().any(|v| v.is_null()) {
792                return Ok(Value::Null);
793            }
794            let haystack = value_to_text(&evaluated[0]);
795            let needle = value_to_text(&evaluated[1]);
796            let pos = haystack
797                .find(&needle)
798                .map(|i| haystack[..i].chars().count() as i64 + 1)
799                .unwrap_or(0);
800            Ok(Value::Integer(pos))
801        }
802        "CONCAT" => {
803            if evaluated.is_empty() {
804                return Ok(Value::Text(CompactString::default()));
805            }
806            let mut result = String::new();
807            for v in &evaluated {
808                match v {
809                    Value::Null => {}
810                    _ => result.push_str(&value_to_text(v)),
811                }
812            }
813            Ok(Value::Text(result.into()))
814        }
815        "ABS" => {
816            check_args(name, &evaluated, 1)?;
817            match &evaluated[0] {
818                Value::Null => Ok(Value::Null),
819                Value::Integer(i) => i
820                    .checked_abs()
821                    .map(Value::Integer)
822                    .ok_or(SqlError::IntegerOverflow),
823                Value::Real(r) => Ok(Value::Real(r.abs())),
824                _ => Err(SqlError::TypeMismatch {
825                    expected: "numeric".into(),
826                    got: evaluated[0].data_type().to_string(),
827                }),
828            }
829        }
830        "ROUND" => {
831            if evaluated.is_empty() || evaluated.len() > 2 {
832                return Err(SqlError::InvalidValue(
833                    "ROUND requires 1 or 2 arguments".into(),
834                ));
835            }
836            if evaluated[0].is_null() {
837                return Ok(Value::Null);
838            }
839            let val = match &evaluated[0] {
840                Value::Integer(i) => *i as f64,
841                Value::Real(r) => *r,
842                _ => {
843                    return Err(SqlError::TypeMismatch {
844                        expected: "numeric".into(),
845                        got: evaluated[0].data_type().to_string(),
846                    })
847                }
848            };
849            let places = if evaluated.len() == 2 {
850                match &evaluated[1] {
851                    Value::Null => return Ok(Value::Null),
852                    Value::Integer(i) => *i,
853                    _ => {
854                        return Err(SqlError::TypeMismatch {
855                            expected: "INTEGER".into(),
856                            got: evaluated[1].data_type().to_string(),
857                        })
858                    }
859                }
860            } else {
861                0
862            };
863            let factor = 10f64.powi(places as i32);
864            let rounded = (val * factor).round() / factor;
865            Ok(Value::Real(rounded))
866        }
867        "CEIL" | "CEILING" => {
868            check_args(name, &evaluated, 1)?;
869            match &evaluated[0] {
870                Value::Null => Ok(Value::Null),
871                Value::Integer(i) => Ok(Value::Integer(*i)),
872                Value::Real(r) => Ok(Value::Integer(r.ceil() as i64)),
873                _ => Err(SqlError::TypeMismatch {
874                    expected: "numeric".into(),
875                    got: evaluated[0].data_type().to_string(),
876                }),
877            }
878        }
879        "FLOOR" => {
880            check_args(name, &evaluated, 1)?;
881            match &evaluated[0] {
882                Value::Null => Ok(Value::Null),
883                Value::Integer(i) => Ok(Value::Integer(*i)),
884                Value::Real(r) => Ok(Value::Integer(r.floor() as i64)),
885                _ => Err(SqlError::TypeMismatch {
886                    expected: "numeric".into(),
887                    got: evaluated[0].data_type().to_string(),
888                }),
889            }
890        }
891        "SIGN" => {
892            check_args(name, &evaluated, 1)?;
893            match &evaluated[0] {
894                Value::Null => Ok(Value::Null),
895                Value::Integer(i) => Ok(Value::Integer(i.signum())),
896                Value::Real(r) => {
897                    if *r > 0.0 {
898                        Ok(Value::Integer(1))
899                    } else if *r < 0.0 {
900                        Ok(Value::Integer(-1))
901                    } else {
902                        Ok(Value::Integer(0))
903                    }
904                }
905                _ => Err(SqlError::TypeMismatch {
906                    expected: "numeric".into(),
907                    got: evaluated[0].data_type().to_string(),
908                }),
909            }
910        }
911        "SQRT" => {
912            check_args(name, &evaluated, 1)?;
913            match &evaluated[0] {
914                Value::Null => Ok(Value::Null),
915                Value::Integer(i) => {
916                    if *i < 0 {
917                        Ok(Value::Null)
918                    } else {
919                        Ok(Value::Real((*i as f64).sqrt()))
920                    }
921                }
922                Value::Real(r) => {
923                    if *r < 0.0 {
924                        Ok(Value::Null)
925                    } else {
926                        Ok(Value::Real(r.sqrt()))
927                    }
928                }
929                _ => Err(SqlError::TypeMismatch {
930                    expected: "numeric".into(),
931                    got: evaluated[0].data_type().to_string(),
932                }),
933            }
934        }
935        "RANDOM" => {
936            check_args(name, &evaluated, 0)?;
937            use std::collections::hash_map::DefaultHasher;
938            use std::hash::{Hash, Hasher};
939            use std::time::SystemTime;
940            let mut hasher = DefaultHasher::new();
941            SystemTime::now().hash(&mut hasher);
942            std::thread::current().id().hash(&mut hasher);
943            let mut val = hasher.finish() as i64;
944            if val == i64::MIN {
945                val = i64::MAX;
946            }
947            Ok(Value::Integer(val))
948        }
949        "TYPEOF" => {
950            check_args(name, &evaluated, 1)?;
951            let type_name = match &evaluated[0] {
952                Value::Null => "null",
953                Value::Integer(_) => "integer",
954                Value::Real(_) => "real",
955                Value::Text(_) => "text",
956                Value::Blob(_) => "blob",
957                Value::Boolean(_) => "boolean",
958            };
959            Ok(Value::Text(type_name.into()))
960        }
961        "MIN" => {
962            check_args(name, &evaluated, 2)?;
963            if evaluated[0].is_null() {
964                return Ok(evaluated[1].clone());
965            }
966            if evaluated[1].is_null() {
967                return Ok(evaluated[0].clone());
968            }
969            if evaluated[0] <= evaluated[1] {
970                Ok(evaluated[0].clone())
971            } else {
972                Ok(evaluated[1].clone())
973            }
974        }
975        "MAX" => {
976            check_args(name, &evaluated, 2)?;
977            if evaluated[0].is_null() {
978                return Ok(evaluated[1].clone());
979            }
980            if evaluated[1].is_null() {
981                return Ok(evaluated[0].clone());
982            }
983            if evaluated[0] >= evaluated[1] {
984                Ok(evaluated[0].clone())
985            } else {
986                Ok(evaluated[1].clone())
987            }
988        }
989        "HEX" => {
990            check_args(name, &evaluated, 1)?;
991            match &evaluated[0] {
992                Value::Null => Ok(Value::Null),
993                Value::Blob(b) => {
994                    let mut s = String::with_capacity(b.len() * 2);
995                    for byte in b {
996                        s.push_str(&format!("{byte:02X}"));
997                    }
998                    Ok(Value::Text(s.into()))
999                }
1000                Value::Text(s) => {
1001                    let mut r = String::with_capacity(s.len() * 2);
1002                    for byte in s.as_bytes() {
1003                        r.push_str(&format!("{byte:02X}"));
1004                    }
1005                    Ok(Value::Text(r.into()))
1006                }
1007                _ => Ok(Value::Text(value_to_text(&evaluated[0]).into())),
1008            }
1009        }
1010        _ => Err(SqlError::Unsupported(format!("scalar function: {name}"))),
1011    }
1012}
1013
1014fn check_args(name: &str, args: &[Value], expected: usize) -> Result<()> {
1015    if args.len() != expected {
1016        Err(SqlError::InvalidValue(format!(
1017            "{name} requires {expected} argument(s), got {}",
1018            args.len()
1019        )))
1020    } else {
1021        Ok(())
1022    }
1023}
1024
1025pub fn referenced_columns(expr: &Expr, columns: &[ColumnDef]) -> Vec<usize> {
1026    let mut indices = Vec::new();
1027    collect_column_refs(expr, columns, &mut indices);
1028    indices.sort_unstable();
1029    indices.dedup();
1030    indices
1031}
1032
1033fn collect_column_refs(expr: &Expr, columns: &[ColumnDef], out: &mut Vec<usize>) {
1034    match expr {
1035        Expr::Column(name) => {
1036            for (i, c) in columns.iter().enumerate() {
1037                if c.name == *name || c.name.ends_with(&format!(".{name}")) {
1038                    out.push(i);
1039                    break;
1040                }
1041            }
1042        }
1043        Expr::QualifiedColumn { table, column } => {
1044            let qualified = format!("{table}.{column}");
1045            if let Some(idx) = columns.iter().position(|c| c.name == qualified) {
1046                out.push(idx);
1047            } else {
1048                let matches: Vec<usize> = columns
1049                    .iter()
1050                    .enumerate()
1051                    .filter(|(_, c)| c.name == *column)
1052                    .map(|(i, _)| i)
1053                    .collect();
1054                if matches.len() == 1 {
1055                    out.push(matches[0]);
1056                }
1057            }
1058        }
1059        Expr::BinaryOp { left, right, .. } => {
1060            collect_column_refs(left, columns, out);
1061            collect_column_refs(right, columns, out);
1062        }
1063        Expr::UnaryOp { expr, .. } => {
1064            collect_column_refs(expr, columns, out);
1065        }
1066        Expr::IsNull(e) | Expr::IsNotNull(e) => {
1067            collect_column_refs(e, columns, out);
1068        }
1069        Expr::Function { args, .. } => {
1070            for arg in args {
1071                collect_column_refs(arg, columns, out);
1072            }
1073        }
1074        Expr::InSubquery { expr, .. } => {
1075            collect_column_refs(expr, columns, out);
1076        }
1077        Expr::InList { expr, list, .. } => {
1078            collect_column_refs(expr, columns, out);
1079            for item in list {
1080                collect_column_refs(item, columns, out);
1081            }
1082        }
1083        Expr::InSet { expr, .. } => {
1084            collect_column_refs(expr, columns, out);
1085        }
1086        Expr::Between {
1087            expr, low, high, ..
1088        } => {
1089            collect_column_refs(expr, columns, out);
1090            collect_column_refs(low, columns, out);
1091            collect_column_refs(high, columns, out);
1092        }
1093        Expr::Like {
1094            expr,
1095            pattern,
1096            escape,
1097            ..
1098        } => {
1099            collect_column_refs(expr, columns, out);
1100            collect_column_refs(pattern, columns, out);
1101            if let Some(esc) = escape {
1102                collect_column_refs(esc, columns, out);
1103            }
1104        }
1105        Expr::Case {
1106            operand,
1107            conditions,
1108            else_result,
1109        } => {
1110            if let Some(op) = operand {
1111                collect_column_refs(op, columns, out);
1112            }
1113            for (when, then) in conditions {
1114                collect_column_refs(when, columns, out);
1115                collect_column_refs(then, columns, out);
1116            }
1117            if let Some(e) = else_result {
1118                collect_column_refs(e, columns, out);
1119            }
1120        }
1121        Expr::Coalesce(args) => {
1122            for arg in args {
1123                collect_column_refs(arg, columns, out);
1124            }
1125        }
1126        Expr::Cast { expr, .. } => {
1127            collect_column_refs(expr, columns, out);
1128        }
1129        Expr::WindowFunction { args, spec, .. } => {
1130            for arg in args {
1131                collect_column_refs(arg, columns, out);
1132            }
1133            for pb in &spec.partition_by {
1134                collect_column_refs(pb, columns, out);
1135            }
1136            for ob in &spec.order_by {
1137                collect_column_refs(&ob.expr, columns, out);
1138            }
1139        }
1140        Expr::Literal(_)
1141        | Expr::Parameter(_)
1142        | Expr::CountStar
1143        | Expr::Exists { .. }
1144        | Expr::ScalarSubquery(_) => {}
1145    }
1146}
1147
1148/// Check if an expression result is truthy (for WHERE/HAVING).
1149pub fn is_truthy(val: &Value) -> bool {
1150    match val {
1151        Value::Boolean(b) => *b,
1152        Value::Integer(i) => *i != 0,
1153        Value::Null => false,
1154        _ => true,
1155    }
1156}
1157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DataType;

    fn col(name: &str, dt: DataType, nullable: bool, pos: u16) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            data_type: dt,
            nullable,
            position: pos,
            default_expr: None,
            default_sql: None,
            check_expr: None,
            check_sql: None,
            check_name: None,
        }
    }

    fn test_columns() -> Vec<ColumnDef> {
        vec![
            col("id", DataType::Integer, false, 0),
            col("name", DataType::Text, true, 1),
            col("score", DataType::Real, true, 2),
            col("active", DataType::Boolean, false, 3),
        ]
    }

    fn test_row() -> Vec<Value> {
        vec![
            Value::Integer(1),
            Value::Text("Alice".into()),
            Value::Real(95.5),
            Value::Boolean(true),
        ]
    }

    #[test]
    fn eval_literal() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::Literal(Value::Integer(42));
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Integer(42));
    }

    #[test]
    fn eval_column_ref() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::Column("name".into());
        assert_eq!(
            eval_expr(&expr, &cm, &row).unwrap(),
            Value::Text("Alice".into())
        );
    }

    #[test]
    fn eval_column_case_insensitive() {
        // The column is *defined* with mixed case. `ColumnMap::new` lowercases
        // definitions, so a lowercase reference must still resolve to it.
        let mut cols = test_columns();
        cols[1] = col("Name", DataType::Text, true, 1);
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::Column("name".into());
        assert_eq!(
            eval_expr(&expr, &cm, &row).unwrap(),
            Value::Text("Alice".into())
        );
    }

    #[test]
    fn column_map_resolution() {
        let cols = vec![
            col("a.id", DataType::Integer, false, 0),
            col("b.id", DataType::Integer, false, 1),
            col("id", DataType::Integer, false, 2),
            col("b.name", DataType::Text, true, 3),
        ];
        let cm = ColumnMap::new(&cols);
        // An exact match wins even though the short name "id" is ambiguous.
        assert_eq!(cm.resolve("id").unwrap(), 2);
        // A unique short name resolves to its qualified column.
        assert_eq!(cm.resolve("name").unwrap(), 3);
        assert!(matches!(
            cm.resolve("missing"),
            Err(SqlError::ColumnNotFound(_))
        ));

        // Without an exact "id", two qualified candidates are ambiguous.
        let cm = ColumnMap::new(&cols[..2]);
        assert!(matches!(
            cm.resolve("id"),
            Err(SqlError::AmbiguousColumn(_))
        ));
    }

    #[test]
    fn eval_arithmetic_int() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("id".into())),
            op: BinOp::Add,
            right: Box::new(Expr::Literal(Value::Integer(10))),
        };
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Integer(11));
    }

    #[test]
    fn eval_comparison() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("score".into())),
            op: BinOp::Gt,
            right: Box::new(Expr::Literal(Value::Real(90.0))),
        };
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Boolean(true));
    }

    #[test]
    fn eval_null_propagation() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = vec![
            Value::Integer(1),
            Value::Null,
            Value::Null,
            Value::Boolean(true),
        ];
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::Eq,
            right: Box::new(Expr::Literal(Value::Text("test".into()))),
        };
        assert!(eval_expr(&expr, &cm, &row).unwrap().is_null());
    }

    #[test]
    fn eval_and_three_valued() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = vec![
            Value::Integer(1),
            Value::Null,
            Value::Null,
            Value::Boolean(true),
        ];

        // NULL AND false = false
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::And,
            right: Box::new(Expr::Literal(Value::Boolean(false))),
        };
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Boolean(false));

        // NULL AND true = NULL
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::And,
            right: Box::new(Expr::Literal(Value::Boolean(true))),
        };
        assert!(eval_expr(&expr, &cm, &row).unwrap().is_null());
    }

    #[test]
    fn eval_or_three_valued() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = vec![
            Value::Integer(1),
            Value::Null,
            Value::Null,
            Value::Boolean(true),
        ];

        // NULL OR true = true
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::Or,
            right: Box::new(Expr::Literal(Value::Boolean(true))),
        };
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Boolean(true));

        // NULL OR false = NULL
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::Or,
            right: Box::new(Expr::Literal(Value::Boolean(false))),
        };
        assert!(eval_expr(&expr, &cm, &row).unwrap().is_null());
    }

    #[test]
    fn eval_is_null() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = vec![
            Value::Integer(1),
            Value::Null,
            Value::Null,
            Value::Boolean(true),
        ];
        let expr = Expr::IsNull(Box::new(Expr::Column("name".into())));
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Boolean(true));

        let expr = Expr::IsNotNull(Box::new(Expr::Column("id".into())));
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Boolean(true));
    }

    #[test]
    fn eval_not() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::UnaryOp {
            op: UnaryOp::Not,
            expr: Box::new(Expr::Column("active".into())),
        };
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Boolean(false));
    }

    #[test]
    fn eval_neg() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::UnaryOp {
            op: UnaryOp::Neg,
            expr: Box::new(Expr::Column("id".into())),
        };
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Integer(-1));
    }

    #[test]
    fn eval_division_by_zero() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("id".into())),
            op: BinOp::Div,
            right: Box::new(Expr::Literal(Value::Integer(0))),
        };
        assert!(matches!(
            eval_expr(&expr, &cm, &row),
            Err(SqlError::DivisionByZero)
        ));
    }

    #[test]
    fn eval_mixed_numeric() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        // id (int 1) + score (real 95.5) = real 96.5
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("id".into())),
            op: BinOp::Add,
            right: Box::new(Expr::Column("score".into())),
        };
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Real(96.5));
    }

    #[test]
    fn is_truthy_values() {
        assert!(is_truthy(&Value::Boolean(true)));
        assert!(!is_truthy(&Value::Boolean(false)));
        assert!(!is_truthy(&Value::Null));
        assert!(is_truthy(&Value::Integer(1)));
        assert!(!is_truthy(&Value::Integer(0)));
    }
}