1use std::cell::RefCell;
6use std::collections::{HashMap, HashSet};
7use std::hash::{Hash, Hasher};
8
9use chrono::Datelike;
10use regex::Regex;
11use rust_decimal::Decimal;
12use rustledger_core::{
13 Amount, Directive, InternedStr, Inventory, NaiveDate, Position, Transaction,
14};
15
16use crate::ast::{
17 BalancesQuery, BinaryOp, BinaryOperator, Expr, FromClause, FunctionCall, JournalQuery, Literal,
18 OrderSpec, PrintQuery, Query, SelectQuery, SortDirection, Target, UnaryOp, UnaryOperator,
19 WindowFunction,
20};
21use crate::error::QueryError;
22
/// A single cell value produced by query evaluation.
///
/// Equality is the derived structural equality, so values of different
/// variants never compare equal (e.g. `Integer(1) != Number(1)`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Value {
    /// Text: account names, narrations, payees, flags, formatted directives, ...
    String(String),
    /// An arbitrary-precision decimal number.
    Number(Decimal),
    /// A signed 64-bit integer (e.g. the `year`/`month`/`day` columns, `LENGTH`).
    Integer(i64),
    /// A calendar date.
    Date(NaiveDate),
    /// A boolean, typically the result of a predicate.
    Boolean(bool),
    /// A number paired with a currency.
    Amount(Amount),
    /// A position: units plus an optional cost.
    Position(Position),
    /// A collection of positions, e.g. an account's (running) balance.
    Inventory(Inventory),
    /// A set of strings such as tags or links. Stored as a `Vec`, so the
    /// derived equality is order-sensitive.
    StringSet(Vec<String>),
    /// The absence of a value.
    Null,
}
47
48impl Value {
49 fn hash_value<H: Hasher>(&self, state: &mut H) {
55 std::mem::discriminant(self).hash(state);
56 match self {
57 Self::String(s) => s.hash(state),
58 Self::Number(d) => d.serialize().hash(state),
59 Self::Integer(i) => i.hash(state),
60 Self::Date(d) => {
61 d.year().hash(state);
62 d.month().hash(state);
63 d.day().hash(state);
64 }
65 Self::Boolean(b) => b.hash(state),
66 Self::Amount(a) => {
67 a.number.serialize().hash(state);
68 a.currency.as_str().hash(state);
69 }
70 Self::Position(p) => {
71 p.units.number.serialize().hash(state);
72 p.units.currency.as_str().hash(state);
73 if let Some(cost) = &p.cost {
74 cost.number.serialize().hash(state);
75 cost.currency.as_str().hash(state);
76 }
77 }
78 Self::Inventory(inv) => {
79 for pos in inv.positions() {
80 pos.units.number.serialize().hash(state);
81 pos.units.currency.as_str().hash(state);
82 if let Some(cost) = &pos.cost {
83 cost.number.serialize().hash(state);
84 cost.currency.as_str().hash(state);
85 }
86 }
87 }
88 Self::StringSet(ss) => {
89 let mut sorted = ss.clone();
91 sorted.sort();
92 for s in &sorted {
93 s.hash(state);
94 }
95 }
96 Self::Null => {}
97 }
98 }
99}
100
101fn hash_row(row: &Row) -> u64 {
103 use std::collections::hash_map::DefaultHasher;
104 let mut hasher = DefaultHasher::new();
105 for value in row {
106 value.hash_value(&mut hasher);
107 }
108 hasher.finish()
109}
110
111fn hash_single_value(value: &Value) -> u64 {
113 use std::collections::hash_map::DefaultHasher;
114 let mut hasher = DefaultHasher::new();
115 value.hash_value(&mut hasher);
116 hasher.finish()
117}
118
/// One result row: a `Value` per output column, in column order.
pub type Row = Vec<Value>;
121
/// The tabular output of a query: column headers plus rows of values.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Column names, in output order.
    pub columns: Vec<String>,
    /// Result rows; each is expected to hold one value per column.
    pub rows: Vec<Row>,
}
130
131impl QueryResult {
132 pub const fn new(columns: Vec<String>) -> Self {
134 Self {
135 columns,
136 rows: Vec::new(),
137 }
138 }
139
140 pub fn add_row(&mut self, row: Row) {
142 self.rows.push(row);
143 }
144
145 pub fn len(&self) -> usize {
147 self.rows.len()
148 }
149
150 pub fn is_empty(&self) -> bool {
152 self.rows.is_empty()
153 }
154}
155
/// Evaluation context for a single posting of a transaction.
#[derive(Debug)]
pub struct PostingContext<'a> {
    /// The transaction that owns the posting.
    pub transaction: &'a Transaction,
    /// Index of the posting within `transaction.postings`.
    pub posting_index: usize,
    /// Running balance of the posting's account after this posting was
    /// applied, when one was computed. It is `None` in the contexts built
    /// solely to evaluate FROM filters.
    pub balance: Option<Inventory>,
}
166
/// Window-function values attached to one result row.
///
/// Produced by `compute_window_contexts` (not visible here) and consumed by
/// `evaluate_row_with_window`. `execute_select` indexes these by posting
/// position, so presumably there is one per collected posting.
#[derive(Debug, Clone)]
pub struct WindowContext {
    /// Row number reported for this row (computation not visible here).
    pub row_number: usize,
    /// Rank reported for this row (computation not visible here).
    pub rank: usize,
    /// Dense rank reported for this row (computation not visible here).
    pub dense_rank: usize,
}
177
/// Evaluates queries against a borrowed slice of directives.
pub struct Executor<'a> {
    /// The directives every query runs over, in the order given to `new`.
    directives: &'a [Directive],
    /// Per-account inventories filled in while executing `BALANCES` and
    /// `JOURNAL` queries.
    balances: HashMap<InternedStr, Inventory>,
    /// Price database built from `directives` in `new`.
    price_db: crate::price::PriceDatabase,
    /// Currency set via `set_target_currency`. Not read in this excerpt
    /// (presumably used for conversions — TODO confirm).
    target_currency: Option<String>,
    /// Compiled regexes keyed by pattern text. `None` records a pattern that
    /// failed to compile. `RefCell` lets `&self` methods populate it.
    regex_cache: RefCell<HashMap<String, Option<Regex>>>,
}
191
192impl<'a> Executor<'a> {
193 pub fn new(directives: &'a [Directive]) -> Self {
195 let price_db = crate::price::PriceDatabase::from_directives(directives);
196 Self {
197 directives,
198 balances: HashMap::new(),
199 price_db,
200 target_currency: None,
201 regex_cache: RefCell::new(HashMap::new()),
202 }
203 }
204
205 fn get_or_compile_regex(&self, pattern: &str) -> Option<Regex> {
210 let mut cache = self.regex_cache.borrow_mut();
211 if let Some(cached) = cache.get(pattern) {
212 return cached.clone();
213 }
214 let compiled = Regex::new(pattern).ok();
215 cache.insert(pattern.to_string(), compiled.clone());
216 compiled
217 }
218
219 fn require_regex(&self, pattern: &str) -> Result<Regex, QueryError> {
221 self.get_or_compile_regex(pattern)
222 .ok_or_else(|| QueryError::Type(format!("invalid regex: {pattern}")))
223 }
224
225 pub fn set_target_currency(&mut self, currency: impl Into<String>) {
227 self.target_currency = Some(currency.into());
228 }
229
230 pub fn execute(&mut self, query: &Query) -> Result<QueryResult, QueryError> {
243 match query {
244 Query::Select(select) => self.execute_select(select),
245 Query::Journal(journal) => self.execute_journal(journal),
246 Query::Balances(balances) => self.execute_balances(balances),
247 Query::Print(print) => self.execute_print(print),
248 }
249 }
250
251 fn execute_select(&self, query: &SelectQuery) -> Result<QueryResult, QueryError> {
253 if let Some(from) = &query.from {
255 if let Some(subquery) = &from.subquery {
256 return self.execute_select_from_subquery(query, subquery);
257 }
258 }
259
260 let column_names = self.resolve_column_names(&query.targets)?;
262 let mut result = QueryResult::new(column_names.clone());
263
264 let postings = self.collect_postings(query.from.as_ref(), query.where_clause.as_ref())?;
266
267 let is_aggregate = query
269 .targets
270 .iter()
271 .any(|t| Self::is_aggregate_expr(&t.expr));
272
273 if is_aggregate {
274 let grouped = self.group_postings(&postings, query.group_by.as_ref())?;
276 for (_, group) in grouped {
277 let row = self.evaluate_aggregate_row(&query.targets, &group)?;
278
279 if let Some(having_expr) = &query.having {
281 if !self.evaluate_having_filter(
282 having_expr,
283 &row,
284 &column_names,
285 &query.targets,
286 &group,
287 )? {
288 continue;
289 }
290 }
291
292 result.add_row(row);
293 }
294 } else {
295 let has_windows = Self::has_window_functions(&query.targets);
297 let window_contexts = if has_windows {
298 if let Some(wf) = Self::find_window_function(&query.targets) {
299 Some(self.compute_window_contexts(&postings, wf)?)
300 } else {
301 None
302 }
303 } else {
304 None
305 };
306
307 let mut seen_hashes: HashSet<u64> = if query.distinct {
310 HashSet::with_capacity(postings.len())
311 } else {
312 HashSet::new()
313 };
314
315 for (i, ctx) in postings.iter().enumerate() {
316 let row = if let Some(ref wctxs) = window_contexts {
317 self.evaluate_row_with_window(&query.targets, ctx, Some(&wctxs[i]))?
318 } else {
319 self.evaluate_row(&query.targets, ctx)?
320 };
321 if query.distinct {
322 let row_hash = hash_row(&row);
324 if seen_hashes.insert(row_hash) {
325 result.add_row(row);
326 }
327 } else {
328 result.add_row(row);
329 }
330 }
331 }
332
333 if let Some(pivot_exprs) = &query.pivot_by {
335 result = self.apply_pivot(&result, pivot_exprs, &query.targets)?;
336 }
337
338 if let Some(order_by) = &query.order_by {
340 self.sort_results(&mut result, order_by)?;
341 }
342
343 if let Some(limit) = query.limit {
345 result.rows.truncate(limit as usize);
346 }
347
348 Ok(result)
349 }
350
351 fn execute_select_from_subquery(
353 &self,
354 outer_query: &SelectQuery,
355 inner_query: &SelectQuery,
356 ) -> Result<QueryResult, QueryError> {
357 let inner_result = self.execute_select(inner_query)?;
359
360 let inner_column_map: HashMap<String, usize> = inner_result
362 .columns
363 .iter()
364 .enumerate()
365 .map(|(i, name)| (name.to_lowercase(), i))
366 .collect();
367
368 let outer_column_names =
370 self.resolve_subquery_column_names(&outer_query.targets, &inner_result.columns)?;
371 let mut result = QueryResult::new(outer_column_names);
372
373 let mut seen_hashes: HashSet<u64> = if outer_query.distinct {
375 HashSet::with_capacity(inner_result.rows.len())
376 } else {
377 HashSet::new()
378 };
379
380 for inner_row in &inner_result.rows {
382 if let Some(where_expr) = &outer_query.where_clause {
384 if !self.evaluate_subquery_filter(where_expr, inner_row, &inner_column_map)? {
385 continue;
386 }
387 }
388
389 let outer_row =
391 self.evaluate_subquery_row(&outer_query.targets, inner_row, &inner_column_map)?;
392
393 if outer_query.distinct {
394 let row_hash = hash_row(&outer_row);
396 if seen_hashes.insert(row_hash) {
397 result.add_row(outer_row);
398 }
399 } else {
400 result.add_row(outer_row);
401 }
402 }
403
404 if let Some(order_by) = &outer_query.order_by {
406 self.sort_results(&mut result, order_by)?;
407 }
408
409 if let Some(limit) = outer_query.limit {
411 result.rows.truncate(limit as usize);
412 }
413
414 Ok(result)
415 }
416
417 fn resolve_subquery_column_names(
419 &self,
420 targets: &[Target],
421 inner_columns: &[String],
422 ) -> Result<Vec<String>, QueryError> {
423 let mut names = Vec::new();
424 for (i, target) in targets.iter().enumerate() {
425 if let Some(alias) = &target.alias {
426 names.push(alias.clone());
427 } else if matches!(target.expr, Expr::Wildcard) {
428 names.extend(inner_columns.iter().cloned());
430 } else {
431 names.push(self.expr_to_name(&target.expr, i));
432 }
433 }
434 Ok(names)
435 }
436
437 fn evaluate_subquery_filter(
439 &self,
440 expr: &Expr,
441 row: &[Value],
442 column_map: &HashMap<String, usize>,
443 ) -> Result<bool, QueryError> {
444 let val = self.evaluate_subquery_expr(expr, row, column_map)?;
445 self.to_bool(&val)
446 }
447
448 fn evaluate_subquery_expr(
450 &self,
451 expr: &Expr,
452 row: &[Value],
453 column_map: &HashMap<String, usize>,
454 ) -> Result<Value, QueryError> {
455 match expr {
456 Expr::Wildcard => Err(QueryError::Evaluation(
457 "Wildcard not allowed in expression context".to_string(),
458 )),
459 Expr::Column(name) => {
460 let lower = name.to_lowercase();
461 if let Some(&idx) = column_map.get(&lower) {
462 Ok(row.get(idx).cloned().unwrap_or(Value::Null))
463 } else {
464 Err(QueryError::Evaluation(format!(
465 "Unknown column '{name}' in subquery result"
466 )))
467 }
468 }
469 Expr::Literal(lit) => self.evaluate_literal(lit),
470 Expr::Function(func) => {
471 let args: Vec<Value> = func
473 .args
474 .iter()
475 .map(|a| self.evaluate_subquery_expr(a, row, column_map))
476 .collect::<Result<Vec<_>, _>>()?;
477 self.evaluate_function_on_values(&func.name, &args)
478 }
479 Expr::BinaryOp(op) => {
480 let left = self.evaluate_subquery_expr(&op.left, row, column_map)?;
481 let right = self.evaluate_subquery_expr(&op.right, row, column_map)?;
482 self.binary_op_on_values(op.op, &left, &right)
483 }
484 Expr::UnaryOp(op) => {
485 let val = self.evaluate_subquery_expr(&op.operand, row, column_map)?;
486 self.unary_op_on_value(op.op, &val)
487 }
488 Expr::Paren(inner) => self.evaluate_subquery_expr(inner, row, column_map),
489 Expr::Window(_) => Err(QueryError::Evaluation(
490 "Window functions not supported in subquery expressions".to_string(),
491 )),
492 }
493 }
494
495 fn evaluate_subquery_row(
497 &self,
498 targets: &[Target],
499 inner_row: &[Value],
500 column_map: &HashMap<String, usize>,
501 ) -> Result<Row, QueryError> {
502 let mut row = Vec::new();
503 for target in targets {
504 if matches!(target.expr, Expr::Wildcard) {
505 row.extend(inner_row.iter().cloned());
507 } else {
508 row.push(self.evaluate_subquery_expr(&target.expr, inner_row, column_map)?);
509 }
510 }
511 Ok(row)
512 }
513
514 fn execute_journal(&mut self, query: &JournalQuery) -> Result<QueryResult, QueryError> {
516 let account_pattern = &query.account_pattern;
518
519 let account_regex = self.get_or_compile_regex(account_pattern);
521
522 let columns = vec![
523 "date".to_string(),
524 "flag".to_string(),
525 "payee".to_string(),
526 "narration".to_string(),
527 "account".to_string(),
528 "position".to_string(),
529 "balance".to_string(),
530 ];
531 let mut result = QueryResult::new(columns);
532
533 for directive in self.directives {
535 if let Directive::Transaction(txn) = directive {
536 if let Some(from) = &query.from {
538 if let Some(filter) = &from.filter {
539 if !self.evaluate_from_filter(filter, txn)? {
540 continue;
541 }
542 }
543 }
544
545 for posting in &txn.postings {
546 let matches = if let Some(ref regex) = account_regex {
548 regex.is_match(&posting.account)
549 } else {
550 posting.account.contains(account_pattern)
551 };
552
553 if matches {
554 let balance = self.balances.entry(posting.account.clone()).or_default();
556
557 if let Some(units) = posting.amount() {
559 let pos = if let Some(cost_spec) = &posting.cost {
560 if let Some(cost) = cost_spec.resolve(units.number, txn.date) {
561 Position::with_cost(units.clone(), cost)
562 } else {
563 Position::simple(units.clone())
564 }
565 } else {
566 Position::simple(units.clone())
567 };
568 balance.add(pos.clone());
569 }
570
571 let position_value = if let Some(at_func) = &query.at_function {
573 match at_func.to_uppercase().as_str() {
574 "COST" => {
575 if let Some(units) = posting.amount() {
576 if let Some(cost_spec) = &posting.cost {
577 if let Some(cost) =
578 cost_spec.resolve(units.number, txn.date)
579 {
580 let total = units.number * cost.number;
581 Value::Amount(Amount::new(total, &cost.currency))
582 } else {
583 Value::Amount(units.clone())
584 }
585 } else {
586 Value::Amount(units.clone())
587 }
588 } else {
589 Value::Null
590 }
591 }
592 "UNITS" => posting
593 .amount()
594 .map_or(Value::Null, |u| Value::Amount(u.clone())),
595 _ => posting
596 .amount()
597 .map_or(Value::Null, |u| Value::Amount(u.clone())),
598 }
599 } else {
600 posting
601 .amount()
602 .map_or(Value::Null, |u| Value::Amount(u.clone()))
603 };
604
605 let row = vec![
606 Value::Date(txn.date),
607 Value::String(txn.flag.to_string()),
608 Value::String(
609 txn.payee
610 .as_ref()
611 .map_or_else(String::new, ToString::to_string),
612 ),
613 Value::String(txn.narration.to_string()),
614 Value::String(posting.account.to_string()),
615 position_value,
616 Value::Inventory(balance.clone()),
617 ];
618 result.add_row(row);
619 }
620 }
621 }
622 }
623
624 Ok(result)
625 }
626
627 fn execute_balances(&mut self, query: &BalancesQuery) -> Result<QueryResult, QueryError> {
629 self.build_balances_with_filter(query.from.as_ref())?;
631
632 let columns = vec!["account".to_string(), "balance".to_string()];
633 let mut result = QueryResult::new(columns);
634
635 let mut accounts: Vec<_> = self.balances.keys().collect();
637 accounts.sort();
638
639 for account in accounts {
640 let Some(balance) = self.balances.get(account) else {
642 continue; };
644
645 let balance_value = if let Some(at_func) = &query.at_function {
647 match at_func.to_uppercase().as_str() {
648 "COST" => {
649 let cost_inventory = balance.at_cost();
651 Value::Inventory(cost_inventory)
652 }
653 "UNITS" => {
654 let units_inventory = balance.at_units();
656 Value::Inventory(units_inventory)
657 }
658 _ => Value::Inventory(balance.clone()),
659 }
660 } else {
661 Value::Inventory(balance.clone())
662 };
663
664 let row = vec![Value::String(account.to_string()), balance_value];
665 result.add_row(row);
666 }
667
668 Ok(result)
669 }
670
671 fn execute_print(&self, query: &PrintQuery) -> Result<QueryResult, QueryError> {
673 let columns = vec!["directive".to_string()];
675 let mut result = QueryResult::new(columns);
676
677 for directive in self.directives {
678 if let Some(from) = &query.from {
680 if let Some(filter) = &from.filter {
681 if let Directive::Transaction(txn) = directive {
683 if !self.evaluate_from_filter(filter, txn)? {
684 continue;
685 }
686 }
687 }
688 }
689
690 let formatted = self.format_directive(directive);
692 result.add_row(vec![Value::String(formatted)]);
693 }
694
695 Ok(result)
696 }
697
698 fn format_directive(&self, directive: &Directive) -> String {
700 match directive {
701 Directive::Transaction(txn) => {
702 let mut out = format!("{} {} ", txn.date, txn.flag);
703 if let Some(payee) = &txn.payee {
704 out.push_str(&format!("\"{payee}\" "));
705 }
706 out.push_str(&format!("\"{}\"", txn.narration));
707
708 for tag in &txn.tags {
709 out.push_str(&format!(" #{tag}"));
710 }
711 for link in &txn.links {
712 out.push_str(&format!(" ^{link}"));
713 }
714 out.push('\n');
715
716 for posting in &txn.postings {
717 out.push_str(&format!(" {}", posting.account));
718 if let Some(units) = posting.amount() {
719 out.push_str(&format!(" {} {}", units.number, units.currency));
720 }
721 out.push('\n');
722 }
723 out
724 }
725 Directive::Balance(bal) => {
726 format!(
727 "{} balance {} {} {}\n",
728 bal.date, bal.account, bal.amount.number, bal.amount.currency
729 )
730 }
731 Directive::Open(open) => {
732 let mut out = format!("{} open {}", open.date, open.account);
733 if !open.currencies.is_empty() {
734 out.push_str(&format!(" {}", open.currencies.join(",")));
735 }
736 out.push('\n');
737 out
738 }
739 Directive::Close(close) => {
740 format!("{} close {}\n", close.date, close.account)
741 }
742 Directive::Commodity(comm) => {
743 format!("{} commodity {}\n", comm.date, comm.currency)
744 }
745 Directive::Pad(pad) => {
746 format!("{} pad {} {}\n", pad.date, pad.account, pad.source_account)
747 }
748 Directive::Event(event) => {
749 format!(
750 "{} event \"{}\" \"{}\"\n",
751 event.date, event.event_type, event.value
752 )
753 }
754 Directive::Query(query) => {
755 format!(
756 "{} query \"{}\" \"{}\"\n",
757 query.date, query.name, query.query
758 )
759 }
760 Directive::Note(note) => {
761 format!("{} note {} \"{}\"\n", note.date, note.account, note.comment)
762 }
763 Directive::Document(doc) => {
764 format!("{} document {} \"{}\"\n", doc.date, doc.account, doc.path)
765 }
766 Directive::Price(price) => {
767 format!(
768 "{} price {} {} {}\n",
769 price.date, price.currency, price.amount.number, price.amount.currency
770 )
771 }
772 Directive::Custom(custom) => {
773 format!("{} custom \"{}\"\n", custom.date, custom.custom_type)
774 }
775 }
776 }
777
778 fn build_balances_with_filter(&mut self, from: Option<&FromClause>) -> Result<(), QueryError> {
780 for directive in self.directives {
781 if let Directive::Transaction(txn) = directive {
782 if let Some(from_clause) = from {
784 if let Some(filter) = &from_clause.filter {
785 if !self.evaluate_from_filter(filter, txn)? {
786 continue;
787 }
788 }
789 }
790
791 for posting in &txn.postings {
792 if let Some(units) = posting.amount() {
793 let balance = self.balances.entry(posting.account.clone()).or_default();
794
795 let pos = if let Some(cost_spec) = &posting.cost {
796 if let Some(cost) = cost_spec.resolve(units.number, txn.date) {
797 Position::with_cost(units.clone(), cost)
798 } else {
799 Position::simple(units.clone())
800 }
801 } else {
802 Position::simple(units.clone())
803 };
804 balance.add(pos);
805 }
806 }
807 }
808 }
809 Ok(())
810 }
811
812 fn collect_postings(
814 &self,
815 from: Option<&FromClause>,
816 where_clause: Option<&Expr>,
817 ) -> Result<Vec<PostingContext<'a>>, QueryError> {
818 let mut postings = Vec::new();
819 let mut running_balances: HashMap<InternedStr, Inventory> = HashMap::new();
821
822 for directive in self.directives {
823 if let Directive::Transaction(txn) = directive {
824 if let Some(from) = from {
826 if let Some(open_date) = from.open_on {
828 if txn.date < open_date {
829 for posting in &txn.postings {
831 if let Some(units) = posting.amount() {
832 let balance = running_balances
833 .entry(posting.account.clone())
834 .or_default();
835 balance.add(Position::simple(units.clone()));
836 }
837 }
838 continue;
839 }
840 }
841 if let Some(close_date) = from.close_on {
842 if txn.date > close_date {
843 continue;
844 }
845 }
846 if let Some(filter) = &from.filter {
848 if !self.evaluate_from_filter(filter, txn)? {
849 continue;
850 }
851 }
852 }
853
854 for (i, posting) in txn.postings.iter().enumerate() {
856 if let Some(units) = posting.amount() {
858 let balance = running_balances.entry(posting.account.clone()).or_default();
859 balance.add(Position::simple(units.clone()));
860 }
861
862 let ctx = PostingContext {
863 transaction: txn,
864 posting_index: i,
865 balance: running_balances.get(&posting.account).cloned(),
866 };
867
868 if let Some(where_expr) = where_clause {
870 if self.evaluate_predicate(where_expr, &ctx)? {
871 postings.push(ctx);
872 }
873 } else {
874 postings.push(ctx);
875 }
876 }
877 }
878 }
879
880 Ok(postings)
881 }
882
883 fn evaluate_from_filter(&self, filter: &Expr, txn: &Transaction) -> Result<bool, QueryError> {
885 match filter {
887 Expr::Function(func) => {
888 if func.name.to_uppercase().as_str() == "HAS_ACCOUNT" {
889 if func.args.len() != 1 {
890 return Err(QueryError::InvalidArguments(
891 "has_account".to_string(),
892 "expected 1 argument".to_string(),
893 ));
894 }
895 let pattern = match &func.args[0] {
896 Expr::Literal(Literal::String(s)) => s.clone(),
897 Expr::Column(s) => s.clone(),
898 _ => {
899 return Err(QueryError::Type(
900 "has_account expects a string pattern".to_string(),
901 ));
902 }
903 };
904 let regex = self.require_regex(&pattern)?;
906 for posting in &txn.postings {
907 if regex.is_match(&posting.account) {
908 return Ok(true);
909 }
910 }
911 Ok(false)
912 } else {
913 let dummy_ctx = PostingContext {
915 transaction: txn,
916 posting_index: 0,
917 balance: None,
918 };
919 self.evaluate_predicate(filter, &dummy_ctx)
920 }
921 }
922 Expr::BinaryOp(op) => {
923 match (&op.left, &op.right) {
925 (Expr::Column(col), Expr::Literal(lit)) if col.to_uppercase() == "YEAR" => {
926 if let Literal::Integer(n) = lit {
927 let matches = txn.date.year() == *n as i32;
928 Ok(if op.op == BinaryOperator::Eq {
929 matches
930 } else {
931 !matches
932 })
933 } else {
934 Ok(false)
935 }
936 }
937 (Expr::Column(col), Expr::Literal(lit)) if col.to_uppercase() == "MONTH" => {
938 if let Literal::Integer(n) = lit {
939 let matches = txn.date.month() == *n as u32;
940 Ok(if op.op == BinaryOperator::Eq {
941 matches
942 } else {
943 !matches
944 })
945 } else {
946 Ok(false)
947 }
948 }
949 (Expr::Column(col), Expr::Literal(Literal::Date(d)))
950 if col.to_uppercase() == "DATE" =>
951 {
952 let matches = match op.op {
953 BinaryOperator::Eq => txn.date == *d,
954 BinaryOperator::Ne => txn.date != *d,
955 BinaryOperator::Lt => txn.date < *d,
956 BinaryOperator::Le => txn.date <= *d,
957 BinaryOperator::Gt => txn.date > *d,
958 BinaryOperator::Ge => txn.date >= *d,
959 _ => false,
960 };
961 Ok(matches)
962 }
963 _ => {
964 let dummy_ctx = PostingContext {
966 transaction: txn,
967 posting_index: 0,
968 balance: None,
969 };
970 self.evaluate_predicate(filter, &dummy_ctx)
971 }
972 }
973 }
974 _ => {
975 let dummy_ctx = PostingContext {
977 transaction: txn,
978 posting_index: 0,
979 balance: None,
980 };
981 self.evaluate_predicate(filter, &dummy_ctx)
982 }
983 }
984 }
985
986 fn evaluate_predicate(&self, expr: &Expr, ctx: &PostingContext) -> Result<bool, QueryError> {
988 let value = self.evaluate_expr(expr, ctx)?;
989 match value {
990 Value::Boolean(b) => Ok(b),
991 Value::Null => Ok(false),
992 _ => Err(QueryError::Type("expected boolean expression".to_string())),
993 }
994 }
995
996 fn evaluate_expr(&self, expr: &Expr, ctx: &PostingContext) -> Result<Value, QueryError> {
998 match expr {
999 Expr::Wildcard => Ok(Value::Null), Expr::Column(name) => self.evaluate_column(name, ctx),
1001 Expr::Literal(lit) => self.evaluate_literal(lit),
1002 Expr::Function(func) => self.evaluate_function(func, ctx),
1003 Expr::Window(_) => {
1004 Err(QueryError::Evaluation(
1007 "Window function cannot be evaluated in posting context".to_string(),
1008 ))
1009 }
1010 Expr::BinaryOp(op) => self.evaluate_binary_op(op, ctx),
1011 Expr::UnaryOp(op) => self.evaluate_unary_op(op, ctx),
1012 Expr::Paren(inner) => self.evaluate_expr(inner, ctx),
1013 }
1014 }
1015
1016 fn evaluate_column(&self, name: &str, ctx: &PostingContext) -> Result<Value, QueryError> {
1018 let posting = &ctx.transaction.postings[ctx.posting_index];
1019
1020 match name {
1021 "date" => Ok(Value::Date(ctx.transaction.date)),
1022 "account" => Ok(Value::String(posting.account.to_string())),
1023 "narration" => Ok(Value::String(ctx.transaction.narration.to_string())),
1024 "payee" => Ok(ctx
1025 .transaction
1026 .payee
1027 .as_ref()
1028 .map_or(Value::Null, |p| Value::String(p.to_string()))),
1029 "flag" => Ok(Value::String(ctx.transaction.flag.to_string())),
1030 "tags" => Ok(Value::StringSet(
1031 ctx.transaction
1032 .tags
1033 .iter()
1034 .map(ToString::to_string)
1035 .collect(),
1036 )),
1037 "links" => Ok(Value::StringSet(
1038 ctx.transaction
1039 .links
1040 .iter()
1041 .map(ToString::to_string)
1042 .collect(),
1043 )),
1044 "position" | "units" => Ok(posting
1045 .amount()
1046 .map_or(Value::Null, |u| Value::Amount(u.clone()))),
1047 "cost" => {
1048 if let Some(units) = posting.amount() {
1050 if let Some(cost) = &posting.cost {
1051 if let Some(number_per) = &cost.number_per {
1052 if let Some(currency) = &cost.currency {
1053 let total = units.number.abs() * number_per;
1054 return Ok(Value::Amount(Amount::new(total, currency.clone())));
1055 }
1056 }
1057 }
1058 }
1059 Ok(Value::Null)
1060 }
1061 "weight" => {
1062 if let Some(units) = posting.amount() {
1066 if let Some(cost) = &posting.cost {
1067 if let Some(number_per) = &cost.number_per {
1068 if let Some(currency) = &cost.currency {
1069 let total = units.number * number_per;
1070 return Ok(Value::Amount(Amount::new(total, currency.clone())));
1071 }
1072 }
1073 }
1074 Ok(Value::Amount(units.clone()))
1076 } else {
1077 Ok(Value::Null)
1078 }
1079 }
1080 "balance" => {
1081 if let Some(ref balance) = ctx.balance {
1083 Ok(Value::Inventory(balance.clone()))
1084 } else {
1085 Ok(Value::Null)
1086 }
1087 }
1088 "year" => Ok(Value::Integer(ctx.transaction.date.year().into())),
1089 "month" => Ok(Value::Integer(ctx.transaction.date.month().into())),
1090 "day" => Ok(Value::Integer(ctx.transaction.date.day().into())),
1091 _ => Err(QueryError::UnknownColumn(name.to_string())),
1092 }
1093 }
1094
1095 fn evaluate_literal(&self, lit: &Literal) -> Result<Value, QueryError> {
1097 Ok(match lit {
1098 Literal::String(s) => Value::String(s.clone()),
1099 Literal::Number(n) => Value::Number(*n),
1100 Literal::Integer(i) => Value::Integer(*i),
1101 Literal::Date(d) => Value::Date(*d),
1102 Literal::Boolean(b) => Value::Boolean(*b),
1103 Literal::Null => Value::Null,
1104 })
1105 }
1106
1107 fn evaluate_function(
1111 &self,
1112 func: &FunctionCall,
1113 ctx: &PostingContext,
1114 ) -> Result<Value, QueryError> {
1115 let name = func.name.to_uppercase();
1116 match name.as_str() {
1117 "YEAR" | "MONTH" | "DAY" | "WEEKDAY" | "QUARTER" | "YMONTH" | "TODAY" => {
1119 self.eval_date_function(&name, func, ctx)
1120 }
1121 "LENGTH" | "UPPER" | "LOWER" | "SUBSTR" | "SUBSTRING" | "TRIM" | "STARTSWITH"
1123 | "ENDSWITH" => self.eval_string_function(&name, func, ctx),
1124 "PARENT" | "LEAF" | "ROOT" | "ACCOUNT_DEPTH" | "ACCOUNT_SORTKEY" => {
1126 self.eval_account_function(&name, func, ctx)
1127 }
1128 "ABS" | "NEG" | "ROUND" | "SAFEDIV" => self.eval_math_function(&name, func, ctx),
1130 "NUMBER" | "CURRENCY" | "GETITEM" | "GET" | "UNITS" | "COST" | "WEIGHT" | "VALUE" => {
1132 self.eval_position_function(&name, func, ctx)
1133 }
1134 "COALESCE" => self.eval_coalesce(func, ctx),
1136 "SUM" | "COUNT" | "MIN" | "MAX" | "FIRST" | "LAST" | "AVG" => Ok(Value::Null),
1139 _ => Err(QueryError::UnknownFunction(func.name.clone())),
1140 }
1141 }
1142
1143 fn eval_date_function(
1145 &self,
1146 name: &str,
1147 func: &FunctionCall,
1148 ctx: &PostingContext,
1149 ) -> Result<Value, QueryError> {
1150 if name == "TODAY" {
1151 if !func.args.is_empty() {
1152 return Err(QueryError::InvalidArguments(
1153 "TODAY".to_string(),
1154 "expected 0 arguments".to_string(),
1155 ));
1156 }
1157 return Ok(Value::Date(chrono::Local::now().date_naive()));
1158 }
1159
1160 if func.args.len() != 1 {
1162 return Err(QueryError::InvalidArguments(
1163 name.to_string(),
1164 "expected 1 argument".to_string(),
1165 ));
1166 }
1167
1168 let val = self.evaluate_expr(&func.args[0], ctx)?;
1169 let date = match val {
1170 Value::Date(d) => d,
1171 _ => return Err(QueryError::Type(format!("{name} expects a date"))),
1172 };
1173
1174 match name {
1175 "YEAR" => Ok(Value::Integer(date.year().into())),
1176 "MONTH" => Ok(Value::Integer(date.month().into())),
1177 "DAY" => Ok(Value::Integer(date.day().into())),
1178 "WEEKDAY" => Ok(Value::Integer(date.weekday().num_days_from_monday().into())),
1179 "QUARTER" => {
1180 let quarter = (date.month() - 1) / 3 + 1;
1181 Ok(Value::Integer(quarter.into()))
1182 }
1183 "YMONTH" => Ok(Value::String(format!(
1184 "{:04}-{:02}",
1185 date.year(),
1186 date.month()
1187 ))),
1188 _ => unreachable!(),
1189 }
1190 }
1191
1192 fn eval_string_function(
1194 &self,
1195 name: &str,
1196 func: &FunctionCall,
1197 ctx: &PostingContext,
1198 ) -> Result<Value, QueryError> {
1199 match name {
1200 "LENGTH" => {
1201 Self::require_args(name, func, 1)?;
1202 let val = self.evaluate_expr(&func.args[0], ctx)?;
1203 match val {
1204 Value::String(s) => Ok(Value::Integer(s.len() as i64)),
1205 Value::StringSet(s) => Ok(Value::Integer(s.len() as i64)),
1206 _ => Err(QueryError::Type(
1207 "LENGTH expects a string or set".to_string(),
1208 )),
1209 }
1210 }
1211 "UPPER" => {
1212 Self::require_args(name, func, 1)?;
1213 let val = self.evaluate_expr(&func.args[0], ctx)?;
1214 match val {
1215 Value::String(s) => Ok(Value::String(s.to_uppercase())),
1216 _ => Err(QueryError::Type("UPPER expects a string".to_string())),
1217 }
1218 }
1219 "LOWER" => {
1220 Self::require_args(name, func, 1)?;
1221 let val = self.evaluate_expr(&func.args[0], ctx)?;
1222 match val {
1223 Value::String(s) => Ok(Value::String(s.to_lowercase())),
1224 _ => Err(QueryError::Type("LOWER expects a string".to_string())),
1225 }
1226 }
1227 "TRIM" => {
1228 Self::require_args(name, func, 1)?;
1229 let val = self.evaluate_expr(&func.args[0], ctx)?;
1230 match val {
1231 Value::String(s) => Ok(Value::String(s.trim().to_string())),
1232 _ => Err(QueryError::Type("TRIM expects a string".to_string())),
1233 }
1234 }
1235 "SUBSTR" | "SUBSTRING" => self.eval_substr(func, ctx),
1236 "STARTSWITH" => {
1237 Self::require_args(name, func, 2)?;
1238 let val = self.evaluate_expr(&func.args[0], ctx)?;
1239 let prefix = self.evaluate_expr(&func.args[1], ctx)?;
1240 match (val, prefix) {
1241 (Value::String(s), Value::String(p)) => Ok(Value::Boolean(s.starts_with(&p))),
1242 _ => Err(QueryError::Type(
1243 "STARTSWITH expects two strings".to_string(),
1244 )),
1245 }
1246 }
1247 "ENDSWITH" => {
1248 Self::require_args(name, func, 2)?;
1249 let val = self.evaluate_expr(&func.args[0], ctx)?;
1250 let suffix = self.evaluate_expr(&func.args[1], ctx)?;
1251 match (val, suffix) {
1252 (Value::String(s), Value::String(p)) => Ok(Value::Boolean(s.ends_with(&p))),
1253 _ => Err(QueryError::Type("ENDSWITH expects two strings".to_string())),
1254 }
1255 }
1256 _ => unreachable!(),
1257 }
1258 }
1259
1260 fn eval_substr(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1262 if func.args.len() < 2 || func.args.len() > 3 {
1263 return Err(QueryError::InvalidArguments(
1264 "SUBSTR".to_string(),
1265 "expected 2 or 3 arguments".to_string(),
1266 ));
1267 }
1268
1269 let val = self.evaluate_expr(&func.args[0], ctx)?;
1270 let start = self.evaluate_expr(&func.args[1], ctx)?;
1271 let len = if func.args.len() == 3 {
1272 Some(self.evaluate_expr(&func.args[2], ctx)?)
1273 } else {
1274 None
1275 };
1276
1277 match (val, start, len) {
1278 (Value::String(s), Value::Integer(start), None) => {
1279 let start = start.max(0) as usize;
1280 if start >= s.len() {
1281 Ok(Value::String(String::new()))
1282 } else {
1283 Ok(Value::String(s[start..].to_string()))
1284 }
1285 }
1286 (Value::String(s), Value::Integer(start), Some(Value::Integer(len))) => {
1287 let start = start.max(0) as usize;
1288 let len = len.max(0) as usize;
1289 if start >= s.len() {
1290 Ok(Value::String(String::new()))
1291 } else {
1292 let end = (start + len).min(s.len());
1293 Ok(Value::String(s[start..end].to_string()))
1294 }
1295 }
1296 _ => Err(QueryError::Type(
1297 "SUBSTR expects (string, int, [int])".to_string(),
1298 )),
1299 }
1300 }
1301
1302 fn eval_account_function(
1304 &self,
1305 name: &str,
1306 func: &FunctionCall,
1307 ctx: &PostingContext,
1308 ) -> Result<Value, QueryError> {
1309 match name {
1310 "PARENT" => {
1311 Self::require_args(name, func, 1)?;
1312 let val = self.evaluate_expr(&func.args[0], ctx)?;
1313 match val {
1314 Value::String(s) => {
1315 if let Some(idx) = s.rfind(':') {
1316 Ok(Value::String(s[..idx].to_string()))
1317 } else {
1318 Ok(Value::Null)
1319 }
1320 }
1321 _ => Err(QueryError::Type(
1322 "PARENT expects an account string".to_string(),
1323 )),
1324 }
1325 }
1326 "LEAF" => {
1327 Self::require_args(name, func, 1)?;
1328 let val = self.evaluate_expr(&func.args[0], ctx)?;
1329 match val {
1330 Value::String(s) => {
1331 if let Some(idx) = s.rfind(':') {
1332 Ok(Value::String(s[idx + 1..].to_string()))
1333 } else {
1334 Ok(Value::String(s))
1335 }
1336 }
1337 _ => Err(QueryError::Type(
1338 "LEAF expects an account string".to_string(),
1339 )),
1340 }
1341 }
1342 "ROOT" => self.eval_root(func, ctx),
1343 "ACCOUNT_DEPTH" => {
1344 Self::require_args(name, func, 1)?;
1345 let val = self.evaluate_expr(&func.args[0], ctx)?;
1346 match val {
1347 Value::String(s) => {
1348 let depth = s.chars().filter(|c| *c == ':').count() + 1;
1349 Ok(Value::Integer(depth as i64))
1350 }
1351 _ => Err(QueryError::Type(
1352 "ACCOUNT_DEPTH expects an account string".to_string(),
1353 )),
1354 }
1355 }
1356 "ACCOUNT_SORTKEY" => {
1357 Self::require_args(name, func, 1)?;
1358 let val = self.evaluate_expr(&func.args[0], ctx)?;
1359 match val {
1360 Value::String(s) => Ok(Value::String(s)),
1361 _ => Err(QueryError::Type(
1362 "ACCOUNT_SORTKEY expects an account string".to_string(),
1363 )),
1364 }
1365 }
1366 _ => unreachable!(),
1367 }
1368 }
1369
1370 fn eval_root(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1372 if func.args.is_empty() || func.args.len() > 2 {
1373 return Err(QueryError::InvalidArguments(
1374 "ROOT".to_string(),
1375 "expected 1 or 2 arguments".to_string(),
1376 ));
1377 }
1378
1379 let val = self.evaluate_expr(&func.args[0], ctx)?;
1380 let n = if func.args.len() == 2 {
1381 match self.evaluate_expr(&func.args[1], ctx)? {
1382 Value::Integer(i) => i as usize,
1383 _ => {
1384 return Err(QueryError::Type(
1385 "ROOT second arg must be integer".to_string(),
1386 ));
1387 }
1388 }
1389 } else {
1390 1
1391 };
1392
1393 match val {
1394 Value::String(s) => {
1395 let parts: Vec<&str> = s.split(':').collect();
1396 if n >= parts.len() {
1397 Ok(Value::String(s))
1398 } else {
1399 Ok(Value::String(parts[..n].join(":")))
1400 }
1401 }
1402 _ => Err(QueryError::Type(
1403 "ROOT expects an account string".to_string(),
1404 )),
1405 }
1406 }
1407
1408 fn eval_math_function(
1410 &self,
1411 name: &str,
1412 func: &FunctionCall,
1413 ctx: &PostingContext,
1414 ) -> Result<Value, QueryError> {
1415 match name {
1416 "ABS" => {
1417 Self::require_args(name, func, 1)?;
1418 let val = self.evaluate_expr(&func.args[0], ctx)?;
1419 match val {
1420 Value::Number(n) => Ok(Value::Number(n.abs())),
1421 Value::Integer(i) => Ok(Value::Integer(i.abs())),
1422 _ => Err(QueryError::Type("ABS expects a number".to_string())),
1423 }
1424 }
1425 "NEG" => {
1426 Self::require_args(name, func, 1)?;
1427 let val = self.evaluate_expr(&func.args[0], ctx)?;
1428 match val {
1429 Value::Number(n) => Ok(Value::Number(-n)),
1430 Value::Integer(i) => Ok(Value::Integer(-i)),
1431 _ => Err(QueryError::Type("NEG expects a number".to_string())),
1432 }
1433 }
1434 "ROUND" => self.eval_round(func, ctx),
1435 "SAFEDIV" => self.eval_safediv(func, ctx),
1436 _ => unreachable!(),
1437 }
1438 }
1439
1440 fn eval_round(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1442 if func.args.is_empty() || func.args.len() > 2 {
1443 return Err(QueryError::InvalidArguments(
1444 "ROUND".to_string(),
1445 "expected 1 or 2 arguments".to_string(),
1446 ));
1447 }
1448
1449 let val = self.evaluate_expr(&func.args[0], ctx)?;
1450 let decimals = if func.args.len() == 2 {
1451 match self.evaluate_expr(&func.args[1], ctx)? {
1452 Value::Integer(i) => i as u32,
1453 _ => {
1454 return Err(QueryError::Type(
1455 "ROUND second arg must be integer".to_string(),
1456 ));
1457 }
1458 }
1459 } else {
1460 0
1461 };
1462
1463 match val {
1464 Value::Number(n) => Ok(Value::Number(n.round_dp(decimals))),
1465 Value::Integer(i) => Ok(Value::Integer(i)),
1466 _ => Err(QueryError::Type("ROUND expects a number".to_string())),
1467 }
1468 }
1469
1470 fn eval_safediv(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1472 Self::require_args("SAFEDIV", func, 2)?;
1473 let num = self.evaluate_expr(&func.args[0], ctx)?;
1474 let den = self.evaluate_expr(&func.args[1], ctx)?;
1475
1476 match (num, den) {
1477 (Value::Number(n), Value::Number(d)) => {
1478 if d.is_zero() {
1479 Ok(Value::Number(Decimal::ZERO))
1480 } else {
1481 Ok(Value::Number(n / d))
1482 }
1483 }
1484 (Value::Integer(n), Value::Integer(d)) => {
1485 if d == 0 {
1486 Ok(Value::Integer(0))
1487 } else {
1488 Ok(Value::Integer(n / d))
1489 }
1490 }
1491 _ => Err(QueryError::Type("SAFEDIV expects two numbers".to_string())),
1492 }
1493 }
1494
1495 fn eval_position_function(
1497 &self,
1498 name: &str,
1499 func: &FunctionCall,
1500 ctx: &PostingContext,
1501 ) -> Result<Value, QueryError> {
1502 match name {
1503 "NUMBER" => {
1504 Self::require_args(name, func, 1)?;
1505 let val = self.evaluate_expr(&func.args[0], ctx)?;
1506 match val {
1507 Value::Amount(a) => Ok(Value::Number(a.number)),
1508 Value::Position(p) => Ok(Value::Number(p.units.number)),
1509 Value::Number(n) => Ok(Value::Number(n)),
1510 Value::Integer(i) => Ok(Value::Number(Decimal::from(i))),
1511 _ => Err(QueryError::Type(
1512 "NUMBER expects an amount or position".to_string(),
1513 )),
1514 }
1515 }
1516 "CURRENCY" => {
1517 Self::require_args(name, func, 1)?;
1518 let val = self.evaluate_expr(&func.args[0], ctx)?;
1519 match val {
1520 Value::Amount(a) => Ok(Value::String(a.currency.to_string())),
1521 Value::Position(p) => Ok(Value::String(p.units.currency.to_string())),
1522 _ => Err(QueryError::Type(
1523 "CURRENCY expects an amount or position".to_string(),
1524 )),
1525 }
1526 }
1527 "GETITEM" | "GET" => self.eval_getitem(func, ctx),
1528 "UNITS" => self.eval_units(func, ctx),
1529 "COST" => self.eval_cost(func, ctx),
1530 "WEIGHT" => self.eval_weight(func, ctx),
1531 "VALUE" => self.eval_value(func, ctx),
1532 _ => unreachable!(),
1533 }
1534 }
1535
1536 fn eval_getitem(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1538 Self::require_args("GETITEM", func, 2)?;
1539 let val = self.evaluate_expr(&func.args[0], ctx)?;
1540 let key = self.evaluate_expr(&func.args[1], ctx)?;
1541
1542 match (val, key) {
1543 (Value::Inventory(inv), Value::String(currency)) => {
1544 let total = inv.units(¤cy);
1545 if total.is_zero() {
1546 Ok(Value::Null)
1547 } else {
1548 Ok(Value::Amount(Amount::new(total, currency)))
1549 }
1550 }
1551 _ => Err(QueryError::Type(
1552 "GETITEM expects (inventory, string)".to_string(),
1553 )),
1554 }
1555 }
1556
1557 fn eval_units(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1559 Self::require_args("UNITS", func, 1)?;
1560 let val = self.evaluate_expr(&func.args[0], ctx)?;
1561
1562 match val {
1563 Value::Position(p) => Ok(Value::Amount(p.units)),
1564 Value::Amount(a) => Ok(Value::Amount(a)),
1565 Value::Inventory(inv) => {
1566 let positions: Vec<String> = inv
1567 .positions()
1568 .iter()
1569 .map(|p| format!("{} {}", p.units.number, p.units.currency))
1570 .collect();
1571 Ok(Value::String(positions.join(", ")))
1572 }
1573 _ => Err(QueryError::Type(
1574 "UNITS expects a position or inventory".to_string(),
1575 )),
1576 }
1577 }
1578
1579 fn eval_cost(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1581 Self::require_args("COST", func, 1)?;
1582 let val = self.evaluate_expr(&func.args[0], ctx)?;
1583
1584 match val {
1585 Value::Position(p) => {
1586 if let Some(cost) = &p.cost {
1587 let total = p.units.number.abs() * cost.number;
1588 Ok(Value::Amount(Amount::new(total, cost.currency.clone())))
1589 } else {
1590 Ok(Value::Null)
1591 }
1592 }
1593 Value::Amount(a) => Ok(Value::Amount(a)),
1594 Value::Inventory(inv) => {
1595 let mut total = Decimal::ZERO;
1596 let mut currency: Option<InternedStr> = None;
1597 for pos in inv.positions() {
1598 if let Some(cost) = &pos.cost {
1599 total += pos.units.number.abs() * cost.number;
1600 if currency.is_none() {
1601 currency = Some(cost.currency.clone());
1602 }
1603 }
1604 }
1605 if let Some(curr) = currency {
1606 Ok(Value::Amount(Amount::new(total, curr)))
1607 } else {
1608 Ok(Value::Null)
1609 }
1610 }
1611 _ => Err(QueryError::Type(
1612 "COST expects a position or inventory".to_string(),
1613 )),
1614 }
1615 }
1616
1617 fn eval_weight(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1619 Self::require_args("WEIGHT", func, 1)?;
1620 let val = self.evaluate_expr(&func.args[0], ctx)?;
1621
1622 match val {
1623 Value::Position(p) => {
1624 if let Some(cost) = &p.cost {
1625 let total = p.units.number * cost.number;
1626 Ok(Value::Amount(Amount::new(total, cost.currency.clone())))
1627 } else {
1628 Ok(Value::Amount(p.units))
1629 }
1630 }
1631 Value::Amount(a) => Ok(Value::Amount(a)),
1632 _ => Err(QueryError::Type(
1633 "WEIGHT expects a position or amount".to_string(),
1634 )),
1635 }
1636 }
1637
1638 fn eval_value(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1640 if func.args.is_empty() || func.args.len() > 2 {
1641 return Err(QueryError::InvalidArguments(
1642 "VALUE".to_string(),
1643 "expected 1-2 arguments".to_string(),
1644 ));
1645 }
1646
1647 let target_currency = if func.args.len() == 2 {
1648 match self.evaluate_expr(&func.args[1], ctx)? {
1649 Value::String(s) => s,
1650 _ => {
1651 return Err(QueryError::Type(
1652 "VALUE second argument must be a currency string".to_string(),
1653 ));
1654 }
1655 }
1656 } else {
1657 self.target_currency.clone().ok_or_else(|| {
1658 QueryError::InvalidArguments(
1659 "VALUE".to_string(),
1660 "no target currency set; either call set_target_currency() on the executor \
1661 or pass the currency as VALUE(amount, 'USD')"
1662 .to_string(),
1663 )
1664 })?
1665 };
1666
1667 let val = self.evaluate_expr(&func.args[0], ctx)?;
1668 let date = ctx.transaction.date;
1669
1670 match val {
1671 Value::Position(p) => {
1672 if p.units.currency == target_currency {
1673 Ok(Value::Amount(p.units))
1674 } else if let Some(converted) =
1675 self.price_db.convert(&p.units, &target_currency, date)
1676 {
1677 Ok(Value::Amount(converted))
1678 } else {
1679 Ok(Value::Amount(p.units))
1680 }
1681 }
1682 Value::Amount(a) => {
1683 if a.currency == target_currency {
1684 Ok(Value::Amount(a))
1685 } else if let Some(converted) = self.price_db.convert(&a, &target_currency, date) {
1686 Ok(Value::Amount(converted))
1687 } else {
1688 Ok(Value::Amount(a))
1689 }
1690 }
1691 Value::Inventory(inv) => {
1692 let mut total = Decimal::ZERO;
1693 for pos in inv.positions() {
1694 if pos.units.currency == target_currency {
1695 total += pos.units.number;
1696 } else if let Some(converted) =
1697 self.price_db.convert(&pos.units, &target_currency, date)
1698 {
1699 total += converted.number;
1700 }
1701 }
1702 Ok(Value::Amount(Amount::new(total, &target_currency)))
1703 }
1704 _ => Err(QueryError::Type(
1705 "VALUE expects a position or inventory".to_string(),
1706 )),
1707 }
1708 }
1709
1710 fn eval_coalesce(
1712 &self,
1713 func: &FunctionCall,
1714 ctx: &PostingContext,
1715 ) -> Result<Value, QueryError> {
1716 for arg in &func.args {
1717 let val = self.evaluate_expr(arg, ctx)?;
1718 if !matches!(val, Value::Null) {
1719 return Ok(val);
1720 }
1721 }
1722 Ok(Value::Null)
1723 }
1724
1725 fn evaluate_function_on_values(&self, name: &str, args: &[Value]) -> Result<Value, QueryError> {
1727 let name_upper = name.to_uppercase();
1728 match name_upper.as_str() {
1729 "TODAY" => Ok(Value::Date(chrono::Local::now().date_naive())),
1731 "YEAR" => {
1732 Self::require_args_count(&name_upper, args, 1)?;
1733 match &args[0] {
1734 Value::Date(d) => Ok(Value::Integer(d.year().into())),
1735 _ => Err(QueryError::Type("YEAR expects a date".to_string())),
1736 }
1737 }
1738 "MONTH" => {
1739 Self::require_args_count(&name_upper, args, 1)?;
1740 match &args[0] {
1741 Value::Date(d) => Ok(Value::Integer(d.month().into())),
1742 _ => Err(QueryError::Type("MONTH expects a date".to_string())),
1743 }
1744 }
1745 "DAY" => {
1746 Self::require_args_count(&name_upper, args, 1)?;
1747 match &args[0] {
1748 Value::Date(d) => Ok(Value::Integer(d.day().into())),
1749 _ => Err(QueryError::Type("DAY expects a date".to_string())),
1750 }
1751 }
1752 "LENGTH" => {
1754 Self::require_args_count(&name_upper, args, 1)?;
1755 match &args[0] {
1756 Value::String(s) => Ok(Value::Integer(s.len() as i64)),
1757 _ => Err(QueryError::Type("LENGTH expects a string".to_string())),
1758 }
1759 }
1760 "UPPER" => {
1761 Self::require_args_count(&name_upper, args, 1)?;
1762 match &args[0] {
1763 Value::String(s) => Ok(Value::String(s.to_uppercase())),
1764 _ => Err(QueryError::Type("UPPER expects a string".to_string())),
1765 }
1766 }
1767 "LOWER" => {
1768 Self::require_args_count(&name_upper, args, 1)?;
1769 match &args[0] {
1770 Value::String(s) => Ok(Value::String(s.to_lowercase())),
1771 _ => Err(QueryError::Type("LOWER expects a string".to_string())),
1772 }
1773 }
1774 "TRIM" => {
1775 Self::require_args_count(&name_upper, args, 1)?;
1776 match &args[0] {
1777 Value::String(s) => Ok(Value::String(s.trim().to_string())),
1778 _ => Err(QueryError::Type("TRIM expects a string".to_string())),
1779 }
1780 }
1781 "ABS" => {
1783 Self::require_args_count(&name_upper, args, 1)?;
1784 match &args[0] {
1785 Value::Number(n) => Ok(Value::Number(n.abs())),
1786 Value::Integer(i) => Ok(Value::Integer(i.abs())),
1787 _ => Err(QueryError::Type("ABS expects a number".to_string())),
1788 }
1789 }
1790 "ROUND" => {
1791 if args.is_empty() || args.len() > 2 {
1792 return Err(QueryError::InvalidArguments(
1793 "ROUND".to_string(),
1794 "expected 1 or 2 arguments".to_string(),
1795 ));
1796 }
1797 match &args[0] {
1798 Value::Number(n) => {
1799 let scale = if args.len() == 2 {
1800 match &args[1] {
1801 Value::Integer(i) => *i as u32,
1802 _ => 0,
1803 }
1804 } else {
1805 0
1806 };
1807 Ok(Value::Number(n.round_dp(scale)))
1808 }
1809 Value::Integer(i) => Ok(Value::Integer(*i)),
1810 _ => Err(QueryError::Type("ROUND expects a number".to_string())),
1811 }
1812 }
1813 "COALESCE" => {
1815 for arg in args {
1816 if !matches!(arg, Value::Null) {
1817 return Ok(arg.clone());
1818 }
1819 }
1820 Ok(Value::Null)
1821 }
1822 "SUM" | "COUNT" | "MIN" | "MAX" | "FIRST" | "LAST" | "AVG" => Ok(Value::Null),
1824 _ => Err(QueryError::UnknownFunction(name.to_string())),
1825 }
1826 }
1827
1828 fn require_args_count(name: &str, args: &[Value], expected: usize) -> Result<(), QueryError> {
1830 if args.len() != expected {
1831 return Err(QueryError::InvalidArguments(
1832 name.to_string(),
1833 format!("expected {} argument(s), got {}", expected, args.len()),
1834 ));
1835 }
1836 Ok(())
1837 }
1838
1839 fn require_args(name: &str, func: &FunctionCall, expected: usize) -> Result<(), QueryError> {
1841 if func.args.len() != expected {
1842 return Err(QueryError::InvalidArguments(
1843 name.to_string(),
1844 format!("expected {expected} argument(s)"),
1845 ));
1846 }
1847 Ok(())
1848 }
1849
1850 fn evaluate_binary_op(&self, op: &BinaryOp, ctx: &PostingContext) -> Result<Value, QueryError> {
1852 let left = self.evaluate_expr(&op.left, ctx)?;
1853 let right = self.evaluate_expr(&op.right, ctx)?;
1854
1855 match op.op {
1856 BinaryOperator::Eq => Ok(Value::Boolean(self.values_equal(&left, &right))),
1857 BinaryOperator::Ne => Ok(Value::Boolean(!self.values_equal(&left, &right))),
1858 BinaryOperator::Lt => self.compare_values(&left, &right, std::cmp::Ordering::is_lt),
1859 BinaryOperator::Le => self.compare_values(&left, &right, std::cmp::Ordering::is_le),
1860 BinaryOperator::Gt => self.compare_values(&left, &right, std::cmp::Ordering::is_gt),
1861 BinaryOperator::Ge => self.compare_values(&left, &right, std::cmp::Ordering::is_ge),
1862 BinaryOperator::And => {
1863 let l = self.to_bool(&left)?;
1864 let r = self.to_bool(&right)?;
1865 Ok(Value::Boolean(l && r))
1866 }
1867 BinaryOperator::Or => {
1868 let l = self.to_bool(&left)?;
1869 let r = self.to_bool(&right)?;
1870 Ok(Value::Boolean(l || r))
1871 }
1872 BinaryOperator::Regex => {
1873 let s = match left {
1875 Value::String(s) => s,
1876 _ => {
1877 return Err(QueryError::Type(
1878 "regex requires string left operand".to_string(),
1879 ));
1880 }
1881 };
1882 let pattern = match right {
1883 Value::String(p) => p,
1884 _ => {
1885 return Err(QueryError::Type(
1886 "regex requires string pattern".to_string(),
1887 ));
1888 }
1889 };
1890 Ok(Value::Boolean(s.contains(&pattern)))
1892 }
1893 BinaryOperator::In => {
1894 match right {
1896 Value::StringSet(set) => {
1897 let needle = match left {
1898 Value::String(s) => s,
1899 _ => {
1900 return Err(QueryError::Type(
1901 "IN requires string left operand".to_string(),
1902 ));
1903 }
1904 };
1905 Ok(Value::Boolean(set.contains(&needle)))
1906 }
1907 _ => Err(QueryError::Type(
1908 "IN requires set right operand".to_string(),
1909 )),
1910 }
1911 }
1912 BinaryOperator::Add => self.arithmetic_op(&left, &right, |a, b| a + b),
1913 BinaryOperator::Sub => self.arithmetic_op(&left, &right, |a, b| a - b),
1914 BinaryOperator::Mul => self.arithmetic_op(&left, &right, |a, b| a * b),
1915 BinaryOperator::Div => self.arithmetic_op(&left, &right, |a, b| a / b),
1916 }
1917 }
1918
1919 fn evaluate_unary_op(&self, op: &UnaryOp, ctx: &PostingContext) -> Result<Value, QueryError> {
1921 let val = self.evaluate_expr(&op.operand, ctx)?;
1922 self.unary_op_on_value(op.op, &val)
1923 }
1924
1925 fn unary_op_on_value(&self, op: UnaryOperator, val: &Value) -> Result<Value, QueryError> {
1927 match op {
1928 UnaryOperator::Not => {
1929 let b = self.to_bool(val)?;
1930 Ok(Value::Boolean(!b))
1931 }
1932 UnaryOperator::Neg => match val {
1933 Value::Number(n) => Ok(Value::Number(-*n)),
1934 Value::Integer(i) => Ok(Value::Integer(-*i)),
1935 _ => Err(QueryError::Type(
1936 "negation requires numeric value".to_string(),
1937 )),
1938 },
1939 }
1940 }
1941
1942 fn values_equal(&self, left: &Value, right: &Value) -> bool {
1944 match (left, right) {
1946 (Value::Null, Value::Null) => true,
1947 (Value::String(a), Value::String(b)) => a == b,
1948 (Value::Number(a), Value::Number(b)) => a == b,
1949 (Value::Integer(a), Value::Integer(b)) => a == b,
1950 (Value::Number(a), Value::Integer(b)) => *a == Decimal::from(*b),
1951 (Value::Integer(a), Value::Number(b)) => Decimal::from(*a) == *b,
1952 (Value::Date(a), Value::Date(b)) => a == b,
1953 (Value::Boolean(a), Value::Boolean(b)) => a == b,
1954 _ => false,
1955 }
1956 }
1957
1958 fn compare_values<F>(&self, left: &Value, right: &Value, pred: F) -> Result<Value, QueryError>
1960 where
1961 F: FnOnce(std::cmp::Ordering) -> bool,
1962 {
1963 let ord = match (left, right) {
1964 (Value::Number(a), Value::Number(b)) => a.cmp(b),
1965 (Value::Integer(a), Value::Integer(b)) => a.cmp(b),
1966 (Value::Number(a), Value::Integer(b)) => a.cmp(&Decimal::from(*b)),
1967 (Value::Integer(a), Value::Number(b)) => Decimal::from(*a).cmp(b),
1968 (Value::String(a), Value::String(b)) => a.cmp(b),
1969 (Value::Date(a), Value::Date(b)) => a.cmp(b),
1970 _ => return Err(QueryError::Type("cannot compare values".to_string())),
1971 };
1972 Ok(Value::Boolean(pred(ord)))
1973 }
1974
1975 fn value_less_than(&self, left: &Value, right: &Value) -> Result<bool, QueryError> {
1977 let ord = match (left, right) {
1978 (Value::Number(a), Value::Number(b)) => a.cmp(b),
1979 (Value::Integer(a), Value::Integer(b)) => a.cmp(b),
1980 (Value::Number(a), Value::Integer(b)) => a.cmp(&Decimal::from(*b)),
1981 (Value::Integer(a), Value::Number(b)) => Decimal::from(*a).cmp(b),
1982 (Value::String(a), Value::String(b)) => a.cmp(b),
1983 (Value::Date(a), Value::Date(b)) => a.cmp(b),
1984 _ => return Err(QueryError::Type("cannot compare values".to_string())),
1985 };
1986 Ok(ord.is_lt())
1987 }
1988
1989 fn arithmetic_op<F>(&self, left: &Value, right: &Value, op: F) -> Result<Value, QueryError>
1991 where
1992 F: FnOnce(Decimal, Decimal) -> Decimal,
1993 {
1994 let (a, b) = match (left, right) {
1995 (Value::Number(a), Value::Number(b)) => (*a, *b),
1996 (Value::Integer(a), Value::Integer(b)) => (Decimal::from(*a), Decimal::from(*b)),
1997 (Value::Number(a), Value::Integer(b)) => (*a, Decimal::from(*b)),
1998 (Value::Integer(a), Value::Number(b)) => (Decimal::from(*a), *b),
1999 _ => {
2000 return Err(QueryError::Type(
2001 "arithmetic requires numeric values".to_string(),
2002 ));
2003 }
2004 };
2005 Ok(Value::Number(op(a, b)))
2006 }
2007
2008 fn to_bool(&self, val: &Value) -> Result<bool, QueryError> {
2010 match val {
2011 Value::Boolean(b) => Ok(*b),
2012 Value::Null => Ok(false),
2013 _ => Err(QueryError::Type("expected boolean".to_string())),
2014 }
2015 }
2016
2017 fn is_aggregate_expr(expr: &Expr) -> bool {
2019 match expr {
2020 Expr::Function(func) => {
2021 matches!(
2022 func.name.to_uppercase().as_str(),
2023 "SUM" | "COUNT" | "MIN" | "MAX" | "FIRST" | "LAST" | "AVG"
2024 )
2025 }
2026 Expr::BinaryOp(op) => {
2027 Self::is_aggregate_expr(&op.left) || Self::is_aggregate_expr(&op.right)
2028 }
2029 Expr::UnaryOp(op) => Self::is_aggregate_expr(&op.operand),
2030 Expr::Paren(inner) => Self::is_aggregate_expr(inner),
2031 _ => false,
2032 }
2033 }
2034
2035 const fn is_window_expr(expr: &Expr) -> bool {
2037 matches!(expr, Expr::Window(_))
2038 }
2039
2040 fn has_window_functions(targets: &[Target]) -> bool {
2042 targets.iter().any(|t| Self::is_window_expr(&t.expr))
2043 }
2044
2045 fn resolve_column_names(&self, targets: &[Target]) -> Result<Vec<String>, QueryError> {
2047 let mut names = Vec::new();
2048 for (i, target) in targets.iter().enumerate() {
2049 if let Some(alias) = &target.alias {
2050 names.push(alias.clone());
2051 } else {
2052 names.push(self.expr_to_name(&target.expr, i));
2053 }
2054 }
2055 Ok(names)
2056 }
2057
2058 fn expr_to_name(&self, expr: &Expr, index: usize) -> String {
2060 match expr {
2061 Expr::Wildcard => "*".to_string(),
2062 Expr::Column(name) => name.clone(),
2063 Expr::Function(func) => func.name.clone(),
2064 Expr::Window(wf) => wf.name.clone(),
2065 _ => format!("col{index}"),
2066 }
2067 }
2068
2069 fn evaluate_row(&self, targets: &[Target], ctx: &PostingContext) -> Result<Row, QueryError> {
2071 self.evaluate_row_with_window(targets, ctx, None)
2072 }
2073
2074 fn evaluate_row_with_window(
2076 &self,
2077 targets: &[Target],
2078 ctx: &PostingContext,
2079 window_ctx: Option<&WindowContext>,
2080 ) -> Result<Row, QueryError> {
2081 let mut row = Vec::new();
2082 for target in targets {
2083 if matches!(target.expr, Expr::Wildcard) {
2084 row.push(Value::Date(ctx.transaction.date));
2086 row.push(Value::String(ctx.transaction.flag.to_string()));
2087 row.push(
2088 ctx.transaction
2089 .payee
2090 .as_ref()
2091 .map_or(Value::Null, |p| Value::String(p.to_string())),
2092 );
2093 row.push(Value::String(ctx.transaction.narration.to_string()));
2094 let posting = &ctx.transaction.postings[ctx.posting_index];
2095 row.push(Value::String(posting.account.to_string()));
2096 row.push(
2097 posting
2098 .amount()
2099 .map_or(Value::Null, |u| Value::Amount(u.clone())),
2100 );
2101 } else if let Expr::Window(wf) = &target.expr {
2102 row.push(self.evaluate_window_function(wf, window_ctx)?);
2104 } else {
2105 row.push(self.evaluate_expr(&target.expr, ctx)?);
2106 }
2107 }
2108 Ok(row)
2109 }
2110
2111 fn evaluate_window_function(
2113 &self,
2114 wf: &WindowFunction,
2115 window_ctx: Option<&WindowContext>,
2116 ) -> Result<Value, QueryError> {
2117 let ctx = window_ctx.ok_or_else(|| {
2118 QueryError::Evaluation("Window function requires window context".to_string())
2119 })?;
2120
2121 match wf.name.to_uppercase().as_str() {
2122 "ROW_NUMBER" => Ok(Value::Integer(ctx.row_number as i64)),
2123 "RANK" => Ok(Value::Integer(ctx.rank as i64)),
2124 "DENSE_RANK" => Ok(Value::Integer(ctx.dense_rank as i64)),
2125 _ => Err(QueryError::Evaluation(format!(
2126 "Window function '{}' not yet implemented",
2127 wf.name
2128 ))),
2129 }
2130 }
2131
2132 fn compute_window_contexts(
2134 &self,
2135 postings: &[PostingContext],
2136 wf: &WindowFunction,
2137 ) -> Result<Vec<WindowContext>, QueryError> {
2138 let spec = &wf.over;
2139
2140 let mut partition_keys: Vec<String> = Vec::with_capacity(postings.len());
2142 for ctx in postings {
2143 if let Some(partition_exprs) = &spec.partition_by {
2144 let mut key_values = Vec::new();
2145 for expr in partition_exprs {
2146 key_values.push(self.evaluate_expr(expr, ctx)?);
2147 }
2148 partition_keys.push(Self::make_group_key(&key_values));
2149 } else {
2150 partition_keys.push(String::new());
2152 }
2153 }
2154
2155 let mut partitions: HashMap<String, Vec<usize>> = HashMap::new();
2157 for (idx, key) in partition_keys.iter().enumerate() {
2158 partitions.entry(key.clone()).or_default().push(idx);
2159 }
2160
2161 let mut order_values: Vec<Vec<Value>> = Vec::with_capacity(postings.len());
2163 for ctx in postings {
2164 if let Some(order_specs) = &spec.order_by {
2165 let mut values = Vec::new();
2166 for order_spec in order_specs {
2167 values.push(self.evaluate_expr(&order_spec.expr, ctx)?);
2168 }
2169 order_values.push(values);
2170 } else {
2171 order_values.push(Vec::new());
2172 }
2173 }
2174
2175 let mut window_contexts: Vec<WindowContext> = vec![
2177 WindowContext {
2178 row_number: 0,
2179 rank: 0,
2180 dense_rank: 0,
2181 };
2182 postings.len()
2183 ];
2184
2185 for indices in partitions.values() {
2187 let mut sorted_indices: Vec<usize> = indices.clone();
2189 if spec.order_by.is_some() {
2190 let order_specs = spec.order_by.as_ref().unwrap();
2191 sorted_indices.sort_by(|&a, &b| {
2192 let vals_a = &order_values[a];
2193 let vals_b = &order_values[b];
2194 for (i, (va, vb)) in vals_a.iter().zip(vals_b.iter()).enumerate() {
2195 let cmp = self.compare_values_for_sort(va, vb);
2196 if cmp != std::cmp::Ordering::Equal {
2197 return if order_specs
2198 .get(i)
2199 .is_some_and(|s| s.direction == SortDirection::Desc)
2200 {
2201 cmp.reverse()
2202 } else {
2203 cmp
2204 };
2205 }
2206 }
2207 std::cmp::Ordering::Equal
2208 });
2209 }
2210
2211 let mut row_num = 1;
2213 let mut rank = 1;
2214 let mut dense_rank = 1;
2215 let mut prev_values: Option<&Vec<Value>> = None;
2216
2217 for (position, &original_idx) in sorted_indices.iter().enumerate() {
2218 let current_values = &order_values[original_idx];
2219
2220 let is_tie = if let Some(prev) = prev_values {
2222 current_values == prev
2223 } else {
2224 false
2225 };
2226
2227 if !is_tie && position > 0 {
2228 rank = position + 1;
2230 dense_rank += 1;
2231 }
2232 window_contexts[original_idx] = WindowContext {
2233 row_number: row_num,
2234 rank,
2235 dense_rank,
2236 };
2237
2238 row_num += 1;
2239 prev_values = Some(current_values);
2240 }
2241 }
2242
2243 Ok(window_contexts)
2244 }
2245
2246 fn find_window_function(targets: &[Target]) -> Option<&WindowFunction> {
2248 for target in targets {
2249 if let Expr::Window(wf) = &target.expr {
2250 return Some(wf);
2251 }
2252 }
2253 None
2254 }
2255
2256 fn make_group_key(values: &[Value]) -> String {
2259 use std::fmt::Write;
2260 let mut key = String::new();
2261 for (i, v) in values.iter().enumerate() {
2262 if i > 0 {
2263 key.push('\x00'); }
2265 match v {
2266 Value::String(s) => {
2267 key.push('S');
2268 key.push_str(s);
2269 }
2270 Value::Number(n) => {
2271 key.push('N');
2272 let _ = write!(key, "{n}");
2273 }
2274 Value::Integer(n) => {
2275 key.push('I');
2276 let _ = write!(key, "{n}");
2277 }
2278 Value::Date(d) => {
2279 key.push('D');
2280 let _ = write!(key, "{d}");
2281 }
2282 Value::Boolean(b) => {
2283 key.push(if *b { 'T' } else { 'F' });
2284 }
2285 Value::Amount(a) => {
2286 key.push('A');
2287 let _ = write!(key, "{} {}", a.number, a.currency);
2288 }
2289 Value::Position(p) => {
2290 key.push('P');
2291 let _ = write!(key, "{} {}", p.units.number, p.units.currency);
2292 }
2293 Value::Inventory(_) => {
2294 key.push('V');
2297 }
2298 Value::StringSet(ss) => {
2299 key.push('Z');
2300 for s in ss {
2301 key.push_str(s);
2302 key.push(',');
2303 }
2304 }
2305 Value::Null => {
2306 key.push('0');
2307 }
2308 }
2309 }
2310 key
2311 }
2312
2313 fn group_postings<'b>(
2316 &self,
2317 postings: &'b [PostingContext<'a>],
2318 group_by: Option<&Vec<Expr>>,
2319 ) -> Result<Vec<(Vec<Value>, Vec<&'b PostingContext<'a>>)>, QueryError> {
2320 if let Some(group_exprs) = group_by {
2321 let mut group_map: HashMap<String, (Vec<Value>, Vec<&PostingContext<'a>>)> =
2323 HashMap::new();
2324
2325 for ctx in postings {
2326 let mut key_values = Vec::with_capacity(group_exprs.len());
2327 for expr in group_exprs {
2328 key_values.push(self.evaluate_expr(expr, ctx)?);
2329 }
2330 let hash_key = Self::make_group_key(&key_values);
2331
2332 group_map
2333 .entry(hash_key)
2334 .or_insert_with(|| (key_values, Vec::new()))
2335 .1
2336 .push(ctx);
2337 }
2338
2339 Ok(group_map.into_values().collect())
2340 } else {
2341 Ok(vec![(Vec::new(), postings.iter().collect())])
2343 }
2344 }
2345
2346 fn evaluate_aggregate_row(
2348 &self,
2349 targets: &[Target],
2350 group: &[&PostingContext],
2351 ) -> Result<Row, QueryError> {
2352 let mut row = Vec::new();
2353 for target in targets {
2354 row.push(self.evaluate_aggregate_expr(&target.expr, group)?);
2355 }
2356 Ok(row)
2357 }
2358
2359 fn evaluate_aggregate_expr(
2361 &self,
2362 expr: &Expr,
2363 group: &[&PostingContext],
2364 ) -> Result<Value, QueryError> {
2365 match expr {
2366 Expr::Function(func) => {
2367 match func.name.to_uppercase().as_str() {
2368 "COUNT" => {
2369 Ok(Value::Integer(group.len() as i64))
2371 }
2372 "SUM" => {
2373 if func.args.len() != 1 {
2374 return Err(QueryError::InvalidArguments(
2375 "SUM".to_string(),
2376 "expected 1 argument".to_string(),
2377 ));
2378 }
2379 let mut total = Inventory::new();
2380 for ctx in group {
2381 let val = self.evaluate_expr(&func.args[0], ctx)?;
2382 match val {
2383 Value::Amount(amt) => {
2384 let pos = Position::simple(amt);
2385 total.add(pos);
2386 }
2387 Value::Position(pos) => {
2388 total.add(pos);
2389 }
2390 Value::Number(n) => {
2391 let pos =
2393 Position::simple(Amount::new(n, "__NUMBER__".to_string()));
2394 total.add(pos);
2395 }
2396 Value::Null => {}
2397 _ => {
2398 return Err(QueryError::Type(
2399 "SUM requires numeric or position value".to_string(),
2400 ));
2401 }
2402 }
2403 }
2404 Ok(Value::Inventory(total))
2405 }
2406 "FIRST" => {
2407 if func.args.len() != 1 {
2408 return Err(QueryError::InvalidArguments(
2409 "FIRST".to_string(),
2410 "expected 1 argument".to_string(),
2411 ));
2412 }
2413 if let Some(ctx) = group.first() {
2414 self.evaluate_expr(&func.args[0], ctx)
2415 } else {
2416 Ok(Value::Null)
2417 }
2418 }
2419 "LAST" => {
2420 if func.args.len() != 1 {
2421 return Err(QueryError::InvalidArguments(
2422 "LAST".to_string(),
2423 "expected 1 argument".to_string(),
2424 ));
2425 }
2426 if let Some(ctx) = group.last() {
2427 self.evaluate_expr(&func.args[0], ctx)
2428 } else {
2429 Ok(Value::Null)
2430 }
2431 }
2432 "MIN" => {
2433 if func.args.len() != 1 {
2434 return Err(QueryError::InvalidArguments(
2435 "MIN".to_string(),
2436 "expected 1 argument".to_string(),
2437 ));
2438 }
2439 let mut min_val: Option<Value> = None;
2440 for ctx in group {
2441 let val = self.evaluate_expr(&func.args[0], ctx)?;
2442 if matches!(val, Value::Null) {
2443 continue;
2444 }
2445 min_val = Some(match min_val {
2446 None => val,
2447 Some(current) => {
2448 if self.value_less_than(&val, ¤t)? {
2449 val
2450 } else {
2451 current
2452 }
2453 }
2454 });
2455 }
2456 Ok(min_val.unwrap_or(Value::Null))
2457 }
2458 "MAX" => {
2459 if func.args.len() != 1 {
2460 return Err(QueryError::InvalidArguments(
2461 "MAX".to_string(),
2462 "expected 1 argument".to_string(),
2463 ));
2464 }
2465 let mut max_val: Option<Value> = None;
2466 for ctx in group {
2467 let val = self.evaluate_expr(&func.args[0], ctx)?;
2468 if matches!(val, Value::Null) {
2469 continue;
2470 }
2471 max_val = Some(match max_val {
2472 None => val,
2473 Some(current) => {
2474 if self.value_less_than(¤t, &val)? {
2475 val
2476 } else {
2477 current
2478 }
2479 }
2480 });
2481 }
2482 Ok(max_val.unwrap_or(Value::Null))
2483 }
2484 "AVG" => {
2485 if func.args.len() != 1 {
2486 return Err(QueryError::InvalidArguments(
2487 "AVG".to_string(),
2488 "expected 1 argument".to_string(),
2489 ));
2490 }
2491 let mut sum = Decimal::ZERO;
2492 let mut count = 0i64;
2493 for ctx in group {
2494 let val = self.evaluate_expr(&func.args[0], ctx)?;
2495 match val {
2496 Value::Number(n) => {
2497 sum += n;
2498 count += 1;
2499 }
2500 Value::Integer(i) => {
2501 sum += Decimal::from(i);
2502 count += 1;
2503 }
2504 Value::Null => {}
2505 _ => {
2506 return Err(QueryError::Type(
2507 "AVG expects numeric values".to_string(),
2508 ));
2509 }
2510 }
2511 }
2512 if count == 0 {
2513 Ok(Value::Null)
2514 } else {
2515 Ok(Value::Number(sum / Decimal::from(count)))
2516 }
2517 }
2518 _ => {
2519 if let Some(ctx) = group.first() {
2521 self.evaluate_function(func, ctx)
2522 } else {
2523 Ok(Value::Null)
2524 }
2525 }
2526 }
2527 }
2528 Expr::Column(_) => {
2529 if let Some(ctx) = group.first() {
2531 self.evaluate_expr(expr, ctx)
2532 } else {
2533 Ok(Value::Null)
2534 }
2535 }
2536 Expr::BinaryOp(op) => {
2537 let left = self.evaluate_aggregate_expr(&op.left, group)?;
2538 let right = self.evaluate_aggregate_expr(&op.right, group)?;
2539 self.binary_op_on_values(op.op, &left, &right)
2541 }
2542 _ => {
2543 if let Some(ctx) = group.first() {
2545 self.evaluate_expr(expr, ctx)
2546 } else {
2547 Ok(Value::Null)
2548 }
2549 }
2550 }
2551 }
2552
2553 fn binary_op_on_values(
2555 &self,
2556 op: BinaryOperator,
2557 left: &Value,
2558 right: &Value,
2559 ) -> Result<Value, QueryError> {
2560 match op {
2561 BinaryOperator::Eq => Ok(Value::Boolean(self.values_equal(left, right))),
2562 BinaryOperator::Ne => Ok(Value::Boolean(!self.values_equal(left, right))),
2563 BinaryOperator::Lt => self.compare_values(left, right, std::cmp::Ordering::is_lt),
2564 BinaryOperator::Le => self.compare_values(left, right, std::cmp::Ordering::is_le),
2565 BinaryOperator::Gt => self.compare_values(left, right, std::cmp::Ordering::is_gt),
2566 BinaryOperator::Ge => self.compare_values(left, right, std::cmp::Ordering::is_ge),
2567 BinaryOperator::And => {
2568 let l = self.to_bool(left)?;
2569 let r = self.to_bool(right)?;
2570 Ok(Value::Boolean(l && r))
2571 }
2572 BinaryOperator::Or => {
2573 let l = self.to_bool(left)?;
2574 let r = self.to_bool(right)?;
2575 Ok(Value::Boolean(l || r))
2576 }
2577 BinaryOperator::Regex => {
2578 let s = match left {
2580 Value::String(s) => s,
2581 _ => {
2582 return Err(QueryError::Type(
2583 "regex requires string left operand".to_string(),
2584 ));
2585 }
2586 };
2587 let pattern = match right {
2588 Value::String(p) => p,
2589 _ => {
2590 return Err(QueryError::Type(
2591 "regex requires string pattern".to_string(),
2592 ));
2593 }
2594 };
2595 let regex_result = self.get_or_compile_regex(pattern);
2597 let matches = if let Some(regex) = regex_result {
2598 regex.is_match(s)
2599 } else {
2600 s.contains(pattern)
2601 };
2602 Ok(Value::Boolean(matches))
2603 }
2604 BinaryOperator::In => {
2605 match right {
2607 Value::StringSet(set) => {
2608 let needle = match left {
2609 Value::String(s) => s,
2610 _ => {
2611 return Err(QueryError::Type(
2612 "IN requires string left operand".to_string(),
2613 ));
2614 }
2615 };
2616 Ok(Value::Boolean(set.contains(needle)))
2617 }
2618 _ => Err(QueryError::Type(
2619 "IN requires set right operand".to_string(),
2620 )),
2621 }
2622 }
2623 BinaryOperator::Add => self.arithmetic_op(left, right, |a, b| a + b),
2624 BinaryOperator::Sub => self.arithmetic_op(left, right, |a, b| a - b),
2625 BinaryOperator::Mul => self.arithmetic_op(left, right, |a, b| a * b),
2626 BinaryOperator::Div => self.arithmetic_op(left, right, |a, b| a / b),
2627 }
2628 }
2629
2630 fn sort_results(
2632 &self,
2633 result: &mut QueryResult,
2634 order_by: &[OrderSpec],
2635 ) -> Result<(), QueryError> {
2636 if order_by.is_empty() {
2637 return Ok(());
2638 }
2639
2640 let column_indices: std::collections::HashMap<&str, usize> = result
2642 .columns
2643 .iter()
2644 .enumerate()
2645 .map(|(i, name)| (name.as_str(), i))
2646 .collect();
2647
2648 let mut sort_specs: Vec<(usize, bool)> = Vec::new();
2650 for spec in order_by {
2651 let idx = match &spec.expr {
2653 Expr::Column(name) => column_indices
2654 .get(name.as_str())
2655 .copied()
2656 .ok_or_else(|| QueryError::UnknownColumn(name.clone()))?,
2657 Expr::Function(func) => {
2658 column_indices
2660 .get(func.name.as_str())
2661 .copied()
2662 .ok_or_else(|| {
2663 QueryError::Evaluation(format!(
2664 "ORDER BY expression not found in SELECT: {}",
2665 func.name
2666 ))
2667 })?
2668 }
2669 _ => {
2670 return Err(QueryError::Evaluation(
2671 "ORDER BY expression must reference a selected column".to_string(),
2672 ));
2673 }
2674 };
2675 let ascending = spec.direction != SortDirection::Desc;
2676 sort_specs.push((idx, ascending));
2677 }
2678
2679 result.rows.sort_by(|a, b| {
2681 for (idx, ascending) in &sort_specs {
2682 if *idx >= a.len() || *idx >= b.len() {
2683 continue;
2684 }
2685 let ord = self.compare_values_for_sort(&a[*idx], &b[*idx]);
2686 if ord != std::cmp::Ordering::Equal {
2687 return if *ascending { ord } else { ord.reverse() };
2688 }
2689 }
2690 std::cmp::Ordering::Equal
2691 });
2692
2693 Ok(())
2694 }
2695
2696 fn compare_values_for_sort(&self, left: &Value, right: &Value) -> std::cmp::Ordering {
2698 match (left, right) {
2699 (Value::Null, Value::Null) => std::cmp::Ordering::Equal,
2700 (Value::Null, _) => std::cmp::Ordering::Greater, (_, Value::Null) => std::cmp::Ordering::Less,
2702 (Value::Number(a), Value::Number(b)) => a.cmp(b),
2703 (Value::Integer(a), Value::Integer(b)) => a.cmp(b),
2704 (Value::Number(a), Value::Integer(b)) => a.cmp(&Decimal::from(*b)),
2705 (Value::Integer(a), Value::Number(b)) => Decimal::from(*a).cmp(b),
2706 (Value::String(a), Value::String(b)) => a.cmp(b),
2707 (Value::Date(a), Value::Date(b)) => a.cmp(b),
2708 (Value::Boolean(a), Value::Boolean(b)) => a.cmp(b),
2709 (Value::Amount(a), Value::Amount(b)) => a.number.cmp(&b.number),
2711 (Value::Position(a), Value::Position(b)) => a.units.number.cmp(&b.units.number),
2713 (Value::Inventory(a), Value::Inventory(b)) => {
2715 let a_val = a.positions().first().map(|p| &p.units.number);
2716 let b_val = b.positions().first().map(|p| &p.units.number);
2717 match (a_val, b_val) {
2718 (Some(av), Some(bv)) => av.cmp(bv),
2719 (Some(_), None) => std::cmp::Ordering::Less,
2720 (None, Some(_)) => std::cmp::Ordering::Greater,
2721 (None, None) => std::cmp::Ordering::Equal,
2722 }
2723 }
2724 _ => std::cmp::Ordering::Equal, }
2726 }
2727
2728 fn evaluate_having_filter(
2734 &self,
2735 having_expr: &Expr,
2736 row: &[Value],
2737 column_names: &[String],
2738 targets: &[Target],
2739 group: &[&PostingContext],
2740 ) -> Result<bool, QueryError> {
2741 let col_map: HashMap<String, usize> = column_names
2743 .iter()
2744 .enumerate()
2745 .map(|(i, name)| (name.to_uppercase(), i))
2746 .collect();
2747
2748 let alias_map: HashMap<String, usize> = targets
2750 .iter()
2751 .enumerate()
2752 .filter_map(|(i, t)| t.alias.as_ref().map(|a| (a.to_uppercase(), i)))
2753 .collect();
2754
2755 let val = self.evaluate_having_expr(having_expr, row, &col_map, &alias_map, group)?;
2756
2757 match val {
2758 Value::Boolean(b) => Ok(b),
2759 Value::Null => Ok(false), _ => Err(QueryError::Type(
2761 "HAVING clause must evaluate to boolean".to_string(),
2762 )),
2763 }
2764 }
2765
2766 fn evaluate_having_expr(
2768 &self,
2769 expr: &Expr,
2770 row: &[Value],
2771 col_map: &HashMap<String, usize>,
2772 alias_map: &HashMap<String, usize>,
2773 group: &[&PostingContext],
2774 ) -> Result<Value, QueryError> {
2775 match expr {
2776 Expr::Column(name) => {
2777 let upper_name = name.to_uppercase();
2778 if let Some(&idx) = alias_map.get(&upper_name) {
2780 Ok(row.get(idx).cloned().unwrap_or(Value::Null))
2781 } else if let Some(&idx) = col_map.get(&upper_name) {
2782 Ok(row.get(idx).cloned().unwrap_or(Value::Null))
2783 } else {
2784 Err(QueryError::Evaluation(format!(
2785 "Column '{name}' not found in SELECT clause for HAVING"
2786 )))
2787 }
2788 }
2789 Expr::Literal(lit) => self.evaluate_literal(lit),
2790 Expr::Function(_) => {
2791 self.evaluate_aggregate_expr(expr, group)
2793 }
2794 Expr::BinaryOp(op) => {
2795 let left = self.evaluate_having_expr(&op.left, row, col_map, alias_map, group)?;
2796 let right = self.evaluate_having_expr(&op.right, row, col_map, alias_map, group)?;
2797 self.binary_op_on_values(op.op, &left, &right)
2798 }
2799 Expr::UnaryOp(op) => {
2800 let val = self.evaluate_having_expr(&op.operand, row, col_map, alias_map, group)?;
2801 match op.op {
2802 UnaryOperator::Not => {
2803 let b = self.to_bool(&val)?;
2804 Ok(Value::Boolean(!b))
2805 }
2806 UnaryOperator::Neg => match val {
2807 Value::Number(n) => Ok(Value::Number(-n)),
2808 Value::Integer(i) => Ok(Value::Integer(-i)),
2809 _ => Err(QueryError::Type(
2810 "Cannot negate non-numeric value".to_string(),
2811 )),
2812 },
2813 }
2814 }
2815 Expr::Paren(inner) => self.evaluate_having_expr(inner, row, col_map, alias_map, group),
2816 Expr::Wildcard => Err(QueryError::Evaluation(
2817 "Wildcard not allowed in HAVING clause".to_string(),
2818 )),
2819 Expr::Window(_) => Err(QueryError::Evaluation(
2820 "Window functions not allowed in HAVING clause".to_string(),
2821 )),
2822 }
2823 }
2824
2825 fn apply_pivot(
2831 &self,
2832 result: &QueryResult,
2833 pivot_exprs: &[Expr],
2834 _targets: &[Target],
2835 ) -> Result<QueryResult, QueryError> {
2836 if pivot_exprs.is_empty() {
2837 return Ok(result.clone());
2838 }
2839
2840 let pivot_expr = &pivot_exprs[0];
2843
2844 let pivot_col_idx = self.find_pivot_column(result, pivot_expr)?;
2846
2847 let mut pivot_values: Vec<Value> = result
2849 .rows
2850 .iter()
2851 .map(|row| row.get(pivot_col_idx).cloned().unwrap_or(Value::Null))
2852 .collect();
2853 pivot_values.sort_by(|a, b| self.compare_values_for_sort(a, b));
2854 pivot_values.dedup();
2855
2856 let mut new_columns: Vec<String> = result
2858 .columns
2859 .iter()
2860 .enumerate()
2861 .filter(|(i, _)| *i != pivot_col_idx)
2862 .map(|(_, c)| c.clone())
2863 .collect();
2864
2865 let value_col_idx = result.columns.len() - 1;
2867
2868 for pv in &pivot_values {
2870 new_columns.push(self.value_to_string(pv));
2871 }
2872
2873 let mut new_result = QueryResult::new(new_columns);
2874
2875 let group_cols: Vec<usize> = (0..result.columns.len())
2877 .filter(|i| *i != pivot_col_idx && *i != value_col_idx)
2878 .collect();
2879
2880 let mut groups: HashMap<String, Vec<&Row>> = HashMap::new();
2881 for row in &result.rows {
2882 let key: String = group_cols
2883 .iter()
2884 .map(|&i| self.value_to_string(&row[i]))
2885 .collect::<Vec<_>>()
2886 .join("|");
2887 groups.entry(key).or_default().push(row);
2888 }
2889
2890 for (_key, group_rows) in groups {
2892 let mut new_row: Vec<Value> = group_cols
2893 .iter()
2894 .map(|&i| group_rows[0][i].clone())
2895 .collect();
2896
2897 let pivot_index: HashMap<u64, usize> = group_rows
2899 .iter()
2900 .enumerate()
2901 .filter_map(|(idx, row)| {
2902 row.get(pivot_col_idx).map(|v| (hash_single_value(v), idx))
2903 })
2904 .collect();
2905
2906 for pv in &pivot_values {
2908 let pv_hash = hash_single_value(pv);
2909 if let Some(&row_idx) = pivot_index.get(&pv_hash) {
2910 new_row.push(
2911 group_rows[row_idx]
2912 .get(value_col_idx)
2913 .cloned()
2914 .unwrap_or(Value::Null),
2915 );
2916 } else {
2917 new_row.push(Value::Null);
2918 }
2919 }
2920
2921 new_result.add_row(new_row);
2922 }
2923
2924 Ok(new_result)
2925 }
2926
2927 fn find_pivot_column(
2929 &self,
2930 result: &QueryResult,
2931 pivot_expr: &Expr,
2932 ) -> Result<usize, QueryError> {
2933 match pivot_expr {
2934 Expr::Column(name) => {
2935 let upper_name = name.to_uppercase();
2936 result
2937 .columns
2938 .iter()
2939 .position(|c| c.to_uppercase() == upper_name)
2940 .ok_or_else(|| {
2941 QueryError::Evaluation(format!(
2942 "PIVOT BY column '{name}' not found in SELECT"
2943 ))
2944 })
2945 }
2946 Expr::Literal(Literal::Integer(n)) => {
2947 let idx = (*n as usize).saturating_sub(1);
2948 if idx < result.columns.len() {
2949 Ok(idx)
2950 } else {
2951 Err(QueryError::Evaluation(format!(
2952 "PIVOT BY column index {n} out of range"
2953 )))
2954 }
2955 }
2956 Expr::Literal(Literal::Number(n)) => {
2957 use rust_decimal::prelude::ToPrimitive;
2959 let idx = n.to_usize().unwrap_or(0).saturating_sub(1);
2960 if idx < result.columns.len() {
2961 Ok(idx)
2962 } else {
2963 Err(QueryError::Evaluation(format!(
2964 "PIVOT BY column index {n} out of range"
2965 )))
2966 }
2967 }
2968 _ => {
2969 Err(QueryError::Evaluation(
2972 "PIVOT BY must reference a column name or index".to_string(),
2973 ))
2974 }
2975 }
2976 }
2977
2978 fn value_to_string(&self, val: &Value) -> String {
2980 match val {
2981 Value::String(s) => s.clone(),
2982 Value::Number(n) => n.to_string(),
2983 Value::Integer(i) => i.to_string(),
2984 Value::Date(d) => d.to_string(),
2985 Value::Boolean(b) => b.to_string(),
2986 Value::Amount(a) => format!("{} {}", a.number, a.currency),
2987 Value::Position(p) => format!("{}", p.units),
2988 Value::Inventory(inv) => inv
2989 .positions()
2990 .iter()
2991 .map(|p| format!("{}", p.units))
2992 .collect::<Vec<_>>()
2993 .join(", "),
2994 Value::StringSet(ss) => ss.join(", "),
2995 Value::Null => "NULL".to_string(),
2996 }
2997 }
2998}
2999
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parse;
    use rust_decimal_macros::dec;
    use rustledger_core::Posting;

    /// Build a `NaiveDate`, panicking on an invalid calendar date (test helper).
    fn date(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).unwrap()
    }

    /// Fixture: two `*`-flagged transactions on 2024-01-15 and 2024-01-16.
    /// Each has one expense posting and one `Assets:Bank:Checking` posting,
    /// so there are four postings in total.
    fn sample_directives() -> Vec<Directive> {
        vec![
            Directive::Transaction(
                Transaction::new(date(2024, 1, 15), "Coffee")
                    .with_flag('*')
                    .with_payee("Coffee Shop")
                    .with_posting(Posting::new(
                        "Expenses:Food:Coffee",
                        Amount::new(dec!(5.00), "USD"),
                    ))
                    .with_posting(Posting::new(
                        "Assets:Bank:Checking",
                        Amount::new(dec!(-5.00), "USD"),
                    )),
            ),
            Directive::Transaction(
                Transaction::new(date(2024, 1, 16), "Groceries")
                    .with_flag('*')
                    .with_payee("Supermarket")
                    .with_posting(Posting::new(
                        "Expenses:Food:Groceries",
                        Amount::new(dec!(50.00), "USD"),
                    ))
                    .with_posting(Posting::new(
                        "Assets:Bank:Checking",
                        Amount::new(dec!(-50.00), "USD"),
                    )),
            ),
        ]
    }

    #[test]
    fn test_simple_select() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let query = parse("SELECT date, account").unwrap();
        let result = executor.execute(&query).unwrap();

        assert_eq!(result.columns, vec!["date", "account"]);
        // One row per posting: 2 transactions x 2 postings.
        assert_eq!(result.len(), 4);
    }

    #[test]
    fn test_where_clause() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let query = parse("SELECT account WHERE account ~ \"Expenses:\"").unwrap();
        let result = executor.execute(&query).unwrap();

        // Only the two expense postings match the regex.
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_balances() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let query = parse("BALANCES").unwrap();
        let result = executor.execute(&query).unwrap();

        assert_eq!(result.columns, vec!["account", "balance"]);
        // The fixture touches three distinct accounts.
        assert!(result.len() >= 3);
    }

    #[test]
    fn test_account_functions() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        // LEAF of the two expense accounts: "Coffee" and "Groceries".
        let query = parse("SELECT DISTINCT LEAF(account) WHERE account ~ \"Expenses:\"").unwrap();
        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 2);

        // ROOT of all accounts: "Expenses" and "Assets".
        let query = parse("SELECT DISTINCT ROOT(account)").unwrap();
        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 2);

        // PARENT of the expense accounts: only checked to be non-empty.
        let query = parse("SELECT DISTINCT PARENT(account) WHERE account ~ \"Expenses:\"").unwrap();
        let result = executor.execute(&query).unwrap();
        assert!(!result.is_empty());
    }

    #[test]
    fn test_min_max_aggregate() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        // Without GROUP BY, an aggregate collapses all postings into one row.
        let query = parse("SELECT MIN(date)").unwrap();
        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result.rows[0][0], Value::Date(date(2024, 1, 15)));

        let query = parse("SELECT MAX(date)").unwrap();
        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result.rows[0][0], Value::Date(date(2024, 1, 16)));
    }

    #[test]
    fn test_order_by() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let query = parse("SELECT date, account ORDER BY date DESC").unwrap();
        let result = executor.execute(&query).unwrap();

        assert_eq!(result.len(), 4);
        // Descending order puts the later transaction's postings first.
        assert_eq!(result.rows[0][0], Value::Date(date(2024, 1, 16)));
    }

    #[test]
    fn test_hash_value_all_variants() {
        use rustledger_core::{Cost, Inventory, Position};

        // One sample per `Value` variant (plus both booleans and a costed position).
        let values = vec![
            Value::String("test".to_string()),
            Value::Number(dec!(123.45)),
            Value::Integer(42),
            Value::Date(date(2024, 1, 15)),
            Value::Boolean(true),
            Value::Boolean(false),
            Value::Amount(Amount::new(dec!(100), "USD")),
            Value::Position(Position::simple(Amount::new(dec!(10), "AAPL"))),
            Value::Position(Position::with_cost(
                Amount::new(dec!(10), "AAPL"),
                Cost::new(dec!(150), "USD"),
            )),
            Value::Inventory(Inventory::new()),
            Value::StringSet(vec!["tag1".to_string(), "tag2".to_string()]),
            Value::Null,
        ];

        // Smoke test: hashing every variant works. Null is exempted from the
        // non-zero check.
        for value in &values {
            let hash = hash_single_value(value);
            assert!(hash != 0 || matches!(value, Value::Null));
        }

        // Different values should hash differently...
        let hash1 = hash_single_value(&Value::String("a".to_string()));
        let hash2 = hash_single_value(&Value::String("b".to_string()));
        assert_ne!(hash1, hash2);

        // ...and equal values identically.
        let hash3 = hash_single_value(&Value::Integer(42));
        let hash4 = hash_single_value(&Value::Integer(42));
        assert_eq!(hash3, hash4);
    }

    #[test]
    fn test_hash_row_distinct() {
        // Rows with identical values hash identically; a single differing cell changes the hash.
        let row1 = vec![Value::String("a".to_string()), Value::Integer(1)];
        let row2 = vec![Value::String("a".to_string()), Value::Integer(1)];
        let row3 = vec![Value::String("b".to_string()), Value::Integer(1)];

        assert_eq!(hash_row(&row1), hash_row(&row2));
        assert_ne!(hash_row(&row1), hash_row(&row3));
    }

    #[test]
    fn test_string_set_hash_order_independent() {
        // `hash_value` sorts a StringSet before hashing, so member order must not matter.
        let set1 = Value::StringSet(vec!["a".to_string(), "b".to_string(), "c".to_string()]);
        let set2 = Value::StringSet(vec!["c".to_string(), "a".to_string(), "b".to_string()]);
        let set3 = Value::StringSet(vec!["b".to_string(), "c".to_string(), "a".to_string()]);

        let hash1 = hash_single_value(&set1);
        let hash2 = hash_single_value(&set2);
        let hash3 = hash_single_value(&set3);

        assert_eq!(hash1, hash2);
        assert_eq!(hash2, hash3);
    }

    #[test]
    fn test_inventory_hash_includes_cost() {
        use rustledger_core::{Cost, Inventory, Position};

        // Same units, different cost basis: the hashes must differ.
        let mut inv1 = Inventory::new();
        inv1.add(Position::with_cost(
            Amount::new(dec!(10), "AAPL"),
            Cost::new(dec!(100), "USD"),
        ));

        let mut inv2 = Inventory::new();
        inv2.add(Position::with_cost(
            Amount::new(dec!(10), "AAPL"),
            Cost::new(dec!(200), "USD"),
        ));

        let hash1 = hash_single_value(&Value::Inventory(inv1));
        let hash2 = hash_single_value(&Value::Inventory(inv2));

        assert_ne!(hash1, hash2);
    }

    #[test]
    fn test_distinct_deduplication() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        // One flag value per posting...
        let query = parse("SELECT flag").unwrap();
        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 4);

        // ...but both transactions are flagged '*', so DISTINCT leaves one row.
        let query = parse("SELECT DISTINCT flag").unwrap();
        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn test_limit_clause() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let query = parse("SELECT date, account LIMIT 2").unwrap();
        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 2);

        let query = parse("SELECT date LIMIT 0").unwrap();
        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 0);

        // A limit above the row count returns every row.
        let query = parse("SELECT date LIMIT 100").unwrap();
        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 4);
    }

    #[test]
    fn test_group_by_with_count() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let query = parse("SELECT ROOT(account), COUNT(account) GROUP BY ROOT(account)").unwrap();
        let result = executor.execute(&query).unwrap();

        assert_eq!(result.columns.len(), 2);
        // Two groups: "Expenses" and "Assets".
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_count_aggregate() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        // Ungrouped COUNT counts every posting.
        let query = parse("SELECT COUNT(account)").unwrap();
        let result = executor.execute(&query).unwrap();

        assert_eq!(result.len(), 1);
        assert_eq!(result.rows[0][0], Value::Integer(4));

        let query = parse("SELECT ROOT(account), COUNT(account) GROUP BY ROOT(account)").unwrap();
        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_journal_query() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let query = parse("JOURNAL \"Expenses\"").unwrap();
        let result = executor.execute(&query).unwrap();

        assert!(result.columns.contains(&"account".to_string()));
        // Only the two expense postings are journaled.
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_print_query() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let query = parse("PRINT").unwrap();
        let result = executor.execute(&query).unwrap();

        // PRINT yields a single "directive" column with one row per directive.
        assert_eq!(result.columns.len(), 1);
        assert_eq!(result.columns[0], "directive");
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_empty_directives() {
        let directives: Vec<Directive> = vec![];
        let mut executor = Executor::new(&directives);

        let query = parse("SELECT date, account").unwrap();
        let result = executor.execute(&query).unwrap();
        assert!(result.is_empty());

        let query = parse("BALANCES").unwrap();
        let result = executor.execute(&query).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_comparison_operators() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        // Only the 2024-01-15 transaction's two postings are earlier than the 16th.
        let query = parse("SELECT date WHERE date < 2024-01-16").unwrap();
        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 2);

        // Every posting is from 2024.
        let query = parse("SELECT date WHERE year > 2023").unwrap();
        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 4);

        let query = parse("SELECT account WHERE day = 15").unwrap();
        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_logical_operators() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        // Both expense postings fall on day 15 or 16, so AND keeps both.
        let query = parse("SELECT account WHERE account ~ \"Expenses\" AND day > 14").unwrap();
        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 2);

        // OR across the two transaction days covers every posting.
        let query = parse("SELECT account WHERE day = 15 OR day = 16").unwrap();
        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 4);
    }

    #[test]
    fn test_arithmetic_expressions() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let query = parse("SELECT -day WHERE day = 15").unwrap();
        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 2);
        // NOTE(review): non-Integer results are silently skipped by this `if let`,
        // so the value is only checked when negation yields an Integer.
        for row in &result.rows {
            if let Value::Integer(n) = &row[0] {
                assert_eq!(*n, -15);
            }
        }
    }

    #[test]
    fn test_first_last_aggregates() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        // FIRST/LAST follow posting order, which here is date order.
        let query = parse("SELECT FIRST(date)").unwrap();
        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result.rows[0][0], Value::Date(date(2024, 1, 15)));

        let query = parse("SELECT LAST(date)").unwrap();
        let result = executor.execute(&query).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result.rows[0][0], Value::Date(date(2024, 1, 16)));
    }

    #[test]
    fn test_wildcard_select() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let query = parse("SELECT *").unwrap();
        let result = executor.execute(&query).unwrap();

        // The column header stays a single "*"...
        assert_eq!(result.columns, vec!["*"]);
        assert_eq!(result.len(), 4);
        // ...while each row carries six values from the wildcard expansion.
        assert_eq!(result.rows[0].len(), 6);
    }

    #[test]
    fn test_query_result_methods() {
        let mut result = QueryResult::new(vec!["col1".to_string(), "col2".to_string()]);

        // A freshly created result has columns but no rows.
        assert!(result.is_empty());
        assert_eq!(result.len(), 0);

        result.add_row(vec![Value::Integer(1), Value::String("a".to_string())]);
        assert!(!result.is_empty());
        assert_eq!(result.len(), 1);

        result.add_row(vec![Value::Integer(2), Value::String("b".to_string())]);
        assert_eq!(result.len(), 2);
    }
}