1use std::collections::HashMap;
4
5use crate::error::{Result, SqlError};
6use crate::parser::{BinOp, Expr, UnaryOp};
7use crate::types::{ColumnDef, CompactString, DataType, Value};
8
/// Lookup tables from column names to their positions in a row.
///
/// Both maps are keyed by ASCII-lowercased names; see `ColumnMap::new` for how they are
/// populated.
#[derive(Clone, Debug)]
pub struct ColumnMap {
    /// Full lowercased column name (e.g. `users.id` or `id`) -> row index.
    exact: HashMap<String, usize>,
    /// Lowercased part after the last `.` -> the single owning index, or `Ambiguous`.
    short: HashMap<String, ShortMatch>,
}

/// Outcome of looking up an unqualified (short) column name.
#[derive(Clone, Debug)]
enum ShortMatch {
    /// Exactly one column has this short name; holds its row index.
    Unique(usize),
    /// Two or more columns share this short name.
    Ambiguous,
}
28
29impl ColumnMap {
30 pub fn new(columns: &[ColumnDef]) -> Self {
31 let mut exact = HashMap::with_capacity(columns.len() * 2);
32 let mut short: HashMap<String, ShortMatch> = HashMap::with_capacity(columns.len());
33
34 for (i, col) in columns.iter().enumerate() {
35 let lower = col.name.to_ascii_lowercase();
36 exact.insert(lower.clone(), i);
37
38 let unqualified = if let Some(dot) = lower.rfind('.') {
39 &lower[dot + 1..]
40 } else {
41 &lower
42 };
43 short
44 .entry(unqualified.to_string())
45 .and_modify(|e| *e = ShortMatch::Ambiguous)
46 .or_insert(ShortMatch::Unique(i));
47 }
48
49 Self { exact, short }
50 }
51
52 pub(crate) fn resolve(&self, name: &str) -> Result<usize> {
53 if let Some(&idx) = self.exact.get(name) {
54 return Ok(idx);
55 }
56 match self.short.get(name) {
57 Some(ShortMatch::Unique(idx)) => Ok(*idx),
58 Some(ShortMatch::Ambiguous) => Err(SqlError::AmbiguousColumn(name.to_string())),
59 None => Err(SqlError::ColumnNotFound(name.to_string())),
60 }
61 }
62
63 pub(crate) fn resolve_qualified(&self, table: &str, column: &str) -> Result<usize> {
64 let qualified = format!("{table}.{column}");
65 if let Some(&idx) = self.exact.get(&qualified) {
66 return Ok(idx);
67 }
68 match self.short.get(column) {
69 Some(ShortMatch::Unique(idx)) => Ok(*idx),
70 _ => Err(SqlError::ColumnNotFound(format!("{table}.{column}"))),
71 }
72 }
73}
74
75pub fn eval_expr(expr: &Expr, col_map: &ColumnMap, row: &[Value]) -> Result<Value> {
76 match expr {
77 Expr::Literal(v) => Ok(v.clone()),
78
79 Expr::Column(name) => {
80 let idx = col_map.resolve(name)?;
81 Ok(row[idx].clone())
82 }
83
84 Expr::QualifiedColumn { table, column } => {
85 let idx = col_map.resolve_qualified(table, column)?;
86 Ok(row[idx].clone())
87 }
88
89 Expr::BinaryOp { left, op, right } => {
90 let lval = eval_expr(left, col_map, row)?;
91 let rval = eval_expr(right, col_map, row)?;
92 eval_binary_op(&lval, *op, &rval)
93 }
94
95 Expr::UnaryOp { op, expr } => {
96 let val = eval_expr(expr, col_map, row)?;
97 eval_unary_op(*op, &val)
98 }
99
100 Expr::IsNull(e) => {
101 let val = eval_expr(e, col_map, row)?;
102 Ok(Value::Boolean(val.is_null()))
103 }
104
105 Expr::IsNotNull(e) => {
106 let val = eval_expr(e, col_map, row)?;
107 Ok(Value::Boolean(!val.is_null()))
108 }
109
110 Expr::Function { name, args } => eval_scalar_function(name, args, col_map, row),
111
112 Expr::CountStar => Err(SqlError::Unsupported(
113 "COUNT(*) in non-aggregate context".into(),
114 )),
115
116 Expr::InList {
117 expr: e,
118 list,
119 negated,
120 } => {
121 let lhs = eval_expr(e, col_map, row)?;
122 eval_in_values(&lhs, list, col_map, row, *negated)
123 }
124
125 Expr::InSet {
126 expr: e,
127 values,
128 has_null,
129 negated,
130 } => {
131 let lhs = eval_expr(e, col_map, row)?;
132 eval_in_set(&lhs, values, *has_null, *negated)
133 }
134
135 Expr::Between {
136 expr: e,
137 low,
138 high,
139 negated,
140 } => {
141 let val = eval_expr(e, col_map, row)?;
142 let lo = eval_expr(low, col_map, row)?;
143 let hi = eval_expr(high, col_map, row)?;
144 eval_between(&val, &lo, &hi, *negated)
145 }
146
147 Expr::Like {
148 expr: e,
149 pattern,
150 escape,
151 negated,
152 } => {
153 let val = eval_expr(e, col_map, row)?;
154 let pat = eval_expr(pattern, col_map, row)?;
155 let esc = escape
156 .as_ref()
157 .map(|e| eval_expr(e, col_map, row))
158 .transpose()?;
159 eval_like(&val, &pat, esc.as_ref(), *negated)
160 }
161
162 Expr::Case {
163 operand,
164 conditions,
165 else_result,
166 } => eval_case(
167 operand.as_deref(),
168 conditions,
169 else_result.as_deref(),
170 col_map,
171 row,
172 ),
173
174 Expr::Coalesce(args) => {
175 for arg in args {
176 let val = eval_expr(arg, col_map, row)?;
177 if !val.is_null() {
178 return Ok(val);
179 }
180 }
181 Ok(Value::Null)
182 }
183
184 Expr::Cast { expr: e, data_type } => {
185 let val = eval_expr(e, col_map, row)?;
186 eval_cast(&val, *data_type)
187 }
188
189 Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::ScalarSubquery(_) => Err(
190 SqlError::Unsupported("subquery not materialized (internal error)".into()),
191 ),
192
193 Expr::Parameter(n) => Err(SqlError::Parse(format!("unbound parameter ${n}"))),
194
195 Expr::WindowFunction { .. } => Err(SqlError::Unsupported(
196 "window functions are only allowed in SELECT columns".into(),
197 )),
198 }
199}
200
201fn eval_binary_op(left: &Value, op: BinOp, right: &Value) -> Result<Value> {
202 match op {
204 BinOp::And => return eval_and(left, right),
205 BinOp::Or => return eval_or(left, right),
206 _ => {}
207 }
208
209 if left.is_null() || right.is_null() {
211 return Ok(Value::Null);
212 }
213
214 match op {
215 BinOp::Eq => Ok(Value::Boolean(left == right)),
216 BinOp::NotEq => Ok(Value::Boolean(left != right)),
217 BinOp::Lt => Ok(Value::Boolean(left < right)),
218 BinOp::Gt => Ok(Value::Boolean(left > right)),
219 BinOp::LtEq => Ok(Value::Boolean(left <= right)),
220 BinOp::GtEq => Ok(Value::Boolean(left >= right)),
221 BinOp::Add => eval_arithmetic(left, right, i64::checked_add, |a, b| a + b),
222 BinOp::Sub => eval_arithmetic(left, right, i64::checked_sub, |a, b| a - b),
223 BinOp::Mul => eval_arithmetic(left, right, i64::checked_mul, |a, b| a * b),
224 BinOp::Div => {
225 match right {
226 Value::Integer(0) => return Err(SqlError::DivisionByZero),
227 Value::Real(r) if *r == 0.0 => return Err(SqlError::DivisionByZero),
228 _ => {}
229 }
230 eval_arithmetic(left, right, i64::checked_div, |a, b| a / b)
231 }
232 BinOp::Mod => {
233 match right {
234 Value::Integer(0) => return Err(SqlError::DivisionByZero),
235 Value::Real(r) if *r == 0.0 => return Err(SqlError::DivisionByZero),
236 _ => {}
237 }
238 eval_arithmetic(left, right, i64::checked_rem, |a, b| a % b)
239 }
240 BinOp::Concat => {
241 let ls = value_to_text(left);
242 let rs = value_to_text(right);
243 Ok(Value::Text(format!("{ls}{rs}").into()))
244 }
245 BinOp::And | BinOp::Or => unreachable!(),
246 }
247}
248
249fn eval_and(left: &Value, right: &Value) -> Result<Value> {
251 let l = to_bool_or_null(left)?;
252 let r = to_bool_or_null(right)?;
253 match (l, r) {
254 (Some(false), _) | (_, Some(false)) => Ok(Value::Boolean(false)),
255 (Some(true), Some(true)) => Ok(Value::Boolean(true)),
256 _ => Ok(Value::Null),
257 }
258}
259
260fn eval_or(left: &Value, right: &Value) -> Result<Value> {
262 let l = to_bool_or_null(left)?;
263 let r = to_bool_or_null(right)?;
264 match (l, r) {
265 (Some(true), _) | (_, Some(true)) => Ok(Value::Boolean(true)),
266 (Some(false), Some(false)) => Ok(Value::Boolean(false)),
267 _ => Ok(Value::Null),
268 }
269}
270
271fn to_bool_or_null(val: &Value) -> Result<Option<bool>> {
272 match val {
273 Value::Boolean(b) => Ok(Some(*b)),
274 Value::Null => Ok(None),
275 Value::Integer(i) => Ok(Some(*i != 0)),
276 _ => Err(SqlError::TypeMismatch {
277 expected: "BOOLEAN".into(),
278 got: format!("{}", val.data_type()),
279 }),
280 }
281}
282
283fn eval_arithmetic(
284 left: &Value,
285 right: &Value,
286 int_op: fn(i64, i64) -> Option<i64>,
287 real_op: fn(f64, f64) -> f64,
288) -> Result<Value> {
289 match (left, right) {
290 (Value::Integer(a), Value::Integer(b)) => int_op(*a, *b)
291 .map(Value::Integer)
292 .ok_or(SqlError::IntegerOverflow),
293 (Value::Real(a), Value::Real(b)) => Ok(Value::Real(real_op(*a, *b))),
294 (Value::Integer(a), Value::Real(b)) => Ok(Value::Real(real_op(*a as f64, *b))),
295 (Value::Real(a), Value::Integer(b)) => Ok(Value::Real(real_op(*a, *b as f64))),
296 _ => Err(SqlError::TypeMismatch {
297 expected: "numeric".into(),
298 got: format!("{} and {}", left.data_type(), right.data_type()),
299 }),
300 }
301}
302
303fn eval_in_values(
304 lhs: &Value,
305 list: &[Expr],
306 col_map: &ColumnMap,
307 row: &[Value],
308 negated: bool,
309) -> Result<Value> {
310 if list.is_empty() {
311 return Ok(Value::Boolean(negated));
312 }
313 if lhs.is_null() {
314 return Ok(Value::Null);
315 }
316 let mut has_null = false;
317 for item in list {
318 let rhs = eval_expr(item, col_map, row)?;
319 if rhs.is_null() {
320 has_null = true;
321 } else if lhs == &rhs {
322 return Ok(Value::Boolean(!negated));
323 }
324 }
325 if has_null {
326 Ok(Value::Null)
327 } else {
328 Ok(Value::Boolean(negated))
329 }
330}
331
332fn eval_in_set(
333 lhs: &Value,
334 values: &std::collections::HashSet<Value>,
335 has_null: bool,
336 negated: bool,
337) -> Result<Value> {
338 if values.is_empty() && !has_null {
339 return Ok(Value::Boolean(negated));
340 }
341 if lhs.is_null() {
342 return Ok(Value::Null);
343 }
344 if values.contains(lhs) {
345 return Ok(Value::Boolean(!negated));
346 }
347 if has_null {
348 Ok(Value::Null)
349 } else {
350 Ok(Value::Boolean(negated))
351 }
352}
353
354fn eval_unary_op(op: UnaryOp, val: &Value) -> Result<Value> {
355 if val.is_null() {
356 return Ok(Value::Null);
357 }
358 match op {
359 UnaryOp::Neg => match val {
360 Value::Integer(i) => i
361 .checked_neg()
362 .map(Value::Integer)
363 .ok_or(SqlError::IntegerOverflow),
364 Value::Real(r) => Ok(Value::Real(-r)),
365 _ => Err(SqlError::TypeMismatch {
366 expected: "numeric".into(),
367 got: format!("{}", val.data_type()),
368 }),
369 },
370 UnaryOp::Not => match val {
371 Value::Boolean(b) => Ok(Value::Boolean(!b)),
372 Value::Integer(i) => Ok(Value::Boolean(*i == 0)),
373 _ => Err(SqlError::TypeMismatch {
374 expected: "BOOLEAN".into(),
375 got: format!("{}", val.data_type()),
376 }),
377 },
378 }
379}
380
381fn value_to_text(val: &Value) -> String {
382 match val {
383 Value::Text(s) => s.to_string(),
384 Value::Integer(i) => i.to_string(),
385 Value::Real(r) => {
386 if r.fract() == 0.0 && r.is_finite() {
387 format!("{r:.1}")
388 } else {
389 format!("{r}")
390 }
391 }
392 Value::Boolean(b) => if *b { "TRUE" } else { "FALSE" }.into(),
393 Value::Null => String::new(),
394 Value::Blob(b) => {
395 let mut s = String::with_capacity(b.len() * 2);
396 for byte in b {
397 s.push_str(&format!("{byte:02X}"));
398 }
399 s
400 }
401 }
402}
403
404fn eval_between(val: &Value, low: &Value, high: &Value, negated: bool) -> Result<Value> {
405 if val.is_null() || low.is_null() || high.is_null() {
406 let ge = if val.is_null() || low.is_null() {
407 None
408 } else {
409 Some(*val >= *low)
410 };
411 let le = if val.is_null() || high.is_null() {
412 None
413 } else {
414 Some(*val <= *high)
415 };
416
417 let result = match (ge, le) {
418 (Some(false), _) | (_, Some(false)) => Some(false),
419 (Some(true), Some(true)) => Some(true),
420 _ => None,
421 };
422
423 return match result {
424 Some(b) => Ok(Value::Boolean(if negated { !b } else { b })),
425 None => Ok(Value::Null),
426 };
427 }
428
429 let in_range = *val >= *low && *val <= *high;
430 Ok(Value::Boolean(if negated { !in_range } else { in_range }))
431}
432
433const MAX_LIKE_PATTERN_LEN: usize = 10_000;
434
435fn eval_like(val: &Value, pattern: &Value, escape: Option<&Value>, negated: bool) -> Result<Value> {
436 if val.is_null() || pattern.is_null() {
437 return Ok(Value::Null);
438 }
439 let text = match val {
440 Value::Text(s) => s.as_str(),
441 _ => {
442 return Err(SqlError::TypeMismatch {
443 expected: "TEXT".into(),
444 got: val.data_type().to_string(),
445 })
446 }
447 };
448 let pat = match pattern {
449 Value::Text(s) => s.as_str(),
450 _ => {
451 return Err(SqlError::TypeMismatch {
452 expected: "TEXT".into(),
453 got: pattern.data_type().to_string(),
454 })
455 }
456 };
457
458 if pat.len() > MAX_LIKE_PATTERN_LEN {
459 return Err(SqlError::InvalidValue(format!(
460 "LIKE pattern too long ({} chars, max {MAX_LIKE_PATTERN_LEN})",
461 pat.len()
462 )));
463 }
464
465 let esc_char = match escape {
466 Some(Value::Text(s)) => {
467 let mut chars = s.chars();
468 let c = chars.next().ok_or_else(|| {
469 SqlError::InvalidValue("ESCAPE must be a single character".into())
470 })?;
471 if chars.next().is_some() {
472 return Err(SqlError::InvalidValue(
473 "ESCAPE must be a single character".into(),
474 ));
475 }
476 Some(c)
477 }
478 Some(Value::Null) => return Ok(Value::Null),
479 Some(_) => {
480 return Err(SqlError::TypeMismatch {
481 expected: "TEXT".into(),
482 got: "non-text".into(),
483 })
484 }
485 None => None,
486 };
487
488 let matched = like_match(text, pat, esc_char);
489 Ok(Value::Boolean(if negated { !matched } else { matched }))
490}
491
/// Matches `text` against a SQL LIKE `pattern`.
///
/// Pattern characters:
/// - `%` matches any (possibly empty) run of characters.
/// - `_` matches exactly one character.
/// - Every other character matches itself ASCII-case-insensitively.
///
/// With `escape` set, the escape character followed by any character matches that character
/// literally. An escape character in the last pattern position is interpreted as if no
/// escape were configured.
fn like_match(text: &str, pattern: &str, escape: Option<char>) -> bool {
    let text_chars = text.chars().collect::<Vec<char>>();
    let pattern_chars = pattern.chars().collect::<Vec<char>>();
    like_match_impl(&text_chars, &pattern_chars, 0, 0, escape)
}

/// Wildcard matcher over `t[ti..]` and `p[pi..]` that backtracks only to the latest `%`.
///
/// `star` remembers the pattern index of the most recent `%` and the text index that `%`
/// currently stretches to. On a mismatch the `%` absorbs one more text character and
/// matching resumes just after it, giving O(len(t) * len(p)) worst-case time.
fn like_match_impl(
    t: &[char],
    p: &[char],
    mut ti: usize,
    mut pi: usize,
    esc: Option<char>,
) -> bool {
    let mut star: Option<(usize, usize)> = None;

    while ti < t.len() {
        let advanced = match p.get(pi) {
            None => false,
            Some(&pc) => {
                if esc == Some(pc) && pi + 1 < p.len() {
                    // Escaped character: compared literally, never as a wildcard.
                    if p[pi + 1].eq_ignore_ascii_case(&t[ti]) {
                        pi += 2;
                        ti += 1;
                        true
                    } else {
                        false
                    }
                } else if pc == '%' {
                    star = Some((pi, ti));
                    pi += 1;
                    true
                } else if pc == '_' || pc.eq_ignore_ascii_case(&t[ti]) {
                    pi += 1;
                    ti += 1;
                    true
                } else {
                    false
                }
            }
        };
        if advanced {
            continue;
        }
        match star.as_mut() {
            Some((star_pi, star_ti)) => {
                *star_ti += 1;
                ti = *star_ti;
                pi = *star_pi + 1;
            }
            None => return false,
        }
    }

    // Text is exhausted; only trailing '%' may remain in the pattern.
    while p.get(pi) == Some(&'%') {
        pi += 1;
    }
    pi == p.len()
}
560
561fn eval_case(
562 operand: Option<&Expr>,
563 conditions: &[(Expr, Expr)],
564 else_result: Option<&Expr>,
565 col_map: &ColumnMap,
566 row: &[Value],
567) -> Result<Value> {
568 if let Some(op_expr) = operand {
569 let op_val = eval_expr(op_expr, col_map, row)?;
570 for (cond, result) in conditions {
571 let cond_val = eval_expr(cond, col_map, row)?;
572 if !op_val.is_null() && !cond_val.is_null() && op_val == cond_val {
573 return eval_expr(result, col_map, row);
574 }
575 }
576 } else {
577 for (cond, result) in conditions {
578 let cond_val = eval_expr(cond, col_map, row)?;
579 if is_truthy(&cond_val) {
580 return eval_expr(result, col_map, row);
581 }
582 }
583 }
584 match else_result {
585 Some(e) => eval_expr(e, col_map, row),
586 None => Ok(Value::Null),
587 }
588}
589
590fn eval_cast(val: &Value, target: DataType) -> Result<Value> {
591 if val.is_null() {
592 return Ok(Value::Null);
593 }
594 match target {
595 DataType::Integer => match val {
596 Value::Integer(_) => Ok(val.clone()),
597 Value::Real(r) => Ok(Value::Integer(*r as i64)),
598 Value::Boolean(b) => Ok(Value::Integer(if *b { 1 } else { 0 })),
599 Value::Text(s) => s
600 .trim()
601 .parse::<i64>()
602 .map(Value::Integer)
603 .or_else(|_| s.trim().parse::<f64>().map(|f| Value::Integer(f as i64)))
604 .map_err(|_| SqlError::InvalidValue(format!("cannot cast '{s}' to INTEGER"))),
605 _ => Err(SqlError::InvalidValue(format!(
606 "cannot cast {} to INTEGER",
607 val.data_type()
608 ))),
609 },
610 DataType::Real => match val {
611 Value::Real(_) => Ok(val.clone()),
612 Value::Integer(i) => Ok(Value::Real(*i as f64)),
613 Value::Boolean(b) => Ok(Value::Real(if *b { 1.0 } else { 0.0 })),
614 Value::Text(s) => s
615 .trim()
616 .parse::<f64>()
617 .map(Value::Real)
618 .map_err(|_| SqlError::InvalidValue(format!("cannot cast '{s}' to REAL"))),
619 _ => Err(SqlError::InvalidValue(format!(
620 "cannot cast {} to REAL",
621 val.data_type()
622 ))),
623 },
624 DataType::Text => Ok(Value::Text(value_to_text(val).into())),
625 DataType::Boolean => match val {
626 Value::Boolean(_) => Ok(val.clone()),
627 Value::Integer(i) => Ok(Value::Boolean(*i != 0)),
628 Value::Text(s) => {
629 let lower = s.trim().to_ascii_lowercase();
630 match lower.as_str() {
631 "true" | "1" | "yes" | "on" => Ok(Value::Boolean(true)),
632 "false" | "0" | "no" | "off" => Ok(Value::Boolean(false)),
633 _ => Err(SqlError::InvalidValue(format!(
634 "cannot cast '{s}' to BOOLEAN"
635 ))),
636 }
637 }
638 _ => Err(SqlError::InvalidValue(format!(
639 "cannot cast {} to BOOLEAN",
640 val.data_type()
641 ))),
642 },
643 DataType::Blob => match val {
644 Value::Blob(_) => Ok(val.clone()),
645 Value::Text(s) => Ok(Value::Blob(s.as_bytes().to_vec())),
646 _ => Err(SqlError::InvalidValue(format!(
647 "cannot cast {} to BLOB",
648 val.data_type()
649 ))),
650 },
651 DataType::Null => Ok(Value::Null),
652 }
653}
654
655fn eval_scalar_function(
656 name: &str,
657 args: &[Expr],
658 col_map: &ColumnMap,
659 row: &[Value],
660) -> Result<Value> {
661 let evaluated: Vec<Value> = args
662 .iter()
663 .map(|a| eval_expr(a, col_map, row))
664 .collect::<Result<Vec<_>>>()?;
665
666 match name {
667 "LENGTH" => {
668 check_args(name, &evaluated, 1)?;
669 match &evaluated[0] {
670 Value::Null => Ok(Value::Null),
671 Value::Text(s) => Ok(Value::Integer(s.chars().count() as i64)),
672 Value::Blob(b) => Ok(Value::Integer(b.len() as i64)),
673 _ => Ok(Value::Integer(
674 value_to_text(&evaluated[0]).chars().count() as i64
675 )),
676 }
677 }
678 "UPPER" => {
679 check_args(name, &evaluated, 1)?;
680 match &evaluated[0] {
681 Value::Null => Ok(Value::Null),
682 Value::Text(s) => Ok(Value::Text(s.to_ascii_uppercase())),
683 _ => Ok(Value::Text(
684 value_to_text(&evaluated[0]).to_ascii_uppercase().into(),
685 )),
686 }
687 }
688 "LOWER" => {
689 check_args(name, &evaluated, 1)?;
690 match &evaluated[0] {
691 Value::Null => Ok(Value::Null),
692 Value::Text(s) => Ok(Value::Text(s.to_ascii_lowercase())),
693 _ => Ok(Value::Text(
694 value_to_text(&evaluated[0]).to_ascii_lowercase().into(),
695 )),
696 }
697 }
698 "SUBSTR" | "SUBSTRING" => {
699 if evaluated.len() < 2 || evaluated.len() > 3 {
700 return Err(SqlError::InvalidValue(format!(
701 "{name} requires 2 or 3 arguments"
702 )));
703 }
704 if evaluated.iter().any(|v| v.is_null()) {
705 return Ok(Value::Null);
706 }
707 let s = value_to_text(&evaluated[0]);
708 let chars: Vec<char> = s.chars().collect();
709 let start = match &evaluated[1] {
710 Value::Integer(i) => *i,
711 _ => {
712 return Err(SqlError::TypeMismatch {
713 expected: "INTEGER".into(),
714 got: evaluated[1].data_type().to_string(),
715 })
716 }
717 };
718 let len = chars.len() as i64;
719
720 let (begin, count) = if evaluated.len() == 3 {
721 let cnt = match &evaluated[2] {
722 Value::Integer(i) => *i,
723 _ => {
724 return Err(SqlError::TypeMismatch {
725 expected: "INTEGER".into(),
726 got: evaluated[2].data_type().to_string(),
727 })
728 }
729 };
730 if start >= 1 {
731 let b = (start - 1).min(len) as usize;
732 let c = cnt.max(0) as usize;
733 (b, c)
734 } else if start == 0 {
735 let c = (cnt - 1).max(0) as usize;
736 (0usize, c)
737 } else {
738 let adjusted_cnt = (cnt + start - 1).max(0) as usize;
739 (0usize, adjusted_cnt)
740 }
741 } else if start >= 1 {
742 let b = (start - 1).min(len) as usize;
743 (b, chars.len() - b)
744 } else if start == 0 {
745 (0usize, chars.len())
746 } else {
747 let b = (len + start).max(0) as usize;
748 (b, chars.len() - b)
749 };
750
751 let result: String = chars.iter().skip(begin).take(count).collect();
752 Ok(Value::Text(result.into()))
753 }
754 "TRIM" | "LTRIM" | "RTRIM" => {
755 if evaluated.is_empty() || evaluated.len() > 2 {
756 return Err(SqlError::InvalidValue(format!(
757 "{name} requires 1 or 2 arguments"
758 )));
759 }
760 if evaluated[0].is_null() {
761 return Ok(Value::Null);
762 }
763 let s = value_to_text(&evaluated[0]);
764 let trim_chars: Vec<char> = if evaluated.len() == 2 {
765 if evaluated[1].is_null() {
766 return Ok(Value::Null);
767 }
768 value_to_text(&evaluated[1]).chars().collect()
769 } else {
770 vec![' ']
771 };
772 let result = match name {
773 "TRIM" => s
774 .trim_matches(|c: char| trim_chars.contains(&c))
775 .to_string(),
776 "LTRIM" => s
777 .trim_start_matches(|c: char| trim_chars.contains(&c))
778 .to_string(),
779 "RTRIM" => s
780 .trim_end_matches(|c: char| trim_chars.contains(&c))
781 .to_string(),
782 _ => unreachable!(),
783 };
784 Ok(Value::Text(result.into()))
785 }
786 "REPLACE" => {
787 check_args(name, &evaluated, 3)?;
788 if evaluated.iter().any(|v| v.is_null()) {
789 return Ok(Value::Null);
790 }
791 let s = value_to_text(&evaluated[0]);
792 let from = value_to_text(&evaluated[1]);
793 let to = value_to_text(&evaluated[2]);
794 if from.is_empty() {
795 return Ok(Value::Text(s.into()));
796 }
797 Ok(Value::Text(s.replace(&from, &to).into()))
798 }
799 "INSTR" => {
800 check_args(name, &evaluated, 2)?;
801 if evaluated.iter().any(|v| v.is_null()) {
802 return Ok(Value::Null);
803 }
804 let haystack = value_to_text(&evaluated[0]);
805 let needle = value_to_text(&evaluated[1]);
806 let pos = haystack
807 .find(&needle)
808 .map(|i| haystack[..i].chars().count() as i64 + 1)
809 .unwrap_or(0);
810 Ok(Value::Integer(pos))
811 }
812 "CONCAT" => {
813 if evaluated.is_empty() {
814 return Ok(Value::Text(CompactString::default()));
815 }
816 let mut result = String::new();
817 for v in &evaluated {
818 match v {
819 Value::Null => {}
820 _ => result.push_str(&value_to_text(v)),
821 }
822 }
823 Ok(Value::Text(result.into()))
824 }
825 "ABS" => {
826 check_args(name, &evaluated, 1)?;
827 match &evaluated[0] {
828 Value::Null => Ok(Value::Null),
829 Value::Integer(i) => i
830 .checked_abs()
831 .map(Value::Integer)
832 .ok_or(SqlError::IntegerOverflow),
833 Value::Real(r) => Ok(Value::Real(r.abs())),
834 _ => Err(SqlError::TypeMismatch {
835 expected: "numeric".into(),
836 got: evaluated[0].data_type().to_string(),
837 }),
838 }
839 }
840 "ROUND" => {
841 if evaluated.is_empty() || evaluated.len() > 2 {
842 return Err(SqlError::InvalidValue(
843 "ROUND requires 1 or 2 arguments".into(),
844 ));
845 }
846 if evaluated[0].is_null() {
847 return Ok(Value::Null);
848 }
849 let val = match &evaluated[0] {
850 Value::Integer(i) => *i as f64,
851 Value::Real(r) => *r,
852 _ => {
853 return Err(SqlError::TypeMismatch {
854 expected: "numeric".into(),
855 got: evaluated[0].data_type().to_string(),
856 })
857 }
858 };
859 let places = if evaluated.len() == 2 {
860 match &evaluated[1] {
861 Value::Null => return Ok(Value::Null),
862 Value::Integer(i) => *i,
863 _ => {
864 return Err(SqlError::TypeMismatch {
865 expected: "INTEGER".into(),
866 got: evaluated[1].data_type().to_string(),
867 })
868 }
869 }
870 } else {
871 0
872 };
873 let factor = 10f64.powi(places as i32);
874 let rounded = (val * factor).round() / factor;
875 Ok(Value::Real(rounded))
876 }
877 "CEIL" | "CEILING" => {
878 check_args(name, &evaluated, 1)?;
879 match &evaluated[0] {
880 Value::Null => Ok(Value::Null),
881 Value::Integer(i) => Ok(Value::Integer(*i)),
882 Value::Real(r) => Ok(Value::Integer(r.ceil() as i64)),
883 _ => Err(SqlError::TypeMismatch {
884 expected: "numeric".into(),
885 got: evaluated[0].data_type().to_string(),
886 }),
887 }
888 }
889 "FLOOR" => {
890 check_args(name, &evaluated, 1)?;
891 match &evaluated[0] {
892 Value::Null => Ok(Value::Null),
893 Value::Integer(i) => Ok(Value::Integer(*i)),
894 Value::Real(r) => Ok(Value::Integer(r.floor() as i64)),
895 _ => Err(SqlError::TypeMismatch {
896 expected: "numeric".into(),
897 got: evaluated[0].data_type().to_string(),
898 }),
899 }
900 }
901 "SIGN" => {
902 check_args(name, &evaluated, 1)?;
903 match &evaluated[0] {
904 Value::Null => Ok(Value::Null),
905 Value::Integer(i) => Ok(Value::Integer(i.signum())),
906 Value::Real(r) => {
907 if *r > 0.0 {
908 Ok(Value::Integer(1))
909 } else if *r < 0.0 {
910 Ok(Value::Integer(-1))
911 } else {
912 Ok(Value::Integer(0))
913 }
914 }
915 _ => Err(SqlError::TypeMismatch {
916 expected: "numeric".into(),
917 got: evaluated[0].data_type().to_string(),
918 }),
919 }
920 }
921 "SQRT" => {
922 check_args(name, &evaluated, 1)?;
923 match &evaluated[0] {
924 Value::Null => Ok(Value::Null),
925 Value::Integer(i) => {
926 if *i < 0 {
927 Ok(Value::Null)
928 } else {
929 Ok(Value::Real((*i as f64).sqrt()))
930 }
931 }
932 Value::Real(r) => {
933 if *r < 0.0 {
934 Ok(Value::Null)
935 } else {
936 Ok(Value::Real(r.sqrt()))
937 }
938 }
939 _ => Err(SqlError::TypeMismatch {
940 expected: "numeric".into(),
941 got: evaluated[0].data_type().to_string(),
942 }),
943 }
944 }
945 "RANDOM" => {
946 check_args(name, &evaluated, 0)?;
947 use std::collections::hash_map::DefaultHasher;
948 use std::hash::{Hash, Hasher};
949 use std::time::SystemTime;
950 let mut hasher = DefaultHasher::new();
951 SystemTime::now().hash(&mut hasher);
952 std::thread::current().id().hash(&mut hasher);
953 let mut val = hasher.finish() as i64;
954 if val == i64::MIN {
955 val = i64::MAX;
956 }
957 Ok(Value::Integer(val))
958 }
959 "TYPEOF" => {
960 check_args(name, &evaluated, 1)?;
961 let type_name = match &evaluated[0] {
962 Value::Null => "null",
963 Value::Integer(_) => "integer",
964 Value::Real(_) => "real",
965 Value::Text(_) => "text",
966 Value::Blob(_) => "blob",
967 Value::Boolean(_) => "boolean",
968 };
969 Ok(Value::Text(type_name.into()))
970 }
971 "MIN" => {
972 check_args(name, &evaluated, 2)?;
973 if evaluated[0].is_null() {
974 return Ok(evaluated[1].clone());
975 }
976 if evaluated[1].is_null() {
977 return Ok(evaluated[0].clone());
978 }
979 if evaluated[0] <= evaluated[1] {
980 Ok(evaluated[0].clone())
981 } else {
982 Ok(evaluated[1].clone())
983 }
984 }
985 "MAX" => {
986 check_args(name, &evaluated, 2)?;
987 if evaluated[0].is_null() {
988 return Ok(evaluated[1].clone());
989 }
990 if evaluated[1].is_null() {
991 return Ok(evaluated[0].clone());
992 }
993 if evaluated[0] >= evaluated[1] {
994 Ok(evaluated[0].clone())
995 } else {
996 Ok(evaluated[1].clone())
997 }
998 }
999 "HEX" => {
1000 check_args(name, &evaluated, 1)?;
1001 match &evaluated[0] {
1002 Value::Null => Ok(Value::Null),
1003 Value::Blob(b) => {
1004 let mut s = String::with_capacity(b.len() * 2);
1005 for byte in b {
1006 s.push_str(&format!("{byte:02X}"));
1007 }
1008 Ok(Value::Text(s.into()))
1009 }
1010 Value::Text(s) => {
1011 let mut r = String::with_capacity(s.len() * 2);
1012 for byte in s.as_bytes() {
1013 r.push_str(&format!("{byte:02X}"));
1014 }
1015 Ok(Value::Text(r.into()))
1016 }
1017 _ => Ok(Value::Text(value_to_text(&evaluated[0]).into())),
1018 }
1019 }
1020 _ => Err(SqlError::Unsupported(format!("scalar function: {name}"))),
1021 }
1022}
1023
1024fn check_args(name: &str, args: &[Value], expected: usize) -> Result<()> {
1025 if args.len() != expected {
1026 Err(SqlError::InvalidValue(format!(
1027 "{name} requires {expected} argument(s), got {}",
1028 args.len()
1029 )))
1030 } else {
1031 Ok(())
1032 }
1033}
1034
1035pub fn referenced_columns(expr: &Expr, columns: &[ColumnDef]) -> Vec<usize> {
1036 let mut indices = Vec::new();
1037 collect_column_refs(expr, columns, &mut indices);
1038 indices.sort_unstable();
1039 indices.dedup();
1040 indices
1041}
1042
1043fn collect_column_refs(expr: &Expr, columns: &[ColumnDef], out: &mut Vec<usize>) {
1044 match expr {
1045 Expr::Column(name) => {
1046 for (i, c) in columns.iter().enumerate() {
1047 if c.name == *name || c.name.ends_with(&format!(".{name}")) {
1048 out.push(i);
1049 break;
1050 }
1051 }
1052 }
1053 Expr::QualifiedColumn { table, column } => {
1054 let qualified = format!("{table}.{column}");
1055 if let Some(idx) = columns.iter().position(|c| c.name == qualified) {
1056 out.push(idx);
1057 } else {
1058 let matches: Vec<usize> = columns
1059 .iter()
1060 .enumerate()
1061 .filter(|(_, c)| c.name == *column)
1062 .map(|(i, _)| i)
1063 .collect();
1064 if matches.len() == 1 {
1065 out.push(matches[0]);
1066 }
1067 }
1068 }
1069 Expr::BinaryOp { left, right, .. } => {
1070 collect_column_refs(left, columns, out);
1071 collect_column_refs(right, columns, out);
1072 }
1073 Expr::UnaryOp { expr, .. } => {
1074 collect_column_refs(expr, columns, out);
1075 }
1076 Expr::IsNull(e) | Expr::IsNotNull(e) => {
1077 collect_column_refs(e, columns, out);
1078 }
1079 Expr::Function { args, .. } => {
1080 for arg in args {
1081 collect_column_refs(arg, columns, out);
1082 }
1083 }
1084 Expr::InSubquery { expr, .. } => {
1085 collect_column_refs(expr, columns, out);
1086 }
1087 Expr::InList { expr, list, .. } => {
1088 collect_column_refs(expr, columns, out);
1089 for item in list {
1090 collect_column_refs(item, columns, out);
1091 }
1092 }
1093 Expr::InSet { expr, .. } => {
1094 collect_column_refs(expr, columns, out);
1095 }
1096 Expr::Between {
1097 expr, low, high, ..
1098 } => {
1099 collect_column_refs(expr, columns, out);
1100 collect_column_refs(low, columns, out);
1101 collect_column_refs(high, columns, out);
1102 }
1103 Expr::Like {
1104 expr,
1105 pattern,
1106 escape,
1107 ..
1108 } => {
1109 collect_column_refs(expr, columns, out);
1110 collect_column_refs(pattern, columns, out);
1111 if let Some(esc) = escape {
1112 collect_column_refs(esc, columns, out);
1113 }
1114 }
1115 Expr::Case {
1116 operand,
1117 conditions,
1118 else_result,
1119 } => {
1120 if let Some(op) = operand {
1121 collect_column_refs(op, columns, out);
1122 }
1123 for (when, then) in conditions {
1124 collect_column_refs(when, columns, out);
1125 collect_column_refs(then, columns, out);
1126 }
1127 if let Some(e) = else_result {
1128 collect_column_refs(e, columns, out);
1129 }
1130 }
1131 Expr::Coalesce(args) => {
1132 for arg in args {
1133 collect_column_refs(arg, columns, out);
1134 }
1135 }
1136 Expr::Cast { expr, .. } => {
1137 collect_column_refs(expr, columns, out);
1138 }
1139 Expr::WindowFunction { args, spec, .. } => {
1140 for arg in args {
1141 collect_column_refs(arg, columns, out);
1142 }
1143 for pb in &spec.partition_by {
1144 collect_column_refs(pb, columns, out);
1145 }
1146 for ob in &spec.order_by {
1147 collect_column_refs(&ob.expr, columns, out);
1148 }
1149 }
1150 Expr::Literal(_)
1151 | Expr::Parameter(_)
1152 | Expr::CountStar
1153 | Expr::Exists { .. }
1154 | Expr::ScalarSubquery(_) => {}
1155 }
1156}
1157
1158pub fn is_truthy(val: &Value) -> bool {
1160 match val {
1161 Value::Boolean(b) => *b,
1162 Value::Integer(i) => *i != 0,
1163 Value::Null => false,
1164 _ => true,
1165 }
1166}
1167
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DataType;

    /// Builds a `ColumnDef` with no default value and no CHECK constraint.
    fn col(name: &str, dt: DataType, nullable: bool, pos: u16) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            data_type: dt,
            nullable,
            position: pos,
            default_expr: None,
            default_sql: None,
            check_expr: None,
            check_sql: None,
            check_name: None,
        }
    }

    /// Four-column schema shared by most tests: `id`, `name`, `score`, `active`.
    fn test_columns() -> Vec<ColumnDef> {
        vec![
            col("id", DataType::Integer, false, 0),
            col("name", DataType::Text, true, 1),
            col("score", DataType::Real, true, 2),
            col("active", DataType::Boolean, false, 3),
        ]
    }

    /// A fully populated row matching `test_columns()`.
    fn test_row() -> Vec<Value> {
        vec![
            Value::Integer(1),
            Value::Text("Alice".into()),
            Value::Real(95.5),
            Value::Boolean(true),
        ]
    }

    /// A row matching `test_columns()` whose `name` and `score` are NULL.
    fn null_row() -> Vec<Value> {
        vec![
            Value::Integer(1),
            Value::Null,
            Value::Null,
            Value::Boolean(true),
        ]
    }

    #[test]
    fn eval_literal() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::Literal(Value::Integer(42));
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Integer(42));
    }

    #[test]
    fn eval_column_ref() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::Column("name".into());
        assert_eq!(
            eval_expr(&expr, &cm, &row).unwrap(),
            Value::Text("Alice".into())
        );
    }

    #[test]
    fn eval_column_case_insensitive() {
        // The schema is declared in mixed case. `ColumnMap::new` lowercases
        // definition names, so a lowercase reference must still resolve.
        // (The previous version used an all-lowercase schema and was
        // identical to `eval_column_ref`.)
        let cols = vec![
            col("ID", DataType::Integer, false, 0),
            col("Name", DataType::Text, true, 1),
            col("SCORE", DataType::Real, true, 2),
            col("Active", DataType::Boolean, false, 3),
        ];
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::Column("name".into());
        assert_eq!(
            eval_expr(&expr, &cm, &row).unwrap(),
            Value::Text("Alice".into())
        );
    }

    #[test]
    fn column_map_qualified_and_ambiguous() {
        let cols = vec![
            col("t1.id", DataType::Integer, false, 0),
            col("t2.id", DataType::Integer, false, 1),
            col("T1.Name", DataType::Text, true, 2),
        ];
        let cm = ColumnMap::new(&cols);

        // Exact (qualified) names resolve directly, case-folded at build time.
        assert_eq!(cm.resolve("t1.id").unwrap(), 0);
        assert_eq!(cm.resolve("t1.name").unwrap(), 2);
        // A short name that appears once falls back to the unique match.
        assert_eq!(cm.resolve("name").unwrap(), 2);
        // A short name shared by two tables is ambiguous.
        assert!(matches!(
            cm.resolve("id"),
            Err(SqlError::AmbiguousColumn(_))
        ));
        assert!(matches!(
            cm.resolve("missing"),
            Err(SqlError::ColumnNotFound(_))
        ));

        assert_eq!(cm.resolve_qualified("t2", "id").unwrap(), 1);
        // Unknown table plus an ambiguous short name cannot be resolved.
        assert!(matches!(
            cm.resolve_qualified("t3", "id"),
            Err(SqlError::ColumnNotFound(_))
        ));
    }

    #[test]
    fn eval_arithmetic_int() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("id".into())),
            op: BinOp::Add,
            right: Box::new(Expr::Literal(Value::Integer(10))),
        };
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Integer(11));
    }

    #[test]
    fn eval_comparison() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("score".into())),
            op: BinOp::Gt,
            right: Box::new(Expr::Literal(Value::Real(90.0))),
        };
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Boolean(true));
    }

    #[test]
    fn eval_null_propagation() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = null_row();
        // NULL = 'test' yields NULL rather than FALSE.
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::Eq,
            right: Box::new(Expr::Literal(Value::Text("test".into()))),
        };
        assert!(eval_expr(&expr, &cm, &row).unwrap().is_null());
    }

    #[test]
    fn eval_and_three_valued() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = null_row();

        // NULL AND FALSE => FALSE
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::And,
            right: Box::new(Expr::Literal(Value::Boolean(false))),
        };
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Boolean(false));

        // NULL AND TRUE => NULL
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::And,
            right: Box::new(Expr::Literal(Value::Boolean(true))),
        };
        assert!(eval_expr(&expr, &cm, &row).unwrap().is_null());
    }

    #[test]
    fn eval_or_three_valued() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = null_row();

        // NULL OR TRUE => TRUE
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::Or,
            right: Box::new(Expr::Literal(Value::Boolean(true))),
        };
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Boolean(true));

        // NULL OR FALSE => NULL
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::Or,
            right: Box::new(Expr::Literal(Value::Boolean(false))),
        };
        assert!(eval_expr(&expr, &cm, &row).unwrap().is_null());
    }

    #[test]
    fn eval_is_null() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = null_row();

        let expr = Expr::IsNull(Box::new(Expr::Column("name".into())));
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Boolean(true));

        let expr = Expr::IsNotNull(Box::new(Expr::Column("id".into())));
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Boolean(true));
    }

    #[test]
    fn eval_not() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::UnaryOp {
            op: UnaryOp::Not,
            expr: Box::new(Expr::Column("active".into())),
        };
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Boolean(false));
    }

    #[test]
    fn eval_neg() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::UnaryOp {
            op: UnaryOp::Neg,
            expr: Box::new(Expr::Column("id".into())),
        };
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Integer(-1));
    }

    #[test]
    fn eval_division_by_zero() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("id".into())),
            op: BinOp::Div,
            right: Box::new(Expr::Literal(Value::Integer(0))),
        };
        assert!(matches!(
            eval_expr(&expr, &cm, &row),
            Err(SqlError::DivisionByZero)
        ));
    }

    #[test]
    fn eval_mixed_numeric() {
        let cols = test_columns();
        let cm = ColumnMap::new(&cols);
        let row = test_row();
        // INTEGER + REAL is promoted to REAL.
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("id".into())),
            op: BinOp::Add,
            right: Box::new(Expr::Column("score".into())),
        };
        assert_eq!(eval_expr(&expr, &cm, &row).unwrap(), Value::Real(96.5));
    }

    #[test]
    fn is_truthy_values() {
        assert!(is_truthy(&Value::Boolean(true)));
        assert!(!is_truthy(&Value::Boolean(false)));
        assert!(!is_truthy(&Value::Null));
        assert!(is_truthy(&Value::Integer(1)));
        assert!(is_truthy(&Value::Integer(-1)));
        assert!(!is_truthy(&Value::Integer(0)));
    }
}