1use std::collections::HashMap;
4
5use crate::error::{Result, SqlError};
6use crate::parser::{BinOp, Expr, UnaryOp};
7use crate::types::{ColumnDef, CompactString, DataType, Value};
8
9pub struct ColumnMap {
10 exact: HashMap<String, usize>,
11 short: HashMap<String, ShortMatch>,
12}
13
14enum ShortMatch {
15 Unique(usize),
16 Ambiguous,
17}
18
19impl ColumnMap {
20 pub fn new(columns: &[ColumnDef]) -> Self {
21 let mut exact = HashMap::with_capacity(columns.len() * 2);
22 let mut short: HashMap<String, ShortMatch> = HashMap::with_capacity(columns.len());
23
24 for (i, col) in columns.iter().enumerate() {
25 let lower = col.name.to_ascii_lowercase();
26 exact.insert(lower.clone(), i);
27
28 let unqualified = if let Some(dot) = lower.rfind('.') {
29 &lower[dot + 1..]
30 } else {
31 &lower
32 };
33 short
34 .entry(unqualified.to_string())
35 .and_modify(|e| *e = ShortMatch::Ambiguous)
36 .or_insert(ShortMatch::Unique(i));
37 }
38
39 Self { exact, short }
40 }
41
42 pub(crate) fn resolve(&self, name: &str) -> Result<usize> {
43 if let Some(&idx) = self.exact.get(name) {
44 return Ok(idx);
45 }
46 match self.short.get(name) {
47 Some(ShortMatch::Unique(idx)) => Ok(*idx),
48 Some(ShortMatch::Ambiguous) => Err(SqlError::AmbiguousColumn(name.to_string())),
49 None => Err(SqlError::ColumnNotFound(name.to_string())),
50 }
51 }
52
53 pub(crate) fn resolve_qualified(&self, table: &str, column: &str) -> Result<usize> {
54 let qualified = format!("{table}.{column}");
55 if let Some(&idx) = self.exact.get(&qualified) {
56 return Ok(idx);
57 }
58 match self.short.get(column) {
59 Some(ShortMatch::Unique(idx)) => Ok(*idx),
60 _ => Err(SqlError::ColumnNotFound(format!("{table}.{column}"))),
61 }
62 }
63}
64
65pub fn eval_expr(expr: &Expr, col_map: &ColumnMap, row: &[Value]) -> Result<Value> {
66 match expr {
67 Expr::Literal(v) => Ok(v.clone()),
68
69 Expr::Column(name) => {
70 let idx = col_map.resolve(name)?;
71 Ok(row[idx].clone())
72 }
73
74 Expr::QualifiedColumn { table, column } => {
75 let idx = col_map.resolve_qualified(table, column)?;
76 Ok(row[idx].clone())
77 }
78
79 Expr::BinaryOp { left, op, right } => {
80 let lval = eval_expr(left, col_map, row)?;
81 let rval = eval_expr(right, col_map, row)?;
82 eval_binary_op(&lval, *op, &rval)
83 }
84
85 Expr::UnaryOp { op, expr } => {
86 let val = eval_expr(expr, col_map, row)?;
87 eval_unary_op(*op, &val)
88 }
89
90 Expr::IsNull(e) => {
91 let val = eval_expr(e, col_map, row)?;
92 Ok(Value::Boolean(val.is_null()))
93 }
94
95 Expr::IsNotNull(e) => {
96 let val = eval_expr(e, col_map, row)?;
97 Ok(Value::Boolean(!val.is_null()))
98 }
99
100 Expr::Function { name, args } => eval_scalar_function(name, args, col_map, row),
101
102 Expr::CountStar => Err(SqlError::Unsupported(
103 "COUNT(*) in non-aggregate context".into(),
104 )),
105
106 Expr::InList {
107 expr: e,
108 list,
109 negated,
110 } => {
111 let lhs = eval_expr(e, col_map, row)?;
112 eval_in_values(&lhs, list, col_map, row, *negated)
113 }
114
115 Expr::InSet {
116 expr: e,
117 values,
118 has_null,
119 negated,
120 } => {
121 let lhs = eval_expr(e, col_map, row)?;
122 eval_in_set(&lhs, values, *has_null, *negated)
123 }
124
125 Expr::Between {
126 expr: e,
127 low,
128 high,
129 negated,
130 } => {
131 let val = eval_expr(e, col_map, row)?;
132 let lo = eval_expr(low, col_map, row)?;
133 let hi = eval_expr(high, col_map, row)?;
134 eval_between(&val, &lo, &hi, *negated)
135 }
136
137 Expr::Like {
138 expr: e,
139 pattern,
140 escape,
141 negated,
142 } => {
143 let val = eval_expr(e, col_map, row)?;
144 let pat = eval_expr(pattern, col_map, row)?;
145 let esc = escape
146 .as_ref()
147 .map(|e| eval_expr(e, col_map, row))
148 .transpose()?;
149 eval_like(&val, &pat, esc.as_ref(), *negated)
150 }
151
152 Expr::Case {
153 operand,
154 conditions,
155 else_result,
156 } => eval_case(
157 operand.as_deref(),
158 conditions,
159 else_result.as_deref(),
160 col_map,
161 row,
162 ),
163
164 Expr::Coalesce(args) => {
165 for arg in args {
166 let val = eval_expr(arg, col_map, row)?;
167 if !val.is_null() {
168 return Ok(val);
169 }
170 }
171 Ok(Value::Null)
172 }
173
174 Expr::Cast { expr: e, data_type } => {
175 let val = eval_expr(e, col_map, row)?;
176 eval_cast(&val, *data_type)
177 }
178
179 Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::ScalarSubquery(_) => Err(
180 SqlError::Unsupported("subquery not materialized (internal error)".into()),
181 ),
182
183 Expr::Parameter(n) => Err(SqlError::Parse(format!("unbound parameter ${n}"))),
184
185 Expr::WindowFunction { .. } => Err(SqlError::Unsupported(
186 "window functions are only allowed in SELECT columns".into(),
187 )),
188 }
189}
190
191fn eval_binary_op(left: &Value, op: BinOp, right: &Value) -> Result<Value> {
192 match op {
194 BinOp::And => return eval_and(left, right),
195 BinOp::Or => return eval_or(left, right),
196 _ => {}
197 }
198
199 if left.is_null() || right.is_null() {
201 return Ok(Value::Null);
202 }
203
204 match op {
205 BinOp::Eq => Ok(Value::Boolean(left == right)),
206 BinOp::NotEq => Ok(Value::Boolean(left != right)),
207 BinOp::Lt => Ok(Value::Boolean(left < right)),
208 BinOp::Gt => Ok(Value::Boolean(left > right)),
209 BinOp::LtEq => Ok(Value::Boolean(left <= right)),
210 BinOp::GtEq => Ok(Value::Boolean(left >= right)),
211 BinOp::Add => eval_arithmetic(left, right, i64::checked_add, |a, b| a + b),
212 BinOp::Sub => eval_arithmetic(left, right, i64::checked_sub, |a, b| a - b),
213 BinOp::Mul => eval_arithmetic(left, right, i64::checked_mul, |a, b| a * b),
214 BinOp::Div => {
215 match right {
216 Value::Integer(0) => return Err(SqlError::DivisionByZero),
217 Value::Real(r) if *r == 0.0 => return Err(SqlError::DivisionByZero),
218 _ => {}
219 }
220 eval_arithmetic(left, right, i64::checked_div, |a, b| a / b)
221 }
222 BinOp::Mod => {
223 match right {
224 Value::Integer(0) => return Err(SqlError::DivisionByZero),
225 Value::Real(r) if *r == 0.0 => return Err(SqlError::DivisionByZero),
226 _ => {}
227 }
228 eval_arithmetic(left, right, i64::checked_rem, |a, b| a % b)
229 }
230 BinOp::Concat => {
231 let ls = value_to_text(left);
232 let rs = value_to_text(right);
233 Ok(Value::Text(format!("{ls}{rs}").into()))
234 }
235 BinOp::And | BinOp::Or => unreachable!(),
236 }
237}
238
239fn eval_and(left: &Value, right: &Value) -> Result<Value> {
241 let l = to_bool_or_null(left)?;
242 let r = to_bool_or_null(right)?;
243 match (l, r) {
244 (Some(false), _) | (_, Some(false)) => Ok(Value::Boolean(false)),
245 (Some(true), Some(true)) => Ok(Value::Boolean(true)),
246 _ => Ok(Value::Null),
247 }
248}
249
250fn eval_or(left: &Value, right: &Value) -> Result<Value> {
252 let l = to_bool_or_null(left)?;
253 let r = to_bool_or_null(right)?;
254 match (l, r) {
255 (Some(true), _) | (_, Some(true)) => Ok(Value::Boolean(true)),
256 (Some(false), Some(false)) => Ok(Value::Boolean(false)),
257 _ => Ok(Value::Null),
258 }
259}
260
261fn to_bool_or_null(val: &Value) -> Result<Option<bool>> {
262 match val {
263 Value::Boolean(b) => Ok(Some(*b)),
264 Value::Null => Ok(None),
265 Value::Integer(i) => Ok(Some(*i != 0)),
266 _ => Err(SqlError::TypeMismatch {
267 expected: "BOOLEAN".into(),
268 got: format!("{}", val.data_type()),
269 }),
270 }
271}
272
273fn eval_arithmetic(
274 left: &Value,
275 right: &Value,
276 int_op: fn(i64, i64) -> Option<i64>,
277 real_op: fn(f64, f64) -> f64,
278) -> Result<Value> {
279 match (left, right) {
280 (Value::Integer(a), Value::Integer(b)) => int_op(*a, *b)
281 .map(Value::Integer)
282 .ok_or(SqlError::IntegerOverflow),
283 (Value::Real(a), Value::Real(b)) => Ok(Value::Real(real_op(*a, *b))),
284 (Value::Integer(a), Value::Real(b)) => Ok(Value::Real(real_op(*a as f64, *b))),
285 (Value::Real(a), Value::Integer(b)) => Ok(Value::Real(real_op(*a, *b as f64))),
286 _ => Err(SqlError::TypeMismatch {
287 expected: "numeric".into(),
288 got: format!("{} and {}", left.data_type(), right.data_type()),
289 }),
290 }
291}
292
293fn eval_in_values(
294 lhs: &Value,
295 list: &[Expr],
296 col_map: &ColumnMap,
297 row: &[Value],
298 negated: bool,
299) -> Result<Value> {
300 if list.is_empty() {
301 return Ok(Value::Boolean(negated));
302 }
303 if lhs.is_null() {
304 return Ok(Value::Null);
305 }
306 let mut has_null = false;
307 for item in list {
308 let rhs = eval_expr(item, col_map, row)?;
309 if rhs.is_null() {
310 has_null = true;
311 } else if lhs == &rhs {
312 return Ok(Value::Boolean(!negated));
313 }
314 }
315 if has_null {
316 Ok(Value::Null)
317 } else {
318 Ok(Value::Boolean(negated))
319 }
320}
321
322fn eval_in_set(
323 lhs: &Value,
324 values: &std::collections::HashSet<Value>,
325 has_null: bool,
326 negated: bool,
327) -> Result<Value> {
328 if values.is_empty() && !has_null {
329 return Ok(Value::Boolean(negated));
330 }
331 if lhs.is_null() {
332 return Ok(Value::Null);
333 }
334 if values.contains(lhs) {
335 return Ok(Value::Boolean(!negated));
336 }
337 if has_null {
338 Ok(Value::Null)
339 } else {
340 Ok(Value::Boolean(negated))
341 }
342}
343
344fn eval_unary_op(op: UnaryOp, val: &Value) -> Result<Value> {
345 if val.is_null() {
346 return Ok(Value::Null);
347 }
348 match op {
349 UnaryOp::Neg => match val {
350 Value::Integer(i) => i
351 .checked_neg()
352 .map(Value::Integer)
353 .ok_or(SqlError::IntegerOverflow),
354 Value::Real(r) => Ok(Value::Real(-r)),
355 _ => Err(SqlError::TypeMismatch {
356 expected: "numeric".into(),
357 got: format!("{}", val.data_type()),
358 }),
359 },
360 UnaryOp::Not => match val {
361 Value::Boolean(b) => Ok(Value::Boolean(!b)),
362 Value::Integer(i) => Ok(Value::Boolean(*i == 0)),
363 _ => Err(SqlError::TypeMismatch {
364 expected: "BOOLEAN".into(),
365 got: format!("{}", val.data_type()),
366 }),
367 },
368 }
369}
370
371fn value_to_text(val: &Value) -> String {
372 match val {
373 Value::Text(s) => s.to_string(),
374 Value::Integer(i) => i.to_string(),
375 Value::Real(r) => {
376 if r.fract() == 0.0 && r.is_finite() {
377 format!("{r:.1}")
378 } else {
379 format!("{r}")
380 }
381 }
382 Value::Boolean(b) => if *b { "TRUE" } else { "FALSE" }.into(),
383 Value::Null => String::new(),
384 Value::Blob(b) => {
385 let mut s = String::with_capacity(b.len() * 2);
386 for byte in b {
387 s.push_str(&format!("{byte:02X}"));
388 }
389 s
390 }
391 }
392}
393
394fn eval_between(val: &Value, low: &Value, high: &Value, negated: bool) -> Result<Value> {
395 if val.is_null() || low.is_null() || high.is_null() {
396 let ge = if val.is_null() || low.is_null() {
397 None
398 } else {
399 Some(*val >= *low)
400 };
401 let le = if val.is_null() || high.is_null() {
402 None
403 } else {
404 Some(*val <= *high)
405 };
406
407 let result = match (ge, le) {
408 (Some(false), _) | (_, Some(false)) => Some(false),
409 (Some(true), Some(true)) => Some(true),
410 _ => None,
411 };
412
413 return match result {
414 Some(b) => Ok(Value::Boolean(if negated { !b } else { b })),
415 None => Ok(Value::Null),
416 };
417 }
418
419 let in_range = *val >= *low && *val <= *high;
420 Ok(Value::Boolean(if negated { !in_range } else { in_range }))
421}
422
423const MAX_LIKE_PATTERN_LEN: usize = 10_000;
424
425fn eval_like(val: &Value, pattern: &Value, escape: Option<&Value>, negated: bool) -> Result<Value> {
426 if val.is_null() || pattern.is_null() {
427 return Ok(Value::Null);
428 }
429 let text = match val {
430 Value::Text(s) => s.as_str(),
431 _ => {
432 return Err(SqlError::TypeMismatch {
433 expected: "TEXT".into(),
434 got: val.data_type().to_string(),
435 })
436 }
437 };
438 let pat = match pattern {
439 Value::Text(s) => s.as_str(),
440 _ => {
441 return Err(SqlError::TypeMismatch {
442 expected: "TEXT".into(),
443 got: pattern.data_type().to_string(),
444 })
445 }
446 };
447
448 if pat.len() > MAX_LIKE_PATTERN_LEN {
449 return Err(SqlError::InvalidValue(format!(
450 "LIKE pattern too long ({} chars, max {MAX_LIKE_PATTERN_LEN})",
451 pat.len()
452 )));
453 }
454
455 let esc_char = match escape {
456 Some(Value::Text(s)) => {
457 let mut chars = s.chars();
458 let c = chars.next().ok_or_else(|| {
459 SqlError::InvalidValue("ESCAPE must be a single character".into())
460 })?;
461 if chars.next().is_some() {
462 return Err(SqlError::InvalidValue(
463 "ESCAPE must be a single character".into(),
464 ));
465 }
466 Some(c)
467 }
468 Some(Value::Null) => return Ok(Value::Null),
469 Some(_) => {
470 return Err(SqlError::TypeMismatch {
471 expected: "TEXT".into(),
472 got: "non-text".into(),
473 })
474 }
475 None => None,
476 };
477
478 let matched = like_match(text, pat, esc_char);
479 Ok(Value::Boolean(if negated { !matched } else { matched }))
480}
481
482fn like_match(text: &str, pattern: &str, escape: Option<char>) -> bool {
483 let t: Vec<char> = text.chars().collect();
484 let p: Vec<char> = pattern.chars().collect();
485 like_match_impl(&t, &p, 0, 0, escape)
486}
487
/// Iterative LIKE matcher over pre-split character slices.
///
/// `t` is the text and `p` the pattern; matching begins at `t[ti]` and
/// `p[pi]` (`like_match` passes 0 for both). In the pattern:
/// - `%` matches any run of characters, including none;
/// - `_` matches exactly one character;
/// - any other character must equal the text character, ignoring ASCII case.
///
/// When `esc` is `Some`, an escape character followed by at least one more
/// pattern character makes that next character match literally (still
/// ignoring ASCII case) instead of as a wildcard. An escape character in the
/// final pattern position is not treated as an escape.
///
/// Only the most recent `%` is remembered for backtracking. On a mismatch,
/// the matcher lets that `%` absorb one more text character and resumes just
/// after it. With no `%` seen yet, a mismatch fails immediately.
fn like_match_impl(
    t: &[char],
    p: &[char],
    mut ti: usize,
    mut pi: usize,
    esc: Option<char>,
) -> bool {
    // Pattern index of the most recent `%`, if any.
    let mut star_pi: Option<usize> = None;
    // Text index at which the most recent `%` currently stops matching.
    let mut star_ti: usize = 0;

    while ti < t.len() {
        if pi < p.len() {
            if let Some(ec) = esc {
                // Escape + following char: compare the following char literally.
                if p[pi] == ec && pi + 1 < p.len() {
                    pi += 1;
                    let pc_lower = p[pi].to_ascii_lowercase();
                    let tc_lower = t[ti].to_ascii_lowercase();
                    if pc_lower == tc_lower {
                        pi += 1;
                        ti += 1;
                        continue;
                    } else if let Some(sp) = star_pi {
                        // Mismatch: let the last `%` swallow one more char and retry.
                        pi = sp + 1;
                        star_ti += 1;
                        ti = star_ti;
                        continue;
                    } else {
                        return false;
                    }
                }
            }
            if p[pi] == '%' {
                // Try `%` as matching nothing first; later mismatches widen it.
                star_pi = Some(pi);
                star_ti = ti;
                pi += 1;
                continue;
            }
            if p[pi] == '_' {
                pi += 1;
                ti += 1;
                continue;
            }
            if p[pi].eq_ignore_ascii_case(&t[ti]) {
                pi += 1;
                ti += 1;
                continue;
            }
        }
        // Mismatch, or pattern exhausted with text left: backtrack to the last `%`.
        if let Some(sp) = star_pi {
            pi = sp + 1;
            star_ti += 1;
            ti = star_ti;
        } else {
            return false;
        }
    }

    // Text fully consumed: the rest of the pattern must be only `%`.
    while pi < p.len() && p[pi] == '%' {
        pi += 1;
    }
    pi == p.len()
}
550
551fn eval_case(
552 operand: Option<&Expr>,
553 conditions: &[(Expr, Expr)],
554 else_result: Option<&Expr>,
555 col_map: &ColumnMap,
556 row: &[Value],
557) -> Result<Value> {
558 if let Some(op_expr) = operand {
559 let op_val = eval_expr(op_expr, col_map, row)?;
560 for (cond, result) in conditions {
561 let cond_val = eval_expr(cond, col_map, row)?;
562 if !op_val.is_null() && !cond_val.is_null() && op_val == cond_val {
563 return eval_expr(result, col_map, row);
564 }
565 }
566 } else {
567 for (cond, result) in conditions {
568 let cond_val = eval_expr(cond, col_map, row)?;
569 if is_truthy(&cond_val) {
570 return eval_expr(result, col_map, row);
571 }
572 }
573 }
574 match else_result {
575 Some(e) => eval_expr(e, col_map, row),
576 None => Ok(Value::Null),
577 }
578}
579
580fn eval_cast(val: &Value, target: DataType) -> Result<Value> {
581 if val.is_null() {
582 return Ok(Value::Null);
583 }
584 match target {
585 DataType::Integer => match val {
586 Value::Integer(_) => Ok(val.clone()),
587 Value::Real(r) => Ok(Value::Integer(*r as i64)),
588 Value::Boolean(b) => Ok(Value::Integer(if *b { 1 } else { 0 })),
589 Value::Text(s) => s
590 .trim()
591 .parse::<i64>()
592 .map(Value::Integer)
593 .or_else(|_| s.trim().parse::<f64>().map(|f| Value::Integer(f as i64)))
594 .map_err(|_| SqlError::InvalidValue(format!("cannot cast '{s}' to INTEGER"))),
595 _ => Err(SqlError::InvalidValue(format!(
596 "cannot cast {} to INTEGER",
597 val.data_type()
598 ))),
599 },
600 DataType::Real => match val {
601 Value::Real(_) => Ok(val.clone()),
602 Value::Integer(i) => Ok(Value::Real(*i as f64)),
603 Value::Boolean(b) => Ok(Value::Real(if *b { 1.0 } else { 0.0 })),
604 Value::Text(s) => s
605 .trim()
606 .parse::<f64>()
607 .map(Value::Real)
608 .map_err(|_| SqlError::InvalidValue(format!("cannot cast '{s}' to REAL"))),
609 _ => Err(SqlError::InvalidValue(format!(
610 "cannot cast {} to REAL",
611 val.data_type()
612 ))),
613 },
614 DataType::Text => Ok(Value::Text(value_to_text(val).into())),
615 DataType::Boolean => match val {
616 Value::Boolean(_) => Ok(val.clone()),
617 Value::Integer(i) => Ok(Value::Boolean(*i != 0)),
618 Value::Text(s) => {
619 let lower = s.trim().to_ascii_lowercase();
620 match lower.as_str() {
621 "true" | "1" | "yes" | "on" => Ok(Value::Boolean(true)),
622 "false" | "0" | "no" | "off" => Ok(Value::Boolean(false)),
623 _ => Err(SqlError::InvalidValue(format!(
624 "cannot cast '{s}' to BOOLEAN"
625 ))),
626 }
627 }
628 _ => Err(SqlError::InvalidValue(format!(
629 "cannot cast {} to BOOLEAN",
630 val.data_type()
631 ))),
632 },
633 DataType::Blob => match val {
634 Value::Blob(_) => Ok(val.clone()),
635 Value::Text(s) => Ok(Value::Blob(s.as_bytes().to_vec())),
636 _ => Err(SqlError::InvalidValue(format!(
637 "cannot cast {} to BLOB",
638 val.data_type()
639 ))),
640 },
641 DataType::Null => Ok(Value::Null),
642 }
643}
644
645fn eval_scalar_function(
646 name: &str,
647 args: &[Expr],
648 col_map: &ColumnMap,
649 row: &[Value],
650) -> Result<Value> {
651 let evaluated: Vec<Value> = args
652 .iter()
653 .map(|a| eval_expr(a, col_map, row))
654 .collect::<Result<Vec<_>>>()?;
655
656 match name {
657 "LENGTH" => {
658 check_args(name, &evaluated, 1)?;
659 match &evaluated[0] {
660 Value::Null => Ok(Value::Null),
661 Value::Text(s) => Ok(Value::Integer(s.chars().count() as i64)),
662 Value::Blob(b) => Ok(Value::Integer(b.len() as i64)),
663 _ => Ok(Value::Integer(
664 value_to_text(&evaluated[0]).chars().count() as i64
665 )),
666 }
667 }
668 "UPPER" => {
669 check_args(name, &evaluated, 1)?;
670 match &evaluated[0] {
671 Value::Null => Ok(Value::Null),
672 Value::Text(s) => Ok(Value::Text(s.to_ascii_uppercase())),
673 _ => Ok(Value::Text(
674 value_to_text(&evaluated[0]).to_ascii_uppercase().into(),
675 )),
676 }
677 }
678 "LOWER" => {
679 check_args(name, &evaluated, 1)?;
680 match &evaluated[0] {
681 Value::Null => Ok(Value::Null),
682 Value::Text(s) => Ok(Value::Text(s.to_ascii_lowercase())),
683 _ => Ok(Value::Text(
684 value_to_text(&evaluated[0]).to_ascii_lowercase().into(),
685 )),
686 }
687 }
688 "SUBSTR" | "SUBSTRING" => {
689 if evaluated.len() < 2 || evaluated.len() > 3 {
690 return Err(SqlError::InvalidValue(format!(
691 "{name} requires 2 or 3 arguments"
692 )));
693 }
694 if evaluated.iter().any(|v| v.is_null()) {
695 return Ok(Value::Null);
696 }
697 let s = value_to_text(&evaluated[0]);
698 let chars: Vec<char> = s.chars().collect();
699 let start = match &evaluated[1] {
700 Value::Integer(i) => *i,
701 _ => {
702 return Err(SqlError::TypeMismatch {
703 expected: "INTEGER".into(),
704 got: evaluated[1].data_type().to_string(),
705 })
706 }
707 };
708 let len = chars.len() as i64;
709
710 let (begin, count) = if evaluated.len() == 3 {
711 let cnt = match &evaluated[2] {
712 Value::Integer(i) => *i,
713 _ => {
714 return Err(SqlError::TypeMismatch {
715 expected: "INTEGER".into(),
716 got: evaluated[2].data_type().to_string(),
717 })
718 }
719 };
720 if start >= 1 {
721 let b = (start - 1).min(len) as usize;
722 let c = cnt.max(0) as usize;
723 (b, c)
724 } else if start == 0 {
725 let c = (cnt - 1).max(0) as usize;
726 (0usize, c)
727 } else {
728 let adjusted_cnt = (cnt + start - 1).max(0) as usize;
729 (0usize, adjusted_cnt)
730 }
731 } else if start >= 1 {
732 let b = (start - 1).min(len) as usize;
733 (b, chars.len() - b)
734 } else if start == 0 {
735 (0usize, chars.len())
736 } else {
737 let b = (len + start).max(0) as usize;
738 (b, chars.len() - b)
739 };
740
741 let result: String = chars.iter().skip(begin).take(count).collect();
742 Ok(Value::Text(result.into()))
743 }
744 "TRIM" | "LTRIM" | "RTRIM" => {
745 if evaluated.is_empty() || evaluated.len() > 2 {
746 return Err(SqlError::InvalidValue(format!(
747 "{name} requires 1 or 2 arguments"
748 )));
749 }
750 if evaluated[0].is_null() {
751 return Ok(Value::Null);
752 }
753 let s = value_to_text(&evaluated[0]);
754 let trim_chars: Vec<char> = if evaluated.len() == 2 {
755 if evaluated[1].is_null() {
756 return Ok(Value::Null);
757 }
758 value_to_text(&evaluated[1]).chars().collect()
759 } else {
760 vec![' ']
761 };
762 let result = match name {
763 "TRIM" => s
764 .trim_matches(|c: char| trim_chars.contains(&c))
765 .to_string(),
766 "LTRIM" => s
767 .trim_start_matches(|c: char| trim_chars.contains(&c))
768 .to_string(),
769 "RTRIM" => s
770 .trim_end_matches(|c: char| trim_chars.contains(&c))
771 .to_string(),
772 _ => unreachable!(),
773 };
774 Ok(Value::Text(result.into()))
775 }
776 "REPLACE" => {
777 check_args(name, &evaluated, 3)?;
778 if evaluated.iter().any(|v| v.is_null()) {
779 return Ok(Value::Null);
780 }
781 let s = value_to_text(&evaluated[0]);
782 let from = value_to_text(&evaluated[1]);
783 let to = value_to_text(&evaluated[2]);
784 if from.is_empty() {
785 return Ok(Value::Text(s.into()));
786 }
787 Ok(Value::Text(s.replace(&from, &to).into()))
788 }
789 "INSTR" => {
790 check_args(name, &evaluated, 2)?;
791 if evaluated.iter().any(|v| v.is_null()) {
792 return Ok(Value::Null);
793 }
794 let haystack = value_to_text(&evaluated[0]);
795 let needle = value_to_text(&evaluated[1]);
796 let pos = haystack
797 .find(&needle)
798 .map(|i| haystack[..i].chars().count() as i64 + 1)
799 .unwrap_or(0);
800 Ok(Value::Integer(pos))
801 }
802 "CONCAT" => {
803 if evaluated.is_empty() {
804 return Ok(Value::Text(CompactString::default()));
805 }
806 let mut result = String::new();
807 for v in &evaluated {
808 match v {
809 Value::Null => {}
810 _ => result.push_str(&value_to_text(v)),
811 }
812 }
813 Ok(Value::Text(result.into()))
814 }
815 "ABS" => {
816 check_args(name, &evaluated, 1)?;
817 match &evaluated[0] {
818 Value::Null => Ok(Value::Null),
819 Value::Integer(i) => i
820 .checked_abs()
821 .map(Value::Integer)
822 .ok_or(SqlError::IntegerOverflow),
823 Value::Real(r) => Ok(Value::Real(r.abs())),
824 _ => Err(SqlError::TypeMismatch {
825 expected: "numeric".into(),
826 got: evaluated[0].data_type().to_string(),
827 }),
828 }
829 }
830 "ROUND" => {
831 if evaluated.is_empty() || evaluated.len() > 2 {
832 return Err(SqlError::InvalidValue(
833 "ROUND requires 1 or 2 arguments".into(),
834 ));
835 }
836 if evaluated[0].is_null() {
837 return Ok(Value::Null);
838 }
839 let val = match &evaluated[0] {
840 Value::Integer(i) => *i as f64,
841 Value::Real(r) => *r,
842 _ => {
843 return Err(SqlError::TypeMismatch {
844 expected: "numeric".into(),
845 got: evaluated[0].data_type().to_string(),
846 })
847 }
848 };
849 let places = if evaluated.len() == 2 {
850 match &evaluated[1] {
851 Value::Null => return Ok(Value::Null),
852 Value::Integer(i) => *i,
853 _ => {
854 return Err(SqlError::TypeMismatch {
855 expected: "INTEGER".into(),
856 got: evaluated[1].data_type().to_string(),
857 })
858 }
859 }
860 } else {
861 0
862 };
863 let factor = 10f64.powi(places as i32);
864 let rounded = (val * factor).round() / factor;
865 Ok(Value::Real(rounded))
866 }
867 "CEIL" | "CEILING" => {
868 check_args(name, &evaluated, 1)?;
869 match &evaluated[0] {
870 Value::Null => Ok(Value::Null),
871 Value::Integer(i) => Ok(Value::Integer(*i)),
872 Value::Real(r) => Ok(Value::Integer(r.ceil() as i64)),
873 _ => Err(SqlError::TypeMismatch {
874 expected: "numeric".into(),
875 got: evaluated[0].data_type().to_string(),
876 }),
877 }
878 }
879 "FLOOR" => {
880 check_args(name, &evaluated, 1)?;
881 match &evaluated[0] {
882 Value::Null => Ok(Value::Null),
883 Value::Integer(i) => Ok(Value::Integer(*i)),
884 Value::Real(r) => Ok(Value::Integer(r.floor() as i64)),
885 _ => Err(SqlError::TypeMismatch {
886 expected: "numeric".into(),
887 got: evaluated[0].data_type().to_string(),
888 }),
889 }
890 }
891 "SIGN" => {
892 check_args(name, &evaluated, 1)?;
893 match &evaluated[0] {
894 Value::Null => Ok(Value::Null),
895 Value::Integer(i) => Ok(Value::Integer(i.signum())),
896 Value::Real(r) => {
897 if *r > 0.0 {
898 Ok(Value::Integer(1))
899 } else if *r < 0.0 {
900 Ok(Value::Integer(-1))
901 } else {
902 Ok(Value::Integer(0))
903 }
904 }
905 _ => Err(SqlError::TypeMismatch {
906 expected: "numeric".into(),
907 got: evaluated[0].data_type().to_string(),
908 }),
909 }
910 }
911 "SQRT" => {
912 check_args(name, &evaluated, 1)?;
913 match &evaluated[0] {
914 Value::Null => Ok(Value::Null),
915 Value::Integer(i) => {
916 if *i < 0 {
917 Ok(Value::Null)
918 } else {
919 Ok(Value::Real((*i as f64).sqrt()))
920 }
921 }
922 Value::Real(r) => {
923 if *r < 0.0 {
924 Ok(Value::Null)
925 } else {
926 Ok(Value::Real(r.sqrt()))
927 }
928 }
929 _ => Err(SqlError::TypeMismatch {
930 expected: "numeric".into(),
931 got: evaluated[0].data_type().to_string(),
932 }),
933 }
934 }
935 "RANDOM" => {
936 check_args(name, &evaluated, 0)?;
937 use std::collections::hash_map::DefaultHasher;
938 use std::hash::{Hash, Hasher};
939 use std::time::SystemTime;
940 let mut hasher = DefaultHasher::new();
941 SystemTime::now().hash(&mut hasher);
942 std::thread::current().id().hash(&mut hasher);
943 let mut val = hasher.finish() as i64;
944 if val == i64::MIN {
945 val = i64::MAX;
946 }
947 Ok(Value::Integer(val))
948 }
949 "TYPEOF" => {
950 check_args(name, &evaluated, 1)?;
951 let type_name = match &evaluated[0] {
952 Value::Null => "null",
953 Value::Integer(_) => "integer",
954 Value::Real(_) => "real",
955 Value::Text(_) => "text",
956 Value::Blob(_) => "blob",
957 Value::Boolean(_) => "boolean",
958 };
959 Ok(Value::Text(type_name.into()))
960 }
961 "MIN" => {
962 check_args(name, &evaluated, 2)?;
963 if evaluated[0].is_null() {
964 return Ok(evaluated[1].clone());
965 }
966 if evaluated[1].is_null() {
967 return Ok(evaluated[0].clone());
968 }
969 if evaluated[0] <= evaluated[1] {
970 Ok(evaluated[0].clone())
971 } else {
972 Ok(evaluated[1].clone())
973 }
974 }
975 "MAX" => {
976 check_args(name, &evaluated, 2)?;
977 if evaluated[0].is_null() {
978 return Ok(evaluated[1].clone());
979 }
980 if evaluated[1].is_null() {
981 return Ok(evaluated[0].clone());
982 }
983 if evaluated[0] >= evaluated[1] {
984 Ok(evaluated[0].clone())
985 } else {
986 Ok(evaluated[1].clone())
987 }
988 }
989 "HEX" => {
990 check_args(name, &evaluated, 1)?;
991 match &evaluated[0] {
992 Value::Null => Ok(Value::Null),
993 Value::Blob(b) => {
994 let mut s = String::with_capacity(b.len() * 2);
995 for byte in b {
996 s.push_str(&format!("{byte:02X}"));
997 }
998 Ok(Value::Text(s.into()))
999 }
1000 Value::Text(s) => {
1001 let mut r = String::with_capacity(s.len() * 2);
1002 for byte in s.as_bytes() {
1003 r.push_str(&format!("{byte:02X}"));
1004 }
1005 Ok(Value::Text(r.into()))
1006 }
1007 _ => Ok(Value::Text(value_to_text(&evaluated[0]).into())),
1008 }
1009 }
1010 _ => Err(SqlError::Unsupported(format!("scalar function: {name}"))),
1011 }
1012}
1013
1014fn check_args(name: &str, args: &[Value], expected: usize) -> Result<()> {
1015 if args.len() != expected {
1016 Err(SqlError::InvalidValue(format!(
1017 "{name} requires {expected} argument(s), got {}",
1018 args.len()
1019 )))
1020 } else {
1021 Ok(())
1022 }
1023}
1024
1025pub fn referenced_columns(expr: &Expr, columns: &[ColumnDef]) -> Vec<usize> {
1026 let mut indices = Vec::new();
1027 collect_column_refs(expr, columns, &mut indices);
1028 indices.sort_unstable();
1029 indices.dedup();
1030 indices
1031}
1032
1033fn collect_column_refs(expr: &Expr, columns: &[ColumnDef], out: &mut Vec<usize>) {
1034 match expr {
1035 Expr::Column(name) => {
1036 for (i, c) in columns.iter().enumerate() {
1037 if c.name == *name || c.name.ends_with(&format!(".{name}")) {
1038 out.push(i);
1039 break;
1040 }
1041 }
1042 }
1043 Expr::QualifiedColumn { table, column } => {
1044 let qualified = format!("{table}.{column}");
1045 if let Some(idx) = columns.iter().position(|c| c.name == qualified) {
1046 out.push(idx);
1047 } else {
1048 let matches: Vec<usize> = columns
1049 .iter()
1050 .enumerate()
1051 .filter(|(_, c)| c.name == *column)
1052 .map(|(i, _)| i)
1053 .collect();
1054 if matches.len() == 1 {
1055 out.push(matches[0]);
1056 }
1057 }
1058 }
1059 Expr::BinaryOp { left, right, .. } => {
1060 collect_column_refs(left, columns, out);
1061 collect_column_refs(right, columns, out);
1062 }
1063 Expr::UnaryOp { expr, .. } => {
1064 collect_column_refs(expr, columns, out);
1065 }
1066 Expr::IsNull(e) | Expr::IsNotNull(e) => {
1067 collect_column_refs(e, columns, out);
1068 }
1069 Expr::Function { args, .. } => {
1070 for arg in args {
1071 collect_column_refs(arg, columns, out);
1072 }
1073 }
1074 Expr::InSubquery { expr, .. } => {
1075 collect_column_refs(expr, columns, out);
1076 }
1077 Expr::InList { expr, list, .. } => {
1078 collect_column_refs(expr, columns, out);
1079 for item in list {
1080 collect_column_refs(item, columns, out);
1081 }
1082 }
1083 Expr::InSet { expr, .. } => {
1084 collect_column_refs(expr, columns, out);
1085 }
1086 Expr::Between {
1087 expr, low, high, ..
1088 } => {
1089 collect_column_refs(expr, columns, out);
1090 collect_column_refs(low, columns, out);
1091 collect_column_refs(high, columns, out);
1092 }
1093 Expr::Like {
1094 expr,
1095 pattern,
1096 escape,
1097 ..
1098 } => {
1099 collect_column_refs(expr, columns, out);
1100 collect_column_refs(pattern, columns, out);
1101 if let Some(esc) = escape {
1102 collect_column_refs(esc, columns, out);
1103 }
1104 }
1105 Expr::Case {
1106 operand,
1107 conditions,
1108 else_result,
1109 } => {
1110 if let Some(op) = operand {
1111 collect_column_refs(op, columns, out);
1112 }
1113 for (when, then) in conditions {
1114 collect_column_refs(when, columns, out);
1115 collect_column_refs(then, columns, out);
1116 }
1117 if let Some(e) = else_result {
1118 collect_column_refs(e, columns, out);
1119 }
1120 }
1121 Expr::Coalesce(args) => {
1122 for arg in args {
1123 collect_column_refs(arg, columns, out);
1124 }
1125 }
1126 Expr::Cast { expr, .. } => {
1127 collect_column_refs(expr, columns, out);
1128 }
1129 Expr::WindowFunction { args, spec, .. } => {
1130 for arg in args {
1131 collect_column_refs(arg, columns, out);
1132 }
1133 for pb in &spec.partition_by {
1134 collect_column_refs(pb, columns, out);
1135 }
1136 for ob in &spec.order_by {
1137 collect_column_refs(&ob.expr, columns, out);
1138 }
1139 }
1140 Expr::Literal(_)
1141 | Expr::Parameter(_)
1142 | Expr::CountStar
1143 | Expr::Exists { .. }
1144 | Expr::ScalarSubquery(_) => {}
1145 }
1146}
1147
1148pub fn is_truthy(val: &Value) -> bool {
1150 match val {
1151 Value::Boolean(b) => *b,
1152 Value::Integer(i) => *i != 0,
1153 Value::Null => false,
1154 _ => true,
1155 }
1156}
1157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DataType;

    /// Builds a plain `ColumnDef` with no default value and no CHECK constraint.
    fn col(name: &str, dt: DataType, nullable: bool, pos: u16) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            data_type: dt,
            nullable,
            position: pos,
            default_expr: None,
            default_sql: None,
            check_expr: None,
            check_sql: None,
            check_name: None,
        }
    }

    /// Schema shared by most tests: `id`, `name`, `score`, `active`.
    fn test_columns() -> Vec<ColumnDef> {
        vec![
            col("id", DataType::Integer, false, 0),
            col("name", DataType::Text, true, 1),
            col("score", DataType::Real, true, 2),
            col("active", DataType::Boolean, false, 3),
        ]
    }

    /// A fully populated row matching `test_columns()`.
    fn test_row() -> Vec<Value> {
        vec![
            Value::Integer(1),
            Value::Text("Alice".into()),
            Value::Real(95.5),
            Value::Boolean(true),
        ]
    }

    /// A row matching `test_columns()` in which `name` and `score` are NULL.
    fn row_with_nulls() -> Vec<Value> {
        vec![
            Value::Integer(1),
            Value::Null,
            Value::Null,
            Value::Boolean(true),
        ]
    }

    #[test]
    fn eval_literal() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::Literal(Value::Integer(42));
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Integer(42));
    }

    #[test]
    fn eval_column_ref() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::Column("name".into());
        assert_eq!(
            eval_expr(&expr, &cm, &row).unwrap(),
            Value::Text("Alice".into())
        );
    }

    #[test]
    fn eval_column_case_insensitive() {
        // ColumnMap::new lowercases declared column names, so columns declared
        // with mixed or upper case must still resolve via lowercase references.
        let cols = vec![
            col("ID", DataType::Integer, false, 0),
            col("Name", DataType::Text, true, 1),
            col("SCORE", DataType::Real, true, 2),
            col("Active", DataType::Boolean, false, 3),
        ];
        let cm = ColumnMap::new(&cols);
        let row = test_row();

        let expr = Expr::Column("name".into());
        assert_eq!(
            eval_expr(&expr, &cm, &row).unwrap(),
            Value::Text("Alice".into())
        );

        let expr = Expr::Column("score".into());
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Real(95.5));
    }

    #[test]
    fn column_map_ambiguous_and_qualified() {
        // Two qualified columns sharing the short name `id`: the bare name is
        // ambiguous, while fully qualified lookups stay exact.
        let cols = vec![
            col("a.id", DataType::Integer, false, 0),
            col("b.id", DataType::Integer, false, 1),
        ];
        let cm = ColumnMap::new(&cols);

        assert!(matches!(
            cm.resolve("id"),
            Err(SqlError::AmbiguousColumn(_))
        ));
        assert_eq!(cm.resolve("a.id").unwrap(), 0);
        assert_eq!(cm.resolve_qualified("b", "id").unwrap(), 1);
        assert!(matches!(
            cm.resolve("missing"),
            Err(SqlError::ColumnNotFound(_))
        ));
    }

    #[test]
    fn eval_arithmetic_int() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("id".into())),
            op: BinOp::Add,
            right: Box::new(Expr::Literal(Value::Integer(10))),
        };
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Integer(11));
    }

    #[test]
    fn eval_comparison() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("score".into())),
            op: BinOp::Gt,
            right: Box::new(Expr::Literal(Value::Real(90.0))),
        };
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Boolean(true));
    }

    #[test]
    fn eval_null_propagation() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = row_with_nulls();
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::Eq,
            right: Box::new(Expr::Literal(Value::Text("test".into()))),
        };
        assert!(eval_expr(&expr, &cm, &row).unwrap().is_null());
    }

    #[test]
    fn eval_and_three_valued() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = row_with_nulls();

        // NULL AND FALSE -> FALSE
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::And,
            right: Box::new(Expr::Literal(Value::Boolean(false))),
        };
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Boolean(false));

        // NULL AND TRUE -> NULL
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::And,
            right: Box::new(Expr::Literal(Value::Boolean(true))),
        };
        assert!(eval_expr(&expr, &cm, &row).unwrap().is_null());
    }

    #[test]
    fn eval_or_three_valued() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = row_with_nulls();

        // NULL OR TRUE -> TRUE
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::Or,
            right: Box::new(Expr::Literal(Value::Boolean(true))),
        };
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Boolean(true));

        // NULL OR FALSE -> NULL
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::Or,
            right: Box::new(Expr::Literal(Value::Boolean(false))),
        };
        assert!(eval_expr(&expr, &cm, &row).unwrap().is_null());
    }

    #[test]
    fn eval_is_null() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = row_with_nulls();

        let expr = Expr::IsNull(Box::new(Expr::Column("name".into())));
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Boolean(true));

        let expr = Expr::IsNotNull(Box::new(Expr::Column("id".into())));
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Boolean(true));
    }

    #[test]
    fn eval_not() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::UnaryOp {
            op: UnaryOp::Not,
            expr: Box::new(Expr::Column("active".into())),
        };
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Boolean(false));
    }

    #[test]
    fn eval_neg() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::UnaryOp {
            op: UnaryOp::Neg,
            expr: Box::new(Expr::Column("id".into())),
        };
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Integer(-1));
    }

    #[test]
    fn eval_division_by_zero() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("id".into())),
            op: BinOp::Div,
            right: Box::new(Expr::Literal(Value::Integer(0))),
        };
        assert!(matches!(
            eval_expr(&expr, &cm, &row),
            Err(SqlError::DivisionByZero)
        ));
    }

    #[test]
    fn eval_mixed_numeric() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        // Integer + Real promotes to Real.
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("id".into())),
            op: BinOp::Add,
            right: Box::new(Expr::Column("score".into())),
        };
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Real(96.5));
    }

    #[test]
    fn is_truthy_values() {
        assert!(is_truthy(&Value::Boolean(true)));
        assert!(!is_truthy(&Value::Boolean(false)));
        assert!(!is_truthy(&Value::Null));
        assert!(is_truthy(&Value::Integer(1)));
        assert!(is_truthy(&Value::Integer(-1)));
        assert!(!is_truthy(&Value::Integer(0)));
    }
}