1use super::{AccountID, AccountStatus, AccountingError, ProgramState};
2use chrono::NaiveDate;
3use commodity::exchange_rate::ExchangeRate;
4use commodity::Commodity;
5use rust_decimal::{prelude::Zero, Decimal};
6use std::fmt;
7use std::rc::Rc;
8use std::{marker::PhantomData, slice};
9
10#[cfg(feature = "serde-support")]
11use serde::{Deserialize, Serialize};
12
/// The kind of an action, used by [`ActionOrder`] as the tie-breaker between actions that
/// share a date.
///
/// The derived `Ord` follows declaration order: `EditAccountStatus` < `BalanceAssertion` <
/// `Transaction`.
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum ActionType {
    /// Corresponds to [`EditAccountStatus`].
    EditAccountStatus,
    /// Corresponds to [`BalanceAssertion`].
    BalanceAssertion,
    /// Corresponds to [`Transaction`].
    Transaction,
}
31
32impl ActionTypeFor<ActionType> for ActionTypeValue {
33 fn action_type(&self) -> ActionType {
34 match self {
35 ActionTypeValue::EditAccountStatus(_) => ActionType::EditAccountStatus,
36 ActionTypeValue::BalanceAssertion(_) => ActionType::BalanceAssertion,
37 ActionTypeValue::Transaction(_) => ActionType::Transaction,
38 }
39 }
40}
41
42impl ActionType {
43 pub fn iterator() -> slice::Iter<'static, ActionType> {
45 static ACTION_TYPES: [ActionType; 3] = [
46 ActionType::EditAccountStatus,
47 ActionType::BalanceAssertion,
48 ActionType::Transaction,
49 ];
50 ACTION_TYPES.iter()
51 }
52}
53
/// Gives uniform access to an action-value wrapper (such as [`ActionTypeValue`]) as an
/// [`Action`] trait object.
pub trait ActionTypeValueEnum<AT> {
    /// Borrow the wrapped action as a `dyn Action` whose value type is the implementor itself.
    fn as_action(&self) -> &dyn Action<AT, Self>;
}
64
/// A concrete action of any of the supported kinds.
///
/// With the `serde-support` feature enabled, this is (de)serialized as an internally tagged
/// enum: a `"type"` field holds the variant name next to the action's own fields.
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde-support", serde(tag = "type"))]
#[derive(Debug, Clone, PartialEq)]
pub enum ActionTypeValue {
    /// Change an account's status.
    EditAccountStatus(EditAccountStatus),
    /// Check an account's balance against an expected value.
    BalanceAssertion(BalanceAssertion),
    /// Move amounts between accounts.
    Transaction(Transaction),
}
75
76impl<AT> ActionTypeValueEnum<AT> for ActionTypeValue {
77 fn as_action(&self) -> &dyn Action<AT, ActionTypeValue> {
78 match self {
79 ActionTypeValue::EditAccountStatus(action) => action,
80 ActionTypeValue::BalanceAssertion(action) => action,
81 ActionTypeValue::Transaction(action) => action,
82 }
83 }
84}
85
86impl From<EditAccountStatus> for ActionTypeValue {
87 fn from(action: EditAccountStatus) -> Self {
88 ActionTypeValue::EditAccountStatus(action)
89 }
90}
91
92impl From<BalanceAssertion> for ActionTypeValue {
93 fn from(action: BalanceAssertion) -> Self {
94 ActionTypeValue::BalanceAssertion(action)
95 }
96}
97
98impl From<Transaction> for ActionTypeValue {
99 fn from(action: Transaction) -> Self {
100 ActionTypeValue::Transaction(action)
101 }
102}
103
/// Reports the action type (of type `AT`) that a value belongs to.
pub trait ActionTypeFor<AT> {
    /// The action type of `self`. [`ActionOrder`] uses it to order actions that share a date.
    fn action_type(&self) -> AT;
}
109
/// An operation that can be applied to a [`ProgramState`].
pub trait Action<AT, ATV>: fmt::Display + fmt::Debug {
    /// The date associated with this action; [`ActionOrder`] sorts by it first.
    fn date(&self) -> NaiveDate;

    /// Apply this action to `program_state`.
    ///
    /// # Errors
    /// Returns an [`AccountingError`] when the action cannot be applied.
    fn perform(&self, program_state: &mut ProgramState<AT, ATV>) -> Result<(), AccountingError>;
}
118
/// Wrapper around a shared action value that orders actions by date, then by action type.
///
/// Equality likewise compares only the date and the action type, not the whole action.
pub struct ActionOrder<AT, ATV> {
    /// The wrapped action value.
    action_value: Rc<ATV>,
    /// No `AT` value is stored; this marker only ties the `AT` type parameter to the struct.
    action_type: PhantomData<AT>,
}
139
140impl<AT, ATV> ActionOrder<AT, ATV> {
141 pub fn new(action_value: Rc<ATV>) -> Self {
142 Self {
143 action_value,
144 action_type: PhantomData::default(),
145 }
146 }
147}
148
149impl<AT, ATV> PartialEq for ActionOrder<AT, ATV>
150where
151 AT: PartialEq,
152 ATV: ActionTypeValueEnum<AT> + ActionTypeFor<AT>,
153{
154 fn eq(&self, other: &ActionOrder<AT, ATV>) -> bool {
155 let self_action = self.action_value.as_action();
156 let other_action = other.action_value.as_action();
157 self.action_value.action_type() == other.action_value.action_type()
158 && self_action.date() == other_action.date()
159 }
160}
161
// Marks `ActionOrder` equality (date + action type) as a full equivalence relation.
// NOTE(review): only `AT: PartialEq` is required here; this is sound only if `AT`'s
// equality is reflexive (true for `ActionType`, which derives `Eq`) — confirm for other `AT`s.
impl<AT, ATV> Eq for ActionOrder<AT, ATV>
where
    ATV: ActionTypeValueEnum<AT> + ActionTypeFor<AT>,
    AT: PartialEq,
{
}
168
169impl<AT, ATV> PartialOrd for ActionOrder<AT, ATV>
170where
171 AT: Ord,
172 ATV: ActionTypeValueEnum<AT> + ActionTypeFor<AT>,
173{
174 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
175 let self_action = self.action_value.as_action();
176 let other_action = other.action_value.as_action();
177 self_action
178 .date()
179 .partial_cmp(&other_action.date())
180 .map(|date_order| {
181 date_order.then(
182 self.action_value
183 .action_type()
184 .cmp(&other.action_value.action_type()),
185 )
186 })
187 }
188}
189
190impl<AT, ATV> Ord for ActionOrder<AT, ATV>
191where
192 AT: Ord,
193 ATV: ActionTypeValueEnum<AT> + ActionTypeFor<AT>,
194{
195 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
196 let self_action = self.action_value.as_action();
197 let other_action = other.action_value.as_action();
198 self_action.date().cmp(&other_action.date()).then(
199 self.action_value
200 .action_type()
201 .cmp(&other.action_value.action_type()),
202 )
203 }
204}
205
/// A set of amounts applied to accounts on a single date. When performed, the amounts must
/// sum to zero.
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct Transaction {
    /// Optional description of the transaction.
    pub description: Option<String>,
    /// The date of the transaction.
    pub date: NaiveDate,
    /// One entry per affected account; see [`TransactionElement`].
    pub elements: Vec<TransactionElement>,
}
227
228impl Transaction {
229 pub fn new<S: Into<String>>(
231 description: Option<S>,
232 date: NaiveDate,
233 elements: Vec<TransactionElement>,
234 ) -> Transaction {
235 Transaction {
236 description: description.map(|s| s.into()),
237 date,
238 elements,
239 }
240 }
241
242 pub fn new_simple<S: Into<String>>(
280 description: Option<S>,
281 date: NaiveDate,
282 from_account: AccountID,
283 to_account: AccountID,
284 amount: Commodity,
285 exchange_rate: Option<ExchangeRate>,
286 ) -> Transaction {
287 Transaction::new(
288 description,
289 date,
290 vec![
291 TransactionElement::new(from_account, Some(amount.neg()), exchange_rate.clone()),
292 TransactionElement::new(to_account, None, exchange_rate),
293 ],
294 )
295 }
296
297 pub fn get_element(&self, account_id: &AccountID) -> Option<&TransactionElement> {
300 self.elements.iter().find(|e| &e.account_id == account_id)
301 }
302}
303
304impl ActionTypeFor<ActionType> for Transaction {
305 fn action_type(&self) -> ActionType {
306 todo!()
307 }
308}
309
310impl fmt::Display for Transaction {
311 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
312 write!(f, "Transaction")
313 }
314}
315
316impl<AT, ATV> Action<AT, ATV> for Transaction
317where
318 ATV: ActionTypeValueEnum<AT>,
319{
320 fn date(&self) -> NaiveDate {
321 self.date
322 }
323
324 fn perform(&self, program_state: &mut ProgramState<AT, ATV>) -> Result<(), AccountingError> {
325 if self.elements.len() < 2 {
327 return Err(AccountingError::InvalidTransaction(
328 self.clone(),
329 String::from("a transaction cannot have less than 2 elements"),
330 ));
331 }
332
333 let mut empty_amount_element: Option<usize> = None;
338 for (i, element) in self.elements.iter().enumerate() {
339 if element.amount.is_none() {
340 if empty_amount_element.is_none() {
341 empty_amount_element = Some(i)
342 } else {
343 return Err(AccountingError::InvalidTransaction(
344 self.clone(),
345 String::from("multiple elements with no amount specified"),
346 ));
347 }
348 }
349 }
350
351 let sum_commodity_type_id = match empty_amount_element {
352 Some(empty_i) => {
353 let empty_element = self.elements.get(empty_i).unwrap();
354
355 match program_state.get_account(&empty_element.account_id) {
356 Some(account) => account.commodity_type_id,
357 None => {
358 return Err(AccountingError::MissingAccountState(
359 empty_element.account_id,
360 ))
361 }
362 }
363 }
364 None => {
365 let account_id = self
366 .elements
367 .get(0)
368 .expect("there should be at least 2 elements in the transaction")
369 .account_id;
370
371 match program_state.get_account(&account_id) {
372 Some(account) => account.commodity_type_id,
373 None => return Err(AccountingError::MissingAccountState(account_id)),
374 }
375 }
376 };
377
378 let mut sum = Commodity::new(Decimal::zero(), sum_commodity_type_id);
379
380 let mut modified_elements = self.elements.clone();
381
382 for (i, element) in self.elements.iter().enumerate() {
384 if let Some(empty_i) = empty_amount_element {
385 if i != empty_i {
386 sum = match sum.add(&element.amount.as_ref().unwrap()) {
388 Ok(value) => value,
389 Err(error) => return Err(AccountingError::Commodity(error)),
390 }
391 }
392 }
393 }
394
395 if let Some(empty_i) = empty_amount_element {
397 let modified_emtpy_element: &mut TransactionElement =
398 modified_elements.get_mut(empty_i).unwrap();
399 let negated_sum = sum.neg();
400 modified_emtpy_element.amount = Some(negated_sum);
401
402 sum = match sum.add(&negated_sum) {
403 Ok(value) => value,
404 Err(error) => return Err(AccountingError::Commodity(error)),
405 }
406 }
407
408 if sum.value != Decimal::zero() {
409 return Err(AccountingError::InvalidTransaction(
410 self.clone(),
411 String::from("sum of transaction elements does not equal zero"),
412 ));
413 }
414
415 for transaction in &modified_elements {
416 let mut account_state = program_state
417 .get_account_state_mut(&transaction.account_id)
418 .unwrap_or_else(||
419 panic!(
420 "unable to find state for account with id: {} please ensure this account was added to the program state before execution.",
421 transaction.account_id
422 )
423 );
424
425 match account_state.status {
426 AccountStatus::Closed => Err(AccountingError::InvalidAccountStatus {
427 account_id: transaction.account_id,
428 status: account_state.status,
429 }),
430 _ => Ok(()),
431 }?;
432
433 let transaction_amount = match &transaction.amount {
436 Some(amount) => amount,
437 None => {
438 return Err(AccountingError::InvalidTransaction(
439 self.clone(),
440 String::from(
441 "unable to calculate all required amounts for this transaction",
442 ),
443 ))
444 }
445 };
446
447 account_state.amount = match account_state.amount.add(transaction_amount) {
448 Ok(commodity) => commodity,
449 Err(err) => {
450 return Err(AccountingError::Commodity(err));
451 }
452 }
453 }
454
455 Ok(())
456 }
457}
458
/// One account's entry within a [`Transaction`].
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct TransactionElement {
    /// The account whose balance this element changes.
    pub account_id: AccountID,

    /// The amount added to the account's balance. At most one element of a transaction may
    /// leave this as `None`; its value is then computed so that the transaction sums to zero.
    pub amount: Option<Commodity>,

    /// Optional exchange rate for this element.
    /// NOTE(review): not currently read when a transaction is performed.
    pub exchange_rate: Option<ExchangeRate>,
}
478
479impl TransactionElement {
480 pub fn new(
482 account_id: AccountID,
483 amount: Option<Commodity>,
484 exchange_rate: Option<ExchangeRate>,
485 ) -> TransactionElement {
486 TransactionElement {
487 account_id,
488 amount,
489 exchange_rate,
490 }
491 }
492}
493
/// Action that sets an account's [`AccountStatus`] on a given date.
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct EditAccountStatus {
    /// The account whose status is changed.
    account_id: AccountID,
    /// The status assigned to the account.
    newstatus: AccountStatus,
    /// The date of the change.
    date: NaiveDate,
}
504
505impl EditAccountStatus {
506 pub fn new(
508 account_id: AccountID,
509 newstatus: AccountStatus,
510 date: NaiveDate,
511 ) -> EditAccountStatus {
512 EditAccountStatus {
513 account_id,
514 newstatus,
515 date,
516 }
517 }
518}
519
520impl fmt::Display for EditAccountStatus {
521 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
522 write!(f, "Edit Account Status")
523 }
524}
525
526impl<AT, ATV> Action<AT, ATV> for EditAccountStatus
527where
528 ATV: ActionTypeValueEnum<AT>,
529{
530 fn date(&self) -> NaiveDate {
531 self.date
532 }
533
534 fn perform(&self, program_state: &mut ProgramState<AT, ATV>) -> Result<(), AccountingError> {
535 let mut account_state = program_state
536 .get_account_state_mut(&self.account_id)
537 .unwrap();
538 account_state.status = self.newstatus;
539 Ok(())
540 }
541}
542
impl ActionTypeFor<ActionType> for EditAccountStatus {
    /// An [`EditAccountStatus`] always has the action type [`ActionType::EditAccountStatus`].
    fn action_type(&self) -> ActionType {
        ActionType::EditAccountStatus
    }
}
548
/// Action that checks, on a given date, that an account's balance approximately equals an
/// expected value; a mismatch is recorded in the program state rather than returned as an error.
#[cfg_attr(feature = "serde-support", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct BalanceAssertion {
    /// The account whose balance is checked.
    account_id: AccountID,
    /// The date of the check.
    date: NaiveDate,
    /// The balance the account is expected to hold.
    expected_balance: Commodity,
}
563
564impl BalanceAssertion {
565 pub fn new(
568 account_id: AccountID,
569 date: NaiveDate,
570 expected_balance: Commodity,
571 ) -> BalanceAssertion {
572 BalanceAssertion {
573 account_id,
574 date,
575 expected_balance,
576 }
577 }
578}
579
580impl fmt::Display for BalanceAssertion {
581 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
582 write!(f, "Assert Account Balance")
583 }
584}
585
/// Record of a [`BalanceAssertion`] whose expected balance did not match the account.
#[derive(Debug, Clone)]
pub struct FailedBalanceAssertion {
    /// The assertion that failed.
    pub assertion: BalanceAssertion,
    /// The balance the account actually held when the assertion was performed.
    pub actual_balance: Commodity,
}
594
595impl FailedBalanceAssertion {
596 pub fn new(assertion: BalanceAssertion, actual_balance: Commodity) -> FailedBalanceAssertion {
598 FailedBalanceAssertion {
599 assertion,
600 actual_balance,
601 }
602 }
603}
604
605impl fmt::Display for FailedBalanceAssertion {
606 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
607 write!(f, "Failed Account Balance Assertion")
608 }
609}
610
611impl<AT, ATV> Action<AT, ATV> for BalanceAssertion
615where
616 ATV: ActionTypeValueEnum<AT>,
617{
618 fn date(&self) -> NaiveDate {
619 self.date
620 }
621
622 fn perform(&self, program_state: &mut ProgramState<AT, ATV>) -> Result<(), AccountingError> {
623 let failed_assertion = match program_state.get_account_state(&self.account_id) {
624 Some(state) => {
625 if !state
626 .amount
627 .eq_approx(self.expected_balance, Commodity::default_epsilon())
628 {
629 Some(FailedBalanceAssertion::new(self.clone(), state.amount))
630 } else {
631 None
632 }
633 }
634 None => {
635 return Err(AccountingError::MissingAccountState(self.account_id));
636 }
637 };
638
639 if let Some(failed_assertion) = failed_assertion {
640 program_state.record_failed_balance_assertion(failed_assertion)
641 }
642
643 Ok(())
644 }
645}
646
impl ActionTypeFor<ActionType> for BalanceAssertion {
    /// A [`BalanceAssertion`] always has the action type [`ActionType::BalanceAssertion`].
    fn action_type(&self) -> ActionType {
        ActionType::BalanceAssertion
    }
}
652
#[cfg(test)]
mod tests {
    use super::ActionType;
    use crate::{
        Account, AccountStatus, AccountingError, ActionTypeValue, BalanceAssertion, Program,
        ProgramState, Transaction,
    };
    use chrono::NaiveDate;
    use commodity::{Commodity, CommodityType};
    use rust_decimal::Decimal;
    use std::{collections::HashSet, rc::Rc};

    /// `ActionType::iterator` lists every variant, and sorting orders variants by declaration.
    #[test]
    fn action_type_order() {
        let mut action_types_unordered: Vec<ActionType> = vec![
            ActionType::Transaction,
            ActionType::EditAccountStatus,
            ActionType::BalanceAssertion,
            ActionType::EditAccountStatus,
            ActionType::Transaction,
            ActionType::BalanceAssertion,
        ];

        // Guards against a new variant being added without extending this fixture.
        let tested_types: HashSet<ActionType> = action_types_unordered.iter().cloned().collect();
        assert_eq!(ActionType::iterator().count(), tested_types.len());

        action_types_unordered.sort();

        let action_types_ordered: Vec<ActionType> = vec![
            ActionType::EditAccountStatus,
            ActionType::EditAccountStatus,
            ActionType::BalanceAssertion,
            ActionType::BalanceAssertion,
            ActionType::Transaction,
            ActionType::Transaction,
        ];

        assert_eq!(action_types_ordered, action_types_unordered);
    }

    /// A balance assertion dated the same day as a transaction observes the pre-transaction
    /// balance (and fails); the assertion on the following day passes.
    #[test]
    fn balance_assertion() {
        let aud = Rc::from(CommodityType::from_str("AUD", "Australian Dollar").unwrap());
        let account1 = Rc::from(Account::new_with_id(Some("Account 1"), aud.id, None));
        let account2 = Rc::from(Account::new_with_id(Some("Account 2"), aud.id, None));

        let date_1 = NaiveDate::from_ymd(2020, 1, 1);
        let date_2 = NaiveDate::from_ymd(2020, 1, 2);
        let actions: Vec<Rc<ActionTypeValue>> = vec![
            Rc::new(
                Transaction::new_simple::<String>(
                    None,
                    date_1,
                    account1.id,
                    account2.id,
                    Commodity::new(Decimal::new(100, 2), &*aud),
                    None,
                )
                .into(),
            ),
            Rc::new(
                BalanceAssertion::new(
                    account2.id,
                    date_1,
                    Commodity::new(Decimal::new(100, 2), &*aud),
                )
                .into(),
            ),
            Rc::new(
                BalanceAssertion::new(
                    account2.id,
                    date_2,
                    Commodity::new(Decimal::new(100, 2), &*aud),
                )
                .into(),
            ),
        ];

        let program = Program::new(actions);

        let accounts = vec![account1, account2];
        let mut program_state = ProgramState::new(&accounts, AccountStatus::Open);
        match program_state.execute_program(&program) {
            Err(AccountingError::BalanceAssertionFailed(failure)) => {
                assert_eq!(
                    Commodity::new(Decimal::new(0, 2), &*aud),
                    failure.actual_balance
                );
                assert_eq!(date_1, failure.assertion.date);
            }
            _ => panic!("Expected an AccountingError:BalanceAssertionFailed"),
        }

        assert_eq!(1, program_state.failed_balance_assertions.len());
    }
}
760
#[cfg(feature = "serde-support")]
#[cfg(test)]
mod serde_tests {
    use super::{BalanceAssertion, EditAccountStatus, Transaction};
    use crate::{AccountID, AccountStatus};
    use chrono::NaiveDate;
    use commodity::Commodity;
    use std::str::FromStr;

    /// `EditAccountStatus` deserializes from JSON and matches its snapshot.
    #[test]
    fn edit_account_status_serde() {
        use serde_json;

        let json = r#"{
  "account_id": "TestAccount",
  "newstatus": "Open",
  "date": "2020-05-10"
}"#;
        let action: EditAccountStatus = serde_json::from_str(json).unwrap();

        let reference_action = EditAccountStatus::new(
            AccountID::from("TestAccount").unwrap(),
            AccountStatus::Open,
            NaiveDate::from_ymd(2020, 5, 10),
        );

        assert_eq!(action, reference_action);

        insta::assert_json_snapshot!(action);
    }

    /// `BalanceAssertion` deserializes from JSON and matches its snapshot.
    #[test]
    fn balance_assertion_serde() {
        use serde_json;

        let json = r#"{
  "account_id": "TestAccount",
  "date": "2020-05-10",
  "expected_balance": {
    "value": "1.0",
    "type_id": "AUD"
  }
}"#;
        let action: BalanceAssertion = serde_json::from_str(json).unwrap();

        let reference_action = BalanceAssertion::new(
            AccountID::from("TestAccount").unwrap(),
            NaiveDate::from_ymd(2020, 5, 10),
            Commodity::from_str("1.0 AUD").unwrap(),
        );

        assert_eq!(action, reference_action);

        insta::assert_json_snapshot!(action);
    }

    /// A `Transaction` whose second element omits `amount` (and `exchange_rate`) deserializes
    /// to the same value as `Transaction::new_simple`. The whole module is already gated on
    /// `serde-support`, so no per-test `cfg` is needed.
    #[test]
    fn transaction_serde() {
        use serde_json;

        let json = r#"{
  "description": "TestTransaction",
  "date": "2020-05-10",
  "elements": [
    {
      "account_id": "TestAccount1",
      "amount": {
        "value": "-1.0",
        "type_id": "AUD"
      }
    },
    {
      "account_id": "TestAccount2"
    }
  ]
}"#;
        let action: Transaction = serde_json::from_str(json).unwrap();

        let reference_action = Transaction::new_simple(
            Some("TestTransaction"),
            NaiveDate::from_ymd(2020, 5, 10),
            AccountID::from("TestAccount1").unwrap(),
            AccountID::from("TestAccount2").unwrap(),
            Commodity::from_str("1.0 AUD").unwrap(),
            None,
        );

        assert_eq!(action, reference_action);

        insta::assert_json_snapshot!(action);
    }
}