1use std::cmp::Ordering;
4use std::collections::HashMap;
5
6use regex::Regex;
7
8use crate::errors::MdqlError;
9use crate::model::{Row, Value};
10use crate::query_parser::*;
11use crate::schema::Schema;
12
13pub fn execute_query(
14 query: &SelectQuery,
15 rows: &[Row],
16 _schema: &Schema,
17) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
18 execute(query, rows, None)
19}
20
21pub fn execute_query_indexed(
23 query: &SelectQuery,
24 rows: &[Row],
25 schema: &Schema,
26 index: Option<&crate::index::TableIndex>,
27 searcher: Option<&crate::search::TableSearcher>,
28) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
29 let fts_results = if let (Some(ref wc), Some(searcher)) = (&query.where_clause, searcher) {
31 collect_fts_results(wc, schema, searcher)
32 } else {
33 HashMap::new()
34 };
35
36 execute_with_fts(query, rows, index, &fts_results)
37}
38
39fn collect_fts_results(
42 clause: &WhereClause,
43 schema: &Schema,
44 searcher: &crate::search::TableSearcher,
45) -> HashMap<(String, String), std::collections::HashSet<String>> {
46 let mut results = HashMap::new();
47 collect_fts_results_inner(clause, schema, searcher, &mut results);
48 results
49}
50
51fn collect_fts_results_inner(
52 clause: &WhereClause,
53 schema: &Schema,
54 searcher: &crate::search::TableSearcher,
55 results: &mut HashMap<(String, String), std::collections::HashSet<String>>,
56) {
57 match clause {
58 WhereClause::Comparison(cmp) => {
59 if (cmp.op == "LIKE" || cmp.op == "NOT LIKE") && schema.sections.contains_key(&cmp.column) {
60 if let Some(SqlValue::String(pattern)) = &cmp.value {
61 let search_term = pattern.replace('%', " ").replace('_', " ").trim().to_string();
63 if !search_term.is_empty() {
64 if let Ok(paths) = searcher.search(&search_term, Some(&cmp.column)) {
65 let key = (cmp.column.clone(), pattern.clone());
66 results.insert(key, paths.into_iter().collect());
67 }
68 }
69 }
70 }
71 }
72 WhereClause::BoolOp(bop) => {
73 collect_fts_results_inner(&bop.left, schema, searcher, results);
74 collect_fts_results_inner(&bop.right, schema, searcher, results);
75 }
76 }
77}
78
/// Precomputed full-text-search matches for `LIKE` / `NOT LIKE` predicates on
/// schema section columns, keyed by `(column, LIKE pattern as written)`; each
/// value is the set of row `path`s the search returned for that pattern.
type FtsResults = HashMap<(String, String), std::collections::HashSet<String>>;
80
81fn execute_with_fts(
82 query: &SelectQuery,
83 rows: &[Row],
84 index: Option<&crate::index::TableIndex>,
85 fts: &FtsResults,
86) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
87 let mut all_columns: Vec<String> = Vec::new();
89 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
90 for r in rows {
91 for k in r.keys() {
92 if seen.insert(k.clone()) {
93 all_columns.push(k.clone());
94 }
95 }
96 }
97
98 let has_aggregates = match &query.columns {
100 ColumnList::Named(exprs) => exprs.iter().any(|e| e.is_aggregate()),
101 _ => false,
102 };
103
104 let columns: Vec<String> = match &query.columns {
106 ColumnList::All => all_columns,
107 ColumnList::Named(exprs) => exprs.iter().map(|e| e.output_name()).collect(),
108 };
109
110 let filtered: Vec<Row> = if let Some(ref wc) = query.where_clause {
112 let candidate_paths = index.and_then(|idx| try_index_filter(wc, idx));
113 if let Some(paths) = candidate_paths {
114 rows.iter()
115 .filter(|r| {
116 r.get("path")
117 .and_then(|v| v.as_str())
118 .map_or(false, |p| paths.contains(p))
119 })
120 .filter(|r| evaluate_with_fts(wc, r, fts))
121 .cloned()
122 .collect()
123 } else {
124 rows.iter()
125 .filter(|r| evaluate_with_fts(wc, r, fts))
126 .cloned()
127 .collect()
128 }
129 } else {
130 rows.to_vec()
131 };
132
133 let mut result = if has_aggregates || query.group_by.is_some() {
135 let exprs = match &query.columns {
136 ColumnList::Named(exprs) => exprs.clone(),
137 _ => return Err(MdqlError::QueryExecution(
138 "SELECT * with GROUP BY is not supported".into(),
139 )),
140 };
141 let group_keys = query.group_by.as_deref().unwrap_or(&[]);
142 aggregate_rows(&filtered, &exprs, group_keys)?
143 } else {
144 filtered
145 };
146
147 if let Some(ref having) = query.having {
149 result.retain(|row| evaluate(having, row));
150 }
151
152 if let Some(ref order_by) = query.order_by {
154 let resolved = resolve_order_aliases(order_by, &query.columns);
155 sort_rows(&mut result, &resolved);
156 }
157
158 if let Some(limit) = query.limit {
160 result.truncate(limit as usize);
161 }
162
163 if !matches!(query.columns, ColumnList::All) {
165 let named_exprs = match &query.columns {
166 ColumnList::Named(exprs) => exprs,
167 _ => unreachable!(),
168 };
169
170 let has_expr_cols = named_exprs.iter().any(|e| matches!(e, SelectExpr::Expr { .. }));
174 let already_aggregated = has_aggregates || query.group_by.is_some();
175 if has_expr_cols && !already_aggregated {
176 for row in &mut result {
177 for expr in named_exprs {
178 if let SelectExpr::Expr { expr: e, alias } = expr {
179 let name = alias.clone().unwrap_or_else(|| e.display_name());
180 let val = evaluate_expr(e, row);
181 row.insert(name, val);
182 }
183 }
184 }
185 }
186
187 let col_set: std::collections::HashSet<&str> =
188 columns.iter().map(|s| s.as_str()).collect();
189 for row in &mut result {
190 row.retain(|k, _| col_set.contains(k.as_str()));
191 }
192 }
193
194 Ok((result, columns))
195}
196
197fn aggregate_rows(
198 rows: &[Row],
199 exprs: &[SelectExpr],
200 group_keys: &[String],
201) -> crate::errors::Result<Vec<Row>> {
202 let mut groups: Vec<(Vec<Value>, Vec<&Row>)> = Vec::new();
204 let mut key_index: HashMap<Vec<String>, usize> = HashMap::new();
205
206 if group_keys.is_empty() {
207 let all_refs: Vec<&Row> = rows.iter().collect();
209 groups.push((vec![], all_refs));
210 } else {
211 for row in rows {
212 let key: Vec<String> = group_keys
213 .iter()
214 .map(|k| {
215 row.get(k)
216 .map(|v| v.to_display_string())
217 .unwrap_or_default()
218 })
219 .collect();
220 let key_vals: Vec<Value> = group_keys
221 .iter()
222 .map(|k| row.get(k).cloned().unwrap_or(Value::Null))
223 .collect();
224 if let Some(&idx) = key_index.get(&key) {
225 groups[idx].1.push(row);
226 } else {
227 let idx = groups.len();
228 key_index.insert(key, idx);
229 groups.push((key_vals, vec![row]));
230 }
231 }
232 }
233
234 let mut result = Vec::new();
236 for (key_vals, group_rows) in &groups {
237 let mut out = Row::new();
238
239 for (i, k) in group_keys.iter().enumerate() {
241 out.insert(k.clone(), key_vals[i].clone());
242 }
243
244 for expr in exprs {
246 match expr {
247 SelectExpr::Column(name) => {
248 if !out.contains_key(name) {
250 if let Some(first) = group_rows.first() {
251 out.insert(
252 name.clone(),
253 first.get(name).cloned().unwrap_or(Value::Null),
254 );
255 }
256 }
257 }
258 SelectExpr::Aggregate { func, arg, arg_expr, alias } => {
259 let out_name = alias
260 .clone()
261 .unwrap_or_else(|| expr.output_name());
262 let val = compute_aggregate(func, arg, arg_expr.as_ref(), group_rows);
263 out.insert(out_name, val);
264 }
265 SelectExpr::Expr { expr: e, alias } => {
266 let out_name = alias.clone().unwrap_or_else(|| e.display_name());
267 if let Some(first) = group_rows.first() {
268 let val = evaluate_expr(e, first);
269 out.insert(out_name, val);
270 }
271 }
272 }
273 }
274
275 result.push(out);
276 }
277
278 Ok(result)
279}
280
281fn resolve_agg_value<'a>(arg: &str, arg_expr: Option<&Expr>, row: &'a Row) -> Value {
284 if let Some(expr) = arg_expr {
285 evaluate_expr(expr, row)
286 } else {
287 row.get(arg).cloned().unwrap_or(Value::Null)
288 }
289}
290
291fn compute_aggregate(func: &AggFunc, arg: &str, arg_expr: Option<&Expr>, rows: &[&Row]) -> Value {
292 match func {
293 AggFunc::Count => {
294 if arg == "*" && arg_expr.is_none() {
295 Value::Int(rows.len() as i64)
296 } else {
297 let count = rows
298 .iter()
299 .filter(|r| {
300 let v = resolve_agg_value(arg, arg_expr, r);
301 !v.is_null()
302 })
303 .count();
304 Value::Int(count as i64)
305 }
306 }
307 AggFunc::Sum => {
308 let mut total = 0.0f64;
309 let mut has_any = false;
310 for r in rows {
311 let v = resolve_agg_value(arg, arg_expr, r);
312 match v {
313 Value::Int(n) => { total += n as f64; has_any = true; }
314 Value::Float(f) => { total += f; has_any = true; }
315 _ => {}
316 }
317 }
318 if has_any { Value::Float(total) } else { Value::Null }
319 }
320 AggFunc::Avg => {
321 let mut total = 0.0f64;
322 let mut count = 0usize;
323 for r in rows {
324 let v = resolve_agg_value(arg, arg_expr, r);
325 match v {
326 Value::Int(n) => { total += n as f64; count += 1; }
327 Value::Float(f) => { total += f; count += 1; }
328 _ => {}
329 }
330 }
331 if count > 0 { Value::Float(total / count as f64) } else { Value::Null }
332 }
333 AggFunc::Min => {
334 let mut min_val: Option<Value> = None;
335 for r in rows {
336 let v = resolve_agg_value(arg, arg_expr, r);
337 if v.is_null() { continue; }
338 min_val = Some(match min_val {
339 None => v,
340 Some(ref current) => {
341 if v.partial_cmp(current) == Some(std::cmp::Ordering::Less) {
342 v
343 } else {
344 current.clone()
345 }
346 }
347 });
348 }
349 min_val.unwrap_or(Value::Null)
350 }
351 AggFunc::Max => {
352 let mut max_val: Option<Value> = None;
353 for r in rows {
354 let v = resolve_agg_value(arg, arg_expr, r);
355 if v.is_null() { continue; }
356 max_val = Some(match max_val {
357 None => v,
358 Some(ref current) => {
359 if v.partial_cmp(current) == Some(std::cmp::Ordering::Greater) {
360 v
361 } else {
362 current.clone()
363 }
364 }
365 });
366 }
367 max_val.unwrap_or(Value::Null)
368 }
369 }
370}
371
372fn evaluate_with_fts(clause: &WhereClause, row: &Row, fts: &FtsResults) -> bool {
373 match clause {
374 WhereClause::BoolOp(bop) => {
375 let left = evaluate_with_fts(&bop.left, row, fts);
376 match bop.op.as_str() {
377 "AND" => left && evaluate_with_fts(&bop.right, row, fts),
378 "OR" => left || evaluate_with_fts(&bop.right, row, fts),
379 _ => false,
380 }
381 }
382 WhereClause::Comparison(cmp) => {
383 if cmp.op == "LIKE" || cmp.op == "NOT LIKE" {
385 if let Some(SqlValue::String(pattern)) = &cmp.value {
386 let key = (cmp.column.clone(), pattern.clone());
387 if let Some(matching_paths) = fts.get(&key) {
388 let row_path = row.get("path").and_then(|v| v.as_str()).unwrap_or("");
389 let matched = matching_paths.contains(row_path);
390 return if cmp.op == "LIKE" { matched } else { !matched };
391 }
392 }
393 }
394 evaluate_comparison(cmp, row)
395 }
396 }
397}
398
399pub fn execute_join_query(
400 query: &SelectQuery,
401 tables: &HashMap<String, (Schema, Vec<Row>)>,
402) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
403 if query.joins.is_empty() {
404 return Err(MdqlError::QueryExecution("No JOIN clause in query".into()));
405 }
406
407 let left_name = &query.table;
408 let left_alias = query.table_alias.as_deref().unwrap_or(left_name);
409
410 let mut aliases: HashMap<String, String> = HashMap::new();
412 aliases.insert(left_name.clone(), left_name.clone());
413 if let Some(ref a) = query.table_alias {
414 aliases.insert(a.clone(), left_name.clone());
415 }
416 for join in &query.joins {
417 aliases.insert(join.table.clone(), join.table.clone());
418 if let Some(ref a) = join.alias {
419 aliases.insert(a.clone(), join.table.clone());
420 }
421 }
422
423 let (_left_schema, left_rows) = tables.get(left_name.as_str()).ok_or_else(|| {
425 MdqlError::QueryExecution(format!("Unknown table '{}'", left_name))
426 })?;
427
428 let mut current_rows: Vec<Row> = left_rows
429 .iter()
430 .map(|r| {
431 let mut prefixed = Row::new();
432 for (k, v) in r {
433 prefixed.insert(format!("{}.{}", left_alias, k), v.clone());
434 }
435 prefixed
436 })
437 .collect();
438
439 for join in &query.joins {
441 let right_name = &join.table;
442 let right_alias = join.alias.as_deref().unwrap_or(right_name);
443
444 let (_right_schema, right_rows) = tables.get(right_name.as_str()).ok_or_else(|| {
445 MdqlError::QueryExecution(format!("Unknown table '{}'", right_name))
446 })?;
447
448 let (on_left_table, on_left_col) = resolve_dotted(&join.left_col, &aliases);
450 let (on_right_table, on_right_col) = resolve_dotted(&join.right_col, &aliases);
451
452 let (left_key, right_key) = if on_right_table == *right_name {
454 let left_alias_for_col = reverse_alias(&on_left_table, &aliases, query, &query.joins);
456 (format!("{}.{}", left_alias_for_col, on_left_col), on_right_col)
457 } else {
458 let right_alias_for_col = reverse_alias(&on_right_table, &aliases, query, &query.joins);
460 (format!("{}.{}", right_alias_for_col, on_right_col), on_left_col)
461 };
462
463 let mut right_index: HashMap<String, Vec<&Row>> = HashMap::new();
465 for r in right_rows {
466 if let Some(key) = r.get(&right_key) {
467 let key_str = key.to_display_string();
468 right_index.entry(key_str).or_default().push(r);
469 }
470 }
471
472 let mut next_rows: Vec<Row> = Vec::new();
474 for lr in ¤t_rows {
475 if let Some(key) = lr.get(&left_key) {
476 let key_str = key.to_display_string();
477 if let Some(matching) = right_index.get(&key_str) {
478 for rr in matching {
479 let mut merged = lr.clone();
480 for (k, v) in *rr {
481 merged.insert(format!("{}.{}", right_alias, k), v.clone());
482 }
483 next_rows.push(merged);
484 }
485 }
486 }
487 }
488 current_rows = next_rows;
489 }
490
491 let (mut result, columns) = execute(query, ¤t_rows, None)?;
492
493 if !result.is_empty() {
497 let mut base_counts: HashMap<String, usize> = HashMap::new();
498 for key in &columns {
499 if let Some((_prefix, base)) = key.split_once('.') {
500 *base_counts.entry(base.to_string()).or_default() += 1;
501 }
502 }
503 let unique_bases: Vec<String> = base_counts
504 .into_iter()
505 .filter(|(_, count)| *count == 1)
506 .map(|(base, _)| base)
507 .collect();
508
509 if !unique_bases.is_empty() {
510 let unique_set: std::collections::HashSet<&str> =
511 unique_bases.iter().map(|s| s.as_str()).collect();
512 for row in &mut result {
513 let additions: Vec<(String, Value)> = row
514 .iter()
515 .filter_map(|(k, v)| {
516 k.split_once('.').and_then(|(_, base)| {
517 if unique_set.contains(base) {
518 Some((base.to_string(), v.clone()))
519 } else {
520 None
521 }
522 })
523 })
524 .collect();
525 for (k, v) in additions {
526 row.insert(k, v);
527 }
528 }
529 }
530 }
531
532 Ok((result, columns))
533}
534
535fn reverse_alias(
537 table_name: &str,
538 aliases: &HashMap<String, String>,
539 query: &SelectQuery,
540 joins: &[JoinClause],
541) -> String {
542 if query.table == table_name {
544 return query.table_alias.as_deref().unwrap_or(&query.table).to_string();
545 }
546 for j in joins {
548 if j.table == table_name {
549 return j.alias.as_deref().unwrap_or(&j.table).to_string();
550 }
551 }
552 if aliases.contains_key(table_name) {
554 return table_name.to_string();
555 }
556 table_name.to_string()
557}
558
/// Splits a possibly qualified column reference at its first `.`.
///
/// Returns `(table, column)`. The qualifier is translated through `aliases`
/// when it is a known alias; unknown qualifiers are returned as-is. An
/// unqualified name yields an empty table name.
fn resolve_dotted(col: &str, aliases: &HashMap<String, String>) -> (String, String) {
    match col.split_once('.') {
        None => (String::new(), col.to_string()),
        Some((qualifier, column)) => {
            let table = aliases
                .get(qualifier)
                .map_or_else(|| qualifier.to_string(), String::clone);
            (table, column.to_string())
        }
    }
}
567
568fn execute(
569 query: &SelectQuery,
570 rows: &[Row],
571 index: Option<&crate::index::TableIndex>,
572) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
573 let empty_fts = HashMap::new();
574 execute_with_fts(query, rows, index, &empty_fts)
575}
576
577pub fn evaluate(clause: &WhereClause, row: &Row) -> bool {
578 match clause {
579 WhereClause::BoolOp(bop) => {
580 let left = evaluate(&bop.left, row);
581 match bop.op.as_str() {
582 "AND" => left && evaluate(&bop.right, row),
583 "OR" => left || evaluate(&bop.right, row),
584 _ => false,
585 }
586 }
587 WhereClause::Comparison(cmp) => evaluate_comparison(cmp, row),
588 }
589}
590
591pub fn evaluate_expr(expr: &Expr, row: &Row) -> Value {
593 match expr {
594 Expr::Literal(SqlValue::Int(n)) => Value::Int(*n),
595 Expr::Literal(SqlValue::Float(f)) => Value::Float(*f),
596 Expr::Literal(SqlValue::String(s)) => Value::String(s.clone()),
597 Expr::Literal(SqlValue::Null) => Value::Null,
598 Expr::Literal(SqlValue::List(_)) => Value::Null,
599 Expr::Column(name) => {
600 if let Some(val) = row.get(name) {
601 return val.clone();
602 }
603 for (i, _) in name.match_indices('.') {
605 let dict_col = &name[..i];
606 let dict_key = &name[i + 1..];
607 if let Some(Value::Dict(map)) = row.get(dict_col) {
608 return map.get(dict_key).cloned().unwrap_or(Value::Null);
609 }
610 }
611 Value::Null
612 }
613 Expr::UnaryMinus(inner) => {
614 match evaluate_expr(inner, row) {
615 Value::Int(n) => Value::Int(-n),
616 Value::Float(f) => Value::Float(-f),
617 Value::Null => Value::Null,
618 _ => Value::Null, }
620 }
621 Expr::BinaryOp { left, op, right } => {
622 let lv = evaluate_expr(left, row);
623 let rv = evaluate_expr(right, row);
624
625 if lv.is_null() || rv.is_null() {
627 return Value::Null;
628 }
629
630 match (&lv, &rv) {
632 (Value::Int(a), Value::Int(b)) => {
633 match op {
634 ArithOp::Add => Value::Int(a.wrapping_add(*b)),
635 ArithOp::Sub => Value::Int(a.wrapping_sub(*b)),
636 ArithOp::Mul => Value::Int(a.wrapping_mul(*b)),
637 ArithOp::Div => {
638 if *b == 0 { Value::Null } else { Value::Int(a / b) }
639 }
640 ArithOp::Mod => {
641 if *b == 0 { Value::Null } else { Value::Int(a % b) }
642 }
643 }
644 }
645 _ => {
646 let a = match &lv {
648 Value::Int(n) => *n as f64,
649 Value::Float(f) => *f,
650 _ => return Value::Null,
651 };
652 let b = match &rv {
653 Value::Int(n) => *n as f64,
654 Value::Float(f) => *f,
655 _ => return Value::Null,
656 };
657 match op {
658 ArithOp::Add => Value::Float(a + b),
659 ArithOp::Sub => Value::Float(a - b),
660 ArithOp::Mul => Value::Float(a * b),
661 ArithOp::Div => {
662 if b == 0.0 { Value::Null } else { Value::Float(a / b) }
663 }
664 ArithOp::Mod => {
665 if b == 0.0 { Value::Null } else { Value::Float(a % b) }
666 }
667 }
668 }
669 }
670 }
671 Expr::Case { whens, else_expr } => {
672 for (condition, result) in whens {
673 if evaluate(condition, row) {
674 return evaluate_expr(result, row);
675 }
676 }
677 match else_expr {
678 Some(e) => evaluate_expr(e, row),
679 None => Value::Null,
680 }
681 }
682 Expr::CurrentDate => {
683 Value::Date(chrono::Local::now().naive_local().date())
684 }
685 Expr::CurrentTimestamp => {
686 Value::DateTime(chrono::Local::now().naive_local())
687 }
688 Expr::DateAdd { date, days } => {
689 let date_val = evaluate_expr(date, row);
690 let days_val = evaluate_expr(days, row);
691 let n = match &days_val {
692 Value::Int(n) => *n,
693 Value::Float(f) => *f as i64,
694 _ => return Value::Null,
695 };
696 let duration = chrono::Duration::days(n);
697 match date_val {
698 Value::Date(d) => {
699 match d.checked_add_signed(duration) {
700 Some(result) => Value::Date(result),
701 None => Value::Null,
702 }
703 }
704 Value::DateTime(dt) => {
705 match dt.checked_add_signed(duration) {
706 Some(result) => Value::DateTime(result),
707 None => Value::Null,
708 }
709 }
710 _ => Value::Null,
711 }
712 }
713 Expr::DateDiff { left, right } => {
714 let lv = evaluate_expr(left, row);
715 let rv = evaluate_expr(right, row);
716 let left_date = match &lv {
717 Value::Date(d) => d.and_hms_opt(0, 0, 0).unwrap(),
718 Value::DateTime(dt) => *dt,
719 _ => return Value::Null,
720 };
721 let right_date = match &rv {
722 Value::Date(d) => d.and_hms_opt(0, 0, 0).unwrap(),
723 Value::DateTime(dt) => *dt,
724 _ => return Value::Null,
725 };
726 Value::Int((left_date - right_date).num_days())
727 }
728 }
729}
730
731fn evaluate_comparison(cmp: &Comparison, row: &Row) -> bool {
732 if let (Some(left_expr), Some(right_expr)) = (&cmp.left_expr, &cmp.right_expr) {
734 if ["=", "!=", "<", ">", "<=", ">="].contains(&cmp.op.as_str()) {
735 let left_val = evaluate_expr(left_expr, row);
736 let right_val = evaluate_expr(right_expr, row);
737
738 if left_val.is_null() || right_val.is_null() {
740 return false;
741 }
742
743 let ord = compare_model_values(&left_val, &right_val);
745
746 return match cmp.op.as_str() {
747 "=" => ord == Some(Ordering::Equal),
748 "!=" => ord != Some(Ordering::Equal),
749 "<" => ord == Some(Ordering::Less),
750 ">" => ord == Some(Ordering::Greater),
751 "<=" => matches!(ord, Some(Ordering::Less | Ordering::Equal)),
752 ">=" => matches!(ord, Some(Ordering::Greater | Ordering::Equal)),
753 _ => false,
754 };
755 }
756 }
757
758 let actual = row.get(&cmp.column);
760
761 if cmp.op == "IS NULL" {
762 return actual.map_or(true, |v| v.is_null());
763 }
764 if cmp.op == "IS NOT NULL" {
765 return actual.map_or(false, |v| !v.is_null());
766 }
767
768 let actual = match actual {
769 Some(v) if !v.is_null() => v,
770 _ => return false,
771 };
772
773 let expected = match &cmp.value {
774 Some(v) => v,
775 None => return false,
776 };
777
778 match cmp.op.as_str() {
779 "=" => eq_match(actual, expected),
780 "!=" => !eq_match(actual, expected),
781 "<" => compare_values(actual, expected) == Some(Ordering::Less),
782 ">" => compare_values(actual, expected) == Some(Ordering::Greater),
783 "<=" => matches!(compare_values(actual, expected), Some(Ordering::Less | Ordering::Equal)),
784 ">=" => matches!(compare_values(actual, expected), Some(Ordering::Greater | Ordering::Equal)),
785 "LIKE" => like_match(actual, expected),
786 "NOT LIKE" => !like_match(actual, expected),
787 "IN" => {
788 if let SqlValue::List(items) = expected {
789 items.iter().any(|v| eq_match(actual, v))
790 } else {
791 eq_match(actual, expected)
792 }
793 }
794 _ => false,
795 }
796}
797
798fn compare_model_values(a: &Value, b: &Value) -> Option<Ordering> {
800 match (a, b) {
801 (Value::Int(x), Value::Float(y)) => (*x as f64).partial_cmp(y),
802 (Value::Float(x), Value::Int(y)) => x.partial_cmp(&(*y as f64)),
803 _ => a.partial_cmp(b),
804 }
805}
806
807fn coerce_sql_to_value(sql_val: &SqlValue, target: &Value) -> Value {
808 match sql_val {
809 SqlValue::Null => Value::Null,
810 SqlValue::String(s) => {
811 match target {
812 Value::Int(_) => s.parse::<i64>().map(Value::Int).unwrap_or(Value::String(s.clone())),
813 Value::Float(_) => s.parse::<f64>().map(Value::Float).unwrap_or(Value::String(s.clone())),
814 Value::Date(_) => {
815 chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d")
816 .map(Value::Date)
817 .unwrap_or(Value::String(s.clone()))
818 }
819 Value::DateTime(_) => {
820 chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S")
821 .or_else(|_| chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f"))
822 .map(Value::DateTime)
823 .unwrap_or(Value::String(s.clone()))
824 }
825 _ => Value::String(s.clone()),
826 }
827 }
828 SqlValue::Int(n) => {
829 match target {
830 Value::Float(_) => Value::Float(*n as f64),
831 _ => Value::Int(*n),
832 }
833 }
834 SqlValue::Float(f) => Value::Float(*f),
835 SqlValue::List(_) => Value::Null, }
837}
838
839fn eq_match(actual: &Value, expected: &SqlValue) -> bool {
840 if let Value::List(items) = actual {
842 if let SqlValue::String(s) = expected {
843 return items.contains(s);
844 }
845 }
846
847 let coerced = coerce_sql_to_value(expected, actual);
848 actual == &coerced
849}
850
851fn like_match(actual: &Value, pattern: &SqlValue) -> bool {
852 let pattern_str = match pattern {
853 SqlValue::String(s) => s,
854 _ => return false,
855 };
856
857 let mut regex_str = String::from("(?is)^");
859 for ch in pattern_str.chars() {
860 match ch {
861 '%' => regex_str.push_str(".*"),
862 '_' => regex_str.push('.'),
863 c => {
864 if regex::escape(&c.to_string()) != c.to_string() {
865 regex_str.push_str(®ex::escape(&c.to_string()));
866 } else {
867 regex_str.push(c);
868 }
869 }
870 }
871 }
872 regex_str.push('$');
873
874 let re = match Regex::new(®ex_str) {
875 Ok(r) => r,
876 Err(_) => return false,
877 };
878
879 match actual {
880 Value::List(items) => items.iter().any(|item| re.is_match(item)),
881 _ => re.is_match(&actual.to_display_string()),
882 }
883}
884
885fn compare_values(actual: &Value, expected: &SqlValue) -> Option<Ordering> {
886 let coerced = coerce_sql_to_value(expected, actual);
887 actual.partial_cmp(&coerced).map(|o| o)
888}
889
890fn sql_value_to_index_value(sv: &SqlValue) -> Value {
892 match sv {
893 SqlValue::String(s) => {
894 if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S") {
896 return Value::DateTime(dt);
897 }
898 if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f") {
899 return Value::DateTime(dt);
900 }
901 if let Ok(d) = chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d") {
903 return Value::Date(d);
904 }
905 Value::String(s.clone())
906 }
907 SqlValue::Int(n) => Value::Int(*n),
908 SqlValue::Float(f) => Value::Float(*f),
909 SqlValue::Null => Value::Null,
910 SqlValue::List(_) => Value::Null,
911 }
912}
913
914fn try_index_filter(
918 clause: &WhereClause,
919 index: &crate::index::TableIndex,
920) -> Option<std::collections::HashSet<String>> {
921 match clause {
922 WhereClause::Comparison(cmp) => {
923 if !index.has_index(&cmp.column) {
924 return None;
925 }
926 match cmp.op.as_str() {
927 "=" => {
928 let val = sql_value_to_index_value(cmp.value.as_ref()?);
929 let paths = index.lookup_eq(&cmp.column, &val);
930 Some(paths.into_iter().map(|s| s.to_string()).collect())
931 }
932 "<" => {
933 let val = sql_value_to_index_value(cmp.value.as_ref()?);
934 let range_paths = index.lookup_range(&cmp.column, None, Some(&val));
937 let eq_paths: std::collections::HashSet<&str> = index.lookup_eq(&cmp.column, &val).into_iter().collect();
938 Some(range_paths.into_iter().filter(|p| !eq_paths.contains(p)).map(|s| s.to_string()).collect())
939 }
940 ">" => {
941 let val = sql_value_to_index_value(cmp.value.as_ref()?);
942 let range_paths = index.lookup_range(&cmp.column, Some(&val), None);
943 let eq_paths: std::collections::HashSet<&str> = index.lookup_eq(&cmp.column, &val).into_iter().collect();
944 Some(range_paths.into_iter().filter(|p| !eq_paths.contains(p)).map(|s| s.to_string()).collect())
945 }
946 "<=" => {
947 let val = sql_value_to_index_value(cmp.value.as_ref()?);
948 let paths = index.lookup_range(&cmp.column, None, Some(&val));
949 Some(paths.into_iter().map(|s| s.to_string()).collect())
950 }
951 ">=" => {
952 let val = sql_value_to_index_value(cmp.value.as_ref()?);
953 let paths = index.lookup_range(&cmp.column, Some(&val), None);
954 Some(paths.into_iter().map(|s| s.to_string()).collect())
955 }
956 "IN" => {
957 if let Some(SqlValue::List(items)) = &cmp.value {
958 let vals: Vec<Value> = items.iter().map(sql_value_to_index_value).collect();
959 let paths = index.lookup_in(&cmp.column, &vals);
960 Some(paths.into_iter().map(|s| s.to_string()).collect())
961 } else {
962 None
963 }
964 }
965 _ => None, }
967 }
968 WhereClause::BoolOp(bop) => {
969 let left = try_index_filter(&bop.left, index);
970 let right = try_index_filter(&bop.right, index);
971 match bop.op.as_str() {
972 "AND" => {
973 match (left, right) {
974 (Some(l), Some(r)) => Some(l.intersection(&r).cloned().collect()),
975 (Some(l), None) => Some(l), (None, Some(r)) => Some(r),
977 (None, None) => None,
978 }
979 }
980 "OR" => {
981 match (left, right) {
982 (Some(l), Some(r)) => Some(l.union(&r).cloned().collect()),
983 _ => None, }
985 }
986 _ => None,
987 }
988 }
989 }
990}
991
992fn resolve_order_aliases(specs: &[OrderSpec], columns: &ColumnList) -> Vec<OrderSpec> {
995 let named = match columns {
996 ColumnList::Named(exprs) => exprs,
997 _ => return specs.to_vec(),
998 };
999
1000 let alias_map: HashMap<String, &Expr> = named
1002 .iter()
1003 .filter_map(|se| match se {
1004 SelectExpr::Expr { expr, alias: Some(a) } => Some((a.clone(), expr)),
1005 _ => None,
1006 })
1007 .collect();
1008
1009 specs
1010 .iter()
1011 .map(|spec| {
1012 if let Some(expr) = alias_map.get(&spec.column) {
1014 OrderSpec {
1015 column: spec.column.clone(),
1016 expr: Some((*expr).clone()),
1017 descending: spec.descending,
1018 }
1019 } else {
1020 spec.clone()
1021 }
1022 })
1023 .collect()
1024}
1025
1026fn sort_rows(rows: &mut Vec<Row>, specs: &[OrderSpec]) {
1027 rows.sort_by(|a, b| {
1028 for spec in specs {
1029 let (va, vb) = if let Some(ref expr) = spec.expr {
1030 (evaluate_expr(expr, a), evaluate_expr(expr, b))
1031 } else {
1032 (
1033 a.get(&spec.column).cloned().unwrap_or(Value::Null),
1034 b.get(&spec.column).cloned().unwrap_or(Value::Null),
1035 )
1036 };
1037
1038 let ordering = match (&va, &vb) {
1040 (Value::Null, Value::Null) => Ordering::Equal,
1041 (Value::Null, _) => Ordering::Greater,
1042 (_, Value::Null) => Ordering::Less,
1043 (a_val, b_val) => {
1044 compare_model_values(a_val, b_val).unwrap_or(Ordering::Equal)
1045 }
1046 };
1047
1048 let ordering = if spec.descending {
1049 ordering.reverse()
1050 } else {
1051 ordering
1052 };
1053
1054 if ordering != Ordering::Equal {
1055 return ordering;
1056 }
1057 }
1058 Ordering::Equal
1059 });
1060}
1061
1062pub fn sql_value_to_value(sql_val: &SqlValue) -> Value {
1064 match sql_val {
1065 SqlValue::Null => Value::Null,
1066 SqlValue::String(s) => Value::String(s.clone()),
1067 SqlValue::Int(n) => Value::Int(*n),
1068 SqlValue::Float(f) => Value::Float(*f),
1069 SqlValue::List(items) => {
1070 let strings: Vec<String> = items
1071 .iter()
1072 .filter_map(|v| match v {
1073 SqlValue::String(s) => Some(s.clone()),
1074 _ => None,
1075 })
1076 .collect();
1077 Value::List(strings)
1078 }
1079 }
1080}
1081
1082#[cfg(test)]
1083mod tests {
1084 use super::*;
1085
1086 fn make_rows() -> Vec<Row> {
1087 vec![
1088 Row::from([
1089 ("path".into(), Value::String("a.md".into())),
1090 ("title".into(), Value::String("Alpha".into())),
1091 ("count".into(), Value::Int(10)),
1092 ]),
1093 Row::from([
1094 ("path".into(), Value::String("b.md".into())),
1095 ("title".into(), Value::String("Beta".into())),
1096 ("count".into(), Value::Int(5)),
1097 ]),
1098 Row::from([
1099 ("path".into(), Value::String("c.md".into())),
1100 ("title".into(), Value::String("Gamma".into())),
1101 ("count".into(), Value::Int(20)),
1102 ]),
1103 ]
1104 }
1105
1106 #[test]
1107 fn test_select_all() {
1108 let q = SelectQuery {
1109 columns: ColumnList::All,
1110 table: "test".into(),
1111 table_alias: None,
1112 joins: vec![],
1113 where_clause: None,
1114 group_by: None,
1115 having: None,
1116 order_by: None,
1117 limit: None,
1118 };
1119 let (rows, _cols) = execute(&q, &make_rows(), None).unwrap();
1120 assert_eq!(rows.len(), 3);
1121 }
1122
1123 #[test]
1124 fn test_where_gt() {
1125 let q = SelectQuery {
1126 columns: ColumnList::All,
1127 table: "test".into(),
1128 table_alias: None,
1129 joins: vec![],
1130 where_clause: Some(WhereClause::Comparison(Comparison {
1131 column: "count".into(),
1132 op: ">".into(),
1133 value: Some(SqlValue::Int(5)),
1134 left_expr: Some(Expr::Column("count".into())),
1135 right_expr: Some(Expr::Literal(SqlValue::Int(5))),
1136 })),
1137 group_by: None,
1138 having: None,
1139 order_by: None,
1140 limit: None,
1141 };
1142 let (rows, _) = execute(&q, &make_rows(), None).unwrap();
1143 assert_eq!(rows.len(), 2);
1144 }
1145
1146 #[test]
1147 fn test_order_by_desc() {
1148 let q = SelectQuery {
1149 columns: ColumnList::All,
1150 table: "test".into(),
1151 table_alias: None,
1152 joins: vec![],
1153 where_clause: None,
1154 group_by: None,
1155 having: None,
1156 order_by: Some(vec![OrderSpec {
1157 column: "count".into(),
1158 expr: Some(Expr::Column("count".into())),
1159 descending: true,
1160 }]),
1161 limit: None,
1162 };
1163 let (rows, _) = execute(&q, &make_rows(), None).unwrap();
1164 assert_eq!(rows[0]["count"], Value::Int(20));
1165 assert_eq!(rows[2]["count"], Value::Int(5));
1166 }
1167
1168 #[test]
1169 fn test_limit() {
1170 let q = SelectQuery {
1171 columns: ColumnList::All,
1172 table: "test".into(),
1173 table_alias: None,
1174 joins: vec![],
1175 where_clause: None,
1176 group_by: None,
1177 having: None,
1178 order_by: None,
1179 limit: Some(2),
1180 };
1181 let (rows, _) = execute(&q, &make_rows(), None).unwrap();
1182 assert_eq!(rows.len(), 2);
1183 }
1184
1185 #[test]
1186 fn test_like() {
1187 let q = SelectQuery {
1188 columns: ColumnList::All,
1189 table: "test".into(),
1190 table_alias: None,
1191 joins: vec![],
1192 where_clause: Some(WhereClause::Comparison(Comparison {
1193 column: "title".into(),
1194 op: "LIKE".into(),
1195 value: Some(SqlValue::String("%lph%".into())),
1196 left_expr: Some(Expr::Column("title".into())),
1197 right_expr: None,
1198 })),
1199 group_by: None,
1200 having: None,
1201 order_by: None,
1202 limit: None,
1203 };
1204 let (rows, _) = execute(&q, &make_rows(), None).unwrap();
1205 assert_eq!(rows.len(), 1);
1206 assert_eq!(rows[0]["title"], Value::String("Alpha".into()));
1207 }
1208
1209 #[test]
1210 fn test_is_null() {
1211 let mut rows = make_rows();
1212 rows[1].insert("optional".into(), Value::Null);
1213
1214 let q = SelectQuery {
1215 columns: ColumnList::All,
1216 table: "test".into(),
1217 table_alias: None,
1218 joins: vec![],
1219 where_clause: Some(WhereClause::Comparison(Comparison {
1220 column: "optional".into(),
1221 op: "IS NULL".into(),
1222 value: None,
1223 left_expr: Some(Expr::Column("optional".into())),
1224 right_expr: None,
1225 })),
1226 group_by: None,
1227 having: None,
1228 order_by: None,
1229 limit: None,
1230 };
1231 let (result, _) = execute(&q, &rows, None).unwrap();
1232 assert_eq!(result.len(), 3);
1234 }
1235
1236 #[test]
1239 fn test_evaluate_expr_literal() {
1240 let row = Row::new();
1241 assert_eq!(evaluate_expr(&Expr::Literal(SqlValue::Int(42)), &row), Value::Int(42));
1242 assert_eq!(evaluate_expr(&Expr::Literal(SqlValue::Float(3.14)), &row), Value::Float(3.14));
1243 assert_eq!(evaluate_expr(&Expr::Literal(SqlValue::Null), &row), Value::Null);
1244 }
1245
1246 #[test]
1247 fn test_evaluate_expr_column() {
1248 let row = Row::from([("x".into(), Value::Int(10))]);
1249 assert_eq!(evaluate_expr(&Expr::Column("x".into()), &row), Value::Int(10));
1250 assert_eq!(evaluate_expr(&Expr::Column("missing".into()), &row), Value::Null);
1251 }
1252
1253 #[test]
1254 fn test_evaluate_expr_int_arithmetic() {
1255 let row = Row::from([("a".into(), Value::Int(10)), ("b".into(), Value::Int(3))]);
1256 let add = Expr::BinaryOp {
1257 left: Box::new(Expr::Column("a".into())),
1258 op: ArithOp::Add,
1259 right: Box::new(Expr::Column("b".into())),
1260 };
1261 assert_eq!(evaluate_expr(&add, &row), Value::Int(13));
1262
1263 let sub = Expr::BinaryOp {
1264 left: Box::new(Expr::Column("a".into())),
1265 op: ArithOp::Sub,
1266 right: Box::new(Expr::Column("b".into())),
1267 };
1268 assert_eq!(evaluate_expr(&sub, &row), Value::Int(7));
1269
1270 let mul = Expr::BinaryOp {
1271 left: Box::new(Expr::Column("a".into())),
1272 op: ArithOp::Mul,
1273 right: Box::new(Expr::Column("b".into())),
1274 };
1275 assert_eq!(evaluate_expr(&mul, &row), Value::Int(30));
1276
1277 let div = Expr::BinaryOp {
1278 left: Box::new(Expr::Column("a".into())),
1279 op: ArithOp::Div,
1280 right: Box::new(Expr::Column("b".into())),
1281 };
1282 assert_eq!(evaluate_expr(&div, &row), Value::Int(3)); let modulo = Expr::BinaryOp {
1285 left: Box::new(Expr::Column("a".into())),
1286 op: ArithOp::Mod,
1287 right: Box::new(Expr::Column("b".into())),
1288 };
1289 assert_eq!(evaluate_expr(&modulo, &row), Value::Int(1));
1290 }
1291
1292 #[test]
1293 fn test_evaluate_expr_float_coercion() {
1294 let row = Row::from([("a".into(), Value::Int(10)), ("b".into(), Value::Float(3.0))]);
1295 let add = Expr::BinaryOp {
1296 left: Box::new(Expr::Column("a".into())),
1297 op: ArithOp::Add,
1298 right: Box::new(Expr::Column("b".into())),
1299 };
1300 assert_eq!(evaluate_expr(&add, &row), Value::Float(13.0));
1301 }
1302
1303 #[test]
1304 fn test_evaluate_expr_null_propagation() {
1305 let row = Row::from([("a".into(), Value::Int(10))]);
1306 let add = Expr::BinaryOp {
1307 left: Box::new(Expr::Column("a".into())),
1308 op: ArithOp::Add,
1309 right: Box::new(Expr::Column("missing".into())),
1310 };
1311 assert_eq!(evaluate_expr(&add, &row), Value::Null);
1312 }
1313
1314 #[test]
1315 fn test_evaluate_expr_div_by_zero() {
1316 let row = Row::from([("a".into(), Value::Int(10)), ("b".into(), Value::Int(0))]);
1317 let div = Expr::BinaryOp {
1318 left: Box::new(Expr::Column("a".into())),
1319 op: ArithOp::Div,
1320 right: Box::new(Expr::Column("b".into())),
1321 };
1322 assert_eq!(evaluate_expr(&div, &row), Value::Null);
1323 }
1324
1325 #[test]
1326 fn test_evaluate_expr_unary_minus() {
1327 let row = Row::from([("x".into(), Value::Int(5))]);
1328 let neg = Expr::UnaryMinus(Box::new(Expr::Column("x".into())));
1329 assert_eq!(evaluate_expr(&neg, &row), Value::Int(-5));
1330 }
1331
1332 #[test]
1333 fn test_select_with_expression() {
1334 let stmt = crate::query_parser::parse_query(
1336 "SELECT count * 2 AS doubled FROM test"
1337 ).unwrap();
1338 if let crate::query_parser::Statement::Select(q) = stmt {
1339 let (rows, cols) = execute(&q, &make_rows(), None).unwrap();
1340 assert_eq!(cols, vec!["doubled"]);
1341 assert_eq!(rows.len(), 3);
1342 let values: Vec<Value> = rows.iter().map(|r| r["doubled"].clone()).collect();
1344 assert!(values.contains(&Value::Int(20)));
1345 assert!(values.contains(&Value::Int(10)));
1346 assert!(values.contains(&Value::Int(40)));
1347 } else {
1348 panic!("Expected Select");
1349 }
1350 }
1351
1352 #[test]
1353 fn test_where_with_expression() {
1354 let stmt = crate::query_parser::parse_query(
1356 "SELECT * FROM test WHERE count * 2 > 15"
1357 ).unwrap();
1358 if let crate::query_parser::Statement::Select(q) = stmt {
1359 let (rows, _) = execute(&q, &make_rows(), None).unwrap();
1360 assert_eq!(rows.len(), 2);
1362 } else {
1363 panic!("Expected Select");
1364 }
1365 }
1366
1367 #[test]
1368 fn test_order_by_expression() {
1369 let stmt = crate::query_parser::parse_query(
1371 "SELECT title, count FROM test ORDER BY count * -1 ASC"
1372 ).unwrap();
1373 if let crate::query_parser::Statement::Select(q) = stmt {
1374 let (rows, _) = execute(&q, &make_rows(), None).unwrap();
1375 assert_eq!(rows[0]["count"], Value::Int(20));
1377 assert_eq!(rows[1]["count"], Value::Int(10));
1378 assert_eq!(rows[2]["count"], Value::Int(5));
1379 } else {
1380 panic!("Expected Select");
1381 }
1382 }
1383
1384 #[test]
1387 fn test_case_when_eval_basic() {
1388 let row = Row::from([("status".into(), Value::String("ACTIVE".into()))]);
1389 let expr = Expr::Case {
1390 whens: vec![(
1391 WhereClause::Comparison(Comparison {
1392 column: "status".into(),
1393 op: "=".into(),
1394 value: Some(SqlValue::String("ACTIVE".into())),
1395 left_expr: Some(Expr::Column("status".into())),
1396 right_expr: Some(Expr::Literal(SqlValue::String("ACTIVE".into()))),
1397 }),
1398 Box::new(Expr::Literal(SqlValue::Int(1))),
1399 )],
1400 else_expr: Some(Box::new(Expr::Literal(SqlValue::Int(0)))),
1401 };
1402 assert_eq!(evaluate_expr(&expr, &row), Value::Int(1));
1403 }
1404
1405 #[test]
1406 fn test_case_when_eval_else() {
1407 let row = Row::from([("status".into(), Value::String("KILLED".into()))]);
1408 let expr = Expr::Case {
1409 whens: vec![(
1410 WhereClause::Comparison(Comparison {
1411 column: "status".into(),
1412 op: "=".into(),
1413 value: Some(SqlValue::String("ACTIVE".into())),
1414 left_expr: Some(Expr::Column("status".into())),
1415 right_expr: Some(Expr::Literal(SqlValue::String("ACTIVE".into()))),
1416 }),
1417 Box::new(Expr::Literal(SqlValue::Int(1))),
1418 )],
1419 else_expr: Some(Box::new(Expr::Literal(SqlValue::Int(0)))),
1420 };
1421 assert_eq!(evaluate_expr(&expr, &row), Value::Int(0));
1422 }
1423
1424 #[test]
1425 fn test_case_when_eval_no_else_null() {
1426 let row = Row::from([("x".into(), Value::Int(99))]);
1427 let expr = Expr::Case {
1428 whens: vec![(
1429 WhereClause::Comparison(Comparison {
1430 column: "x".into(),
1431 op: "=".into(),
1432 value: Some(SqlValue::Int(1)),
1433 left_expr: Some(Expr::Column("x".into())),
1434 right_expr: Some(Expr::Literal(SqlValue::Int(1))),
1435 }),
1436 Box::new(Expr::Literal(SqlValue::String("one".into()))),
1437 )],
1438 else_expr: None,
1439 };
1440 assert_eq!(evaluate_expr(&expr, &row), Value::Null);
1441 }
1442
1443 #[test]
1444 fn test_case_when_in_aggregate_query() {
1445 let stmt = crate::query_parser::parse_query(
1448 "SELECT SUM(CASE WHEN count > 5 THEN count ELSE 0 END) AS total FROM test"
1449 ).unwrap();
1450 if let crate::query_parser::Statement::Select(q) = stmt {
1451 let (rows, cols) = execute(&q, &make_rows(), None).unwrap();
1452 assert_eq!(cols, vec!["total"]);
1453 assert_eq!(rows.len(), 1);
1454 assert_eq!(rows[0]["total"], Value::Float(30.0));
1455 } else {
1456 panic!("Expected Select");
1457 }
1458 }
1459
1460 #[test]
1461 fn test_case_when_with_unary_minus_in_aggregate() {
1462 let stmt = crate::query_parser::parse_query(
1465 "SELECT SUM(CASE WHEN title = 'Alpha' THEN count ELSE -count END) AS net FROM test"
1466 ).unwrap();
1467 if let crate::query_parser::Statement::Select(q) = stmt {
1468 let (rows, _) = execute(&q, &make_rows(), None).unwrap();
1469 assert_eq!(rows.len(), 1);
1470 assert_eq!(rows[0]["net"], Value::Float(-15.0));
1471 } else {
1472 panic!("Expected Select");
1473 }
1474 }
1475
1476 #[test]
1477 fn test_dateadd_with_dict_in_group_by() {
1478 use indexmap::IndexMap;
1480 let mut params = IndexMap::new();
1481 params.insert("exit_days".to_string(), Value::Int(21));
1482
1483 let rows = vec![
1484 Row::from([
1485 ("o.token".into(), Value::String("BTC".into())),
1486 ("o.event_date".into(), Value::Date(
1487 chrono::NaiveDate::from_ymd_opt(2026, 1, 1).unwrap()
1488 )),
1489 ("o.size".into(), Value::Int(100)),
1490 ("s.params".into(), Value::Dict(params.clone())),
1491 ]),
1492 Row::from([
1493 ("o.token".into(), Value::String("BTC".into())),
1494 ("o.event_date".into(), Value::Date(
1495 chrono::NaiveDate::from_ymd_opt(2026, 1, 1).unwrap()
1496 )),
1497 ("o.size".into(), Value::Int(50)),
1498 ("s.params".into(), Value::Dict(params.clone())),
1499 ]),
1500 ];
1501
1502 let q = SelectQuery {
1503 columns: ColumnList::Named(vec![
1504 SelectExpr::Column("o.token".into()),
1505 SelectExpr::Column("o.event_date".into()),
1506 SelectExpr::Expr {
1507 expr: Expr::DateAdd {
1508 date: Box::new(Expr::Column("o.event_date".into())),
1509 days: Box::new(Expr::Column("s.params.exit_days".into())),
1510 },
1511 alias: Some("exit_date".into()),
1512 },
1513 SelectExpr::Aggregate {
1514 func: AggFunc::Sum,
1515 arg: "o.size".into(),
1516 arg_expr: Some(Expr::Column("o.size".into())),
1517 alias: Some("total".into()),
1518 },
1519 ]),
1520 table: "orders".into(),
1521 table_alias: None,
1522 joins: vec![],
1523 where_clause: None,
1524 group_by: Some(vec!["o.token".into(), "o.event_date".into()]),
1525 having: None,
1526 order_by: None,
1527 limit: None,
1528 };
1529
1530 let (rows, cols) = execute(&q, &rows, None).unwrap();
1531 assert_eq!(rows.len(), 1);
1532 assert!(cols.contains(&"exit_date".to_string()));
1533 assert_eq!(rows[0]["total"], Value::Float(150.0));
1534 assert_eq!(
1536 rows[0]["exit_date"],
1537 Value::Date(chrono::NaiveDate::from_ymd_opt(2026, 1, 22).unwrap())
1538 );
1539 }
1540}