use regex::Regex;
use rust_decimal::Decimal;
use rustledger_core::NaiveDate;
use std::collections::HashSet;
use std::str::FromStr;
use std::sync::LazyLock;
use crate::types::{
AmountData, DirectiveData, DirectiveWrapper, OpenData, PluginInput, PluginOutput, PostingData,
TransactionData,
};
use super::super::NativePlugin;
/// Matches a single plugin configuration entry of the form
/// `'<key>': ['<a>', '<b>', '<c>']`.
///
/// Capture group 1 is the (non-empty) quoted key; groups 2, 3 and 4 are the
/// three quoted list elements, each of which may be empty. Whitespace around
/// the colon, the brackets and the commas is tolerated.
static CONFIG_ENTRY_RE: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(r"'([^']+)'\s*:\s*\[\s*'([^']*)'\s*,\s*'([^']*)'\s*,\s*'([^']*)'\s*\]")
// The pattern is a compile-time literal, so a failure here is a programming error.
.expect("CONFIG_ENTRY_RE: invalid regex pattern")
});
/// Plugin that splits capital-gains postings into short-term and long-term
/// accounts according to how long the sold lots were held.
pub struct CapitalGainsLongShortPlugin;
/// Plugin that renames capital-gains postings into separate gains and losses
/// accounts according to the sign of the posting amount.
pub struct CapitalGainsGainLossPlugin;
/// `NativePlugin` implementation that delegates all work to `process_long_short`.
impl NativePlugin for CapitalGainsLongShortPlugin {
/// Returns the plugin identifier, `"long_short"`.
fn name(&self) -> &'static str {
"long_short"
}
/// Returns a one-line, human-readable summary of what the plugin does.
fn description(&self) -> &'static str {
"Classify capital gains into long-term vs short-term based on holding period"
}
/// Runs the long/short classification over `input` and returns the resulting directives.
fn process(&self, input: PluginInput) -> PluginOutput {
process_long_short(input)
}
}
/// `NativePlugin` implementation that delegates all work to `process_gain_loss`.
impl NativePlugin for CapitalGainsGainLossPlugin {
/// Returns the plugin identifier, `"gain_loss"`.
fn name(&self) -> &'static str {
"gain_loss"
}
/// Returns a one-line, human-readable summary of what the plugin does.
fn description(&self) -> &'static str {
"Classify capital gains into gains vs losses based on posting amount"
}
/// Runs the gain/loss classification over `input` and returns the resulting directives.
fn process(&self, input: PluginInput) -> PluginOutput {
process_gain_loss(input)
}
}
/// Parsed configuration for the long/short classifier.
struct LongShortConfig {
// Regex tested against posting account names to find the generic
// capital-gains postings that should be split.
pattern: Regex,
// Substring of the matched account name that gets substituted.
account_to_replace: String,
// Text substituted for `account_to_replace` on the short-term posting.
short_replacement: String,
// Text substituted for `account_to_replace` on the long-term posting.
long_replacement: String,
}
/// Parsed configuration for the gain/loss classifier.
struct GainLossConfig {
// Regex tested against posting account names to find the postings to rename.
pattern: Regex,
// Substring of the matched account name that gets substituted.
account_to_replace: String,
// Text substituted for `account_to_replace` when the posting amount is negative.
gains_replacement: String,
// Text substituted for `account_to_replace` when the posting amount is zero or positive.
losses_replacement: String,
}
fn process_long_short(input: PluginInput) -> PluginOutput {
let config = match &input.config {
Some(c) => match parse_long_short_config(c) {
Some(cfg) => cfg,
None => {
return PluginOutput {
directives: input.directives,
errors: Vec::new(),
};
}
},
None => {
return PluginOutput {
directives: input.directives,
errors: Vec::new(),
};
}
};
let mut new_accounts: HashSet<String> = HashSet::new();
let mut new_directives: Vec<DirectiveWrapper> = Vec::new();
for directive in input.directives {
if directive.directive_type != "transaction" {
new_directives.push(directive);
continue;
}
if let DirectiveData::Transaction(txn) = &directive.data {
let has_generic = txn
.postings
.iter()
.any(|p| config.pattern.is_match(&p.account));
let has_specific = txn.postings.iter().any(|p| {
p.account.contains(&config.short_replacement)
|| p.account.contains(&config.long_replacement)
});
if !has_generic || has_specific {
new_directives.push(directive);
continue;
}
let reductions: Vec<&PostingData> = txn
.postings
.iter()
.filter(|p| p.cost.is_some() && p.units.is_some() && p.price.is_some())
.collect();
if reductions.is_empty() {
new_directives.push(directive);
continue;
}
let entry_date = if let Ok(d) = directive.date.parse::<NaiveDate>() {
d
} else {
new_directives.push(directive);
continue;
};
let mut short_gains = Decimal::ZERO;
let mut long_gains = Decimal::ZERO;
for posting in &reductions {
if let (Some(cost), Some(units), Some(price)) =
(&posting.cost, &posting.units, &posting.price)
{
let cost_date = cost.date.as_ref().and_then(|d| d.parse::<NaiveDate>().ok());
if let Some(cost_date) = cost_date {
let cost_number = cost
.number_per
.as_ref()
.and_then(|n| Decimal::from_str(n).ok())
.unwrap_or(Decimal::ZERO);
let price_number = price
.amount
.as_ref()
.and_then(|a| Decimal::from_str(&a.number).ok())
.unwrap_or(Decimal::ZERO);
let units_number =
Decimal::from_str(&units.number).unwrap_or(Decimal::ZERO);
let gain = (cost_number - price_number) * units_number.abs();
let days_held = entry_date.since(cost_date).map_or(0, |s| s.get_days());
let years_held = (days_held / 365) as u32;
let is_long_term = years_held > 1
|| (years_held == 1
&& (entry_date.month() > cost_date.month()
|| (entry_date.month() == cost_date.month()
&& entry_date.day() >= cost_date.day())));
if is_long_term {
long_gains += gain;
} else {
short_gains += gain;
}
}
}
}
let orig_postings: Vec<&PostingData> = txn
.postings
.iter()
.filter(|p| config.pattern.is_match(&p.account))
.collect();
if orig_postings.is_empty() {
new_directives.push(directive);
continue;
}
let orig_sum: Decimal = orig_postings
.iter()
.filter_map(|p| p.units.as_ref())
.filter_map(|u| Decimal::from_str(&u.number).ok())
.sum();
let diff = orig_sum - (short_gains + long_gains);
if diff.abs() > Decimal::new(1, 6) {
let total = short_gains + long_gains;
if total != Decimal::ZERO {
short_gains += (short_gains / total) * diff;
long_gains += (long_gains / total) * diff;
}
}
let mut new_postings: Vec<PostingData> = txn
.postings
.iter()
.filter(|p| !config.pattern.is_match(&p.account))
.cloned()
.collect();
let template = orig_postings[0];
if short_gains != Decimal::ZERO {
let new_account = template
.account
.replace(&config.account_to_replace, &config.short_replacement);
new_accounts.insert(new_account.clone());
new_postings.push(PostingData {
account: new_account,
units: template.units.as_ref().map(|u| AmountData {
number: format_decimal(short_gains),
currency: u.currency.clone(),
}),
cost: None,
price: None,
flag: template.flag.clone(),
metadata: vec![],
});
}
if long_gains != Decimal::ZERO {
let new_account = template
.account
.replace(&config.account_to_replace, &config.long_replacement);
new_accounts.insert(new_account.clone());
new_postings.push(PostingData {
account: new_account,
units: template.units.as_ref().map(|u| AmountData {
number: format_decimal(long_gains),
currency: u.currency.clone(),
}),
cost: None,
price: None,
flag: template.flag.clone(),
metadata: vec![],
});
}
new_directives.push(DirectiveWrapper {
directive_type: "transaction".to_string(),
date: directive.date.clone(),
filename: directive.filename.clone(),
lineno: directive.lineno,
data: DirectiveData::Transaction(TransactionData {
flag: txn.flag.clone(),
payee: txn.payee.clone(),
narration: txn.narration.clone(),
tags: txn.tags.clone(),
links: txn.links.clone(),
metadata: txn.metadata.clone(),
postings: new_postings,
}),
});
} else {
new_directives.push(directive);
}
}
let earliest_date = new_directives
.iter()
.map(|d| d.date.as_str())
.min()
.unwrap_or("1970-01-01")
.to_string();
let mut open_directives: Vec<DirectiveWrapper> = new_accounts
.iter()
.map(|account| DirectiveWrapper {
directive_type: "open".to_string(),
date: earliest_date.clone(),
filename: Some("<long_short>".to_string()),
lineno: Some(0),
data: DirectiveData::Open(OpenData {
account: account.clone(),
currencies: vec![],
booking: None,
metadata: vec![],
}),
})
.collect();
open_directives.extend(new_directives);
PluginOutput {
directives: open_directives,
errors: Vec::new(),
}
}
fn process_gain_loss(input: PluginInput) -> PluginOutput {
let config = match &input.config {
Some(c) => match parse_gain_loss_config(c) {
Some(cfg) => cfg,
None => {
return PluginOutput {
directives: input.directives,
errors: Vec::new(),
};
}
},
None => {
return PluginOutput {
directives: input.directives,
errors: Vec::new(),
};
}
};
let mut new_accounts: HashSet<String> = HashSet::new();
let mut new_directives: Vec<DirectiveWrapper> = Vec::new();
for directive in input.directives {
if directive.directive_type != "transaction" {
new_directives.push(directive);
continue;
}
if let DirectiveData::Transaction(txn) = &directive.data {
let mut modified = false;
let mut new_postings: Vec<PostingData> = Vec::new();
for posting in &txn.postings {
if config.pattern.is_match(&posting.account)
&& let Some(units) = &posting.units
&& let Ok(number) = Decimal::from_str(&units.number)
{
let new_account = if number < Decimal::ZERO {
posting
.account
.replace(&config.account_to_replace, &config.gains_replacement)
} else {
posting
.account
.replace(&config.account_to_replace, &config.losses_replacement)
};
new_accounts.insert(new_account.clone());
new_postings.push(PostingData {
account: new_account,
units: posting.units.clone(),
cost: posting.cost.clone(),
price: posting.price.clone(),
flag: posting.flag.clone(),
metadata: posting.metadata.clone(),
});
modified = true;
continue;
}
new_postings.push(posting.clone());
}
if modified {
new_directives.push(DirectiveWrapper {
directive_type: "transaction".to_string(),
date: directive.date.clone(),
filename: directive.filename.clone(),
lineno: directive.lineno,
data: DirectiveData::Transaction(TransactionData {
flag: txn.flag.clone(),
payee: txn.payee.clone(),
narration: txn.narration.clone(),
tags: txn.tags.clone(),
links: txn.links.clone(),
metadata: txn.metadata.clone(),
postings: new_postings,
}),
});
} else {
new_directives.push(directive);
}
} else {
new_directives.push(directive);
}
}
let earliest_date = new_directives
.iter()
.map(|d| d.date.as_str())
.min()
.unwrap_or("1970-01-01")
.to_string();
let mut open_directives: Vec<DirectiveWrapper> = new_accounts
.iter()
.map(|account| DirectiveWrapper {
directive_type: "open".to_string(),
date: earliest_date.clone(),
filename: Some("<gain_loss>".to_string()),
lineno: Some(0),
data: DirectiveData::Open(OpenData {
account: account.clone(),
currencies: vec![],
booking: None,
metadata: vec![],
}),
})
.collect();
open_directives.extend(new_directives);
PluginOutput {
directives: open_directives,
errors: Vec::new(),
}
}
/// Parses the long/short plugin configuration string.
///
/// Only the first entry matched by `CONFIG_ENTRY_RE` is used. Returns `None`
/// when no entry matches or when the entry's key is not a valid regex.
fn parse_long_short_config(config: &str) -> Option<LongShortConfig> {
    let entry = CONFIG_ENTRY_RE.captures(config)?;
    let group = |index: usize| entry[index].to_string();
    match Regex::new(&entry[1]) {
        Ok(pattern) => Some(LongShortConfig {
            pattern,
            account_to_replace: group(2),
            short_replacement: group(3),
            long_replacement: group(4),
        }),
        Err(_) => None,
    }
}
/// Parses the gain/loss plugin configuration string.
///
/// Only the first entry matched by `CONFIG_ENTRY_RE` is used. Returns `None`
/// when no entry matches or when the entry's key is not a valid regex.
fn parse_gain_loss_config(config: &str) -> Option<GainLossConfig> {
    let entry = CONFIG_ENTRY_RE.captures(config)?;
    Regex::new(&entry[1]).ok().map(|pattern| GainLossConfig {
        pattern,
        account_to_replace: entry[2].to_string(),
        gains_replacement: entry[3].to_string(),
        losses_replacement: entry[4].to_string(),
    })
}
/// Renders `d` as text without trailing fractional zeros.
///
/// When the rendered value has a decimal point, trailing `0` characters are
/// removed, followed by a trailing `.` if one is left. Values rendered without
/// a decimal point are returned unchanged.
fn format_decimal(d: Decimal) -> String {
    let rendered = d.to_string();
    match rendered.find('.') {
        None => rendered,
        Some(_) => rendered
            .trim_end_matches('0')
            .trim_end_matches('.')
            .to_string(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::*;

    /// The long/short configuration exposes the pattern and all three account fragments.
    #[test]
    fn test_parse_long_short_config() {
        let config = "{'Income.*:Capital-Gains': [':Capital-Gains', ':Capital-Gains:Short', ':Capital-Gains:Long']}";
        let cfg = parse_long_short_config(config).expect("long/short config should parse");
        assert!(cfg.pattern.is_match("Income:Broker:Capital-Gains"));
        assert_eq!(cfg.account_to_replace, ":Capital-Gains");
        assert_eq!(cfg.short_replacement, ":Capital-Gains:Short");
        assert_eq!(cfg.long_replacement, ":Capital-Gains:Long");
    }

    /// The gain/loss configuration exposes the pattern and all three account fragments.
    #[test]
    fn test_parse_gain_loss_config() {
        let config = "{'Income.*:Long': [':Long', ':Long:Gains', ':Long:Losses']}";
        let cfg = parse_gain_loss_config(config).expect("gain/loss config should parse");
        assert!(cfg.pattern.is_match("Income:Capital-Gains:Long"));
        assert_eq!(cfg.account_to_replace, ":Long");
        assert_eq!(cfg.gains_replacement, ":Long:Gains");
        assert_eq!(cfg.losses_replacement, ":Long:Losses");
    }

    /// A negative capital-gains posting is moved to the gains account, the
    /// unrelated posting is left alone, and the new account is opened.
    #[test]
    fn test_gain_loss_classification() {
        let plugin = CapitalGainsGainLossPlugin;
        let input = PluginInput {
            directives: vec![DirectiveWrapper {
                directive_type: "transaction".to_string(),
                date: "2024-01-15".to_string(),
                filename: None,
                lineno: None,
                data: DirectiveData::Transaction(TransactionData {
                    flag: "*".to_string(),
                    payee: None,
                    narration: "Sell stock".to_string(),
                    tags: vec![],
                    links: vec![],
                    metadata: vec![],
                    postings: vec![
                        PostingData {
                            account: "Assets:Broker".to_string(),
                            units: Some(AmountData {
                                number: "1000".to_string(),
                                currency: "USD".to_string(),
                            }),
                            cost: None,
                            price: None,
                            flag: None,
                            metadata: vec![],
                        },
                        PostingData {
                            account: "Income:Capital-Gains:Long".to_string(),
                            units: Some(AmountData {
                                number: "-100".to_string(),
                                currency: "USD".to_string(),
                            }),
                            cost: None,
                            price: None,
                            flag: None,
                            metadata: vec![],
                        },
                    ],
                }),
            }],
            options: PluginOptions {
                operating_currencies: vec!["USD".to_string()],
                title: None,
            },
            config: Some(
                "{'Income.*:Capital-Gains:Long': [':Long', ':Long:Gains', ':Long:Losses']}"
                    .to_string(),
            ),
        };

        let output = plugin.process(input);
        assert_eq!(output.errors.len(), 0);

        let txn = output
            .directives
            .iter()
            .find(|d| d.directive_type == "transaction")
            .expect("the transaction should still be present");
        // Fail loudly instead of silently skipping the assertions below.
        let DirectiveData::Transaction(t) = &txn.data else {
            panic!("directive tagged as transaction does not carry transaction data");
        };

        assert_eq!(t.postings.len(), 2);
        assert!(t.postings.iter().any(|p| p.account == "Assets:Broker"));
        assert!(
            t.postings
                .iter()
                .any(|p| p.account == "Income:Capital-Gains:Long:Gains")
        );
        assert!(
            !t.postings
                .iter()
                .any(|p| p.account == "Income:Capital-Gains:Long")
        );

        // The generated account must be opened by the plugin.
        assert!(output.directives.iter().any(|d| matches!(
            &d.data,
            DirectiveData::Open(open) if open.account == "Income:Capital-Gains:Long:Gains"
        )));
    }
}