1use std::cell::RefCell;
6use std::collections::{HashMap, HashSet};
7use std::hash::{Hash, Hasher};
8
9use chrono::Datelike;
10use regex::Regex;
11use rust_decimal::Decimal;
12use rustledger_core::{
13 Amount, Directive, InternedStr, Inventory, NaiveDate, Position, Transaction,
14};
15
16use crate::ast::{
17 BalancesQuery, BinaryOp, BinaryOperator, Expr, FromClause, FunctionCall, JournalQuery, Literal,
18 OrderSpec, PrintQuery, Query, SelectQuery, SortDirection, Target, UnaryOp, UnaryOperator,
19 WindowFunction,
20};
21use crate::error::QueryError;
22
/// A single cell value produced while evaluating a query.
///
/// `PartialEq`/`Eq` are derived; hashing for DISTINCT is done manually by
/// `Value::hash_value`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Value {
    /// Text (accounts, narrations, payees, flags, formatted directives, ...).
    String(String),
    /// Arbitrary-precision decimal number.
    Number(Decimal),
    /// Integer (e.g. the `year`/`month`/`day` columns, `LENGTH`).
    Integer(i64),
    /// Calendar date.
    Date(NaiveDate),
    /// Boolean, e.g. the result of a predicate.
    Boolean(bool),
    /// A number together with its currency.
    Amount(Amount),
    /// Units with an optional cost basis.
    Position(Position),
    /// A collection of positions, e.g. a running or aggregated balance.
    Inventory(Inventory),
    /// A set of strings such as a transaction's tags or links.
    StringSet(Vec<String>),
    /// Missing / not-applicable value.
    Null,
}
47
48impl Value {
49 fn hash_value<H: Hasher>(&self, state: &mut H) {
55 std::mem::discriminant(self).hash(state);
56 match self {
57 Self::String(s) => s.hash(state),
58 Self::Number(d) => d.serialize().hash(state),
59 Self::Integer(i) => i.hash(state),
60 Self::Date(d) => {
61 d.year().hash(state);
62 d.month().hash(state);
63 d.day().hash(state);
64 }
65 Self::Boolean(b) => b.hash(state),
66 Self::Amount(a) => {
67 a.number.serialize().hash(state);
68 a.currency.as_str().hash(state);
69 }
70 Self::Position(p) => {
71 p.units.number.serialize().hash(state);
72 p.units.currency.as_str().hash(state);
73 if let Some(cost) = &p.cost {
74 cost.number.serialize().hash(state);
75 cost.currency.as_str().hash(state);
76 }
77 }
78 Self::Inventory(inv) => {
79 for pos in inv.positions() {
80 pos.units.number.serialize().hash(state);
81 pos.units.currency.as_str().hash(state);
82 if let Some(cost) = &pos.cost {
83 cost.number.serialize().hash(state);
84 cost.currency.as_str().hash(state);
85 }
86 }
87 }
88 Self::StringSet(ss) => {
89 let mut sorted = ss.clone();
91 sorted.sort();
92 for s in &sorted {
93 s.hash(state);
94 }
95 }
96 Self::Null => {}
97 }
98 }
99}
100
101fn hash_row(row: &Row) -> u64 {
103 use std::collections::hash_map::DefaultHasher;
104 let mut hasher = DefaultHasher::new();
105 for value in row {
106 value.hash_value(&mut hasher);
107 }
108 hasher.finish()
109}
110
111fn hash_single_value(value: &Value) -> u64 {
113 use std::collections::hash_map::DefaultHasher;
114 let mut hasher = DefaultHasher::new();
115 value.hash_value(&mut hasher);
116 hasher.finish()
117}
118
/// One result row: one `Value` per output column, in column order.
pub type Row = Vec<Value>;
121
/// Tabular output of a query.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Column headers, in output order.
    pub columns: Vec<String>,
    /// Result rows; each row is expected to have one value per column.
    pub rows: Vec<Row>,
}
130
131impl QueryResult {
132 pub const fn new(columns: Vec<String>) -> Self {
134 Self {
135 columns,
136 rows: Vec::new(),
137 }
138 }
139
140 pub fn add_row(&mut self, row: Row) {
142 self.rows.push(row);
143 }
144
145 pub fn len(&self) -> usize {
147 self.rows.len()
148 }
149
150 pub fn is_empty(&self) -> bool {
152 self.rows.is_empty()
153 }
154}
155
/// Evaluation context for one posting of a transaction.
#[derive(Debug)]
pub struct PostingContext<'a> {
    /// The transaction that owns the posting.
    pub transaction: &'a Transaction,
    /// Index of the posting within `transaction.postings`.
    pub posting_index: usize,
    /// Running balance of the posting's account after this posting, or `None`
    /// when no balance is tracked (e.g. the placeholder context used for FROM
    /// filters).
    pub balance: Option<Inventory>,
}
166
/// Per-row window-function values computed for a SELECT with window functions.
#[derive(Debug, Clone)]
pub struct WindowContext {
    /// Position of the row within its window (`ROW_NUMBER`); presumably 1-based — TODO confirm.
    pub row_number: usize,
    /// Rank with gaps after ties (`RANK`).
    pub rank: usize,
    /// Rank without gaps after ties (`DENSE_RANK`).
    pub dense_rank: usize,
}
177
/// Executes parsed queries against a borrowed slice of ledger directives.
pub struct Executor<'a> {
    /// The directives every query runs over.
    directives: &'a [Directive],
    /// Scratch per-account inventories used by balance-producing queries.
    balances: HashMap<InternedStr, Inventory>,
    /// Price database built from `directives` at construction time.
    price_db: crate::price::PriceDatabase,
    /// Target currency set via `set_target_currency` (`None` until set).
    target_currency: Option<String>,
    /// Compiled regexes keyed by pattern; `None` caches a pattern that failed to compile.
    regex_cache: RefCell<HashMap<String, Option<Regex>>>,
}
191
192impl<'a> Executor<'a> {
193 pub fn new(directives: &'a [Directive]) -> Self {
195 let price_db = crate::price::PriceDatabase::from_directives(directives);
196 Self {
197 directives,
198 balances: HashMap::new(),
199 price_db,
200 target_currency: None,
201 regex_cache: RefCell::new(HashMap::new()),
202 }
203 }
204
205 fn get_or_compile_regex(&self, pattern: &str) -> Option<Regex> {
210 let mut cache = self.regex_cache.borrow_mut();
211 if let Some(cached) = cache.get(pattern) {
212 return cached.clone();
213 }
214 let compiled = Regex::new(pattern).ok();
215 cache.insert(pattern.to_string(), compiled.clone());
216 compiled
217 }
218
219 fn require_regex(&self, pattern: &str) -> Result<Regex, QueryError> {
221 self.get_or_compile_regex(pattern)
222 .ok_or_else(|| QueryError::Type(format!("invalid regex: {pattern}")))
223 }
224
225 pub fn set_target_currency(&mut self, currency: impl Into<String>) {
227 self.target_currency = Some(currency.into());
228 }
229
230 pub fn execute(&mut self, query: &Query) -> Result<QueryResult, QueryError> {
243 match query {
244 Query::Select(select) => self.execute_select(select),
245 Query::Journal(journal) => self.execute_journal(journal),
246 Query::Balances(balances) => self.execute_balances(balances),
247 Query::Print(print) => self.execute_print(print),
248 }
249 }
250
251 fn execute_select(&self, query: &SelectQuery) -> Result<QueryResult, QueryError> {
253 if let Some(from) = &query.from {
255 if let Some(subquery) = &from.subquery {
256 return self.execute_select_from_subquery(query, subquery);
257 }
258 }
259
260 let column_names = self.resolve_column_names(&query.targets)?;
262 let mut result = QueryResult::new(column_names.clone());
263
264 let postings = self.collect_postings(query.from.as_ref(), query.where_clause.as_ref())?;
266
267 let is_aggregate = query
269 .targets
270 .iter()
271 .any(|t| Self::is_aggregate_expr(&t.expr));
272
273 if is_aggregate {
274 let grouped = self.group_postings(&postings, query.group_by.as_ref())?;
276 for (_, group) in grouped {
277 let row = self.evaluate_aggregate_row(&query.targets, &group)?;
278
279 if let Some(having_expr) = &query.having {
281 if !self.evaluate_having_filter(
282 having_expr,
283 &row,
284 &column_names,
285 &query.targets,
286 &group,
287 )? {
288 continue;
289 }
290 }
291
292 result.add_row(row);
293 }
294 } else {
295 let has_windows = Self::has_window_functions(&query.targets);
297 let window_contexts = if has_windows {
298 if let Some(wf) = Self::find_window_function(&query.targets) {
299 Some(self.compute_window_contexts(&postings, wf)?)
300 } else {
301 None
302 }
303 } else {
304 None
305 };
306
307 let mut seen_hashes: HashSet<u64> = if query.distinct {
310 HashSet::with_capacity(postings.len())
311 } else {
312 HashSet::new()
313 };
314
315 for (i, ctx) in postings.iter().enumerate() {
316 let row = if let Some(ref wctxs) = window_contexts {
317 self.evaluate_row_with_window(&query.targets, ctx, Some(&wctxs[i]))?
318 } else {
319 self.evaluate_row(&query.targets, ctx)?
320 };
321 if query.distinct {
322 let row_hash = hash_row(&row);
324 if seen_hashes.insert(row_hash) {
325 result.add_row(row);
326 }
327 } else {
328 result.add_row(row);
329 }
330 }
331 }
332
333 if let Some(pivot_exprs) = &query.pivot_by {
335 result = self.apply_pivot(&result, pivot_exprs, &query.targets)?;
336 }
337
338 if let Some(order_by) = &query.order_by {
340 self.sort_results(&mut result, order_by)?;
341 } else if query.group_by.is_some() && !result.rows.is_empty() && !result.columns.is_empty()
342 {
343 let first_col = result.columns[0].clone();
346 let default_order = vec![OrderSpec {
347 expr: Expr::Column(first_col),
348 direction: SortDirection::Asc,
349 }];
350 self.sort_results(&mut result, &default_order)?;
351 }
352
353 if let Some(limit) = query.limit {
355 result.rows.truncate(limit as usize);
356 }
357
358 Ok(result)
359 }
360
361 fn execute_select_from_subquery(
363 &self,
364 outer_query: &SelectQuery,
365 inner_query: &SelectQuery,
366 ) -> Result<QueryResult, QueryError> {
367 let inner_result = self.execute_select(inner_query)?;
369
370 let inner_column_map: HashMap<String, usize> = inner_result
372 .columns
373 .iter()
374 .enumerate()
375 .map(|(i, name)| (name.to_lowercase(), i))
376 .collect();
377
378 let outer_column_names =
380 self.resolve_subquery_column_names(&outer_query.targets, &inner_result.columns)?;
381 let mut result = QueryResult::new(outer_column_names);
382
383 let mut seen_hashes: HashSet<u64> = if outer_query.distinct {
385 HashSet::with_capacity(inner_result.rows.len())
386 } else {
387 HashSet::new()
388 };
389
390 for inner_row in &inner_result.rows {
392 if let Some(where_expr) = &outer_query.where_clause {
394 if !self.evaluate_subquery_filter(where_expr, inner_row, &inner_column_map)? {
395 continue;
396 }
397 }
398
399 let outer_row =
401 self.evaluate_subquery_row(&outer_query.targets, inner_row, &inner_column_map)?;
402
403 if outer_query.distinct {
404 let row_hash = hash_row(&outer_row);
406 if seen_hashes.insert(row_hash) {
407 result.add_row(outer_row);
408 }
409 } else {
410 result.add_row(outer_row);
411 }
412 }
413
414 if let Some(order_by) = &outer_query.order_by {
416 self.sort_results(&mut result, order_by)?;
417 }
418
419 if let Some(limit) = outer_query.limit {
421 result.rows.truncate(limit as usize);
422 }
423
424 Ok(result)
425 }
426
427 fn resolve_subquery_column_names(
429 &self,
430 targets: &[Target],
431 inner_columns: &[String],
432 ) -> Result<Vec<String>, QueryError> {
433 let mut names = Vec::new();
434 for (i, target) in targets.iter().enumerate() {
435 if let Some(alias) = &target.alias {
436 names.push(alias.clone());
437 } else if matches!(target.expr, Expr::Wildcard) {
438 names.extend(inner_columns.iter().cloned());
440 } else {
441 names.push(self.expr_to_name(&target.expr, i));
442 }
443 }
444 Ok(names)
445 }
446
447 fn evaluate_subquery_filter(
449 &self,
450 expr: &Expr,
451 row: &[Value],
452 column_map: &HashMap<String, usize>,
453 ) -> Result<bool, QueryError> {
454 let val = self.evaluate_subquery_expr(expr, row, column_map)?;
455 self.to_bool(&val)
456 }
457
458 fn evaluate_subquery_expr(
460 &self,
461 expr: &Expr,
462 row: &[Value],
463 column_map: &HashMap<String, usize>,
464 ) -> Result<Value, QueryError> {
465 match expr {
466 Expr::Wildcard => Err(QueryError::Evaluation(
467 "Wildcard not allowed in expression context".to_string(),
468 )),
469 Expr::Column(name) => {
470 let lower = name.to_lowercase();
471 if let Some(&idx) = column_map.get(&lower) {
472 Ok(row.get(idx).cloned().unwrap_or(Value::Null))
473 } else {
474 Err(QueryError::Evaluation(format!(
475 "Unknown column '{name}' in subquery result"
476 )))
477 }
478 }
479 Expr::Literal(lit) => self.evaluate_literal(lit),
480 Expr::Function(func) => {
481 let args: Vec<Value> = func
483 .args
484 .iter()
485 .map(|a| self.evaluate_subquery_expr(a, row, column_map))
486 .collect::<Result<Vec<_>, _>>()?;
487 self.evaluate_function_on_values(&func.name, &args)
488 }
489 Expr::BinaryOp(op) => {
490 let left = self.evaluate_subquery_expr(&op.left, row, column_map)?;
491 let right = self.evaluate_subquery_expr(&op.right, row, column_map)?;
492 self.binary_op_on_values(op.op, &left, &right)
493 }
494 Expr::UnaryOp(op) => {
495 let val = self.evaluate_subquery_expr(&op.operand, row, column_map)?;
496 self.unary_op_on_value(op.op, &val)
497 }
498 Expr::Paren(inner) => self.evaluate_subquery_expr(inner, row, column_map),
499 Expr::Window(_) => Err(QueryError::Evaluation(
500 "Window functions not supported in subquery expressions".to_string(),
501 )),
502 }
503 }
504
505 fn evaluate_subquery_row(
507 &self,
508 targets: &[Target],
509 inner_row: &[Value],
510 column_map: &HashMap<String, usize>,
511 ) -> Result<Row, QueryError> {
512 let mut row = Vec::new();
513 for target in targets {
514 if matches!(target.expr, Expr::Wildcard) {
515 row.extend(inner_row.iter().cloned());
517 } else {
518 row.push(self.evaluate_subquery_expr(&target.expr, inner_row, column_map)?);
519 }
520 }
521 Ok(row)
522 }
523
524 fn execute_journal(&mut self, query: &JournalQuery) -> Result<QueryResult, QueryError> {
526 let account_pattern = &query.account_pattern;
528
529 let account_regex = self.get_or_compile_regex(account_pattern);
531
532 let columns = vec![
533 "date".to_string(),
534 "flag".to_string(),
535 "payee".to_string(),
536 "narration".to_string(),
537 "account".to_string(),
538 "position".to_string(),
539 "balance".to_string(),
540 ];
541 let mut result = QueryResult::new(columns);
542
543 for directive in self.directives {
545 if let Directive::Transaction(txn) = directive {
546 if let Some(from) = &query.from {
548 if let Some(filter) = &from.filter {
549 if !self.evaluate_from_filter(filter, txn)? {
550 continue;
551 }
552 }
553 }
554
555 for posting in &txn.postings {
556 let matches = if let Some(ref regex) = account_regex {
558 regex.is_match(&posting.account)
559 } else {
560 posting.account.contains(account_pattern)
561 };
562
563 if matches {
564 let balance = self.balances.entry(posting.account.clone()).or_default();
566
567 if let Some(units) = posting.amount() {
569 let pos = if let Some(cost_spec) = &posting.cost {
570 if let Some(cost) = cost_spec.resolve(units.number, txn.date) {
571 Position::with_cost(units.clone(), cost)
572 } else {
573 Position::simple(units.clone())
574 }
575 } else {
576 Position::simple(units.clone())
577 };
578 balance.add(pos.clone());
579 }
580
581 let position_value = if let Some(at_func) = &query.at_function {
583 match at_func.to_uppercase().as_str() {
584 "COST" => {
585 if let Some(units) = posting.amount() {
586 if let Some(cost_spec) = &posting.cost {
587 if let Some(cost) =
588 cost_spec.resolve(units.number, txn.date)
589 {
590 let total = units.number * cost.number;
591 Value::Amount(Amount::new(total, &cost.currency))
592 } else {
593 Value::Amount(units.clone())
594 }
595 } else {
596 Value::Amount(units.clone())
597 }
598 } else {
599 Value::Null
600 }
601 }
602 "UNITS" => posting
603 .amount()
604 .map_or(Value::Null, |u| Value::Amount(u.clone())),
605 _ => posting
606 .amount()
607 .map_or(Value::Null, |u| Value::Amount(u.clone())),
608 }
609 } else {
610 posting
611 .amount()
612 .map_or(Value::Null, |u| Value::Amount(u.clone()))
613 };
614
615 let row = vec![
616 Value::Date(txn.date),
617 Value::String(txn.flag.to_string()),
618 Value::String(
619 txn.payee
620 .as_ref()
621 .map_or_else(String::new, ToString::to_string),
622 ),
623 Value::String(txn.narration.to_string()),
624 Value::String(posting.account.to_string()),
625 position_value,
626 Value::Inventory(balance.clone()),
627 ];
628 result.add_row(row);
629 }
630 }
631 }
632 }
633
634 Ok(result)
635 }
636
637 fn execute_balances(&mut self, query: &BalancesQuery) -> Result<QueryResult, QueryError> {
639 self.build_balances_with_filter(query.from.as_ref())?;
641
642 let columns = vec!["account".to_string(), "balance".to_string()];
643 let mut result = QueryResult::new(columns);
644
645 let mut accounts: Vec<_> = self.balances.keys().collect();
647 accounts.sort();
648
649 for account in accounts {
650 let Some(balance) = self.balances.get(account) else {
652 continue; };
654
655 let balance_value = if let Some(at_func) = &query.at_function {
657 match at_func.to_uppercase().as_str() {
658 "COST" => {
659 let cost_inventory = balance.at_cost();
661 Value::Inventory(cost_inventory)
662 }
663 "UNITS" => {
664 let units_inventory = balance.at_units();
666 Value::Inventory(units_inventory)
667 }
668 _ => Value::Inventory(balance.clone()),
669 }
670 } else {
671 Value::Inventory(balance.clone())
672 };
673
674 let row = vec![Value::String(account.to_string()), balance_value];
675 result.add_row(row);
676 }
677
678 Ok(result)
679 }
680
681 fn execute_print(&self, query: &PrintQuery) -> Result<QueryResult, QueryError> {
683 let columns = vec!["directive".to_string()];
685 let mut result = QueryResult::new(columns);
686
687 for directive in self.directives {
688 if let Some(from) = &query.from {
690 if let Some(filter) = &from.filter {
691 if let Directive::Transaction(txn) = directive {
693 if !self.evaluate_from_filter(filter, txn)? {
694 continue;
695 }
696 }
697 }
698 }
699
700 let formatted = self.format_directive(directive);
702 result.add_row(vec![Value::String(formatted)]);
703 }
704
705 Ok(result)
706 }
707
708 fn format_directive(&self, directive: &Directive) -> String {
710 match directive {
711 Directive::Transaction(txn) => {
712 let mut out = format!("{} {} ", txn.date, txn.flag);
713 if let Some(payee) = &txn.payee {
714 out.push_str(&format!("\"{payee}\" "));
715 }
716 out.push_str(&format!("\"{}\"", txn.narration));
717
718 for tag in &txn.tags {
719 out.push_str(&format!(" #{tag}"));
720 }
721 for link in &txn.links {
722 out.push_str(&format!(" ^{link}"));
723 }
724 out.push('\n');
725
726 for posting in &txn.postings {
727 out.push_str(&format!(" {}", posting.account));
728 if let Some(units) = posting.amount() {
729 out.push_str(&format!(" {} {}", units.number, units.currency));
730 }
731 out.push('\n');
732 }
733 out
734 }
735 Directive::Balance(bal) => {
736 format!(
737 "{} balance {} {} {}\n",
738 bal.date, bal.account, bal.amount.number, bal.amount.currency
739 )
740 }
741 Directive::Open(open) => {
742 let mut out = format!("{} open {}", open.date, open.account);
743 if !open.currencies.is_empty() {
744 out.push_str(&format!(" {}", open.currencies.join(",")));
745 }
746 out.push('\n');
747 out
748 }
749 Directive::Close(close) => {
750 format!("{} close {}\n", close.date, close.account)
751 }
752 Directive::Commodity(comm) => {
753 format!("{} commodity {}\n", comm.date, comm.currency)
754 }
755 Directive::Pad(pad) => {
756 format!("{} pad {} {}\n", pad.date, pad.account, pad.source_account)
757 }
758 Directive::Event(event) => {
759 format!(
760 "{} event \"{}\" \"{}\"\n",
761 event.date, event.event_type, event.value
762 )
763 }
764 Directive::Query(query) => {
765 format!(
766 "{} query \"{}\" \"{}\"\n",
767 query.date, query.name, query.query
768 )
769 }
770 Directive::Note(note) => {
771 format!("{} note {} \"{}\"\n", note.date, note.account, note.comment)
772 }
773 Directive::Document(doc) => {
774 format!("{} document {} \"{}\"\n", doc.date, doc.account, doc.path)
775 }
776 Directive::Price(price) => {
777 format!(
778 "{} price {} {} {}\n",
779 price.date, price.currency, price.amount.number, price.amount.currency
780 )
781 }
782 Directive::Custom(custom) => {
783 format!("{} custom \"{}\"\n", custom.date, custom.custom_type)
784 }
785 }
786 }
787
788 fn build_balances_with_filter(&mut self, from: Option<&FromClause>) -> Result<(), QueryError> {
790 for directive in self.directives {
791 if let Directive::Transaction(txn) = directive {
792 if let Some(from_clause) = from {
794 if let Some(filter) = &from_clause.filter {
795 if !self.evaluate_from_filter(filter, txn)? {
796 continue;
797 }
798 }
799 }
800
801 for posting in &txn.postings {
802 if let Some(units) = posting.amount() {
803 let balance = self.balances.entry(posting.account.clone()).or_default();
804
805 let pos = if let Some(cost_spec) = &posting.cost {
806 if let Some(cost) = cost_spec.resolve(units.number, txn.date) {
807 Position::with_cost(units.clone(), cost)
808 } else {
809 Position::simple(units.clone())
810 }
811 } else {
812 Position::simple(units.clone())
813 };
814 balance.add(pos);
815 }
816 }
817 }
818 }
819 Ok(())
820 }
821
    /// Collects one `PostingContext` per posting that survives the FROM clause and
    /// the WHERE predicate, in directive order.
    ///
    /// Each context carries the running balance of the posting's account *after*
    /// that posting has been added.
    ///
    /// FROM handling:
    /// - `OPEN ON d`: transactions dated before `d` are not returned, but their
    ///   postings still seed the running balances.
    /// - `CLOSE ON d`: transactions dated after `d` are skipped. A transaction
    ///   dated exactly `d` is kept.
    /// - filter expression: non-matching transactions are skipped entirely and do
    ///   not contribute to running balances.
    ///
    /// Postings rejected by WHERE still contribute to running balances.
    ///
    /// # Errors
    /// Propagates errors from evaluating the FROM filter or the WHERE predicate.
    fn collect_postings(
        &self,
        from: Option<&FromClause>,
        where_clause: Option<&Expr>,
    ) -> Result<Vec<PostingContext<'a>>, QueryError> {
        let mut postings = Vec::new();
        // Per-account running inventories. These use `Position::simple`, so cost
        // information is not reflected in the `balance` column.
        // NOTE(review): `build_balances_with_filter` resolves cost specs instead —
        // confirm whether this difference is intended.
        let mut running_balances: HashMap<InternedStr, Inventory> = HashMap::new();

        for directive in self.directives {
            if let Directive::Transaction(txn) = directive {
                if let Some(from) = from {
                    if let Some(open_date) = from.open_on {
                        if txn.date < open_date {
                            // Before the OPEN date: fold into running balances only.
                            for posting in &txn.postings {
                                if let Some(units) = posting.amount() {
                                    let balance = running_balances
                                        .entry(posting.account.clone())
                                        .or_default();
                                    balance.add(Position::simple(units.clone()));
                                }
                            }
                            continue;
                        }
                    }
                    if let Some(close_date) = from.close_on {
                        // NOTE(review): Beancount's CLOSE ON truncates entries at *and
                        // after* the date; this keeps transactions dated exactly
                        // `close_date` — confirm the intended semantics.
                        if txn.date > close_date {
                            continue;
                        }
                    }
                    if let Some(filter) = &from.filter {
                        if !self.evaluate_from_filter(filter, txn)? {
                            continue;
                        }
                    }
                }

                for (i, posting) in txn.postings.iter().enumerate() {
                    // Update first, so the context sees the balance *including* this posting.
                    if let Some(units) = posting.amount() {
                        let balance = running_balances.entry(posting.account.clone()).or_default();
                        balance.add(Position::simple(units.clone()));
                    }

                    let ctx = PostingContext {
                        transaction: txn,
                        posting_index: i,
                        balance: running_balances.get(&posting.account).cloned(),
                    };

                    if let Some(where_expr) = where_clause {
                        if self.evaluate_predicate(where_expr, &ctx)? {
                            postings.push(ctx);
                        }
                    } else {
                        postings.push(ctx);
                    }
                }
            }
        }

        Ok(postings)
    }
892
893 fn evaluate_from_filter(&self, filter: &Expr, txn: &Transaction) -> Result<bool, QueryError> {
895 match filter {
897 Expr::Function(func) => {
898 if func.name.to_uppercase().as_str() == "HAS_ACCOUNT" {
899 if func.args.len() != 1 {
900 return Err(QueryError::InvalidArguments(
901 "has_account".to_string(),
902 "expected 1 argument".to_string(),
903 ));
904 }
905 let pattern = match &func.args[0] {
906 Expr::Literal(Literal::String(s)) => s.clone(),
907 Expr::Column(s) => s.clone(),
908 _ => {
909 return Err(QueryError::Type(
910 "has_account expects a string pattern".to_string(),
911 ));
912 }
913 };
914 let regex = self.require_regex(&pattern)?;
916 for posting in &txn.postings {
917 if regex.is_match(&posting.account) {
918 return Ok(true);
919 }
920 }
921 Ok(false)
922 } else {
923 let dummy_ctx = PostingContext {
925 transaction: txn,
926 posting_index: 0,
927 balance: None,
928 };
929 self.evaluate_predicate(filter, &dummy_ctx)
930 }
931 }
932 Expr::BinaryOp(op) => {
933 match (&op.left, &op.right) {
935 (Expr::Column(col), Expr::Literal(lit)) if col.to_uppercase() == "YEAR" => {
936 if let Literal::Integer(n) = lit {
937 let matches = txn.date.year() == *n as i32;
938 Ok(if op.op == BinaryOperator::Eq {
939 matches
940 } else {
941 !matches
942 })
943 } else {
944 Ok(false)
945 }
946 }
947 (Expr::Column(col), Expr::Literal(lit)) if col.to_uppercase() == "MONTH" => {
948 if let Literal::Integer(n) = lit {
949 let matches = txn.date.month() == *n as u32;
950 Ok(if op.op == BinaryOperator::Eq {
951 matches
952 } else {
953 !matches
954 })
955 } else {
956 Ok(false)
957 }
958 }
959 (Expr::Column(col), Expr::Literal(Literal::Date(d)))
960 if col.to_uppercase() == "DATE" =>
961 {
962 let matches = match op.op {
963 BinaryOperator::Eq => txn.date == *d,
964 BinaryOperator::Ne => txn.date != *d,
965 BinaryOperator::Lt => txn.date < *d,
966 BinaryOperator::Le => txn.date <= *d,
967 BinaryOperator::Gt => txn.date > *d,
968 BinaryOperator::Ge => txn.date >= *d,
969 _ => false,
970 };
971 Ok(matches)
972 }
973 _ => {
974 let dummy_ctx = PostingContext {
976 transaction: txn,
977 posting_index: 0,
978 balance: None,
979 };
980 self.evaluate_predicate(filter, &dummy_ctx)
981 }
982 }
983 }
984 _ => {
985 let dummy_ctx = PostingContext {
987 transaction: txn,
988 posting_index: 0,
989 balance: None,
990 };
991 self.evaluate_predicate(filter, &dummy_ctx)
992 }
993 }
994 }
995
996 fn evaluate_predicate(&self, expr: &Expr, ctx: &PostingContext) -> Result<bool, QueryError> {
998 let value = self.evaluate_expr(expr, ctx)?;
999 match value {
1000 Value::Boolean(b) => Ok(b),
1001 Value::Null => Ok(false),
1002 _ => Err(QueryError::Type("expected boolean expression".to_string())),
1003 }
1004 }
1005
1006 fn evaluate_expr(&self, expr: &Expr, ctx: &PostingContext) -> Result<Value, QueryError> {
1008 match expr {
1009 Expr::Wildcard => Ok(Value::Null), Expr::Column(name) => self.evaluate_column(name, ctx),
1011 Expr::Literal(lit) => self.evaluate_literal(lit),
1012 Expr::Function(func) => self.evaluate_function(func, ctx),
1013 Expr::Window(_) => {
1014 Err(QueryError::Evaluation(
1017 "Window function cannot be evaluated in posting context".to_string(),
1018 ))
1019 }
1020 Expr::BinaryOp(op) => self.evaluate_binary_op(op, ctx),
1021 Expr::UnaryOp(op) => self.evaluate_unary_op(op, ctx),
1022 Expr::Paren(inner) => self.evaluate_expr(inner, ctx),
1023 }
1024 }
1025
1026 fn evaluate_column(&self, name: &str, ctx: &PostingContext) -> Result<Value, QueryError> {
1028 let posting = &ctx.transaction.postings[ctx.posting_index];
1029
1030 match name {
1031 "date" => Ok(Value::Date(ctx.transaction.date)),
1032 "account" => Ok(Value::String(posting.account.to_string())),
1033 "narration" => Ok(Value::String(ctx.transaction.narration.to_string())),
1034 "payee" => Ok(ctx
1035 .transaction
1036 .payee
1037 .as_ref()
1038 .map_or(Value::Null, |p| Value::String(p.to_string()))),
1039 "flag" => Ok(Value::String(ctx.transaction.flag.to_string())),
1040 "tags" => Ok(Value::StringSet(
1041 ctx.transaction
1042 .tags
1043 .iter()
1044 .map(ToString::to_string)
1045 .collect(),
1046 )),
1047 "links" => Ok(Value::StringSet(
1048 ctx.transaction
1049 .links
1050 .iter()
1051 .map(ToString::to_string)
1052 .collect(),
1053 )),
1054 "position" => {
1055 if let Some(units) = posting.amount() {
1057 if let Some(cost_spec) = &posting.cost {
1058 if let (Some(number_per), Some(currency)) =
1059 (&cost_spec.number_per, &cost_spec.currency)
1060 {
1061 let cost = rustledger_core::Cost::new(*number_per, currency.clone())
1063 .with_date_opt(cost_spec.date)
1064 .with_label_opt(cost_spec.label.clone());
1065 return Ok(Value::Position(Position::with_cost(units.clone(), cost)));
1066 }
1067 }
1068 Ok(Value::Position(Position::simple(units.clone())))
1069 } else {
1070 Ok(Value::Null)
1071 }
1072 }
1073 "units" => Ok(posting
1074 .amount()
1075 .map_or(Value::Null, |u| Value::Amount(u.clone()))),
1076 "cost" => {
1077 if let Some(units) = posting.amount() {
1079 if let Some(cost) = &posting.cost {
1080 if let Some(number_per) = &cost.number_per {
1081 if let Some(currency) = &cost.currency {
1082 let total = units.number.abs() * number_per;
1083 return Ok(Value::Amount(Amount::new(total, currency.clone())));
1084 }
1085 }
1086 }
1087 }
1088 Ok(Value::Null)
1089 }
1090 "weight" => {
1091 if let Some(units) = posting.amount() {
1095 if let Some(cost) = &posting.cost {
1096 if let Some(number_per) = &cost.number_per {
1097 if let Some(currency) = &cost.currency {
1098 let total = units.number * number_per;
1099 return Ok(Value::Amount(Amount::new(total, currency.clone())));
1100 }
1101 }
1102 }
1103 Ok(Value::Amount(units.clone()))
1105 } else {
1106 Ok(Value::Null)
1107 }
1108 }
1109 "balance" => {
1110 if let Some(ref balance) = ctx.balance {
1112 Ok(Value::Inventory(balance.clone()))
1113 } else {
1114 Ok(Value::Null)
1115 }
1116 }
1117 "year" => Ok(Value::Integer(ctx.transaction.date.year().into())),
1118 "month" => Ok(Value::Integer(ctx.transaction.date.month().into())),
1119 "day" => Ok(Value::Integer(ctx.transaction.date.day().into())),
1120 "currency" => Ok(posting
1121 .amount()
1122 .map_or(Value::Null, |u| Value::String(u.currency.to_string()))),
1123 "number" => Ok(posting
1124 .amount()
1125 .map_or(Value::Null, |u| Value::Number(u.number))),
1126 _ => Err(QueryError::UnknownColumn(name.to_string())),
1127 }
1128 }
1129
1130 fn evaluate_literal(&self, lit: &Literal) -> Result<Value, QueryError> {
1132 Ok(match lit {
1133 Literal::String(s) => Value::String(s.clone()),
1134 Literal::Number(n) => Value::Number(*n),
1135 Literal::Integer(i) => Value::Integer(*i),
1136 Literal::Date(d) => Value::Date(*d),
1137 Literal::Boolean(b) => Value::Boolean(*b),
1138 Literal::Null => Value::Null,
1139 })
1140 }
1141
1142 fn evaluate_function(
1146 &self,
1147 func: &FunctionCall,
1148 ctx: &PostingContext,
1149 ) -> Result<Value, QueryError> {
1150 let name = func.name.to_uppercase();
1151 match name.as_str() {
1152 "YEAR" | "MONTH" | "DAY" | "WEEKDAY" | "QUARTER" | "YMONTH" | "TODAY" => {
1154 self.eval_date_function(&name, func, ctx)
1155 }
1156 "LENGTH" | "UPPER" | "LOWER" | "SUBSTR" | "SUBSTRING" | "TRIM" | "STARTSWITH"
1158 | "ENDSWITH" => self.eval_string_function(&name, func, ctx),
1159 "PARENT" | "LEAF" | "ROOT" | "ACCOUNT_DEPTH" | "ACCOUNT_SORTKEY" => {
1161 self.eval_account_function(&name, func, ctx)
1162 }
1163 "ABS" | "NEG" | "ROUND" | "SAFEDIV" => self.eval_math_function(&name, func, ctx),
1165 "NUMBER" | "CURRENCY" | "GETITEM" | "GET" | "UNITS" | "COST" | "WEIGHT" | "VALUE" => {
1167 self.eval_position_function(&name, func, ctx)
1168 }
1169 "COALESCE" => self.eval_coalesce(func, ctx),
1171 "SUM" | "COUNT" | "MIN" | "MAX" | "FIRST" | "LAST" | "AVG" => Ok(Value::Null),
1174 _ => Err(QueryError::UnknownFunction(func.name.clone())),
1175 }
1176 }
1177
1178 fn eval_date_function(
1180 &self,
1181 name: &str,
1182 func: &FunctionCall,
1183 ctx: &PostingContext,
1184 ) -> Result<Value, QueryError> {
1185 if name == "TODAY" {
1186 if !func.args.is_empty() {
1187 return Err(QueryError::InvalidArguments(
1188 "TODAY".to_string(),
1189 "expected 0 arguments".to_string(),
1190 ));
1191 }
1192 return Ok(Value::Date(chrono::Local::now().date_naive()));
1193 }
1194
1195 if func.args.len() != 1 {
1197 return Err(QueryError::InvalidArguments(
1198 name.to_string(),
1199 "expected 1 argument".to_string(),
1200 ));
1201 }
1202
1203 let val = self.evaluate_expr(&func.args[0], ctx)?;
1204 let date = match val {
1205 Value::Date(d) => d,
1206 _ => return Err(QueryError::Type(format!("{name} expects a date"))),
1207 };
1208
1209 match name {
1210 "YEAR" => Ok(Value::Integer(date.year().into())),
1211 "MONTH" => Ok(Value::Integer(date.month().into())),
1212 "DAY" => Ok(Value::Integer(date.day().into())),
1213 "WEEKDAY" => Ok(Value::Integer(date.weekday().num_days_from_monday().into())),
1214 "QUARTER" => {
1215 let quarter = (date.month() - 1) / 3 + 1;
1216 Ok(Value::Integer(quarter.into()))
1217 }
1218 "YMONTH" => Ok(Value::String(format!(
1219 "{:04}-{:02}",
1220 date.year(),
1221 date.month()
1222 ))),
1223 _ => unreachable!(),
1224 }
1225 }
1226
1227 fn eval_string_function(
1229 &self,
1230 name: &str,
1231 func: &FunctionCall,
1232 ctx: &PostingContext,
1233 ) -> Result<Value, QueryError> {
1234 match name {
1235 "LENGTH" => {
1236 Self::require_args(name, func, 1)?;
1237 let val = self.evaluate_expr(&func.args[0], ctx)?;
1238 match val {
1239 Value::String(s) => Ok(Value::Integer(s.len() as i64)),
1240 Value::StringSet(s) => Ok(Value::Integer(s.len() as i64)),
1241 _ => Err(QueryError::Type(
1242 "LENGTH expects a string or set".to_string(),
1243 )),
1244 }
1245 }
1246 "UPPER" => {
1247 Self::require_args(name, func, 1)?;
1248 let val = self.evaluate_expr(&func.args[0], ctx)?;
1249 match val {
1250 Value::String(s) => Ok(Value::String(s.to_uppercase())),
1251 _ => Err(QueryError::Type("UPPER expects a string".to_string())),
1252 }
1253 }
1254 "LOWER" => {
1255 Self::require_args(name, func, 1)?;
1256 let val = self.evaluate_expr(&func.args[0], ctx)?;
1257 match val {
1258 Value::String(s) => Ok(Value::String(s.to_lowercase())),
1259 _ => Err(QueryError::Type("LOWER expects a string".to_string())),
1260 }
1261 }
1262 "TRIM" => {
1263 Self::require_args(name, func, 1)?;
1264 let val = self.evaluate_expr(&func.args[0], ctx)?;
1265 match val {
1266 Value::String(s) => Ok(Value::String(s.trim().to_string())),
1267 _ => Err(QueryError::Type("TRIM expects a string".to_string())),
1268 }
1269 }
1270 "SUBSTR" | "SUBSTRING" => self.eval_substr(func, ctx),
1271 "STARTSWITH" => {
1272 Self::require_args(name, func, 2)?;
1273 let val = self.evaluate_expr(&func.args[0], ctx)?;
1274 let prefix = self.evaluate_expr(&func.args[1], ctx)?;
1275 match (val, prefix) {
1276 (Value::String(s), Value::String(p)) => Ok(Value::Boolean(s.starts_with(&p))),
1277 _ => Err(QueryError::Type(
1278 "STARTSWITH expects two strings".to_string(),
1279 )),
1280 }
1281 }
1282 "ENDSWITH" => {
1283 Self::require_args(name, func, 2)?;
1284 let val = self.evaluate_expr(&func.args[0], ctx)?;
1285 let suffix = self.evaluate_expr(&func.args[1], ctx)?;
1286 match (val, suffix) {
1287 (Value::String(s), Value::String(p)) => Ok(Value::Boolean(s.ends_with(&p))),
1288 _ => Err(QueryError::Type("ENDSWITH expects two strings".to_string())),
1289 }
1290 }
1291 _ => unreachable!(),
1292 }
1293 }
1294
1295 fn eval_substr(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1297 if func.args.len() < 2 || func.args.len() > 3 {
1298 return Err(QueryError::InvalidArguments(
1299 "SUBSTR".to_string(),
1300 "expected 2 or 3 arguments".to_string(),
1301 ));
1302 }
1303
1304 let val = self.evaluate_expr(&func.args[0], ctx)?;
1305 let start = self.evaluate_expr(&func.args[1], ctx)?;
1306 let len = if func.args.len() == 3 {
1307 Some(self.evaluate_expr(&func.args[2], ctx)?)
1308 } else {
1309 None
1310 };
1311
1312 match (val, start, len) {
1313 (Value::String(s), Value::Integer(start), None) => {
1314 let start = start.max(0) as usize;
1315 if start >= s.len() {
1316 Ok(Value::String(String::new()))
1317 } else {
1318 Ok(Value::String(s[start..].to_string()))
1319 }
1320 }
1321 (Value::String(s), Value::Integer(start), Some(Value::Integer(len))) => {
1322 let start = start.max(0) as usize;
1323 let len = len.max(0) as usize;
1324 if start >= s.len() {
1325 Ok(Value::String(String::new()))
1326 } else {
1327 let end = (start + len).min(s.len());
1328 Ok(Value::String(s[start..end].to_string()))
1329 }
1330 }
1331 _ => Err(QueryError::Type(
1332 "SUBSTR expects (string, int, [int])".to_string(),
1333 )),
1334 }
1335 }
1336
1337 fn eval_account_function(
1339 &self,
1340 name: &str,
1341 func: &FunctionCall,
1342 ctx: &PostingContext,
1343 ) -> Result<Value, QueryError> {
1344 match name {
1345 "PARENT" => {
1346 Self::require_args(name, func, 1)?;
1347 let val = self.evaluate_expr(&func.args[0], ctx)?;
1348 match val {
1349 Value::String(s) => {
1350 if let Some(idx) = s.rfind(':') {
1351 Ok(Value::String(s[..idx].to_string()))
1352 } else {
1353 Ok(Value::Null)
1354 }
1355 }
1356 _ => Err(QueryError::Type(
1357 "PARENT expects an account string".to_string(),
1358 )),
1359 }
1360 }
1361 "LEAF" => {
1362 Self::require_args(name, func, 1)?;
1363 let val = self.evaluate_expr(&func.args[0], ctx)?;
1364 match val {
1365 Value::String(s) => {
1366 if let Some(idx) = s.rfind(':') {
1367 Ok(Value::String(s[idx + 1..].to_string()))
1368 } else {
1369 Ok(Value::String(s))
1370 }
1371 }
1372 _ => Err(QueryError::Type(
1373 "LEAF expects an account string".to_string(),
1374 )),
1375 }
1376 }
1377 "ROOT" => self.eval_root(func, ctx),
1378 "ACCOUNT_DEPTH" => {
1379 Self::require_args(name, func, 1)?;
1380 let val = self.evaluate_expr(&func.args[0], ctx)?;
1381 match val {
1382 Value::String(s) => {
1383 let depth = s.chars().filter(|c| *c == ':').count() + 1;
1384 Ok(Value::Integer(depth as i64))
1385 }
1386 _ => Err(QueryError::Type(
1387 "ACCOUNT_DEPTH expects an account string".to_string(),
1388 )),
1389 }
1390 }
1391 "ACCOUNT_SORTKEY" => {
1392 Self::require_args(name, func, 1)?;
1393 let val = self.evaluate_expr(&func.args[0], ctx)?;
1394 match val {
1395 Value::String(s) => Ok(Value::String(s)),
1396 _ => Err(QueryError::Type(
1397 "ACCOUNT_SORTKEY expects an account string".to_string(),
1398 )),
1399 }
1400 }
1401 _ => unreachable!(),
1402 }
1403 }
1404
1405 fn eval_root(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1407 if func.args.is_empty() || func.args.len() > 2 {
1408 return Err(QueryError::InvalidArguments(
1409 "ROOT".to_string(),
1410 "expected 1 or 2 arguments".to_string(),
1411 ));
1412 }
1413
1414 let val = self.evaluate_expr(&func.args[0], ctx)?;
1415 let n = if func.args.len() == 2 {
1416 match self.evaluate_expr(&func.args[1], ctx)? {
1417 Value::Integer(i) => i as usize,
1418 _ => {
1419 return Err(QueryError::Type(
1420 "ROOT second arg must be integer".to_string(),
1421 ));
1422 }
1423 }
1424 } else {
1425 1
1426 };
1427
1428 match val {
1429 Value::String(s) => {
1430 let parts: Vec<&str> = s.split(':').collect();
1431 if n >= parts.len() {
1432 Ok(Value::String(s))
1433 } else {
1434 Ok(Value::String(parts[..n].join(":")))
1435 }
1436 }
1437 _ => Err(QueryError::Type(
1438 "ROOT expects an account string".to_string(),
1439 )),
1440 }
1441 }
1442
1443 fn eval_math_function(
1445 &self,
1446 name: &str,
1447 func: &FunctionCall,
1448 ctx: &PostingContext,
1449 ) -> Result<Value, QueryError> {
1450 match name {
1451 "ABS" => {
1452 Self::require_args(name, func, 1)?;
1453 let val = self.evaluate_expr(&func.args[0], ctx)?;
1454 match val {
1455 Value::Number(n) => Ok(Value::Number(n.abs())),
1456 Value::Integer(i) => Ok(Value::Integer(i.abs())),
1457 _ => Err(QueryError::Type("ABS expects a number".to_string())),
1458 }
1459 }
1460 "NEG" => {
1461 Self::require_args(name, func, 1)?;
1462 let val = self.evaluate_expr(&func.args[0], ctx)?;
1463 match val {
1464 Value::Number(n) => Ok(Value::Number(-n)),
1465 Value::Integer(i) => Ok(Value::Integer(-i)),
1466 _ => Err(QueryError::Type("NEG expects a number".to_string())),
1467 }
1468 }
1469 "ROUND" => self.eval_round(func, ctx),
1470 "SAFEDIV" => self.eval_safediv(func, ctx),
1471 _ => unreachable!(),
1472 }
1473 }
1474
1475 fn eval_round(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1477 if func.args.is_empty() || func.args.len() > 2 {
1478 return Err(QueryError::InvalidArguments(
1479 "ROUND".to_string(),
1480 "expected 1 or 2 arguments".to_string(),
1481 ));
1482 }
1483
1484 let val = self.evaluate_expr(&func.args[0], ctx)?;
1485 let decimals = if func.args.len() == 2 {
1486 match self.evaluate_expr(&func.args[1], ctx)? {
1487 Value::Integer(i) => i as u32,
1488 _ => {
1489 return Err(QueryError::Type(
1490 "ROUND second arg must be integer".to_string(),
1491 ));
1492 }
1493 }
1494 } else {
1495 0
1496 };
1497
1498 match val {
1499 Value::Number(n) => Ok(Value::Number(n.round_dp(decimals))),
1500 Value::Integer(i) => Ok(Value::Integer(i)),
1501 _ => Err(QueryError::Type("ROUND expects a number".to_string())),
1502 }
1503 }
1504
1505 fn eval_safediv(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1507 Self::require_args("SAFEDIV", func, 2)?;
1508 let num = self.evaluate_expr(&func.args[0], ctx)?;
1509 let den = self.evaluate_expr(&func.args[1], ctx)?;
1510
1511 match (num, den) {
1512 (Value::Number(n), Value::Number(d)) => {
1513 if d.is_zero() {
1514 Ok(Value::Number(Decimal::ZERO))
1515 } else {
1516 Ok(Value::Number(n / d))
1517 }
1518 }
1519 (Value::Integer(n), Value::Integer(d)) => {
1520 if d == 0 {
1521 Ok(Value::Integer(0))
1522 } else {
1523 Ok(Value::Integer(n / d))
1524 }
1525 }
1526 _ => Err(QueryError::Type("SAFEDIV expects two numbers".to_string())),
1527 }
1528 }
1529
1530 fn eval_position_function(
1532 &self,
1533 name: &str,
1534 func: &FunctionCall,
1535 ctx: &PostingContext,
1536 ) -> Result<Value, QueryError> {
1537 match name {
1538 "NUMBER" => {
1539 Self::require_args(name, func, 1)?;
1540 let val = self.evaluate_expr(&func.args[0], ctx)?;
1541 match val {
1542 Value::Amount(a) => Ok(Value::Number(a.number)),
1543 Value::Position(p) => Ok(Value::Number(p.units.number)),
1544 Value::Number(n) => Ok(Value::Number(n)),
1545 Value::Integer(i) => Ok(Value::Number(Decimal::from(i))),
1546 _ => Err(QueryError::Type(
1547 "NUMBER expects an amount or position".to_string(),
1548 )),
1549 }
1550 }
1551 "CURRENCY" => {
1552 Self::require_args(name, func, 1)?;
1553 let val = self.evaluate_expr(&func.args[0], ctx)?;
1554 match val {
1555 Value::Amount(a) => Ok(Value::String(a.currency.to_string())),
1556 Value::Position(p) => Ok(Value::String(p.units.currency.to_string())),
1557 _ => Err(QueryError::Type(
1558 "CURRENCY expects an amount or position".to_string(),
1559 )),
1560 }
1561 }
1562 "GETITEM" | "GET" => self.eval_getitem(func, ctx),
1563 "UNITS" => self.eval_units(func, ctx),
1564 "COST" => self.eval_cost(func, ctx),
1565 "WEIGHT" => self.eval_weight(func, ctx),
1566 "VALUE" => self.eval_value(func, ctx),
1567 _ => unreachable!(),
1568 }
1569 }
1570
1571 fn eval_getitem(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1573 Self::require_args("GETITEM", func, 2)?;
1574 let val = self.evaluate_expr(&func.args[0], ctx)?;
1575 let key = self.evaluate_expr(&func.args[1], ctx)?;
1576
1577 match (val, key) {
1578 (Value::Inventory(inv), Value::String(currency)) => {
1579 let total = inv.units(¤cy);
1580 if total.is_zero() {
1581 Ok(Value::Null)
1582 } else {
1583 Ok(Value::Amount(Amount::new(total, currency)))
1584 }
1585 }
1586 _ => Err(QueryError::Type(
1587 "GETITEM expects (inventory, string)".to_string(),
1588 )),
1589 }
1590 }
1591
1592 fn eval_units(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1594 Self::require_args("UNITS", func, 1)?;
1595 let val = self.evaluate_expr(&func.args[0], ctx)?;
1596
1597 match val {
1598 Value::Position(p) => Ok(Value::Amount(p.units)),
1599 Value::Amount(a) => Ok(Value::Amount(a)),
1600 Value::Inventory(inv) => {
1601 let positions: Vec<String> = inv
1602 .positions()
1603 .iter()
1604 .map(|p| format!("{} {}", p.units.number, p.units.currency))
1605 .collect();
1606 Ok(Value::String(positions.join(", ")))
1607 }
1608 _ => Err(QueryError::Type(
1609 "UNITS expects a position or inventory".to_string(),
1610 )),
1611 }
1612 }
1613
1614 fn eval_cost(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1616 Self::require_args("COST", func, 1)?;
1617 let val = self.evaluate_expr(&func.args[0], ctx)?;
1618
1619 match val {
1620 Value::Position(p) => {
1621 if let Some(cost) = &p.cost {
1622 let total = p.units.number.abs() * cost.number;
1623 Ok(Value::Amount(Amount::new(total, cost.currency.clone())))
1624 } else {
1625 Ok(Value::Null)
1626 }
1627 }
1628 Value::Amount(a) => Ok(Value::Amount(a)),
1629 Value::Inventory(inv) => {
1630 let mut total = Decimal::ZERO;
1631 let mut currency: Option<InternedStr> = None;
1632 for pos in inv.positions() {
1633 if let Some(cost) = &pos.cost {
1634 total += pos.units.number.abs() * cost.number;
1635 if currency.is_none() {
1636 currency = Some(cost.currency.clone());
1637 }
1638 }
1639 }
1640 if let Some(curr) = currency {
1641 Ok(Value::Amount(Amount::new(total, curr)))
1642 } else {
1643 Ok(Value::Null)
1644 }
1645 }
1646 _ => Err(QueryError::Type(
1647 "COST expects a position or inventory".to_string(),
1648 )),
1649 }
1650 }
1651
1652 fn eval_weight(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1654 Self::require_args("WEIGHT", func, 1)?;
1655 let val = self.evaluate_expr(&func.args[0], ctx)?;
1656
1657 match val {
1658 Value::Position(p) => {
1659 if let Some(cost) = &p.cost {
1660 let total = p.units.number * cost.number;
1661 Ok(Value::Amount(Amount::new(total, cost.currency.clone())))
1662 } else {
1663 Ok(Value::Amount(p.units))
1664 }
1665 }
1666 Value::Amount(a) => Ok(Value::Amount(a)),
1667 _ => Err(QueryError::Type(
1668 "WEIGHT expects a position or amount".to_string(),
1669 )),
1670 }
1671 }
1672
1673 fn eval_value(&self, func: &FunctionCall, ctx: &PostingContext) -> Result<Value, QueryError> {
1675 if func.args.is_empty() || func.args.len() > 2 {
1676 return Err(QueryError::InvalidArguments(
1677 "VALUE".to_string(),
1678 "expected 1-2 arguments".to_string(),
1679 ));
1680 }
1681
1682 let target_currency = if func.args.len() == 2 {
1683 match self.evaluate_expr(&func.args[1], ctx)? {
1684 Value::String(s) => s,
1685 _ => {
1686 return Err(QueryError::Type(
1687 "VALUE second argument must be a currency string".to_string(),
1688 ));
1689 }
1690 }
1691 } else {
1692 self.target_currency.clone().ok_or_else(|| {
1693 QueryError::InvalidArguments(
1694 "VALUE".to_string(),
1695 "no target currency set; either call set_target_currency() on the executor \
1696 or pass the currency as VALUE(amount, 'USD')"
1697 .to_string(),
1698 )
1699 })?
1700 };
1701
1702 let val = self.evaluate_expr(&func.args[0], ctx)?;
1703 let date = ctx.transaction.date;
1704
1705 match val {
1706 Value::Position(p) => {
1707 if p.units.currency == target_currency {
1708 Ok(Value::Amount(p.units))
1709 } else if let Some(converted) =
1710 self.price_db.convert(&p.units, &target_currency, date)
1711 {
1712 Ok(Value::Amount(converted))
1713 } else {
1714 Ok(Value::Amount(p.units))
1715 }
1716 }
1717 Value::Amount(a) => {
1718 if a.currency == target_currency {
1719 Ok(Value::Amount(a))
1720 } else if let Some(converted) = self.price_db.convert(&a, &target_currency, date) {
1721 Ok(Value::Amount(converted))
1722 } else {
1723 Ok(Value::Amount(a))
1724 }
1725 }
1726 Value::Inventory(inv) => {
1727 let mut total = Decimal::ZERO;
1728 for pos in inv.positions() {
1729 if pos.units.currency == target_currency {
1730 total += pos.units.number;
1731 } else if let Some(converted) =
1732 self.price_db.convert(&pos.units, &target_currency, date)
1733 {
1734 total += converted.number;
1735 }
1736 }
1737 Ok(Value::Amount(Amount::new(total, &target_currency)))
1738 }
1739 _ => Err(QueryError::Type(
1740 "VALUE expects a position or inventory".to_string(),
1741 )),
1742 }
1743 }
1744
1745 fn eval_coalesce(
1747 &self,
1748 func: &FunctionCall,
1749 ctx: &PostingContext,
1750 ) -> Result<Value, QueryError> {
1751 for arg in &func.args {
1752 let val = self.evaluate_expr(arg, ctx)?;
1753 if !matches!(val, Value::Null) {
1754 return Ok(val);
1755 }
1756 }
1757 Ok(Value::Null)
1758 }
1759
1760 fn evaluate_function_on_values(&self, name: &str, args: &[Value]) -> Result<Value, QueryError> {
1762 let name_upper = name.to_uppercase();
1763 match name_upper.as_str() {
1764 "TODAY" => Ok(Value::Date(chrono::Local::now().date_naive())),
1766 "YEAR" => {
1767 Self::require_args_count(&name_upper, args, 1)?;
1768 match &args[0] {
1769 Value::Date(d) => Ok(Value::Integer(d.year().into())),
1770 _ => Err(QueryError::Type("YEAR expects a date".to_string())),
1771 }
1772 }
1773 "MONTH" => {
1774 Self::require_args_count(&name_upper, args, 1)?;
1775 match &args[0] {
1776 Value::Date(d) => Ok(Value::Integer(d.month().into())),
1777 _ => Err(QueryError::Type("MONTH expects a date".to_string())),
1778 }
1779 }
1780 "DAY" => {
1781 Self::require_args_count(&name_upper, args, 1)?;
1782 match &args[0] {
1783 Value::Date(d) => Ok(Value::Integer(d.day().into())),
1784 _ => Err(QueryError::Type("DAY expects a date".to_string())),
1785 }
1786 }
1787 "LENGTH" => {
1789 Self::require_args_count(&name_upper, args, 1)?;
1790 match &args[0] {
1791 Value::String(s) => Ok(Value::Integer(s.len() as i64)),
1792 _ => Err(QueryError::Type("LENGTH expects a string".to_string())),
1793 }
1794 }
1795 "UPPER" => {
1796 Self::require_args_count(&name_upper, args, 1)?;
1797 match &args[0] {
1798 Value::String(s) => Ok(Value::String(s.to_uppercase())),
1799 _ => Err(QueryError::Type("UPPER expects a string".to_string())),
1800 }
1801 }
1802 "LOWER" => {
1803 Self::require_args_count(&name_upper, args, 1)?;
1804 match &args[0] {
1805 Value::String(s) => Ok(Value::String(s.to_lowercase())),
1806 _ => Err(QueryError::Type("LOWER expects a string".to_string())),
1807 }
1808 }
1809 "TRIM" => {
1810 Self::require_args_count(&name_upper, args, 1)?;
1811 match &args[0] {
1812 Value::String(s) => Ok(Value::String(s.trim().to_string())),
1813 _ => Err(QueryError::Type("TRIM expects a string".to_string())),
1814 }
1815 }
1816 "ABS" => {
1818 Self::require_args_count(&name_upper, args, 1)?;
1819 match &args[0] {
1820 Value::Number(n) => Ok(Value::Number(n.abs())),
1821 Value::Integer(i) => Ok(Value::Integer(i.abs())),
1822 _ => Err(QueryError::Type("ABS expects a number".to_string())),
1823 }
1824 }
1825 "ROUND" => {
1826 if args.is_empty() || args.len() > 2 {
1827 return Err(QueryError::InvalidArguments(
1828 "ROUND".to_string(),
1829 "expected 1 or 2 arguments".to_string(),
1830 ));
1831 }
1832 match &args[0] {
1833 Value::Number(n) => {
1834 let scale = if args.len() == 2 {
1835 match &args[1] {
1836 Value::Integer(i) => *i as u32,
1837 _ => 0,
1838 }
1839 } else {
1840 0
1841 };
1842 Ok(Value::Number(n.round_dp(scale)))
1843 }
1844 Value::Integer(i) => Ok(Value::Integer(*i)),
1845 _ => Err(QueryError::Type("ROUND expects a number".to_string())),
1846 }
1847 }
1848 "COALESCE" => {
1850 for arg in args {
1851 if !matches!(arg, Value::Null) {
1852 return Ok(arg.clone());
1853 }
1854 }
1855 Ok(Value::Null)
1856 }
1857 "SUM" | "COUNT" | "MIN" | "MAX" | "FIRST" | "LAST" | "AVG" => Ok(Value::Null),
1859 _ => Err(QueryError::UnknownFunction(name.to_string())),
1860 }
1861 }
1862
1863 fn require_args_count(name: &str, args: &[Value], expected: usize) -> Result<(), QueryError> {
1865 if args.len() != expected {
1866 return Err(QueryError::InvalidArguments(
1867 name.to_string(),
1868 format!("expected {} argument(s), got {}", expected, args.len()),
1869 ));
1870 }
1871 Ok(())
1872 }
1873
1874 fn require_args(name: &str, func: &FunctionCall, expected: usize) -> Result<(), QueryError> {
1876 if func.args.len() != expected {
1877 return Err(QueryError::InvalidArguments(
1878 name.to_string(),
1879 format!("expected {expected} argument(s)"),
1880 ));
1881 }
1882 Ok(())
1883 }
1884
1885 fn evaluate_binary_op(&self, op: &BinaryOp, ctx: &PostingContext) -> Result<Value, QueryError> {
1887 let left = self.evaluate_expr(&op.left, ctx)?;
1888 let right = self.evaluate_expr(&op.right, ctx)?;
1889
1890 match op.op {
1891 BinaryOperator::Eq => Ok(Value::Boolean(self.values_equal(&left, &right))),
1892 BinaryOperator::Ne => Ok(Value::Boolean(!self.values_equal(&left, &right))),
1893 BinaryOperator::Lt => self.compare_values(&left, &right, std::cmp::Ordering::is_lt),
1894 BinaryOperator::Le => self.compare_values(&left, &right, std::cmp::Ordering::is_le),
1895 BinaryOperator::Gt => self.compare_values(&left, &right, std::cmp::Ordering::is_gt),
1896 BinaryOperator::Ge => self.compare_values(&left, &right, std::cmp::Ordering::is_ge),
1897 BinaryOperator::And => {
1898 let l = self.to_bool(&left)?;
1899 let r = self.to_bool(&right)?;
1900 Ok(Value::Boolean(l && r))
1901 }
1902 BinaryOperator::Or => {
1903 let l = self.to_bool(&left)?;
1904 let r = self.to_bool(&right)?;
1905 Ok(Value::Boolean(l || r))
1906 }
1907 BinaryOperator::Regex => {
1908 let s = match left {
1910 Value::String(s) => s,
1911 _ => {
1912 return Err(QueryError::Type(
1913 "regex requires string left operand".to_string(),
1914 ));
1915 }
1916 };
1917 let pattern = match right {
1918 Value::String(p) => p,
1919 _ => {
1920 return Err(QueryError::Type(
1921 "regex requires string pattern".to_string(),
1922 ));
1923 }
1924 };
1925 Ok(Value::Boolean(s.contains(&pattern)))
1927 }
1928 BinaryOperator::In => {
1929 match right {
1931 Value::StringSet(set) => {
1932 let needle = match left {
1933 Value::String(s) => s,
1934 _ => {
1935 return Err(QueryError::Type(
1936 "IN requires string left operand".to_string(),
1937 ));
1938 }
1939 };
1940 Ok(Value::Boolean(set.contains(&needle)))
1941 }
1942 _ => Err(QueryError::Type(
1943 "IN requires set right operand".to_string(),
1944 )),
1945 }
1946 }
1947 BinaryOperator::Add => self.arithmetic_op(&left, &right, |a, b| a + b),
1948 BinaryOperator::Sub => self.arithmetic_op(&left, &right, |a, b| a - b),
1949 BinaryOperator::Mul => self.arithmetic_op(&left, &right, |a, b| a * b),
1950 BinaryOperator::Div => self.arithmetic_op(&left, &right, |a, b| a / b),
1951 }
1952 }
1953
1954 fn evaluate_unary_op(&self, op: &UnaryOp, ctx: &PostingContext) -> Result<Value, QueryError> {
1956 let val = self.evaluate_expr(&op.operand, ctx)?;
1957 self.unary_op_on_value(op.op, &val)
1958 }
1959
1960 fn unary_op_on_value(&self, op: UnaryOperator, val: &Value) -> Result<Value, QueryError> {
1962 match op {
1963 UnaryOperator::Not => {
1964 let b = self.to_bool(val)?;
1965 Ok(Value::Boolean(!b))
1966 }
1967 UnaryOperator::Neg => match val {
1968 Value::Number(n) => Ok(Value::Number(-*n)),
1969 Value::Integer(i) => Ok(Value::Integer(-*i)),
1970 _ => Err(QueryError::Type(
1971 "negation requires numeric value".to_string(),
1972 )),
1973 },
1974 }
1975 }
1976
1977 fn values_equal(&self, left: &Value, right: &Value) -> bool {
1979 match (left, right) {
1981 (Value::Null, Value::Null) => true,
1982 (Value::String(a), Value::String(b)) => a == b,
1983 (Value::Number(a), Value::Number(b)) => a == b,
1984 (Value::Integer(a), Value::Integer(b)) => a == b,
1985 (Value::Number(a), Value::Integer(b)) => *a == Decimal::from(*b),
1986 (Value::Integer(a), Value::Number(b)) => Decimal::from(*a) == *b,
1987 (Value::Date(a), Value::Date(b)) => a == b,
1988 (Value::Boolean(a), Value::Boolean(b)) => a == b,
1989 _ => false,
1990 }
1991 }
1992
1993 fn compare_values<F>(&self, left: &Value, right: &Value, pred: F) -> Result<Value, QueryError>
1995 where
1996 F: FnOnce(std::cmp::Ordering) -> bool,
1997 {
1998 let ord = match (left, right) {
1999 (Value::Number(a), Value::Number(b)) => a.cmp(b),
2000 (Value::Integer(a), Value::Integer(b)) => a.cmp(b),
2001 (Value::Number(a), Value::Integer(b)) => a.cmp(&Decimal::from(*b)),
2002 (Value::Integer(a), Value::Number(b)) => Decimal::from(*a).cmp(b),
2003 (Value::String(a), Value::String(b)) => a.cmp(b),
2004 (Value::Date(a), Value::Date(b)) => a.cmp(b),
2005 _ => return Err(QueryError::Type("cannot compare values".to_string())),
2006 };
2007 Ok(Value::Boolean(pred(ord)))
2008 }
2009
2010 fn value_less_than(&self, left: &Value, right: &Value) -> Result<bool, QueryError> {
2012 let ord = match (left, right) {
2013 (Value::Number(a), Value::Number(b)) => a.cmp(b),
2014 (Value::Integer(a), Value::Integer(b)) => a.cmp(b),
2015 (Value::Number(a), Value::Integer(b)) => a.cmp(&Decimal::from(*b)),
2016 (Value::Integer(a), Value::Number(b)) => Decimal::from(*a).cmp(b),
2017 (Value::String(a), Value::String(b)) => a.cmp(b),
2018 (Value::Date(a), Value::Date(b)) => a.cmp(b),
2019 _ => return Err(QueryError::Type("cannot compare values".to_string())),
2020 };
2021 Ok(ord.is_lt())
2022 }
2023
2024 fn arithmetic_op<F>(&self, left: &Value, right: &Value, op: F) -> Result<Value, QueryError>
2026 where
2027 F: FnOnce(Decimal, Decimal) -> Decimal,
2028 {
2029 let (a, b) = match (left, right) {
2030 (Value::Number(a), Value::Number(b)) => (*a, *b),
2031 (Value::Integer(a), Value::Integer(b)) => (Decimal::from(*a), Decimal::from(*b)),
2032 (Value::Number(a), Value::Integer(b)) => (*a, Decimal::from(*b)),
2033 (Value::Integer(a), Value::Number(b)) => (Decimal::from(*a), *b),
2034 _ => {
2035 return Err(QueryError::Type(
2036 "arithmetic requires numeric values".to_string(),
2037 ));
2038 }
2039 };
2040 Ok(Value::Number(op(a, b)))
2041 }
2042
2043 fn to_bool(&self, val: &Value) -> Result<bool, QueryError> {
2045 match val {
2046 Value::Boolean(b) => Ok(*b),
2047 Value::Null => Ok(false),
2048 _ => Err(QueryError::Type("expected boolean".to_string())),
2049 }
2050 }
2051
2052 fn is_aggregate_expr(expr: &Expr) -> bool {
2054 match expr {
2055 Expr::Function(func) => {
2056 matches!(
2057 func.name.to_uppercase().as_str(),
2058 "SUM" | "COUNT" | "MIN" | "MAX" | "FIRST" | "LAST" | "AVG"
2059 )
2060 }
2061 Expr::BinaryOp(op) => {
2062 Self::is_aggregate_expr(&op.left) || Self::is_aggregate_expr(&op.right)
2063 }
2064 Expr::UnaryOp(op) => Self::is_aggregate_expr(&op.operand),
2065 Expr::Paren(inner) => Self::is_aggregate_expr(inner),
2066 _ => false,
2067 }
2068 }
2069
2070 const fn is_window_expr(expr: &Expr) -> bool {
2072 matches!(expr, Expr::Window(_))
2073 }
2074
2075 fn has_window_functions(targets: &[Target]) -> bool {
2077 targets.iter().any(|t| Self::is_window_expr(&t.expr))
2078 }
2079
2080 fn resolve_column_names(&self, targets: &[Target]) -> Result<Vec<String>, QueryError> {
2082 let mut names = Vec::new();
2083 for (i, target) in targets.iter().enumerate() {
2084 if let Some(alias) = &target.alias {
2085 names.push(alias.clone());
2086 } else {
2087 names.push(self.expr_to_name(&target.expr, i));
2088 }
2089 }
2090 Ok(names)
2091 }
2092
2093 fn expr_to_name(&self, expr: &Expr, index: usize) -> String {
2095 match expr {
2096 Expr::Wildcard => "*".to_string(),
2097 Expr::Column(name) => name.clone(),
2098 Expr::Function(func) => func.name.clone(),
2099 Expr::Window(wf) => wf.name.clone(),
2100 _ => format!("col{index}"),
2101 }
2102 }
2103
2104 fn evaluate_row(&self, targets: &[Target], ctx: &PostingContext) -> Result<Row, QueryError> {
2106 self.evaluate_row_with_window(targets, ctx, None)
2107 }
2108
2109 fn evaluate_row_with_window(
2111 &self,
2112 targets: &[Target],
2113 ctx: &PostingContext,
2114 window_ctx: Option<&WindowContext>,
2115 ) -> Result<Row, QueryError> {
2116 let mut row = Vec::new();
2117 for target in targets {
2118 if matches!(target.expr, Expr::Wildcard) {
2119 row.push(Value::Date(ctx.transaction.date));
2121 row.push(Value::String(ctx.transaction.flag.to_string()));
2122 row.push(
2123 ctx.transaction
2124 .payee
2125 .as_ref()
2126 .map_or(Value::Null, |p| Value::String(p.to_string())),
2127 );
2128 row.push(Value::String(ctx.transaction.narration.to_string()));
2129 let posting = &ctx.transaction.postings[ctx.posting_index];
2130 row.push(Value::String(posting.account.to_string()));
2131 row.push(
2132 posting
2133 .amount()
2134 .map_or(Value::Null, |u| Value::Amount(u.clone())),
2135 );
2136 } else if let Expr::Window(wf) = &target.expr {
2137 row.push(self.evaluate_window_function(wf, window_ctx)?);
2139 } else {
2140 row.push(self.evaluate_expr(&target.expr, ctx)?);
2141 }
2142 }
2143 Ok(row)
2144 }
2145
2146 fn evaluate_window_function(
2148 &self,
2149 wf: &WindowFunction,
2150 window_ctx: Option<&WindowContext>,
2151 ) -> Result<Value, QueryError> {
2152 let ctx = window_ctx.ok_or_else(|| {
2153 QueryError::Evaluation("Window function requires window context".to_string())
2154 })?;
2155
2156 match wf.name.to_uppercase().as_str() {
2157 "ROW_NUMBER" => Ok(Value::Integer(ctx.row_number as i64)),
2158 "RANK" => Ok(Value::Integer(ctx.rank as i64)),
2159 "DENSE_RANK" => Ok(Value::Integer(ctx.dense_rank as i64)),
2160 _ => Err(QueryError::Evaluation(format!(
2161 "Window function '{}' not yet implemented",
2162 wf.name
2163 ))),
2164 }
2165 }
2166
    /// Computes `ROW_NUMBER`/`RANK`/`DENSE_RANK` for every posting under the
    /// `OVER (PARTITION BY ... ORDER BY ...)` spec of `wf`.
    ///
    /// Returns one `WindowContext` per posting, indexed like `postings`.
    /// Numbering restarts at 1 in each partition and follows the ORDER BY
    /// ordering. Without ORDER BY, every row in a partition ties, so all get
    /// rank 1 and dense rank 1, while row numbers follow the partition's input
    /// order.
    ///
    /// # Errors
    /// Propagates any error from evaluating a partition or ordering expression.
    fn compute_window_contexts(
        &self,
        postings: &[PostingContext],
        wf: &WindowFunction,
    ) -> Result<Vec<WindowContext>, QueryError> {
        let spec = &wf.over;

        // One canonical partition key per posting; without PARTITION BY,
        // all postings share the empty key (a single partition).
        let mut partition_keys: Vec<String> = Vec::with_capacity(postings.len());
        for ctx in postings {
            if let Some(partition_exprs) = &spec.partition_by {
                let mut key_values = Vec::new();
                for expr in partition_exprs {
                    key_values.push(self.evaluate_expr(expr, ctx)?);
                }
                partition_keys.push(Self::make_group_key(&key_values));
            } else {
                partition_keys.push(String::new());
            }
        }

        // Partition key -> indices (into `postings`) of its members, in input order.
        let mut partitions: HashMap<String, Vec<usize>> = HashMap::new();
        for (idx, key) in partition_keys.iter().enumerate() {
            partitions.entry(key.clone()).or_default().push(idx);
        }

        // Precompute the ORDER BY values for each posting (empty when there is no ORDER BY).
        let mut order_values: Vec<Vec<Value>> = Vec::with_capacity(postings.len());
        for ctx in postings {
            if let Some(order_specs) = &spec.order_by {
                let mut values = Vec::new();
                for order_spec in order_specs {
                    values.push(self.evaluate_expr(&order_spec.expr, ctx)?);
                }
                order_values.push(values);
            } else {
                order_values.push(Vec::new());
            }
        }

        // Placeholder contexts; every slot is overwritten below because each
        // posting belongs to exactly one partition.
        let mut window_contexts: Vec<WindowContext> = vec![
            WindowContext {
                row_number: 0,
                rank: 0,
                dense_rank: 0,
            };
            postings.len()
        ];

        // Partitions are processed in HashMap order, which is harmless here
        // because results are written back by original index.
        for indices in partitions.values() {
            let mut sorted_indices: Vec<usize> = indices.clone();
            if let Some(order_specs) = &spec.order_by {
                // Stable sort: compare ORDER BY values left to right, reversing
                // a column's comparison when that column is DESC.
                sorted_indices.sort_by(|&a, &b| {
                    let vals_a = &order_values[a];
                    let vals_b = &order_values[b];
                    for (i, (va, vb)) in vals_a.iter().zip(vals_b.iter()).enumerate() {
                        let cmp = self.compare_values_for_sort(va, vb);
                        if cmp != std::cmp::Ordering::Equal {
                            return if order_specs
                                .get(i)
                                .is_some_and(|s| s.direction == SortDirection::Desc)
                            {
                                cmp.reverse()
                            } else {
                                cmp
                            };
                        }
                    }
                    std::cmp::Ordering::Equal
                });
            }

            let mut row_num = 1;
            let mut rank = 1;
            let mut dense_rank = 1;
            let mut prev_values: Option<&Vec<Value>> = None;

            for (position, &original_idx) in sorted_indices.iter().enumerate() {
                let current_values = &order_values[original_idx];

                // A row ties with its predecessor when their ORDER BY values are equal.
                // NOTE(review): this uses structural `==` on `Value`, while sorting uses
                // `compare_values_for_sort`. The two may disagree, e.g. Integer(1) vs
                // Number(1) under the derived PartialEq. Confirm this is acceptable.
                let is_tie = if let Some(prev) = prev_values {
                    current_values == prev
                } else {
                    false
                };

                // A new distinct value: RANK jumps to the 1-based position (leaving
                // gaps after ties), DENSE_RANK advances by one (no gaps).
                if !is_tie && position > 0 {
                    rank = position + 1;
                    dense_rank += 1;
                }
                window_contexts[original_idx] = WindowContext {
                    row_number: row_num,
                    rank,
                    dense_rank,
                };

                row_num += 1;
                prev_values = Some(current_values);
            }
        }

        Ok(window_contexts)
    }
2279
2280 fn find_window_function(targets: &[Target]) -> Option<&WindowFunction> {
2282 for target in targets {
2283 if let Expr::Window(wf) = &target.expr {
2284 return Some(wf);
2285 }
2286 }
2287 None
2288 }
2289
2290 fn make_group_key(values: &[Value]) -> String {
2293 use std::fmt::Write;
2294 let mut key = String::new();
2295 for (i, v) in values.iter().enumerate() {
2296 if i > 0 {
2297 key.push('\x00'); }
2299 match v {
2300 Value::String(s) => {
2301 key.push('S');
2302 key.push_str(s);
2303 }
2304 Value::Number(n) => {
2305 key.push('N');
2306 let _ = write!(key, "{n}");
2307 }
2308 Value::Integer(n) => {
2309 key.push('I');
2310 let _ = write!(key, "{n}");
2311 }
2312 Value::Date(d) => {
2313 key.push('D');
2314 let _ = write!(key, "{d}");
2315 }
2316 Value::Boolean(b) => {
2317 key.push(if *b { 'T' } else { 'F' });
2318 }
2319 Value::Amount(a) => {
2320 key.push('A');
2321 let _ = write!(key, "{} {}", a.number, a.currency);
2322 }
2323 Value::Position(p) => {
2324 key.push('P');
2325 let _ = write!(key, "{} {}", p.units.number, p.units.currency);
2326 }
2327 Value::Inventory(_) => {
2328 key.push('V');
2331 }
2332 Value::StringSet(ss) => {
2333 key.push('Z');
2334 for s in ss {
2335 key.push_str(s);
2336 key.push(',');
2337 }
2338 }
2339 Value::Null => {
2340 key.push('0');
2341 }
2342 }
2343 }
2344 key
2345 }
2346
2347 fn group_postings<'b>(
2350 &self,
2351 postings: &'b [PostingContext<'a>],
2352 group_by: Option<&Vec<Expr>>,
2353 ) -> Result<Vec<(Vec<Value>, Vec<&'b PostingContext<'a>>)>, QueryError> {
2354 if let Some(group_exprs) = group_by {
2355 let mut group_map: HashMap<String, (Vec<Value>, Vec<&PostingContext<'a>>)> =
2357 HashMap::new();
2358
2359 for ctx in postings {
2360 let mut key_values = Vec::with_capacity(group_exprs.len());
2361 for expr in group_exprs {
2362 key_values.push(self.evaluate_expr(expr, ctx)?);
2363 }
2364 let hash_key = Self::make_group_key(&key_values);
2365
2366 group_map
2367 .entry(hash_key)
2368 .or_insert_with(|| (key_values, Vec::new()))
2369 .1
2370 .push(ctx);
2371 }
2372
2373 Ok(group_map.into_values().collect())
2374 } else {
2375 if postings.is_empty() {
2378 Ok(vec![])
2379 } else {
2380 Ok(vec![(Vec::new(), postings.iter().collect())])
2381 }
2382 }
2383 }
2384
2385 fn evaluate_aggregate_row(
2387 &self,
2388 targets: &[Target],
2389 group: &[&PostingContext],
2390 ) -> Result<Row, QueryError> {
2391 let mut row = Vec::new();
2392 for target in targets {
2393 row.push(self.evaluate_aggregate_expr(&target.expr, group)?);
2394 }
2395 Ok(row)
2396 }
2397
2398 fn evaluate_aggregate_expr(
2400 &self,
2401 expr: &Expr,
2402 group: &[&PostingContext],
2403 ) -> Result<Value, QueryError> {
2404 match expr {
2405 Expr::Function(func) => {
2406 match func.name.to_uppercase().as_str() {
2407 "COUNT" => {
2408 Ok(Value::Integer(group.len() as i64))
2410 }
2411 "SUM" => {
2412 if func.args.len() != 1 {
2413 return Err(QueryError::InvalidArguments(
2414 "SUM".to_string(),
2415 "expected 1 argument".to_string(),
2416 ));
2417 }
2418 let mut total = Inventory::new();
2419 for ctx in group {
2420 let val = self.evaluate_expr(&func.args[0], ctx)?;
2421 match val {
2422 Value::Amount(amt) => {
2423 let pos = Position::simple(amt);
2424 total.add(pos);
2425 }
2426 Value::Position(pos) => {
2427 total.add(pos);
2428 }
2429 Value::Number(n) => {
2430 let pos =
2432 Position::simple(Amount::new(n, "__NUMBER__".to_string()));
2433 total.add(pos);
2434 }
2435 Value::Null => {}
2436 _ => {
2437 return Err(QueryError::Type(
2438 "SUM requires numeric or position value".to_string(),
2439 ));
2440 }
2441 }
2442 }
2443 Ok(Value::Inventory(total))
2444 }
2445 "FIRST" => {
2446 if func.args.len() != 1 {
2447 return Err(QueryError::InvalidArguments(
2448 "FIRST".to_string(),
2449 "expected 1 argument".to_string(),
2450 ));
2451 }
2452 if let Some(ctx) = group.iter().min_by_key(|c| c.transaction.date) {
2454 self.evaluate_expr(&func.args[0], ctx)
2455 } else {
2456 Ok(Value::Null)
2457 }
2458 }
2459 "LAST" => {
2460 if func.args.len() != 1 {
2461 return Err(QueryError::InvalidArguments(
2462 "LAST".to_string(),
2463 "expected 1 argument".to_string(),
2464 ));
2465 }
2466 if let Some(ctx) = group.iter().max_by_key(|c| c.transaction.date) {
2468 self.evaluate_expr(&func.args[0], ctx)
2469 } else {
2470 Ok(Value::Null)
2471 }
2472 }
2473 "MIN" => {
2474 if func.args.len() != 1 {
2475 return Err(QueryError::InvalidArguments(
2476 "MIN".to_string(),
2477 "expected 1 argument".to_string(),
2478 ));
2479 }
2480 let mut min_val: Option<Value> = None;
2481 for ctx in group {
2482 let val = self.evaluate_expr(&func.args[0], ctx)?;
2483 if matches!(val, Value::Null) {
2484 continue;
2485 }
2486 min_val = Some(match min_val {
2487 None => val,
2488 Some(current) => {
2489 if self.value_less_than(&val, ¤t)? {
2490 val
2491 } else {
2492 current
2493 }
2494 }
2495 });
2496 }
2497 Ok(min_val.unwrap_or(Value::Null))
2498 }
2499 "MAX" => {
2500 if func.args.len() != 1 {
2501 return Err(QueryError::InvalidArguments(
2502 "MAX".to_string(),
2503 "expected 1 argument".to_string(),
2504 ));
2505 }
2506 let mut max_val: Option<Value> = None;
2507 for ctx in group {
2508 let val = self.evaluate_expr(&func.args[0], ctx)?;
2509 if matches!(val, Value::Null) {
2510 continue;
2511 }
2512 max_val = Some(match max_val {
2513 None => val,
2514 Some(current) => {
2515 if self.value_less_than(¤t, &val)? {
2516 val
2517 } else {
2518 current
2519 }
2520 }
2521 });
2522 }
2523 Ok(max_val.unwrap_or(Value::Null))
2524 }
2525 "AVG" => {
2526 if func.args.len() != 1 {
2527 return Err(QueryError::InvalidArguments(
2528 "AVG".to_string(),
2529 "expected 1 argument".to_string(),
2530 ));
2531 }
2532 let mut sum = Decimal::ZERO;
2533 let mut count = 0i64;
2534 for ctx in group {
2535 let val = self.evaluate_expr(&func.args[0], ctx)?;
2536 match val {
2537 Value::Number(n) => {
2538 sum += n;
2539 count += 1;
2540 }
2541 Value::Integer(i) => {
2542 sum += Decimal::from(i);
2543 count += 1;
2544 }
2545 Value::Null => {}
2546 _ => {
2547 return Err(QueryError::Type(
2548 "AVG expects numeric values".to_string(),
2549 ));
2550 }
2551 }
2552 }
2553 if count == 0 {
2554 Ok(Value::Null)
2555 } else {
2556 Ok(Value::Number(sum / Decimal::from(count)))
2557 }
2558 }
2559 _ => {
2560 if let Some(ctx) = group.first() {
2562 self.evaluate_function(func, ctx)
2563 } else {
2564 Ok(Value::Null)
2565 }
2566 }
2567 }
2568 }
2569 Expr::Column(_) => {
2570 if let Some(ctx) = group.first() {
2572 self.evaluate_expr(expr, ctx)
2573 } else {
2574 Ok(Value::Null)
2575 }
2576 }
2577 Expr::BinaryOp(op) => {
2578 let left = self.evaluate_aggregate_expr(&op.left, group)?;
2579 let right = self.evaluate_aggregate_expr(&op.right, group)?;
2580 self.binary_op_on_values(op.op, &left, &right)
2582 }
2583 _ => {
2584 if let Some(ctx) = group.first() {
2586 self.evaluate_expr(expr, ctx)
2587 } else {
2588 Ok(Value::Null)
2589 }
2590 }
2591 }
2592 }
2593
2594 fn binary_op_on_values(
2596 &self,
2597 op: BinaryOperator,
2598 left: &Value,
2599 right: &Value,
2600 ) -> Result<Value, QueryError> {
2601 match op {
2602 BinaryOperator::Eq => Ok(Value::Boolean(self.values_equal(left, right))),
2603 BinaryOperator::Ne => Ok(Value::Boolean(!self.values_equal(left, right))),
2604 BinaryOperator::Lt => self.compare_values(left, right, std::cmp::Ordering::is_lt),
2605 BinaryOperator::Le => self.compare_values(left, right, std::cmp::Ordering::is_le),
2606 BinaryOperator::Gt => self.compare_values(left, right, std::cmp::Ordering::is_gt),
2607 BinaryOperator::Ge => self.compare_values(left, right, std::cmp::Ordering::is_ge),
2608 BinaryOperator::And => {
2609 let l = self.to_bool(left)?;
2610 let r = self.to_bool(right)?;
2611 Ok(Value::Boolean(l && r))
2612 }
2613 BinaryOperator::Or => {
2614 let l = self.to_bool(left)?;
2615 let r = self.to_bool(right)?;
2616 Ok(Value::Boolean(l || r))
2617 }
2618 BinaryOperator::Regex => {
2619 let s = match left {
2621 Value::String(s) => s,
2622 _ => {
2623 return Err(QueryError::Type(
2624 "regex requires string left operand".to_string(),
2625 ));
2626 }
2627 };
2628 let pattern = match right {
2629 Value::String(p) => p,
2630 _ => {
2631 return Err(QueryError::Type(
2632 "regex requires string pattern".to_string(),
2633 ));
2634 }
2635 };
2636 let regex_result = self.get_or_compile_regex(pattern);
2638 let matches = if let Some(regex) = regex_result {
2639 regex.is_match(s)
2640 } else {
2641 s.contains(pattern)
2642 };
2643 Ok(Value::Boolean(matches))
2644 }
2645 BinaryOperator::In => {
2646 match right {
2648 Value::StringSet(set) => {
2649 let needle = match left {
2650 Value::String(s) => s,
2651 _ => {
2652 return Err(QueryError::Type(
2653 "IN requires string left operand".to_string(),
2654 ));
2655 }
2656 };
2657 Ok(Value::Boolean(set.contains(needle)))
2658 }
2659 _ => Err(QueryError::Type(
2660 "IN requires set right operand".to_string(),
2661 )),
2662 }
2663 }
2664 BinaryOperator::Add => self.arithmetic_op(left, right, |a, b| a + b),
2665 BinaryOperator::Sub => self.arithmetic_op(left, right, |a, b| a - b),
2666 BinaryOperator::Mul => self.arithmetic_op(left, right, |a, b| a * b),
2667 BinaryOperator::Div => self.arithmetic_op(left, right, |a, b| a / b),
2668 }
2669 }
2670
2671 fn sort_results(
2673 &self,
2674 result: &mut QueryResult,
2675 order_by: &[OrderSpec],
2676 ) -> Result<(), QueryError> {
2677 if order_by.is_empty() {
2678 return Ok(());
2679 }
2680
2681 let column_indices: std::collections::HashMap<&str, usize> = result
2683 .columns
2684 .iter()
2685 .enumerate()
2686 .map(|(i, name)| (name.as_str(), i))
2687 .collect();
2688
2689 let mut sort_specs: Vec<(usize, bool)> = Vec::new();
2691 for spec in order_by {
2692 let idx = match &spec.expr {
2694 Expr::Column(name) => column_indices
2695 .get(name.as_str())
2696 .copied()
2697 .ok_or_else(|| QueryError::UnknownColumn(name.clone()))?,
2698 Expr::Function(func) => {
2699 column_indices
2701 .get(func.name.as_str())
2702 .copied()
2703 .ok_or_else(|| {
2704 QueryError::Evaluation(format!(
2705 "ORDER BY expression not found in SELECT: {}",
2706 func.name
2707 ))
2708 })?
2709 }
2710 _ => {
2711 return Err(QueryError::Evaluation(
2712 "ORDER BY expression must reference a selected column".to_string(),
2713 ));
2714 }
2715 };
2716 let ascending = spec.direction != SortDirection::Desc;
2717 sort_specs.push((idx, ascending));
2718 }
2719
2720 result.rows.sort_by(|a, b| {
2722 for (idx, ascending) in &sort_specs {
2723 if *idx >= a.len() || *idx >= b.len() {
2724 continue;
2725 }
2726 let ord = self.compare_values_for_sort(&a[*idx], &b[*idx]);
2727 if ord != std::cmp::Ordering::Equal {
2728 return if *ascending { ord } else { ord.reverse() };
2729 }
2730 }
2731 std::cmp::Ordering::Equal
2732 });
2733
2734 Ok(())
2735 }
2736
2737 fn compare_values_for_sort(&self, left: &Value, right: &Value) -> std::cmp::Ordering {
2739 match (left, right) {
2740 (Value::Null, Value::Null) => std::cmp::Ordering::Equal,
2741 (Value::Null, _) => std::cmp::Ordering::Greater, (_, Value::Null) => std::cmp::Ordering::Less,
2743 (Value::Number(a), Value::Number(b)) => a.cmp(b),
2744 (Value::Integer(a), Value::Integer(b)) => a.cmp(b),
2745 (Value::Number(a), Value::Integer(b)) => a.cmp(&Decimal::from(*b)),
2746 (Value::Integer(a), Value::Number(b)) => Decimal::from(*a).cmp(b),
2747 (Value::String(a), Value::String(b)) => a.cmp(b),
2748 (Value::Date(a), Value::Date(b)) => a.cmp(b),
2749 (Value::Boolean(a), Value::Boolean(b)) => a.cmp(b),
2750 (Value::Amount(a), Value::Amount(b)) => a.number.cmp(&b.number),
2752 (Value::Position(a), Value::Position(b)) => a.units.number.cmp(&b.units.number),
2754 (Value::Inventory(a), Value::Inventory(b)) => {
2756 let a_val = a.positions().first().map(|p| &p.units.number);
2757 let b_val = b.positions().first().map(|p| &p.units.number);
2758 match (a_val, b_val) {
2759 (Some(av), Some(bv)) => av.cmp(bv),
2760 (Some(_), None) => std::cmp::Ordering::Less,
2761 (None, Some(_)) => std::cmp::Ordering::Greater,
2762 (None, None) => std::cmp::Ordering::Equal,
2763 }
2764 }
2765 _ => std::cmp::Ordering::Equal, }
2767 }
2768
2769 fn evaluate_having_filter(
2775 &self,
2776 having_expr: &Expr,
2777 row: &[Value],
2778 column_names: &[String],
2779 targets: &[Target],
2780 group: &[&PostingContext],
2781 ) -> Result<bool, QueryError> {
2782 let col_map: HashMap<String, usize> = column_names
2784 .iter()
2785 .enumerate()
2786 .map(|(i, name)| (name.to_uppercase(), i))
2787 .collect();
2788
2789 let alias_map: HashMap<String, usize> = targets
2791 .iter()
2792 .enumerate()
2793 .filter_map(|(i, t)| t.alias.as_ref().map(|a| (a.to_uppercase(), i)))
2794 .collect();
2795
2796 let val = self.evaluate_having_expr(having_expr, row, &col_map, &alias_map, group)?;
2797
2798 match val {
2799 Value::Boolean(b) => Ok(b),
2800 Value::Null => Ok(false), _ => Err(QueryError::Type(
2802 "HAVING clause must evaluate to boolean".to_string(),
2803 )),
2804 }
2805 }
2806
2807 fn evaluate_having_expr(
2809 &self,
2810 expr: &Expr,
2811 row: &[Value],
2812 col_map: &HashMap<String, usize>,
2813 alias_map: &HashMap<String, usize>,
2814 group: &[&PostingContext],
2815 ) -> Result<Value, QueryError> {
2816 match expr {
2817 Expr::Column(name) => {
2818 let upper_name = name.to_uppercase();
2819 if let Some(&idx) = alias_map.get(&upper_name) {
2821 Ok(row.get(idx).cloned().unwrap_or(Value::Null))
2822 } else if let Some(&idx) = col_map.get(&upper_name) {
2823 Ok(row.get(idx).cloned().unwrap_or(Value::Null))
2824 } else {
2825 Err(QueryError::Evaluation(format!(
2826 "Column '{name}' not found in SELECT clause for HAVING"
2827 )))
2828 }
2829 }
2830 Expr::Literal(lit) => self.evaluate_literal(lit),
2831 Expr::Function(_) => {
2832 self.evaluate_aggregate_expr(expr, group)
2834 }
2835 Expr::BinaryOp(op) => {
2836 let left = self.evaluate_having_expr(&op.left, row, col_map, alias_map, group)?;
2837 let right = self.evaluate_having_expr(&op.right, row, col_map, alias_map, group)?;
2838 self.binary_op_on_values(op.op, &left, &right)
2839 }
2840 Expr::UnaryOp(op) => {
2841 let val = self.evaluate_having_expr(&op.operand, row, col_map, alias_map, group)?;
2842 match op.op {
2843 UnaryOperator::Not => {
2844 let b = self.to_bool(&val)?;
2845 Ok(Value::Boolean(!b))
2846 }
2847 UnaryOperator::Neg => match val {
2848 Value::Number(n) => Ok(Value::Number(-n)),
2849 Value::Integer(i) => Ok(Value::Integer(-i)),
2850 _ => Err(QueryError::Type(
2851 "Cannot negate non-numeric value".to_string(),
2852 )),
2853 },
2854 }
2855 }
2856 Expr::Paren(inner) => self.evaluate_having_expr(inner, row, col_map, alias_map, group),
2857 Expr::Wildcard => Err(QueryError::Evaluation(
2858 "Wildcard not allowed in HAVING clause".to_string(),
2859 )),
2860 Expr::Window(_) => Err(QueryError::Evaluation(
2861 "Window functions not allowed in HAVING clause".to_string(),
2862 )),
2863 }
2864 }
2865
2866 fn apply_pivot(
2872 &self,
2873 result: &QueryResult,
2874 pivot_exprs: &[Expr],
2875 _targets: &[Target],
2876 ) -> Result<QueryResult, QueryError> {
2877 if pivot_exprs.is_empty() {
2878 return Ok(result.clone());
2879 }
2880
2881 let pivot_expr = &pivot_exprs[0];
2884
2885 let pivot_col_idx = self.find_pivot_column(result, pivot_expr)?;
2887
2888 let mut pivot_values: Vec<Value> = result
2890 .rows
2891 .iter()
2892 .map(|row| row.get(pivot_col_idx).cloned().unwrap_or(Value::Null))
2893 .collect();
2894 pivot_values.sort_by(|a, b| self.compare_values_for_sort(a, b));
2895 pivot_values.dedup();
2896
2897 let mut new_columns: Vec<String> = result
2899 .columns
2900 .iter()
2901 .enumerate()
2902 .filter(|(i, _)| *i != pivot_col_idx)
2903 .map(|(_, c)| c.clone())
2904 .collect();
2905
2906 let value_col_idx = result.columns.len() - 1;
2908
2909 for pv in &pivot_values {
2911 new_columns.push(self.value_to_string(pv));
2912 }
2913
2914 let mut new_result = QueryResult::new(new_columns);
2915
2916 let group_cols: Vec<usize> = (0..result.columns.len())
2918 .filter(|i| *i != pivot_col_idx && *i != value_col_idx)
2919 .collect();
2920
2921 let mut groups: HashMap<String, Vec<&Row>> = HashMap::new();
2922 for row in &result.rows {
2923 let key: String = group_cols
2924 .iter()
2925 .map(|&i| self.value_to_string(&row[i]))
2926 .collect::<Vec<_>>()
2927 .join("|");
2928 groups.entry(key).or_default().push(row);
2929 }
2930
2931 for (_key, group_rows) in groups {
2933 let mut new_row: Vec<Value> = group_cols
2934 .iter()
2935 .map(|&i| group_rows[0][i].clone())
2936 .collect();
2937
2938 let pivot_index: HashMap<u64, usize> = group_rows
2940 .iter()
2941 .enumerate()
2942 .filter_map(|(idx, row)| {
2943 row.get(pivot_col_idx).map(|v| (hash_single_value(v), idx))
2944 })
2945 .collect();
2946
2947 for pv in &pivot_values {
2949 let pv_hash = hash_single_value(pv);
2950 if let Some(&row_idx) = pivot_index.get(&pv_hash) {
2951 new_row.push(
2952 group_rows[row_idx]
2953 .get(value_col_idx)
2954 .cloned()
2955 .unwrap_or(Value::Null),
2956 );
2957 } else {
2958 new_row.push(Value::Null);
2959 }
2960 }
2961
2962 new_result.add_row(new_row);
2963 }
2964
2965 Ok(new_result)
2966 }
2967
2968 fn find_pivot_column(
2970 &self,
2971 result: &QueryResult,
2972 pivot_expr: &Expr,
2973 ) -> Result<usize, QueryError> {
2974 match pivot_expr {
2975 Expr::Column(name) => {
2976 let upper_name = name.to_uppercase();
2977 result
2978 .columns
2979 .iter()
2980 .position(|c| c.to_uppercase() == upper_name)
2981 .ok_or_else(|| {
2982 QueryError::Evaluation(format!(
2983 "PIVOT BY column '{name}' not found in SELECT"
2984 ))
2985 })
2986 }
2987 Expr::Literal(Literal::Integer(n)) => {
2988 let idx = (*n as usize).saturating_sub(1);
2989 if idx < result.columns.len() {
2990 Ok(idx)
2991 } else {
2992 Err(QueryError::Evaluation(format!(
2993 "PIVOT BY column index {n} out of range"
2994 )))
2995 }
2996 }
2997 Expr::Literal(Literal::Number(n)) => {
2998 use rust_decimal::prelude::ToPrimitive;
3000 let idx = n.to_usize().unwrap_or(0).saturating_sub(1);
3001 if idx < result.columns.len() {
3002 Ok(idx)
3003 } else {
3004 Err(QueryError::Evaluation(format!(
3005 "PIVOT BY column index {n} out of range"
3006 )))
3007 }
3008 }
3009 _ => {
3010 Err(QueryError::Evaluation(
3013 "PIVOT BY must reference a column name or index".to_string(),
3014 ))
3015 }
3016 }
3017 }
3018
3019 fn value_to_string(&self, val: &Value) -> String {
3021 match val {
3022 Value::String(s) => s.clone(),
3023 Value::Number(n) => n.to_string(),
3024 Value::Integer(i) => i.to_string(),
3025 Value::Date(d) => d.to_string(),
3026 Value::Boolean(b) => b.to_string(),
3027 Value::Amount(a) => format!("{} {}", a.number, a.currency),
3028 Value::Position(p) => p.to_string(),
3029 Value::Inventory(inv) => inv.to_string(),
3030 Value::StringSet(ss) => ss.join(" "),
3031 Value::Null => "NULL".to_string(),
3032 }
3033 }
3034}
3035
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parse;
    use rust_decimal_macros::dec;
    use rustledger_core::Posting;

    /// Builds a calendar date from components known to be valid.
    fn date(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).unwrap()
    }

    /// A cleared (`*`) purchase moving `amount` USD from checking into `expense`.
    fn purchase(
        on: NaiveDate,
        narration: &'static str,
        payee: &'static str,
        expense: &'static str,
        amount: Decimal,
    ) -> Directive {
        let txn = Transaction::new(on, narration)
            .with_flag('*')
            .with_payee(payee)
            .with_posting(Posting::new(expense, Amount::new(amount, "USD")))
            .with_posting(Posting::new(
                "Assets:Bank:Checking",
                Amount::new(-amount, "USD"),
            ));
        Directive::Transaction(txn)
    }

    /// Shared fixture: two purchases on consecutive days, four postings in total.
    fn sample_directives() -> Vec<Directive> {
        vec![
            purchase(
                date(2024, 1, 15),
                "Coffee",
                "Coffee Shop",
                "Expenses:Food:Coffee",
                dec!(5.00),
            ),
            purchase(
                date(2024, 1, 16),
                "Groceries",
                "Supermarket",
                "Expenses:Food:Groceries",
                dec!(50.00),
            ),
        ]
    }

    /// Parses and runs `sql`, failing the test on any parse or execution error.
    fn run(executor: &mut Executor, sql: &str) -> QueryResult {
        let query = parse(sql).expect("query should parse");
        executor.execute(&query).expect("query should execute")
    }

    #[test]
    fn test_simple_select() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let result = run(&mut executor, "SELECT date, account");
        assert_eq!(result.columns, vec!["date", "account"]);
        // Two transactions with two postings each.
        assert_eq!(result.len(), 4);
    }

    #[test]
    fn test_where_clause() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let result = run(&mut executor, r#"SELECT account WHERE account ~ "Expenses:""#);
        // Only the two expense postings match.
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_balances() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let result = run(&mut executor, "BALANCES");
        assert_eq!(result.columns, vec!["account", "balance"]);
        // Coffee, Groceries and Checking at minimum.
        assert!(result.len() >= 3);
    }

    #[test]
    fn test_account_functions() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let leaves = run(
            &mut executor,
            r#"SELECT DISTINCT LEAF(account) WHERE account ~ "Expenses:""#,
        );
        assert_eq!(leaves.len(), 2);

        let roots = run(&mut executor, "SELECT DISTINCT ROOT(account)");
        assert_eq!(roots.len(), 2);

        let parents = run(
            &mut executor,
            r#"SELECT DISTINCT PARENT(account) WHERE account ~ "Expenses:""#,
        );
        assert!(!parents.is_empty());
    }

    #[test]
    fn test_min_max_aggregate() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let earliest = run(&mut executor, "SELECT MIN(date)");
        assert_eq!(earliest.len(), 1);
        assert_eq!(earliest.rows[0][0], Value::Date(date(2024, 1, 15)));

        let latest = run(&mut executor, "SELECT MAX(date)");
        assert_eq!(latest.len(), 1);
        assert_eq!(latest.rows[0][0], Value::Date(date(2024, 1, 16)));
    }

    #[test]
    fn test_order_by() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let result = run(&mut executor, "SELECT date, account ORDER BY date DESC");
        assert_eq!(result.len(), 4);
        // Newest transaction comes first when sorting descending.
        assert_eq!(result.rows[0][0], Value::Date(date(2024, 1, 16)));
    }

    #[test]
    fn test_hash_value_all_variants() {
        use rustledger_core::{Cost, Inventory, Position};

        let samples = [
            Value::String("test".to_string()),
            Value::Number(dec!(123.45)),
            Value::Integer(42),
            Value::Date(date(2024, 1, 15)),
            Value::Boolean(true),
            Value::Boolean(false),
            Value::Amount(Amount::new(dec!(100), "USD")),
            Value::Position(Position::simple(Amount::new(dec!(10), "AAPL"))),
            Value::Position(Position::with_cost(
                Amount::new(dec!(10), "AAPL"),
                Cost::new(dec!(150), "USD"),
            )),
            Value::Inventory(Inventory::new()),
            Value::StringSet(vec!["tag1".to_string(), "tag2".to_string()]),
            Value::Null,
        ];

        // Every variant hashes without panicking; only NULL may plausibly hash to 0.
        for sample in &samples {
            let hash = hash_single_value(sample);
            assert!(hash != 0 || matches!(sample, Value::Null));
        }

        // Different strings hash differently; equal integers hash equally.
        assert_ne!(
            hash_single_value(&Value::String("a".to_string())),
            hash_single_value(&Value::String("b".to_string()))
        );
        assert_eq!(
            hash_single_value(&Value::Integer(42)),
            hash_single_value(&Value::Integer(42))
        );
    }

    #[test]
    fn test_hash_row_distinct() {
        let first = vec![Value::String("a".to_string()), Value::Integer(1)];
        let same_as_first = vec![Value::String("a".to_string()), Value::Integer(1)];
        let different = vec![Value::String("b".to_string()), Value::Integer(1)];

        assert_eq!(hash_row(&first), hash_row(&same_as_first));
        assert_ne!(hash_row(&first), hash_row(&different));
    }

    #[test]
    fn test_string_set_hash_order_independent() {
        fn string_set(items: &[&str]) -> Value {
            Value::StringSet(items.iter().map(|s| (*s).to_string()).collect())
        }

        let abc = hash_single_value(&string_set(&["a", "b", "c"]));
        let cab = hash_single_value(&string_set(&["c", "a", "b"]));
        let bca = hash_single_value(&string_set(&["b", "c", "a"]));

        assert_eq!(abc, cab);
        assert_eq!(cab, bca);
    }

    #[test]
    fn test_inventory_hash_includes_cost() {
        use rustledger_core::{Cost, Inventory, Position};

        // Ten AAPL shares held at the given per-unit USD cost.
        let holding_at = |per_unit: Decimal| {
            let mut inventory = Inventory::new();
            inventory.add(Position::with_cost(
                Amount::new(dec!(10), "AAPL"),
                Cost::new(per_unit, "USD"),
            ));
            inventory
        };

        let cheap = hash_single_value(&Value::Inventory(holding_at(dec!(100))));
        let pricey = hash_single_value(&Value::Inventory(holding_at(dec!(200))));

        assert_ne!(cheap, pricey);
    }

    #[test]
    fn test_distinct_deduplication() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        // One row per posting...
        let all_flags = run(&mut executor, "SELECT flag");
        assert_eq!(all_flags.len(), 4);

        // ...but every posting shares the same '*' flag.
        let distinct_flags = run(&mut executor, "SELECT DISTINCT flag");
        assert_eq!(distinct_flags.len(), 1);
    }

    #[test]
    fn test_limit_clause() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        assert_eq!(run(&mut executor, "SELECT date, account LIMIT 2").len(), 2);
        assert_eq!(run(&mut executor, "SELECT date LIMIT 0").len(), 0);
        // A limit above the row count returns every row.
        assert_eq!(run(&mut executor, "SELECT date LIMIT 100").len(), 4);
    }

    #[test]
    fn test_group_by_with_count() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let result = run(
            &mut executor,
            "SELECT ROOT(account), COUNT(account) GROUP BY ROOT(account)",
        );
        assert_eq!(result.columns.len(), 2);
        // Two roots: Expenses and Assets.
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_count_aggregate() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let total = run(&mut executor, "SELECT COUNT(account)");
        assert_eq!(total.len(), 1);
        assert_eq!(total.rows[0][0], Value::Integer(4));

        let per_root = run(
            &mut executor,
            "SELECT ROOT(account), COUNT(account) GROUP BY ROOT(account)",
        );
        assert_eq!(per_root.len(), 2);
    }

    #[test]
    fn test_journal_query() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let result = run(&mut executor, r#"JOURNAL "Expenses""#);
        assert!(result.columns.contains(&"account".to_string()));
        // Only the two expense postings are journaled.
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_print_query() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let result = run(&mut executor, "PRINT");
        assert_eq!(result.columns.len(), 1);
        assert_eq!(result.columns[0], "directive");
        // One row per directive.
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_empty_directives() {
        let directives: Vec<Directive> = vec![];
        let mut executor = Executor::new(&directives);

        assert!(run(&mut executor, "SELECT date, account").is_empty());
        assert!(run(&mut executor, "BALANCES").is_empty());
    }

    #[test]
    fn test_comparison_operators() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        // Only the 2024-01-15 transaction precedes the cutoff.
        assert_eq!(run(&mut executor, "SELECT date WHERE date < 2024-01-16").len(), 2);
        // Every posting is from 2024.
        assert_eq!(run(&mut executor, "SELECT date WHERE year > 2023").len(), 4);
        // Both postings of the first transaction.
        assert_eq!(run(&mut executor, "SELECT account WHERE day = 15").len(), 2);
    }

    #[test]
    fn test_logical_operators() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let both = run(
            &mut executor,
            r#"SELECT account WHERE account ~ "Expenses" AND day > 14"#,
        );
        assert_eq!(both.len(), 2);

        let either = run(&mut executor, "SELECT account WHERE day = 15 OR day = 16");
        assert_eq!(either.len(), 4);
    }

    #[test]
    fn test_arithmetic_expressions() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let result = run(&mut executor, "SELECT -day WHERE day = 15");
        assert_eq!(result.len(), 2);
        for row in &result.rows {
            if let Value::Integer(negated) = &row[0] {
                assert_eq!(*negated, -15);
            }
        }
    }

    #[test]
    fn test_first_last_aggregates() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let first = run(&mut executor, "SELECT FIRST(date)");
        assert_eq!(first.len(), 1);
        assert_eq!(first.rows[0][0], Value::Date(date(2024, 1, 15)));

        let last = run(&mut executor, "SELECT LAST(date)");
        assert_eq!(last.len(), 1);
        assert_eq!(last.rows[0][0], Value::Date(date(2024, 1, 16)));
    }

    #[test]
    fn test_wildcard_select() {
        let directives = sample_directives();
        let mut executor = Executor::new(&directives);

        let result = run(&mut executor, "SELECT *");
        assert_eq!(result.columns, vec!["*"]);
        assert_eq!(result.len(), 4);
        // The wildcard expands into six values per row.
        assert_eq!(result.rows[0].len(), 6);
    }

    #[test]
    fn test_query_result_methods() {
        let mut result = QueryResult::new(vec!["col1".to_string(), "col2".to_string()]);
        assert!(result.is_empty());
        assert_eq!(result.len(), 0);

        result.add_row(vec![Value::Integer(1), Value::String("a".to_string())]);
        assert!(!result.is_empty());
        assert_eq!(result.len(), 1);

        result.add_row(vec![Value::Integer(2), Value::String("b".to_string())]);
        assert_eq!(result.len(), 2);
    }
}