1use crate::error::{Result, SqlError};
4use crate::parser::{BinOp, Expr, UnaryOp};
5use crate::types::{ColumnDef, DataType, Value};
6
7pub fn eval_expr(expr: &Expr, columns: &[ColumnDef], row: &[Value]) -> Result<Value> {
12 match expr {
13 Expr::Literal(v) => Ok(v.clone()),
14
15 Expr::Column(name) => {
16 let lower = name.to_ascii_lowercase();
17 let matches: Vec<usize> = columns
18 .iter()
19 .enumerate()
20 .filter(|(_, c)| {
21 let cn = c.name.to_ascii_lowercase();
22 cn == lower || cn.ends_with(&format!(".{lower}"))
23 })
24 .map(|(i, _)| i)
25 .collect();
26 match matches.len() {
27 0 => Err(SqlError::ColumnNotFound(name.clone())),
28 1 => Ok(row[matches[0]].clone()),
29 _ => Err(SqlError::AmbiguousColumn(name.clone())),
30 }
31 }
32
33 Expr::QualifiedColumn { table, column } => {
34 let qualified = format!(
35 "{}.{}",
36 table.to_ascii_lowercase(),
37 column.to_ascii_lowercase()
38 );
39 let idx = columns
40 .iter()
41 .position(|c| c.name.to_ascii_lowercase() == qualified)
42 .or_else(|| {
43 let lower_col = column.to_ascii_lowercase();
44 let matches: Vec<usize> = columns
45 .iter()
46 .enumerate()
47 .filter(|(_, c)| c.name.to_ascii_lowercase() == lower_col)
48 .map(|(i, _)| i)
49 .collect();
50 if matches.len() == 1 {
51 Some(matches[0])
52 } else {
53 None
54 }
55 })
56 .ok_or_else(|| SqlError::ColumnNotFound(format!("{table}.{column}")))?;
57 Ok(row[idx].clone())
58 }
59
60 Expr::BinaryOp { left, op, right } => {
61 let lval = eval_expr(left, columns, row)?;
62 let rval = eval_expr(right, columns, row)?;
63 eval_binary_op(&lval, *op, &rval)
64 }
65
66 Expr::UnaryOp { op, expr } => {
67 let val = eval_expr(expr, columns, row)?;
68 eval_unary_op(*op, &val)
69 }
70
71 Expr::IsNull(e) => {
72 let val = eval_expr(e, columns, row)?;
73 Ok(Value::Boolean(val.is_null()))
74 }
75
76 Expr::IsNotNull(e) => {
77 let val = eval_expr(e, columns, row)?;
78 Ok(Value::Boolean(!val.is_null()))
79 }
80
81 Expr::Function { name, args } => eval_scalar_function(name, args, columns, row),
82
83 Expr::CountStar => Err(SqlError::Unsupported(
84 "COUNT(*) in non-aggregate context".into(),
85 )),
86
87 Expr::InList {
88 expr: e,
89 list,
90 negated,
91 } => {
92 let lhs = eval_expr(e, columns, row)?;
93 eval_in_values(&lhs, list, columns, row, *negated)
94 }
95
96 Expr::InSet {
97 expr: e,
98 values,
99 has_null,
100 negated,
101 } => {
102 let lhs = eval_expr(e, columns, row)?;
103 eval_in_set(&lhs, values, *has_null, *negated)
104 }
105
106 Expr::Between {
107 expr: e,
108 low,
109 high,
110 negated,
111 } => {
112 let val = eval_expr(e, columns, row)?;
113 let lo = eval_expr(low, columns, row)?;
114 let hi = eval_expr(high, columns, row)?;
115 eval_between(&val, &lo, &hi, *negated)
116 }
117
118 Expr::Like {
119 expr: e,
120 pattern,
121 escape,
122 negated,
123 } => {
124 let val = eval_expr(e, columns, row)?;
125 let pat = eval_expr(pattern, columns, row)?;
126 let esc = escape
127 .as_ref()
128 .map(|e| eval_expr(e, columns, row))
129 .transpose()?;
130 eval_like(&val, &pat, esc.as_ref(), *negated)
131 }
132
133 Expr::Case {
134 operand,
135 conditions,
136 else_result,
137 } => eval_case(
138 operand.as_deref(),
139 conditions,
140 else_result.as_deref(),
141 columns,
142 row,
143 ),
144
145 Expr::Coalesce(args) => {
146 for arg in args {
147 let val = eval_expr(arg, columns, row)?;
148 if !val.is_null() {
149 return Ok(val);
150 }
151 }
152 Ok(Value::Null)
153 }
154
155 Expr::Cast { expr: e, data_type } => {
156 let val = eval_expr(e, columns, row)?;
157 eval_cast(&val, *data_type)
158 }
159
160 Expr::InSubquery { .. } | Expr::Exists { .. } | Expr::ScalarSubquery(_) => Err(
161 SqlError::Unsupported("subquery not materialized (internal error)".into()),
162 ),
163
164 Expr::Parameter(n) => Err(SqlError::Parse(format!("unbound parameter ${n}"))),
165 }
166}
167
168fn eval_binary_op(left: &Value, op: BinOp, right: &Value) -> Result<Value> {
169 match op {
171 BinOp::And => return eval_and(left, right),
172 BinOp::Or => return eval_or(left, right),
173 _ => {}
174 }
175
176 if left.is_null() || right.is_null() {
178 return Ok(Value::Null);
179 }
180
181 match op {
182 BinOp::Eq => Ok(Value::Boolean(left == right)),
183 BinOp::NotEq => Ok(Value::Boolean(left != right)),
184 BinOp::Lt => Ok(Value::Boolean(left < right)),
185 BinOp::Gt => Ok(Value::Boolean(left > right)),
186 BinOp::LtEq => Ok(Value::Boolean(left <= right)),
187 BinOp::GtEq => Ok(Value::Boolean(left >= right)),
188 BinOp::Add => eval_arithmetic(left, right, i64::checked_add, |a, b| a + b),
189 BinOp::Sub => eval_arithmetic(left, right, i64::checked_sub, |a, b| a - b),
190 BinOp::Mul => eval_arithmetic(left, right, i64::checked_mul, |a, b| a * b),
191 BinOp::Div => {
192 match right {
193 Value::Integer(0) => return Err(SqlError::DivisionByZero),
194 Value::Real(r) if *r == 0.0 => return Err(SqlError::DivisionByZero),
195 _ => {}
196 }
197 eval_arithmetic(left, right, i64::checked_div, |a, b| a / b)
198 }
199 BinOp::Mod => {
200 match right {
201 Value::Integer(0) => return Err(SqlError::DivisionByZero),
202 Value::Real(r) if *r == 0.0 => return Err(SqlError::DivisionByZero),
203 _ => {}
204 }
205 eval_arithmetic(left, right, i64::checked_rem, |a, b| a % b)
206 }
207 BinOp::Concat => {
208 let ls = value_to_text(left);
209 let rs = value_to_text(right);
210 Ok(Value::Text(format!("{ls}{rs}")))
211 }
212 BinOp::And | BinOp::Or => unreachable!(),
213 }
214}
215
216fn eval_and(left: &Value, right: &Value) -> Result<Value> {
218 let l = to_bool_or_null(left)?;
219 let r = to_bool_or_null(right)?;
220 match (l, r) {
221 (Some(false), _) | (_, Some(false)) => Ok(Value::Boolean(false)),
222 (Some(true), Some(true)) => Ok(Value::Boolean(true)),
223 _ => Ok(Value::Null),
224 }
225}
226
227fn eval_or(left: &Value, right: &Value) -> Result<Value> {
229 let l = to_bool_or_null(left)?;
230 let r = to_bool_or_null(right)?;
231 match (l, r) {
232 (Some(true), _) | (_, Some(true)) => Ok(Value::Boolean(true)),
233 (Some(false), Some(false)) => Ok(Value::Boolean(false)),
234 _ => Ok(Value::Null),
235 }
236}
237
238fn to_bool_or_null(val: &Value) -> Result<Option<bool>> {
239 match val {
240 Value::Boolean(b) => Ok(Some(*b)),
241 Value::Null => Ok(None),
242 Value::Integer(i) => Ok(Some(*i != 0)),
243 _ => Err(SqlError::TypeMismatch {
244 expected: "BOOLEAN".into(),
245 got: format!("{}", val.data_type()),
246 }),
247 }
248}
249
250fn eval_arithmetic(
251 left: &Value,
252 right: &Value,
253 int_op: fn(i64, i64) -> Option<i64>,
254 real_op: fn(f64, f64) -> f64,
255) -> Result<Value> {
256 match (left, right) {
257 (Value::Integer(a), Value::Integer(b)) => int_op(*a, *b)
258 .map(Value::Integer)
259 .ok_or(SqlError::IntegerOverflow),
260 (Value::Real(a), Value::Real(b)) => Ok(Value::Real(real_op(*a, *b))),
261 (Value::Integer(a), Value::Real(b)) => Ok(Value::Real(real_op(*a as f64, *b))),
262 (Value::Real(a), Value::Integer(b)) => Ok(Value::Real(real_op(*a, *b as f64))),
263 _ => Err(SqlError::TypeMismatch {
264 expected: "numeric".into(),
265 got: format!("{} and {}", left.data_type(), right.data_type()),
266 }),
267 }
268}
269
270fn eval_in_values(
271 lhs: &Value,
272 list: &[Expr],
273 columns: &[ColumnDef],
274 row: &[Value],
275 negated: bool,
276) -> Result<Value> {
277 if list.is_empty() {
278 return Ok(Value::Boolean(negated));
279 }
280 if lhs.is_null() {
281 return Ok(Value::Null);
282 }
283 let mut has_null = false;
284 for item in list {
285 let rhs = eval_expr(item, columns, row)?;
286 if rhs.is_null() {
287 has_null = true;
288 } else if lhs == &rhs {
289 return Ok(Value::Boolean(!negated));
290 }
291 }
292 if has_null {
293 Ok(Value::Null)
294 } else {
295 Ok(Value::Boolean(negated))
296 }
297}
298
299fn eval_in_set(
300 lhs: &Value,
301 values: &std::collections::HashSet<Value>,
302 has_null: bool,
303 negated: bool,
304) -> Result<Value> {
305 if values.is_empty() && !has_null {
306 return Ok(Value::Boolean(negated));
307 }
308 if lhs.is_null() {
309 return Ok(Value::Null);
310 }
311 if values.contains(lhs) {
312 return Ok(Value::Boolean(!negated));
313 }
314 if has_null {
315 Ok(Value::Null)
316 } else {
317 Ok(Value::Boolean(negated))
318 }
319}
320
321fn eval_unary_op(op: UnaryOp, val: &Value) -> Result<Value> {
322 if val.is_null() {
323 return Ok(Value::Null);
324 }
325 match op {
326 UnaryOp::Neg => match val {
327 Value::Integer(i) => i
328 .checked_neg()
329 .map(Value::Integer)
330 .ok_or(SqlError::IntegerOverflow),
331 Value::Real(r) => Ok(Value::Real(-r)),
332 _ => Err(SqlError::TypeMismatch {
333 expected: "numeric".into(),
334 got: format!("{}", val.data_type()),
335 }),
336 },
337 UnaryOp::Not => match val {
338 Value::Boolean(b) => Ok(Value::Boolean(!b)),
339 Value::Integer(i) => Ok(Value::Boolean(*i == 0)),
340 _ => Err(SqlError::TypeMismatch {
341 expected: "BOOLEAN".into(),
342 got: format!("{}", val.data_type()),
343 }),
344 },
345 }
346}
347
348fn value_to_text(val: &Value) -> String {
349 match val {
350 Value::Text(s) => s.clone(),
351 Value::Integer(i) => i.to_string(),
352 Value::Real(r) => {
353 if r.fract() == 0.0 && r.is_finite() {
354 format!("{r:.1}")
355 } else {
356 format!("{r}")
357 }
358 }
359 Value::Boolean(b) => if *b { "TRUE" } else { "FALSE" }.into(),
360 Value::Null => String::new(),
361 Value::Blob(b) => {
362 let mut s = String::with_capacity(b.len() * 2);
363 for byte in b {
364 s.push_str(&format!("{byte:02X}"));
365 }
366 s
367 }
368 }
369}
370
371fn eval_between(val: &Value, low: &Value, high: &Value, negated: bool) -> Result<Value> {
372 if val.is_null() || low.is_null() || high.is_null() {
373 let ge = if val.is_null() || low.is_null() {
374 None
375 } else {
376 Some(*val >= *low)
377 };
378 let le = if val.is_null() || high.is_null() {
379 None
380 } else {
381 Some(*val <= *high)
382 };
383
384 let result = match (ge, le) {
385 (Some(false), _) | (_, Some(false)) => Some(false),
386 (Some(true), Some(true)) => Some(true),
387 _ => None,
388 };
389
390 return match result {
391 Some(b) => Ok(Value::Boolean(if negated { !b } else { b })),
392 None => Ok(Value::Null),
393 };
394 }
395
396 let in_range = *val >= *low && *val <= *high;
397 Ok(Value::Boolean(if negated { !in_range } else { in_range }))
398}
399
400const MAX_LIKE_PATTERN_LEN: usize = 10_000;
401
402fn eval_like(val: &Value, pattern: &Value, escape: Option<&Value>, negated: bool) -> Result<Value> {
403 if val.is_null() || pattern.is_null() {
404 return Ok(Value::Null);
405 }
406 let text = match val {
407 Value::Text(s) => s.as_str(),
408 _ => {
409 return Err(SqlError::TypeMismatch {
410 expected: "TEXT".into(),
411 got: val.data_type().to_string(),
412 })
413 }
414 };
415 let pat = match pattern {
416 Value::Text(s) => s.as_str(),
417 _ => {
418 return Err(SqlError::TypeMismatch {
419 expected: "TEXT".into(),
420 got: pattern.data_type().to_string(),
421 })
422 }
423 };
424
425 if pat.len() > MAX_LIKE_PATTERN_LEN {
426 return Err(SqlError::InvalidValue(format!(
427 "LIKE pattern too long ({} chars, max {MAX_LIKE_PATTERN_LEN})",
428 pat.len()
429 )));
430 }
431
432 let esc_char = match escape {
433 Some(Value::Text(s)) => {
434 let mut chars = s.chars();
435 let c = chars.next().ok_or_else(|| {
436 SqlError::InvalidValue("ESCAPE must be a single character".into())
437 })?;
438 if chars.next().is_some() {
439 return Err(SqlError::InvalidValue(
440 "ESCAPE must be a single character".into(),
441 ));
442 }
443 Some(c)
444 }
445 Some(Value::Null) => return Ok(Value::Null),
446 Some(_) => {
447 return Err(SqlError::TypeMismatch {
448 expected: "TEXT".into(),
449 got: "non-text".into(),
450 })
451 }
452 None => None,
453 };
454
455 let matched = like_match(text, pat, esc_char);
456 Ok(Value::Boolean(if negated { !matched } else { matched }))
457}
458
459fn like_match(text: &str, pattern: &str, escape: Option<char>) -> bool {
460 let t: Vec<char> = text.chars().collect();
461 let p: Vec<char> = pattern.chars().collect();
462 like_match_impl(&t, &p, 0, 0, escape)
463}
464
/// Matches `t[ti..]` against `p[pi..]` with LIKE semantics.
///
/// - `%` matches any run of characters (including none).
/// - `_` matches exactly one character.
/// - Any other character matches ASCII case-insensitively.
/// - When `esc` is `Some(c)`, `c` followed by another pattern character makes that character
///   literal. An escape character in the last pattern position has nothing after it and is
///   treated like an ordinary pattern character.
///
/// Iterative matcher that remembers only the most recent `%`. On a mismatch it resumes just
/// after that `%`, letting it absorb one more text character.
fn like_match_impl(
    t: &[char],
    p: &[char],
    mut ti: usize,
    mut pi: usize,
    esc: Option<char>,
) -> bool {
    // (pattern index of the last `%`, text index that `%` currently extends to)
    let mut resume: Option<(usize, usize)> = None;

    while ti < t.len() {
        if let Some(&pc) = p.get(pi) {
            let escaped_literal = esc == Some(pc) && pi + 1 < p.len();
            if escaped_literal {
                if p[pi + 1].eq_ignore_ascii_case(&t[ti]) {
                    pi += 2;
                    ti += 1;
                    continue;
                }
            } else if pc == '%' {
                resume = Some((pi, ti));
                pi += 1;
                continue;
            } else if pc == '_' || pc.eq_ignore_ascii_case(&t[ti]) {
                pi += 1;
                ti += 1;
                continue;
            }
        }

        // Mismatch or exhausted pattern: let the last `%` swallow one more character.
        match resume {
            Some((star, consumed)) => {
                let consumed = consumed + 1;
                resume = Some((star, consumed));
                pi = star + 1;
                ti = consumed;
            }
            None => return false,
        }
    }

    // Text exhausted: the match succeeds only if nothing but `%` remains in the pattern.
    p.get(pi..)
        .map_or(false, |rest| rest.iter().all(|&c| c == '%'))
}
527
528fn eval_case(
529 operand: Option<&Expr>,
530 conditions: &[(Expr, Expr)],
531 else_result: Option<&Expr>,
532 columns: &[ColumnDef],
533 row: &[Value],
534) -> Result<Value> {
535 if let Some(op_expr) = operand {
536 let op_val = eval_expr(op_expr, columns, row)?;
537 for (cond, result) in conditions {
538 let cond_val = eval_expr(cond, columns, row)?;
539 if !op_val.is_null() && !cond_val.is_null() && op_val == cond_val {
540 return eval_expr(result, columns, row);
541 }
542 }
543 } else {
544 for (cond, result) in conditions {
545 let cond_val = eval_expr(cond, columns, row)?;
546 if is_truthy(&cond_val) {
547 return eval_expr(result, columns, row);
548 }
549 }
550 }
551 match else_result {
552 Some(e) => eval_expr(e, columns, row),
553 None => Ok(Value::Null),
554 }
555}
556
557fn eval_cast(val: &Value, target: DataType) -> Result<Value> {
558 if val.is_null() {
559 return Ok(Value::Null);
560 }
561 match target {
562 DataType::Integer => match val {
563 Value::Integer(_) => Ok(val.clone()),
564 Value::Real(r) => Ok(Value::Integer(*r as i64)),
565 Value::Boolean(b) => Ok(Value::Integer(if *b { 1 } else { 0 })),
566 Value::Text(s) => s
567 .trim()
568 .parse::<i64>()
569 .map(Value::Integer)
570 .or_else(|_| s.trim().parse::<f64>().map(|f| Value::Integer(f as i64)))
571 .map_err(|_| SqlError::InvalidValue(format!("cannot cast '{s}' to INTEGER"))),
572 _ => Err(SqlError::InvalidValue(format!(
573 "cannot cast {} to INTEGER",
574 val.data_type()
575 ))),
576 },
577 DataType::Real => match val {
578 Value::Real(_) => Ok(val.clone()),
579 Value::Integer(i) => Ok(Value::Real(*i as f64)),
580 Value::Boolean(b) => Ok(Value::Real(if *b { 1.0 } else { 0.0 })),
581 Value::Text(s) => s
582 .trim()
583 .parse::<f64>()
584 .map(Value::Real)
585 .map_err(|_| SqlError::InvalidValue(format!("cannot cast '{s}' to REAL"))),
586 _ => Err(SqlError::InvalidValue(format!(
587 "cannot cast {} to REAL",
588 val.data_type()
589 ))),
590 },
591 DataType::Text => Ok(Value::Text(value_to_text(val))),
592 DataType::Boolean => match val {
593 Value::Boolean(_) => Ok(val.clone()),
594 Value::Integer(i) => Ok(Value::Boolean(*i != 0)),
595 Value::Text(s) => {
596 let lower = s.trim().to_ascii_lowercase();
597 match lower.as_str() {
598 "true" | "1" | "yes" | "on" => Ok(Value::Boolean(true)),
599 "false" | "0" | "no" | "off" => Ok(Value::Boolean(false)),
600 _ => Err(SqlError::InvalidValue(format!(
601 "cannot cast '{s}' to BOOLEAN"
602 ))),
603 }
604 }
605 _ => Err(SqlError::InvalidValue(format!(
606 "cannot cast {} to BOOLEAN",
607 val.data_type()
608 ))),
609 },
610 DataType::Blob => match val {
611 Value::Blob(_) => Ok(val.clone()),
612 Value::Text(s) => Ok(Value::Blob(s.as_bytes().to_vec())),
613 _ => Err(SqlError::InvalidValue(format!(
614 "cannot cast {} to BLOB",
615 val.data_type()
616 ))),
617 },
618 DataType::Null => Ok(Value::Null),
619 }
620}
621
622fn eval_scalar_function(
623 name: &str,
624 args: &[Expr],
625 columns: &[ColumnDef],
626 row: &[Value],
627) -> Result<Value> {
628 let evaluated: Vec<Value> = args
629 .iter()
630 .map(|a| eval_expr(a, columns, row))
631 .collect::<Result<Vec<_>>>()?;
632
633 match name {
634 "LENGTH" => {
635 check_args(name, &evaluated, 1)?;
636 match &evaluated[0] {
637 Value::Null => Ok(Value::Null),
638 Value::Text(s) => Ok(Value::Integer(s.chars().count() as i64)),
639 Value::Blob(b) => Ok(Value::Integer(b.len() as i64)),
640 _ => Ok(Value::Integer(
641 value_to_text(&evaluated[0]).chars().count() as i64
642 )),
643 }
644 }
645 "UPPER" => {
646 check_args(name, &evaluated, 1)?;
647 match &evaluated[0] {
648 Value::Null => Ok(Value::Null),
649 Value::Text(s) => Ok(Value::Text(s.to_ascii_uppercase())),
650 _ => Ok(Value::Text(
651 value_to_text(&evaluated[0]).to_ascii_uppercase(),
652 )),
653 }
654 }
655 "LOWER" => {
656 check_args(name, &evaluated, 1)?;
657 match &evaluated[0] {
658 Value::Null => Ok(Value::Null),
659 Value::Text(s) => Ok(Value::Text(s.to_ascii_lowercase())),
660 _ => Ok(Value::Text(
661 value_to_text(&evaluated[0]).to_ascii_lowercase(),
662 )),
663 }
664 }
665 "SUBSTR" | "SUBSTRING" => {
666 if evaluated.len() < 2 || evaluated.len() > 3 {
667 return Err(SqlError::InvalidValue(format!(
668 "{name} requires 2 or 3 arguments"
669 )));
670 }
671 if evaluated.iter().any(|v| v.is_null()) {
672 return Ok(Value::Null);
673 }
674 let s = value_to_text(&evaluated[0]);
675 let chars: Vec<char> = s.chars().collect();
676 let start = match &evaluated[1] {
677 Value::Integer(i) => *i,
678 _ => {
679 return Err(SqlError::TypeMismatch {
680 expected: "INTEGER".into(),
681 got: evaluated[1].data_type().to_string(),
682 })
683 }
684 };
685 let len = chars.len() as i64;
686
687 let (begin, count) = if evaluated.len() == 3 {
688 let cnt = match &evaluated[2] {
689 Value::Integer(i) => *i,
690 _ => {
691 return Err(SqlError::TypeMismatch {
692 expected: "INTEGER".into(),
693 got: evaluated[2].data_type().to_string(),
694 })
695 }
696 };
697 if start >= 1 {
698 let b = (start - 1).min(len) as usize;
699 let c = cnt.max(0) as usize;
700 (b, c)
701 } else if start == 0 {
702 let c = (cnt - 1).max(0) as usize;
703 (0usize, c)
704 } else {
705 let adjusted_cnt = (cnt + start - 1).max(0) as usize;
706 (0usize, adjusted_cnt)
707 }
708 } else if start >= 1 {
709 let b = (start - 1).min(len) as usize;
710 (b, chars.len() - b)
711 } else if start == 0 {
712 (0usize, chars.len())
713 } else {
714 let b = (len + start).max(0) as usize;
715 (b, chars.len() - b)
716 };
717
718 let result: String = chars.iter().skip(begin).take(count).collect();
719 Ok(Value::Text(result))
720 }
721 "TRIM" | "LTRIM" | "RTRIM" => {
722 if evaluated.is_empty() || evaluated.len() > 2 {
723 return Err(SqlError::InvalidValue(format!(
724 "{name} requires 1 or 2 arguments"
725 )));
726 }
727 if evaluated[0].is_null() {
728 return Ok(Value::Null);
729 }
730 let s = value_to_text(&evaluated[0]);
731 let trim_chars: Vec<char> = if evaluated.len() == 2 {
732 if evaluated[1].is_null() {
733 return Ok(Value::Null);
734 }
735 value_to_text(&evaluated[1]).chars().collect()
736 } else {
737 vec![' ']
738 };
739 let result = match name {
740 "TRIM" => s
741 .trim_matches(|c: char| trim_chars.contains(&c))
742 .to_string(),
743 "LTRIM" => s
744 .trim_start_matches(|c: char| trim_chars.contains(&c))
745 .to_string(),
746 "RTRIM" => s
747 .trim_end_matches(|c: char| trim_chars.contains(&c))
748 .to_string(),
749 _ => unreachable!(),
750 };
751 Ok(Value::Text(result))
752 }
753 "REPLACE" => {
754 check_args(name, &evaluated, 3)?;
755 if evaluated.iter().any(|v| v.is_null()) {
756 return Ok(Value::Null);
757 }
758 let s = value_to_text(&evaluated[0]);
759 let from = value_to_text(&evaluated[1]);
760 let to = value_to_text(&evaluated[2]);
761 if from.is_empty() {
762 return Ok(Value::Text(s));
763 }
764 Ok(Value::Text(s.replace(&from, &to)))
765 }
766 "INSTR" => {
767 check_args(name, &evaluated, 2)?;
768 if evaluated.iter().any(|v| v.is_null()) {
769 return Ok(Value::Null);
770 }
771 let haystack = value_to_text(&evaluated[0]);
772 let needle = value_to_text(&evaluated[1]);
773 let pos = haystack
774 .find(&needle)
775 .map(|i| haystack[..i].chars().count() as i64 + 1)
776 .unwrap_or(0);
777 Ok(Value::Integer(pos))
778 }
779 "CONCAT" => {
780 if evaluated.is_empty() {
781 return Ok(Value::Text(String::new()));
782 }
783 let mut result = String::new();
784 for v in &evaluated {
785 match v {
786 Value::Null => {}
787 _ => result.push_str(&value_to_text(v)),
788 }
789 }
790 Ok(Value::Text(result))
791 }
792 "ABS" => {
793 check_args(name, &evaluated, 1)?;
794 match &evaluated[0] {
795 Value::Null => Ok(Value::Null),
796 Value::Integer(i) => i
797 .checked_abs()
798 .map(Value::Integer)
799 .ok_or(SqlError::IntegerOverflow),
800 Value::Real(r) => Ok(Value::Real(r.abs())),
801 _ => Err(SqlError::TypeMismatch {
802 expected: "numeric".into(),
803 got: evaluated[0].data_type().to_string(),
804 }),
805 }
806 }
807 "ROUND" => {
808 if evaluated.is_empty() || evaluated.len() > 2 {
809 return Err(SqlError::InvalidValue(
810 "ROUND requires 1 or 2 arguments".into(),
811 ));
812 }
813 if evaluated[0].is_null() {
814 return Ok(Value::Null);
815 }
816 let val = match &evaluated[0] {
817 Value::Integer(i) => *i as f64,
818 Value::Real(r) => *r,
819 _ => {
820 return Err(SqlError::TypeMismatch {
821 expected: "numeric".into(),
822 got: evaluated[0].data_type().to_string(),
823 })
824 }
825 };
826 let places = if evaluated.len() == 2 {
827 match &evaluated[1] {
828 Value::Null => return Ok(Value::Null),
829 Value::Integer(i) => *i,
830 _ => {
831 return Err(SqlError::TypeMismatch {
832 expected: "INTEGER".into(),
833 got: evaluated[1].data_type().to_string(),
834 })
835 }
836 }
837 } else {
838 0
839 };
840 let factor = 10f64.powi(places as i32);
841 let rounded = (val * factor).round() / factor;
842 Ok(Value::Real(rounded))
843 }
844 "CEIL" | "CEILING" => {
845 check_args(name, &evaluated, 1)?;
846 match &evaluated[0] {
847 Value::Null => Ok(Value::Null),
848 Value::Integer(i) => Ok(Value::Integer(*i)),
849 Value::Real(r) => Ok(Value::Integer(r.ceil() as i64)),
850 _ => Err(SqlError::TypeMismatch {
851 expected: "numeric".into(),
852 got: evaluated[0].data_type().to_string(),
853 }),
854 }
855 }
856 "FLOOR" => {
857 check_args(name, &evaluated, 1)?;
858 match &evaluated[0] {
859 Value::Null => Ok(Value::Null),
860 Value::Integer(i) => Ok(Value::Integer(*i)),
861 Value::Real(r) => Ok(Value::Integer(r.floor() as i64)),
862 _ => Err(SqlError::TypeMismatch {
863 expected: "numeric".into(),
864 got: evaluated[0].data_type().to_string(),
865 }),
866 }
867 }
868 "SIGN" => {
869 check_args(name, &evaluated, 1)?;
870 match &evaluated[0] {
871 Value::Null => Ok(Value::Null),
872 Value::Integer(i) => Ok(Value::Integer(i.signum())),
873 Value::Real(r) => {
874 if *r > 0.0 {
875 Ok(Value::Integer(1))
876 } else if *r < 0.0 {
877 Ok(Value::Integer(-1))
878 } else {
879 Ok(Value::Integer(0))
880 }
881 }
882 _ => Err(SqlError::TypeMismatch {
883 expected: "numeric".into(),
884 got: evaluated[0].data_type().to_string(),
885 }),
886 }
887 }
888 "SQRT" => {
889 check_args(name, &evaluated, 1)?;
890 match &evaluated[0] {
891 Value::Null => Ok(Value::Null),
892 Value::Integer(i) => {
893 if *i < 0 {
894 Ok(Value::Null)
895 } else {
896 Ok(Value::Real((*i as f64).sqrt()))
897 }
898 }
899 Value::Real(r) => {
900 if *r < 0.0 {
901 Ok(Value::Null)
902 } else {
903 Ok(Value::Real(r.sqrt()))
904 }
905 }
906 _ => Err(SqlError::TypeMismatch {
907 expected: "numeric".into(),
908 got: evaluated[0].data_type().to_string(),
909 }),
910 }
911 }
912 "RANDOM" => {
913 check_args(name, &evaluated, 0)?;
914 use std::collections::hash_map::DefaultHasher;
915 use std::hash::{Hash, Hasher};
916 use std::time::SystemTime;
917 let mut hasher = DefaultHasher::new();
918 SystemTime::now().hash(&mut hasher);
919 std::thread::current().id().hash(&mut hasher);
920 let mut val = hasher.finish() as i64;
921 if val == i64::MIN {
922 val = i64::MAX;
923 }
924 Ok(Value::Integer(val))
925 }
926 "TYPEOF" => {
927 check_args(name, &evaluated, 1)?;
928 let type_name = match &evaluated[0] {
929 Value::Null => "null",
930 Value::Integer(_) => "integer",
931 Value::Real(_) => "real",
932 Value::Text(_) => "text",
933 Value::Blob(_) => "blob",
934 Value::Boolean(_) => "boolean",
935 };
936 Ok(Value::Text(type_name.into()))
937 }
938 "MIN" => {
939 check_args(name, &evaluated, 2)?;
940 if evaluated[0].is_null() {
941 return Ok(evaluated[1].clone());
942 }
943 if evaluated[1].is_null() {
944 return Ok(evaluated[0].clone());
945 }
946 if evaluated[0] <= evaluated[1] {
947 Ok(evaluated[0].clone())
948 } else {
949 Ok(evaluated[1].clone())
950 }
951 }
952 "MAX" => {
953 check_args(name, &evaluated, 2)?;
954 if evaluated[0].is_null() {
955 return Ok(evaluated[1].clone());
956 }
957 if evaluated[1].is_null() {
958 return Ok(evaluated[0].clone());
959 }
960 if evaluated[0] >= evaluated[1] {
961 Ok(evaluated[0].clone())
962 } else {
963 Ok(evaluated[1].clone())
964 }
965 }
966 "HEX" => {
967 check_args(name, &evaluated, 1)?;
968 match &evaluated[0] {
969 Value::Null => Ok(Value::Null),
970 Value::Blob(b) => {
971 let mut s = String::with_capacity(b.len() * 2);
972 for byte in b {
973 s.push_str(&format!("{byte:02X}"));
974 }
975 Ok(Value::Text(s))
976 }
977 Value::Text(s) => {
978 let mut r = String::with_capacity(s.len() * 2);
979 for byte in s.as_bytes() {
980 r.push_str(&format!("{byte:02X}"));
981 }
982 Ok(Value::Text(r))
983 }
984 _ => Ok(Value::Text(value_to_text(&evaluated[0]))),
985 }
986 }
987 _ => Err(SqlError::Unsupported(format!("scalar function: {name}"))),
988 }
989}
990
991fn check_args(name: &str, args: &[Value], expected: usize) -> Result<()> {
992 if args.len() != expected {
993 Err(SqlError::InvalidValue(format!(
994 "{name} requires {expected} argument(s), got {}",
995 args.len()
996 )))
997 } else {
998 Ok(())
999 }
1000}
1001
1002pub fn is_truthy(val: &Value) -> bool {
1004 match val {
1005 Value::Boolean(b) => *b,
1006 Value::Integer(i) => *i != 0,
1007 Value::Null => false,
1008 _ => true,
1009 }
1010}
1011
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::DataType;

    /// Four-column schema shared by the tests: id (INTEGER), name (TEXT),
    /// score (REAL), active (BOOLEAN).
    fn test_columns() -> Vec<ColumnDef> {
        vec![
            ColumnDef {
                name: "id".into(),
                data_type: DataType::Integer,
                nullable: false,
                position: 0,
            },
            ColumnDef {
                name: "name".into(),
                data_type: DataType::Text,
                nullable: true,
                position: 1,
            },
            ColumnDef {
                name: "score".into(),
                data_type: DataType::Real,
                nullable: true,
                position: 2,
            },
            ColumnDef {
                name: "active".into(),
                data_type: DataType::Boolean,
                nullable: false,
                position: 3,
            },
        ]
    }

    /// A fully populated row matching `test_columns()`.
    fn test_row() -> Vec<Value> {
        vec![
            Value::Integer(1),
            Value::Text("Alice".into()),
            Value::Real(95.5),
            Value::Boolean(true),
        ]
    }

    #[test]
    fn eval_literal() {
        let cols = test_columns();
        let row = test_row();
        let expr = Expr::Literal(Value::Integer(42));
        assert_eq!(eval_expr(&expr, &cols, &row).unwrap(), Value::Integer(42));
    }

    #[test]
    fn eval_column_ref() {
        let cols = test_columns();
        let row = test_row();
        let expr = Expr::Column("name".into());
        assert_eq!(
            eval_expr(&expr, &cols, &row).unwrap(),
            Value::Text("Alice".into())
        );
    }

    // Column lookup ignores ASCII case.
    #[test]
    fn eval_column_case_insensitive() {
        let cols = test_columns();
        let row = test_row();
        let expr = Expr::Column("NAME".into());
        assert_eq!(
            eval_expr(&expr, &cols, &row).unwrap(),
            Value::Text("Alice".into())
        );
    }

    #[test]
    fn eval_arithmetic_int() {
        let cols = test_columns();
        let row = test_row();
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("id".into())),
            op: BinOp::Add,
            right: Box::new(Expr::Literal(Value::Integer(10))),
        };
        assert_eq!(eval_expr(&expr, &cols, &row).unwrap(), Value::Integer(11));
    }

    #[test]
    fn eval_comparison() {
        let cols = test_columns();
        let row = test_row();
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("score".into())),
            op: BinOp::Gt,
            right: Box::new(Expr::Literal(Value::Real(90.0))),
        };
        assert_eq!(eval_expr(&expr, &cols, &row).unwrap(), Value::Boolean(true));
    }

    // Comparing NULL with anything yields NULL, not FALSE.
    #[test]
    fn eval_null_propagation() {
        let cols = test_columns();
        let row = vec![
            Value::Integer(1),
            Value::Null,
            Value::Null,
            Value::Boolean(true),
        ];
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::Eq,
            right: Box::new(Expr::Literal(Value::Text("test".into()))),
        };
        assert!(eval_expr(&expr, &cols, &row).unwrap().is_null());
    }

    // NULL AND FALSE = FALSE; NULL AND TRUE = NULL.
    #[test]
    fn eval_and_three_valued() {
        let cols = test_columns();
        let row = vec![
            Value::Integer(1),
            Value::Null,
            Value::Null,
            Value::Boolean(true),
        ];

        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::And,
            right: Box::new(Expr::Literal(Value::Boolean(false))),
        };
        assert_eq!(
            eval_expr(&expr, &cols, &row).unwrap(),
            Value::Boolean(false)
        );

        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::And,
            right: Box::new(Expr::Literal(Value::Boolean(true))),
        };
        assert!(eval_expr(&expr, &cols, &row).unwrap().is_null());
    }

    // NULL OR TRUE = TRUE; NULL OR FALSE = NULL.
    #[test]
    fn eval_or_three_valued() {
        let cols = test_columns();
        let row = vec![
            Value::Integer(1),
            Value::Null,
            Value::Null,
            Value::Boolean(true),
        ];

        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::Or,
            right: Box::new(Expr::Literal(Value::Boolean(true))),
        };
        assert_eq!(eval_expr(&expr, &cols, &row).unwrap(), Value::Boolean(true));

        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("name".into())),
            op: BinOp::Or,
            right: Box::new(Expr::Literal(Value::Boolean(false))),
        };
        assert!(eval_expr(&expr, &cols, &row).unwrap().is_null());
    }

    #[test]
    fn eval_is_null() {
        let cols = test_columns();
        let row = vec![
            Value::Integer(1),
            Value::Null,
            Value::Null,
            Value::Boolean(true),
        ];
        let expr = Expr::IsNull(Box::new(Expr::Column("name".into())));
        assert_eq!(eval_expr(&expr, &cols, &row).unwrap(), Value::Boolean(true));

        let expr = Expr::IsNotNull(Box::new(Expr::Column("id".into())));
        assert_eq!(eval_expr(&expr, &cols, &row).unwrap(), Value::Boolean(true));
    }

    #[test]
    fn eval_not() {
        let cols = test_columns();
        let row = test_row();
        let expr = Expr::UnaryOp {
            op: UnaryOp::Not,
            expr: Box::new(Expr::Column("active".into())),
        };
        assert_eq!(
            eval_expr(&expr, &cols, &row).unwrap(),
            Value::Boolean(false)
        );
    }

    #[test]
    fn eval_neg() {
        let cols = test_columns();
        let row = test_row();
        let expr = Expr::UnaryOp {
            op: UnaryOp::Neg,
            expr: Box::new(Expr::Column("id".into())),
        };
        assert_eq!(eval_expr(&expr, &cols, &row).unwrap(), Value::Integer(-1));
    }

    // Integer division by zero is an error, not NULL.
    #[test]
    fn eval_division_by_zero() {
        let cols = test_columns();
        let row = test_row();
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("id".into())),
            op: BinOp::Div,
            right: Box::new(Expr::Literal(Value::Integer(0))),
        };
        assert!(matches!(
            eval_expr(&expr, &cols, &row),
            Err(SqlError::DivisionByZero)
        ));
    }

    // INTEGER + REAL promotes to REAL.
    #[test]
    fn eval_mixed_numeric() {
        let cols = test_columns();
        let row = test_row();
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Column("id".into())),
            op: BinOp::Add,
            right: Box::new(Expr::Column("score".into())),
        };
        assert_eq!(eval_expr(&expr, &cols, &row).unwrap(), Value::Real(96.5));
    }

    #[test]
    fn is_truthy_values() {
        assert!(is_truthy(&Value::Boolean(true)));
        assert!(!is_truthy(&Value::Boolean(false)));
        assert!(!is_truthy(&Value::Null));
        assert!(is_truthy(&Value::Integer(1)));
        assert!(!is_truthy(&Value::Integer(0)));
    }
}