1use std::collections::HashMap;
4
5use crate::error::{Result, SqlError};
6use crate::parser::{BinOp, Expr, UnaryOp};
7use crate::types::{ColumnDef, CompactString, DataType, Value};
8
/// Name-to-position lookup table for the columns of a row.
///
/// Built once by `ColumnMap::new` and then consulted for every column
/// reference encountered while evaluating an expression.
pub struct ColumnMap {
    // Lowercased full column name (as stored in the `ColumnDef`) -> column index.
    exact: HashMap<String, usize>,
    // Lowercased text after the last `.` of the column name (or the whole name
    // when it has no `.`) -> either the single matching index or an ambiguity marker.
    short: HashMap<String, ShortMatch>,
}
13
/// Outcome of looking a column up by its unqualified (short) name.
enum ShortMatch {
    // Exactly one column has this short name; the payload is its index.
    Unique(usize),
    // Two or more columns share this short name, so it cannot be resolved alone.
    Ambiguous,
}
18
19impl ColumnMap {
20 pub fn new(columns: &[ColumnDef]) -> Self {
21 let mut exact = HashMap::with_capacity(columns.len() * 2);
22 let mut short: HashMap<String, ShortMatch> = HashMap::with_capacity(columns.len());
23
24 for (i, col) in columns.iter().enumerate() {
25 let lower = col.name.to_ascii_lowercase();
26 exact.insert(lower.clone(), i);
27
28 let unqualified = if let Some(dot) = lower.rfind('.') {
29 &lower[dot + 1..]
30 } else {
31 &lower
32 };
33 short
34 .entry(unqualified.to_string())
35 .and_modify(|e| *e = ShortMatch::Ambiguous)
36 .or_insert(ShortMatch::Unique(i));
37 }
38
39 Self { exact, short }
40 }
41
42 pub(crate) fn resolve(&self, name: &str) -> Result<usize> {
43 if let Some(&idx) = self.exact.get(name) {
44 return Ok(idx);
45 }
46 match self.short.get(name) {
47 Some(ShortMatch::Unique(idx)) => Ok(*idx),
48 Some(ShortMatch::Ambiguous) => Err(SqlError::AmbiguousColumn(name.to_string())),
49 None => Err(SqlError::ColumnNotFound(name.to_string())),
50 }
51 }
52
53 pub(crate) fn resolve_qualified(&self, table: &str, column: &str) -> Result<usize> {
54 let qualified = format!("{table}.{column}");
55 if let Some(&idx) = self.exact.get(&qualified) {
56 return Ok(idx);
57 }
58 match self.short.get(column) {
59 Some(ShortMatch::Unique(idx)) => Ok(*idx),
60 _ => Err(SqlError::ColumnNotFound(format!("{table}.{column}"))),
61 }
62 }
63}
64
65pub fn eval_expr(expr: &Expr, col_map: &ColumnMap, row: &[Value]) -> Result<Value> {
66 match expr {
67 Expr::Literal(v) => Ok(v.clone()),
68
69 Expr::Column(name) => {
70 let idx = col_map.resolve(name)?;
71 Ok(row[idx].clone())
72 }
73
74 Expr::QualifiedColumn { table, column } => {
75 let idx = col_map.resolve_qualified(table, column)?;
76 Ok(row[idx].clone())
77 }
78
79 Expr::BinaryOp { left, op, right } => {
80 let lval = eval_expr(left, col_map, row)?;
81 let rval = eval_expr(right, col_map, row)?;
82 eval_binary_op(&lval, *op, &rval)
83 }
84
85 Expr::UnaryOp { op, expr } => {
86 let val = eval_expr(expr, col_map, row)?;
87 eval_unary_op(*op, &val)
88 }
89
90 Expr::IsNull(e) => {
91 let val = eval_expr(e, col_map, row)?;
92 Ok(Value::Boolean(val.is_null()))
93 }
94
95 Expr::IsNotNull(e) => {
96 let val = eval_expr(e, col_map, row)?;
97 Ok(Value::Boolean(!val.is_null()))
98 }
99
100 Expr::Function { name, args } => eval_scalar_function(name, args, col_map, row),
101
102 Expr::CountStar => Err(SqlError::Unsupported(
103 "COUNT(*) in non-aggregate context".into(),
104 )),
105
106 Expr::InList {
107 expr: e,
108 list,
109 negated,
110 } => {
111 let lhs = eval_expr(e, col_map, row)?;
112 eval_in_values(&lhs, list, col_map, row, *negated)
113 }
114
115 Expr::InSet {
116 expr: e,
117 values,
118 has_null,
119 negated,
120 } => {
121 let lhs = eval_expr(e, col_map, row)?;
122 eval_in_set(&lhs, values, *has_null, *negated)
123 }
124
125 Expr::Between {
126 expr: e,
127 low,
128 high,
129 negated,
130 } => {
131 let val = eval_expr(e, col_map, row)?;
132 let lo = eval_expr(low, col_map, row)?;
133 let hi = eval_expr(high, col_map, row)?;
134 eval_between(&val, &lo, &hi, *negated)
135 }
136
137 Expr::Like {
138 expr: e,
139 pattern,
140 escape,
141 negated,
142 } => {
143 let val = eval_expr(e, col_map, row)?;
144 let pat = eval_expr(pattern, col_map, row)?;
145 let esc = escape
146 .as_ref()
147 .map(|e| eval_expr(e, col_map, row))
148 .transpose()?;
149 eval_like(&val, &pat, esc.as_ref(), *negated)
150 }
151
152 Expr::Case {
153 operand,
154 conditions,
155 else_result,
156 } => eval_case(
157 operand.as_deref(),
158 conditions,
159 else_result.as_deref(),
160 col_map,
161 row,
162 ),
163
164 Expr::Coalesce(args) => {
165 for arg in args {
166 let val = eval_expr(arg, col_map, row)?;
167 if !val.is_null() {
168 return Ok(val);
169 }
170 }
171 Ok(Value::Null)
172 }
173
174 Expr::Cast { expr: e, data_type } => {
175 let val = eval_expr(e, col_map, row)?;
176 eval_cast(&val, *data_type)
177 }
178
179 Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::ScalarSubquery(_) => Err(
180 SqlError::Unsupported("subquery not materialized (internal error)".into()),
181 ),
182
183 Expr::Parameter(n) => Err(SqlError::Parse(format!("unbound parameter ${n}"))),
184 }
185}
186
187fn eval_binary_op(left: &Value, op: BinOp, right: &Value) -> Result<Value> {
188 match op {
190 BinOp::And => return eval_and(left, right),
191 BinOp::Or => return eval_or(left, right),
192 _ => {}
193 }
194
195 if left.is_null() || right.is_null() {
197 return Ok(Value::Null);
198 }
199
200 match op {
201 BinOp::Eq => Ok(Value::Boolean(left == right)),
202 BinOp::NotEq => Ok(Value::Boolean(left != right)),
203 BinOp::Lt => Ok(Value::Boolean(left < right)),
204 BinOp::Gt => Ok(Value::Boolean(left > right)),
205 BinOp::LtEq => Ok(Value::Boolean(left <= right)),
206 BinOp::GtEq => Ok(Value::Boolean(left >= right)),
207 BinOp::Add => eval_arithmetic(left, right, i64::checked_add, |a, b| a + b),
208 BinOp::Sub => eval_arithmetic(left, right, i64::checked_sub, |a, b| a - b),
209 BinOp::Mul => eval_arithmetic(left, right, i64::checked_mul, |a, b| a * b),
210 BinOp::Div => {
211 match right {
212 Value::Integer(0) => return Err(SqlError::DivisionByZero),
213 Value::Real(r) if *r == 0.0 => return Err(SqlError::DivisionByZero),
214 _ => {}
215 }
216 eval_arithmetic(left, right, i64::checked_div, |a, b| a / b)
217 }
218 BinOp::Mod => {
219 match right {
220 Value::Integer(0) => return Err(SqlError::DivisionByZero),
221 Value::Real(r) if *r == 0.0 => return Err(SqlError::DivisionByZero),
222 _ => {}
223 }
224 eval_arithmetic(left, right, i64::checked_rem, |a, b| a % b)
225 }
226 BinOp::Concat => {
227 let ls = value_to_text(left);
228 let rs = value_to_text(right);
229 Ok(Value::Text(format!("{ls}{rs}").into()))
230 }
231 BinOp::And | BinOp::Or => unreachable!(),
232 }
233}
234
235fn eval_and(left: &Value, right: &Value) -> Result<Value> {
237 let l = to_bool_or_null(left)?;
238 let r = to_bool_or_null(right)?;
239 match (l, r) {
240 (Some(false), _) | (_, Some(false)) => Ok(Value::Boolean(false)),
241 (Some(true), Some(true)) => Ok(Value::Boolean(true)),
242 _ => Ok(Value::Null),
243 }
244}
245
246fn eval_or(left: &Value, right: &Value) -> Result<Value> {
248 let l = to_bool_or_null(left)?;
249 let r = to_bool_or_null(right)?;
250 match (l, r) {
251 (Some(true), _) | (_, Some(true)) => Ok(Value::Boolean(true)),
252 (Some(false), Some(false)) => Ok(Value::Boolean(false)),
253 _ => Ok(Value::Null),
254 }
255}
256
257fn to_bool_or_null(val: &Value) -> Result<Option<bool>> {
258 match val {
259 Value::Boolean(b) => Ok(Some(*b)),
260 Value::Null => Ok(None),
261 Value::Integer(i) => Ok(Some(*i != 0)),
262 _ => Err(SqlError::TypeMismatch {
263 expected: "BOOLEAN".into(),
264 got: format!("{}", val.data_type()),
265 }),
266 }
267}
268
269fn eval_arithmetic(
270 left: &Value,
271 right: &Value,
272 int_op: fn(i64, i64) -> Option<i64>,
273 real_op: fn(f64, f64) -> f64,
274) -> Result<Value> {
275 match (left, right) {
276 (Value::Integer(a), Value::Integer(b)) => int_op(*a, *b)
277 .map(Value::Integer)
278 .ok_or(SqlError::IntegerOverflow),
279 (Value::Real(a), Value::Real(b)) => Ok(Value::Real(real_op(*a, *b))),
280 (Value::Integer(a), Value::Real(b)) => Ok(Value::Real(real_op(*a as f64, *b))),
281 (Value::Real(a), Value::Integer(b)) => Ok(Value::Real(real_op(*a, *b as f64))),
282 _ => Err(SqlError::TypeMismatch {
283 expected: "numeric".into(),
284 got: format!("{} and {}", left.data_type(), right.data_type()),
285 }),
286 }
287}
288
289fn eval_in_values(
290 lhs: &Value,
291 list: &[Expr],
292 col_map: &ColumnMap,
293 row: &[Value],
294 negated: bool,
295) -> Result<Value> {
296 if list.is_empty() {
297 return Ok(Value::Boolean(negated));
298 }
299 if lhs.is_null() {
300 return Ok(Value::Null);
301 }
302 let mut has_null = false;
303 for item in list {
304 let rhs = eval_expr(item, col_map, row)?;
305 if rhs.is_null() {
306 has_null = true;
307 } else if lhs == &rhs {
308 return Ok(Value::Boolean(!negated));
309 }
310 }
311 if has_null {
312 Ok(Value::Null)
313 } else {
314 Ok(Value::Boolean(negated))
315 }
316}
317
318fn eval_in_set(
319 lhs: &Value,
320 values: &std::collections::HashSet<Value>,
321 has_null: bool,
322 negated: bool,
323) -> Result<Value> {
324 if values.is_empty() && !has_null {
325 return Ok(Value::Boolean(negated));
326 }
327 if lhs.is_null() {
328 return Ok(Value::Null);
329 }
330 if values.contains(lhs) {
331 return Ok(Value::Boolean(!negated));
332 }
333 if has_null {
334 Ok(Value::Null)
335 } else {
336 Ok(Value::Boolean(negated))
337 }
338}
339
340fn eval_unary_op(op: UnaryOp, val: &Value) -> Result<Value> {
341 if val.is_null() {
342 return Ok(Value::Null);
343 }
344 match op {
345 UnaryOp::Neg => match val {
346 Value::Integer(i) => i
347 .checked_neg()
348 .map(Value::Integer)
349 .ok_or(SqlError::IntegerOverflow),
350 Value::Real(r) => Ok(Value::Real(-r)),
351 _ => Err(SqlError::TypeMismatch {
352 expected: "numeric".into(),
353 got: format!("{}", val.data_type()),
354 }),
355 },
356 UnaryOp::Not => match val {
357 Value::Boolean(b) => Ok(Value::Boolean(!b)),
358 Value::Integer(i) => Ok(Value::Boolean(*i == 0)),
359 _ => Err(SqlError::TypeMismatch {
360 expected: "BOOLEAN".into(),
361 got: format!("{}", val.data_type()),
362 }),
363 },
364 }
365}
366
367fn value_to_text(val: &Value) -> String {
368 match val {
369 Value::Text(s) => s.to_string(),
370 Value::Integer(i) => i.to_string(),
371 Value::Real(r) => {
372 if r.fract() == 0.0 && r.is_finite() {
373 format!("{r:.1}")
374 } else {
375 format!("{r}")
376 }
377 }
378 Value::Boolean(b) => if *b { "TRUE" } else { "FALSE" }.into(),
379 Value::Null => String::new(),
380 Value::Blob(b) => {
381 let mut s = String::with_capacity(b.len() * 2);
382 for byte in b {
383 s.push_str(&format!("{byte:02X}"));
384 }
385 s
386 }
387 }
388}
389
390fn eval_between(val: &Value, low: &Value, high: &Value, negated: bool) -> Result<Value> {
391 if val.is_null() || low.is_null() || high.is_null() {
392 let ge = if val.is_null() || low.is_null() {
393 None
394 } else {
395 Some(*val >= *low)
396 };
397 let le = if val.is_null() || high.is_null() {
398 None
399 } else {
400 Some(*val <= *high)
401 };
402
403 let result = match (ge, le) {
404 (Some(false), _) | (_, Some(false)) => Some(false),
405 (Some(true), Some(true)) => Some(true),
406 _ => None,
407 };
408
409 return match result {
410 Some(b) => Ok(Value::Boolean(if negated { !b } else { b })),
411 None => Ok(Value::Null),
412 };
413 }
414
415 let in_range = *val >= *low && *val <= *high;
416 Ok(Value::Boolean(if negated { !in_range } else { in_range }))
417}
418
/// Longest LIKE pattern accepted by `eval_like`; a longer pattern is rejected
/// with `SqlError::InvalidValue` before any matching is attempted.
const MAX_LIKE_PATTERN_LEN: usize = 10_000;
420
421fn eval_like(val: &Value, pattern: &Value, escape: Option<&Value>, negated: bool) -> Result<Value> {
422 if val.is_null() || pattern.is_null() {
423 return Ok(Value::Null);
424 }
425 let text = match val {
426 Value::Text(s) => s.as_str(),
427 _ => {
428 return Err(SqlError::TypeMismatch {
429 expected: "TEXT".into(),
430 got: val.data_type().to_string(),
431 })
432 }
433 };
434 let pat = match pattern {
435 Value::Text(s) => s.as_str(),
436 _ => {
437 return Err(SqlError::TypeMismatch {
438 expected: "TEXT".into(),
439 got: pattern.data_type().to_string(),
440 })
441 }
442 };
443
444 if pat.len() > MAX_LIKE_PATTERN_LEN {
445 return Err(SqlError::InvalidValue(format!(
446 "LIKE pattern too long ({} chars, max {MAX_LIKE_PATTERN_LEN})",
447 pat.len()
448 )));
449 }
450
451 let esc_char = match escape {
452 Some(Value::Text(s)) => {
453 let mut chars = s.chars();
454 let c = chars.next().ok_or_else(|| {
455 SqlError::InvalidValue("ESCAPE must be a single character".into())
456 })?;
457 if chars.next().is_some() {
458 return Err(SqlError::InvalidValue(
459 "ESCAPE must be a single character".into(),
460 ));
461 }
462 Some(c)
463 }
464 Some(Value::Null) => return Ok(Value::Null),
465 Some(_) => {
466 return Err(SqlError::TypeMismatch {
467 expected: "TEXT".into(),
468 got: "non-text".into(),
469 })
470 }
471 None => None,
472 };
473
474 let matched = like_match(text, pat, esc_char);
475 Ok(Value::Boolean(if negated { !matched } else { matched }))
476}
477
478fn like_match(text: &str, pattern: &str, escape: Option<char>) -> bool {
479 let t: Vec<char> = text.chars().collect();
480 let p: Vec<char> = pattern.chars().collect();
481 like_match_impl(&t, &p, 0, 0, escape)
482}
483
/// Iterative LIKE matcher over character slices.
///
/// `%` matches any run of characters (including none), `_` matches exactly
/// one character, and every other pattern character matches itself ignoring
/// ASCII case. When `esc` is set, an escape character followed by another
/// pattern character makes that next character literal; an escape character
/// in the last pattern position is treated as an ordinary character.
///
/// `ti` / `pi` are the starting positions in `t` (text) and `p` (pattern).
/// The matcher remembers the most recent `%` and, on a mismatch, retries
/// with that `%` absorbing one more text character.
fn like_match_impl(
    t: &[char],
    p: &[char],
    mut ti: usize,
    mut pi: usize,
    esc: Option<char>,
) -> bool {
    // (pattern index of the last `%`, text index it currently stops absorbing at)
    let mut resume: Option<(usize, usize)> = None;

    while ti < t.len() {
        // Number of pattern characters consumed by a successful single-character
        // match at the current position, or `None` when the position does not match.
        let consumed: Option<usize> = match p.get(pi).copied() {
            None => None,
            Some(pc) if esc == Some(pc) && pi + 1 < p.len() => {
                // Escaped literal: compare the character after the escape.
                if p[pi + 1].to_ascii_lowercase() == t[ti].to_ascii_lowercase() {
                    Some(2)
                } else {
                    None
                }
            }
            Some('%') => {
                // Start by letting `%` match nothing; remember where to retry.
                resume = Some((pi, ti));
                pi += 1;
                continue;
            }
            Some('_') => Some(1),
            Some(pc) => {
                if pc.eq_ignore_ascii_case(&t[ti]) {
                    Some(1)
                } else {
                    None
                }
            }
        };

        match (consumed, resume) {
            (Some(width), _) => {
                pi += width;
                ti += 1;
            }
            (None, Some((star_at, absorbed_until))) => {
                // Let the last `%` swallow one more character and try again.
                let retry_from = absorbed_until + 1;
                resume = Some((star_at, retry_from));
                pi = star_at + 1;
                ti = retry_from;
            }
            (None, None) => return false,
        }
    }

    // Text is exhausted: only trailing `%` may remain in the pattern.
    let trailing_wildcards = p.iter().skip(pi).take_while(|&&c| c == '%').count();
    pi + trailing_wildcards == p.len()
}
546
547fn eval_case(
548 operand: Option<&Expr>,
549 conditions: &[(Expr, Expr)],
550 else_result: Option<&Expr>,
551 col_map: &ColumnMap,
552 row: &[Value],
553) -> Result<Value> {
554 if let Some(op_expr) = operand {
555 let op_val = eval_expr(op_expr, col_map, row)?;
556 for (cond, result) in conditions {
557 let cond_val = eval_expr(cond, col_map, row)?;
558 if !op_val.is_null() && !cond_val.is_null() && op_val == cond_val {
559 return eval_expr(result, col_map, row);
560 }
561 }
562 } else {
563 for (cond, result) in conditions {
564 let cond_val = eval_expr(cond, col_map, row)?;
565 if is_truthy(&cond_val) {
566 return eval_expr(result, col_map, row);
567 }
568 }
569 }
570 match else_result {
571 Some(e) => eval_expr(e, col_map, row),
572 None => Ok(Value::Null),
573 }
574}
575
576fn eval_cast(val: &Value, target: DataType) -> Result<Value> {
577 if val.is_null() {
578 return Ok(Value::Null);
579 }
580 match target {
581 DataType::Integer => match val {
582 Value::Integer(_) => Ok(val.clone()),
583 Value::Real(r) => Ok(Value::Integer(*r as i64)),
584 Value::Boolean(b) => Ok(Value::Integer(if *b { 1 } else { 0 })),
585 Value::Text(s) => s
586 .trim()
587 .parse::<i64>()
588 .map(Value::Integer)
589 .or_else(|_| s.trim().parse::<f64>().map(|f| Value::Integer(f as i64)))
590 .map_err(|_| SqlError::InvalidValue(format!("cannot cast '{s}' to INTEGER"))),
591 _ => Err(SqlError::InvalidValue(format!(
592 "cannot cast {} to INTEGER",
593 val.data_type()
594 ))),
595 },
596 DataType::Real => match val {
597 Value::Real(_) => Ok(val.clone()),
598 Value::Integer(i) => Ok(Value::Real(*i as f64)),
599 Value::Boolean(b) => Ok(Value::Real(if *b { 1.0 } else { 0.0 })),
600 Value::Text(s) => s
601 .trim()
602 .parse::<f64>()
603 .map(Value::Real)
604 .map_err(|_| SqlError::InvalidValue(format!("cannot cast '{s}' to REAL"))),
605 _ => Err(SqlError::InvalidValue(format!(
606 "cannot cast {} to REAL",
607 val.data_type()
608 ))),
609 },
610 DataType::Text => Ok(Value::Text(value_to_text(val).into())),
611 DataType::Boolean => match val {
612 Value::Boolean(_) => Ok(val.clone()),
613 Value::Integer(i) => Ok(Value::Boolean(*i != 0)),
614 Value::Text(s) => {
615 let lower = s.trim().to_ascii_lowercase();
616 match lower.as_str() {
617 "true" | "1" | "yes" | "on" => Ok(Value::Boolean(true)),
618 "false" | "0" | "no" | "off" => Ok(Value::Boolean(false)),
619 _ => Err(SqlError::InvalidValue(format!(
620 "cannot cast '{s}' to BOOLEAN"
621 ))),
622 }
623 }
624 _ => Err(SqlError::InvalidValue(format!(
625 "cannot cast {} to BOOLEAN",
626 val.data_type()
627 ))),
628 },
629 DataType::Blob => match val {
630 Value::Blob(_) => Ok(val.clone()),
631 Value::Text(s) => Ok(Value::Blob(s.as_bytes().to_vec())),
632 _ => Err(SqlError::InvalidValue(format!(
633 "cannot cast {} to BLOB",
634 val.data_type()
635 ))),
636 },
637 DataType::Null => Ok(Value::Null),
638 }
639}
640
641fn eval_scalar_function(
642 name: &str,
643 args: &[Expr],
644 col_map: &ColumnMap,
645 row: &[Value],
646) -> Result<Value> {
647 let evaluated: Vec<Value> = args
648 .iter()
649 .map(|a| eval_expr(a, col_map, row))
650 .collect::<Result<Vec<_>>>()?;
651
652 match name {
653 "LENGTH" => {
654 check_args(name, &evaluated, 1)?;
655 match &evaluated[0] {
656 Value::Null => Ok(Value::Null),
657 Value::Text(s) => Ok(Value::Integer(s.chars().count() as i64)),
658 Value::Blob(b) => Ok(Value::Integer(b.len() as i64)),
659 _ => Ok(Value::Integer(
660 value_to_text(&evaluated[0]).chars().count() as i64
661 )),
662 }
663 }
664 "UPPER" => {
665 check_args(name, &evaluated, 1)?;
666 match &evaluated[0] {
667 Value::Null => Ok(Value::Null),
668 Value::Text(s) => Ok(Value::Text(s.to_ascii_uppercase())),
669 _ => Ok(Value::Text(
670 value_to_text(&evaluated[0]).to_ascii_uppercase().into(),
671 )),
672 }
673 }
674 "LOWER" => {
675 check_args(name, &evaluated, 1)?;
676 match &evaluated[0] {
677 Value::Null => Ok(Value::Null),
678 Value::Text(s) => Ok(Value::Text(s.to_ascii_lowercase())),
679 _ => Ok(Value::Text(
680 value_to_text(&evaluated[0]).to_ascii_lowercase().into(),
681 )),
682 }
683 }
684 "SUBSTR" | "SUBSTRING" => {
685 if evaluated.len() < 2 || evaluated.len() > 3 {
686 return Err(SqlError::InvalidValue(format!(
687 "{name} requires 2 or 3 arguments"
688 )));
689 }
690 if evaluated.iter().any(|v| v.is_null()) {
691 return Ok(Value::Null);
692 }
693 let s = value_to_text(&evaluated[0]);
694 let chars: Vec<char> = s.chars().collect();
695 let start = match &evaluated[1] {
696 Value::Integer(i) => *i,
697 _ => {
698 return Err(SqlError::TypeMismatch {
699 expected: "INTEGER".into(),
700 got: evaluated[1].data_type().to_string(),
701 })
702 }
703 };
704 let len = chars.len() as i64;
705
706 let (begin, count) = if evaluated.len() == 3 {
707 let cnt = match &evaluated[2] {
708 Value::Integer(i) => *i,
709 _ => {
710 return Err(SqlError::TypeMismatch {
711 expected: "INTEGER".into(),
712 got: evaluated[2].data_type().to_string(),
713 })
714 }
715 };
716 if start >= 1 {
717 let b = (start - 1).min(len) as usize;
718 let c = cnt.max(0) as usize;
719 (b, c)
720 } else if start == 0 {
721 let c = (cnt - 1).max(0) as usize;
722 (0usize, c)
723 } else {
724 let adjusted_cnt = (cnt + start - 1).max(0) as usize;
725 (0usize, adjusted_cnt)
726 }
727 } else if start >= 1 {
728 let b = (start - 1).min(len) as usize;
729 (b, chars.len() - b)
730 } else if start == 0 {
731 (0usize, chars.len())
732 } else {
733 let b = (len + start).max(0) as usize;
734 (b, chars.len() - b)
735 };
736
737 let result: String = chars.iter().skip(begin).take(count).collect();
738 Ok(Value::Text(result.into()))
739 }
740 "TRIM" | "LTRIM" | "RTRIM" => {
741 if evaluated.is_empty() || evaluated.len() > 2 {
742 return Err(SqlError::InvalidValue(format!(
743 "{name} requires 1 or 2 arguments"
744 )));
745 }
746 if evaluated[0].is_null() {
747 return Ok(Value::Null);
748 }
749 let s = value_to_text(&evaluated[0]);
750 let trim_chars: Vec<char> = if evaluated.len() == 2 {
751 if evaluated[1].is_null() {
752 return Ok(Value::Null);
753 }
754 value_to_text(&evaluated[1]).chars().collect()
755 } else {
756 vec![' ']
757 };
758 let result = match name {
759 "TRIM" => s
760 .trim_matches(|c: char| trim_chars.contains(&c))
761 .to_string(),
762 "LTRIM" => s
763 .trim_start_matches(|c: char| trim_chars.contains(&c))
764 .to_string(),
765 "RTRIM" => s
766 .trim_end_matches(|c: char| trim_chars.contains(&c))
767 .to_string(),
768 _ => unreachable!(),
769 };
770 Ok(Value::Text(result.into()))
771 }
772 "REPLACE" => {
773 check_args(name, &evaluated, 3)?;
774 if evaluated.iter().any(|v| v.is_null()) {
775 return Ok(Value::Null);
776 }
777 let s = value_to_text(&evaluated[0]);
778 let from = value_to_text(&evaluated[1]);
779 let to = value_to_text(&evaluated[2]);
780 if from.is_empty() {
781 return Ok(Value::Text(s.into()));
782 }
783 Ok(Value::Text(s.replace(&from, &to).into()))
784 }
785 "INSTR" => {
786 check_args(name, &evaluated, 2)?;
787 if evaluated.iter().any(|v| v.is_null()) {
788 return Ok(Value::Null);
789 }
790 let haystack = value_to_text(&evaluated[0]);
791 let needle = value_to_text(&evaluated[1]);
792 let pos = haystack
793 .find(&needle)
794 .map(|i| haystack[..i].chars().count() as i64 + 1)
795 .unwrap_or(0);
796 Ok(Value::Integer(pos))
797 }
798 "CONCAT" => {
799 if evaluated.is_empty() {
800 return Ok(Value::Text(CompactString::default()));
801 }
802 let mut result = String::new();
803 for v in &evaluated {
804 match v {
805 Value::Null => {}
806 _ => result.push_str(&value_to_text(v)),
807 }
808 }
809 Ok(Value::Text(result.into()))
810 }
811 "ABS" => {
812 check_args(name, &evaluated, 1)?;
813 match &evaluated[0] {
814 Value::Null => Ok(Value::Null),
815 Value::Integer(i) => i
816 .checked_abs()
817 .map(Value::Integer)
818 .ok_or(SqlError::IntegerOverflow),
819 Value::Real(r) => Ok(Value::Real(r.abs())),
820 _ => Err(SqlError::TypeMismatch {
821 expected: "numeric".into(),
822 got: evaluated[0].data_type().to_string(),
823 }),
824 }
825 }
826 "ROUND" => {
827 if evaluated.is_empty() || evaluated.len() > 2 {
828 return Err(SqlError::InvalidValue(
829 "ROUND requires 1 or 2 arguments".into(),
830 ));
831 }
832 if evaluated[0].is_null() {
833 return Ok(Value::Null);
834 }
835 let val = match &evaluated[0] {
836 Value::Integer(i) => *i as f64,
837 Value::Real(r) => *r,
838 _ => {
839 return Err(SqlError::TypeMismatch {
840 expected: "numeric".into(),
841 got: evaluated[0].data_type().to_string(),
842 })
843 }
844 };
845 let places = if evaluated.len() == 2 {
846 match &evaluated[1] {
847 Value::Null => return Ok(Value::Null),
848 Value::Integer(i) => *i,
849 _ => {
850 return Err(SqlError::TypeMismatch {
851 expected: "INTEGER".into(),
852 got: evaluated[1].data_type().to_string(),
853 })
854 }
855 }
856 } else {
857 0
858 };
859 let factor = 10f64.powi(places as i32);
860 let rounded = (val * factor).round() / factor;
861 Ok(Value::Real(rounded))
862 }
863 "CEIL" | "CEILING" => {
864 check_args(name, &evaluated, 1)?;
865 match &evaluated[0] {
866 Value::Null => Ok(Value::Null),
867 Value::Integer(i) => Ok(Value::Integer(*i)),
868 Value::Real(r) => Ok(Value::Integer(r.ceil() as i64)),
869 _ => Err(SqlError::TypeMismatch {
870 expected: "numeric".into(),
871 got: evaluated[0].data_type().to_string(),
872 }),
873 }
874 }
875 "FLOOR" => {
876 check_args(name, &evaluated, 1)?;
877 match &evaluated[0] {
878 Value::Null => Ok(Value::Null),
879 Value::Integer(i) => Ok(Value::Integer(*i)),
880 Value::Real(r) => Ok(Value::Integer(r.floor() as i64)),
881 _ => Err(SqlError::TypeMismatch {
882 expected: "numeric".into(),
883 got: evaluated[0].data_type().to_string(),
884 }),
885 }
886 }
887 "SIGN" => {
888 check_args(name, &evaluated, 1)?;
889 match &evaluated[0] {
890 Value::Null => Ok(Value::Null),
891 Value::Integer(i) => Ok(Value::Integer(i.signum())),
892 Value::Real(r) => {
893 if *r > 0.0 {
894 Ok(Value::Integer(1))
895 } else if *r < 0.0 {
896 Ok(Value::Integer(-1))
897 } else {
898 Ok(Value::Integer(0))
899 }
900 }
901 _ => Err(SqlError::TypeMismatch {
902 expected: "numeric".into(),
903 got: evaluated[0].data_type().to_string(),
904 }),
905 }
906 }
907 "SQRT" => {
908 check_args(name, &evaluated, 1)?;
909 match &evaluated[0] {
910 Value::Null => Ok(Value::Null),
911 Value::Integer(i) => {
912 if *i < 0 {
913 Ok(Value::Null)
914 } else {
915 Ok(Value::Real((*i as f64).sqrt()))
916 }
917 }
918 Value::Real(r) => {
919 if *r < 0.0 {
920 Ok(Value::Null)
921 } else {
922 Ok(Value::Real(r.sqrt()))
923 }
924 }
925 _ => Err(SqlError::TypeMismatch {
926 expected: "numeric".into(),
927 got: evaluated[0].data_type().to_string(),
928 }),
929 }
930 }
931 "RANDOM" => {
932 check_args(name, &evaluated, 0)?;
933 use std::collections::hash_map::DefaultHasher;
934 use std::hash::{Hash, Hasher};
935 use std::time::SystemTime;
936 let mut hasher = DefaultHasher::new();
937 SystemTime::now().hash(&mut hasher);
938 std::thread::current().id().hash(&mut hasher);
939 let mut val = hasher.finish() as i64;
940 if val == i64::MIN {
941 val = i64::MAX;
942 }
943 Ok(Value::Integer(val))
944 }
945 "TYPEOF" => {
946 check_args(name, &evaluated, 1)?;
947 let type_name = match &evaluated[0] {
948 Value::Null => "null",
949 Value::Integer(_) => "integer",
950 Value::Real(_) => "real",
951 Value::Text(_) => "text",
952 Value::Blob(_) => "blob",
953 Value::Boolean(_) => "boolean",
954 };
955 Ok(Value::Text(type_name.into()))
956 }
957 "MIN" => {
958 check_args(name, &evaluated, 2)?;
959 if evaluated[0].is_null() {
960 return Ok(evaluated[1].clone());
961 }
962 if evaluated[1].is_null() {
963 return Ok(evaluated[0].clone());
964 }
965 if evaluated[0] <= evaluated[1] {
966 Ok(evaluated[0].clone())
967 } else {
968 Ok(evaluated[1].clone())
969 }
970 }
971 "MAX" => {
972 check_args(name, &evaluated, 2)?;
973 if evaluated[0].is_null() {
974 return Ok(evaluated[1].clone());
975 }
976 if evaluated[1].is_null() {
977 return Ok(evaluated[0].clone());
978 }
979 if evaluated[0] >= evaluated[1] {
980 Ok(evaluated[0].clone())
981 } else {
982 Ok(evaluated[1].clone())
983 }
984 }
985 "HEX" => {
986 check_args(name, &evaluated, 1)?;
987 match &evaluated[0] {
988 Value::Null => Ok(Value::Null),
989 Value::Blob(b) => {
990 let mut s = String::with_capacity(b.len() * 2);
991 for byte in b {
992 s.push_str(&format!("{byte:02X}"));
993 }
994 Ok(Value::Text(s.into()))
995 }
996 Value::Text(s) => {
997 let mut r = String::with_capacity(s.len() * 2);
998 for byte in s.as_bytes() {
999 r.push_str(&format!("{byte:02X}"));
1000 }
1001 Ok(Value::Text(r.into()))
1002 }
1003 _ => Ok(Value::Text(value_to_text(&evaluated[0]).into())),
1004 }
1005 }
1006 _ => Err(SqlError::Unsupported(format!("scalar function: {name}"))),
1007 }
1008}
1009
1010fn check_args(name: &str, args: &[Value], expected: usize) -> Result<()> {
1011 if args.len() != expected {
1012 Err(SqlError::InvalidValue(format!(
1013 "{name} requires {expected} argument(s), got {}",
1014 args.len()
1015 )))
1016 } else {
1017 Ok(())
1018 }
1019}
1020
1021pub fn referenced_columns(expr: &Expr, columns: &[ColumnDef]) -> Vec<usize> {
1022 let mut indices = Vec::new();
1023 collect_column_refs(expr, columns, &mut indices);
1024 indices.sort_unstable();
1025 indices.dedup();
1026 indices
1027}
1028
1029fn collect_column_refs(expr: &Expr, columns: &[ColumnDef], out: &mut Vec<usize>) {
1030 match expr {
1031 Expr::Column(name) => {
1032 for (i, c) in columns.iter().enumerate() {
1033 if c.name == *name || c.name.ends_with(&format!(".{name}")) {
1034 out.push(i);
1035 break;
1036 }
1037 }
1038 }
1039 Expr::QualifiedColumn { table, column } => {
1040 let qualified = format!("{table}.{column}");
1041 if let Some(idx) = columns.iter().position(|c| c.name == qualified) {
1042 out.push(idx);
1043 } else {
1044 let matches: Vec<usize> = columns
1045 .iter()
1046 .enumerate()
1047 .filter(|(_, c)| c.name == *column)
1048 .map(|(i, _)| i)
1049 .collect();
1050 if matches.len() == 1 {
1051 out.push(matches[0]);
1052 }
1053 }
1054 }
1055 Expr::BinaryOp { left, right, .. } => {
1056 collect_column_refs(left, columns, out);
1057 collect_column_refs(right, columns, out);
1058 }
1059 Expr::UnaryOp { expr, .. } => {
1060 collect_column_refs(expr, columns, out);
1061 }
1062 Expr::IsNull(e) | Expr::IsNotNull(e) => {
1063 collect_column_refs(e, columns, out);
1064 }
1065 Expr::Function { args, .. } => {
1066 for arg in args {
1067 collect_column_refs(arg, columns, out);
1068 }
1069 }
1070 Expr::InSubquery { expr, .. } => {
1071 collect_column_refs(expr, columns, out);
1072 }
1073 Expr::InList { expr, list, .. } => {
1074 collect_column_refs(expr, columns, out);
1075 for item in list {
1076 collect_column_refs(item, columns, out);
1077 }
1078 }
1079 Expr::InSet { expr, .. } => {
1080 collect_column_refs(expr, columns, out);
1081 }
1082 Expr::Between {
1083 expr, low, high, ..
1084 } => {
1085 collect_column_refs(expr, columns, out);
1086 collect_column_refs(low, columns, out);
1087 collect_column_refs(high, columns, out);
1088 }
1089 Expr::Like {
1090 expr,
1091 pattern,
1092 escape,
1093 ..
1094 } => {
1095 collect_column_refs(expr, columns, out);
1096 collect_column_refs(pattern, columns, out);
1097 if let Some(esc) = escape {
1098 collect_column_refs(esc, columns, out);
1099 }
1100 }
1101 Expr::Case {
1102 operand,
1103 conditions,
1104 else_result,
1105 } => {
1106 if let Some(op) = operand {
1107 collect_column_refs(op, columns, out);
1108 }
1109 for (when, then) in conditions {
1110 collect_column_refs(when, columns, out);
1111 collect_column_refs(then, columns, out);
1112 }
1113 if let Some(e) = else_result {
1114 collect_column_refs(e, columns, out);
1115 }
1116 }
1117 Expr::Coalesce(args) => {
1118 for arg in args {
1119 collect_column_refs(arg, columns, out);
1120 }
1121 }
1122 Expr::Cast { expr, .. } => {
1123 collect_column_refs(expr, columns, out);
1124 }
1125 Expr::Literal(_)
1126 | Expr::Parameter(_)
1127 | Expr::CountStar
1128 | Expr::Exists { .. }
1129 | Expr::ScalarSubquery(_) => {}
1130 }
1131}
1132
1133pub fn is_truthy(val: &Value) -> bool {
1135 match val {
1136 Value::Boolean(b) => *b,
1137 Value::Integer(i) => *i != 0,
1138 Value::Null => false,
1139 _ => true,
1140 }
1141}
1142
1143#[cfg(test)]
1144mod tests {
1145 use super::*;
1146 use crate::types::DataType;
1147
1148 fn test_columns() -> Vec<ColumnDef> {
1149 vec![
1150 ColumnDef {
1151 name: "id".into(),
1152 data_type: DataType::Integer,
1153 nullable: false,
1154 position: 0,
1155 },
1156 ColumnDef {
1157 name: "name".into(),
1158 data_type: DataType::Text,
1159 nullable: true,
1160 position: 1,
1161 },
1162 ColumnDef {
1163 name: "score".into(),
1164 data_type: DataType::Real,
1165 nullable: true,
1166 position: 2,
1167 },
1168 ColumnDef {
1169 name: "active".into(),
1170 data_type: DataType::Boolean,
1171 nullable: false,
1172 position: 3,
1173 },
1174 ]
1175 }
1176
1177 fn test_row() -> Vec<Value> {
1178 vec![
1179 Value::Integer(1),
1180 Value::Text("Alice".into()),
1181 Value::Real(95.5),
1182 Value::Boolean(true),
1183 ]
1184 }
1185
1186 #[test]
1187 fn eval_literal() {
1188 let cols = test_columns();
1189 let cm = ColumnMap::new(&cols);
1190 let row = test_row();
1191 let expr = Expr::Literal(Value::Integer(42));
1192 assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Integer(42));
1193 }
1194
1195 #[test]
1196 fn eval_column_ref() {
1197 let cols = test_columns();
1198 let cm = ColumnMap::new(&cols);
1199 let row = test_row();
1200 let expr = Expr::Column("name".into());
1201 assert_eq!(
1202 eval_expr(&expr, &cm, &row).unwrap(),
1203 Value::Text("Alice".into())
1204 );
1205 }
1206
1207 #[test]
1208 fn eval_column_case_insensitive() {
1209 let cols = test_columns();
1210 let cm = ColumnMap::new(&cols);
1211 let row = test_row();
1212 let expr = Expr::Column("name".into());
1213 assert_eq!(
1214 eval_expr(&expr, &cm, &row).unwrap(),
1215 Value::Text("Alice".into())
1216 );
1217 }
1218
/// Adding an integer literal to an integer column stays in integer
/// arithmetic: `id + 10` is `11` for the sample row.
#[test]
fn eval_arithmetic_int() {
    let schema = test_columns();
    let col_map = ColumnMap::new(&schema);
    let values = test_row();

    let lhs = Expr::Column("id".into());
    let rhs = Expr::Literal(Value::Integer(10));
    let sum = Expr::BinaryOp {
        left: Box::new(lhs),
        op: BinOp::Add,
        right: Box::new(rhs),
    };

    let result = eval_expr(&sum, &col_map, &values).unwrap();
    assert_eq!(result, Value::Integer(11));
}
1231
/// A greater-than comparison produces a boolean: `score > 90.0` is true
/// for the sample row, whose score is 95.5.
#[test]
fn eval_comparison() {
    let schema = test_columns();
    let col_map = ColumnMap::new(&schema);
    let values = test_row();

    let score = Expr::Column("score".into());
    let threshold = Expr::Literal(Value::Real(90.0));
    let above_threshold = Expr::BinaryOp {
        left: Box::new(score),
        op: BinOp::Gt,
        right: Box::new(threshold),
    };

    let result = eval_expr(&above_threshold, &col_map, &values).unwrap();
    assert_eq!(result, Value::Boolean(true));
}
1244
/// Comparing a NULL column against a non-NULL literal evaluates to NULL
/// rather than to a boolean.
#[test]
fn eval_null_propagation() {
    let schema = test_columns();
    let col_map = ColumnMap::new(&schema);
    // Same shape as the sample row, but `name` and `score` are NULL.
    let values = vec![
        Value::Integer(1),
        Value::Null,
        Value::Null,
        Value::Boolean(true),
    ];

    let null_name = Expr::Column("name".into());
    let text = Expr::Literal(Value::Text("test".into()));
    let equality = Expr::BinaryOp {
        left: Box::new(null_name),
        op: BinOp::Eq,
        right: Box::new(text),
    };

    let result = eval_expr(&equality, &col_map, &values).unwrap();
    assert!(result.is_null());
}
1262
/// Three-valued AND: `NULL AND false` is false, while `NULL AND true`
/// remains NULL.
#[test]
fn eval_and_three_valued() {
    let schema = test_columns();
    let col_map = ColumnMap::new(&schema);
    // `name` (index 1) is NULL and serves as the left operand below.
    let values = vec![
        Value::Integer(1),
        Value::Null,
        Value::Null,
        Value::Boolean(true),
    ];

    // Evaluates `name AND <rhs>` against the row above.
    let null_and = |rhs: bool| {
        let conjunction = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::And,
            right: Box::new(Expr::Literal(Value::Boolean(rhs))),
        };
        eval_expr(&conjunction, &col_map, &values).unwrap()
    };

    assert_eq!(null_and(false), Value::Boolean(false));
    assert!(null_and(true).is_null());
}
1290
/// Three-valued OR: `NULL OR true` is true, while `NULL OR false`
/// remains NULL.
#[test]
fn eval_or_three_valued() {
    let schema = test_columns();
    let col_map = ColumnMap::new(&schema);
    // `name` (index 1) is NULL and serves as the left operand below.
    let values = vec![
        Value::Integer(1),
        Value::Null,
        Value::Null,
        Value::Boolean(true),
    ];

    // Evaluates `name OR <rhs>` against the row above.
    let null_or = |rhs: bool| {
        let disjunction = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::Or,
            right: Box::new(Expr::Literal(Value::Boolean(rhs))),
        };
        eval_expr(&disjunction, &col_map, &values).unwrap()
    };

    assert_eq!(null_or(true), Value::Boolean(true));
    assert!(null_or(false).is_null());
}
1318
/// `IS NULL` is true for a NULL column and `IS NOT NULL` is true for a
/// populated one; both yield plain booleans.
#[test]
fn eval_is_null() {
    let schema = test_columns();
    let col_map = ColumnMap::new(&schema);
    // `id` holds a value while `name` is NULL.
    let values = vec![
        Value::Integer(1),
        Value::Null,
        Value::Null,
        Value::Boolean(true),
    ];

    let name_is_null = Expr::IsNull(Box::new(Expr::Column("name".into())));
    let first = eval_expr(&name_is_null, &col_map, &values).unwrap();
    assert_eq!(first, Value::Boolean(true));

    let id_is_not_null = Expr::IsNotNull(Box::new(Expr::Column("id".into())));
    let second = eval_expr(&id_is_not_null, &col_map, &values).unwrap();
    assert_eq!(second, Value::Boolean(true));
}
1335
/// Logical NOT inverts a boolean column: `NOT active` is false for the
/// sample row, where `active` is true.
#[test]
fn eval_not() {
    let schema = test_columns();
    let col_map = ColumnMap::new(&schema);
    let values = test_row();

    let operand = Expr::Column("active".into());
    let negation = Expr::UnaryOp {
        op: UnaryOp::Not,
        expr: Box::new(operand),
    };

    let result = eval_expr(&negation, &col_map, &values).unwrap();
    assert_eq!(result, Value::Boolean(false));
}
1347
/// Unary minus negates an integer column: `-id` is `-1` for the sample row.
#[test]
fn eval_neg() {
    let schema = test_columns();
    let col_map = ColumnMap::new(&schema);
    let values = test_row();

    let operand = Expr::Column("id".into());
    let negated = Expr::UnaryOp {
        op: UnaryOp::Neg,
        expr: Box::new(operand),
    };

    let result = eval_expr(&negated, &col_map, &values).unwrap();
    assert_eq!(result, Value::Integer(-1));
}
1359
/// Dividing an integer column by a literal zero is reported as
/// `SqlError::DivisionByZero` instead of producing a value.
#[test]
fn eval_division_by_zero() {
    let schema = test_columns();
    let col_map = ColumnMap::new(&schema);
    let values = test_row();

    let dividend = Expr::Column("id".into());
    let zero = Expr::Literal(Value::Integer(0));
    let quotient = Expr::BinaryOp {
        left: Box::new(dividend),
        op: BinOp::Div,
        right: Box::new(zero),
    };

    let outcome = eval_expr(&quotient, &col_map, &values);
    assert!(matches!(outcome, Err(SqlError::DivisionByZero)));
}
1375
/// Adding an integer column to a real column produces a real:
/// `id + score` is `1 + 95.5 = 96.5` for the sample row.
#[test]
fn eval_mixed_numeric() {
    let schema = test_columns();
    let col_map = ColumnMap::new(&schema);
    let values = test_row();

    let int_operand = Expr::Column("id".into());
    let real_operand = Expr::Column("score".into());
    let sum = Expr::BinaryOp {
        left: Box::new(int_operand),
        op: BinOp::Add,
        right: Box::new(real_operand),
    };

    let result = eval_expr(&sum, &col_map, &values).unwrap();
    assert_eq!(result, Value::Real(96.5));
}
1389
/// `is_truthy` accepts `true` and the integer 1, and rejects `false`,
/// NULL and the integer 0.
#[test]
fn is_truthy_values() {
    let truthy = [Value::Boolean(true), Value::Integer(1)];
    let falsy = [Value::Boolean(false), Value::Null, Value::Integer(0)];

    for value in &truthy {
        assert!(is_truthy(value));
    }
    for value in &falsy {
        assert!(!is_truthy(value));
    }
}
1398}