use anyhow::Result;
use crate::types::{CurrencyAmount, Price};
/// Computes the total value of a position, fees included, in one currency.
///
/// Each optional fee is first added to its matching amount (`fee0` to
/// `amount0`, `fee1` to `amount1`); a `None` fee leaves the amount untouched.
/// The two totals are then handed to [`calculate_value_in_target_currency`],
/// which converts one side through `price` and sums both sides.
///
/// `target_currency0` selects the denomination of the result: `true` yields
/// a value in the currency of `amount0`, `false` in the currency of `amount1`.
///
/// # Errors
///
/// Returns an error if adding a fee to its amount fails, or if the
/// conversion/summation step fails. The underlying error is embedded in the
/// message using its `Debug` representation.
pub fn calculate_position_value(
    amount0: CurrencyAmount,
    amount1: CurrencyAmount,
    fee0: Option<CurrencyAmount>,
    fee1: Option<CurrencyAmount>,
    price: Price,
    target_currency0: bool,
) -> Result<CurrencyAmount> {
    // Fold the token0 fee (when present) into the token0 amount.
    let with_fee0 = match fee0 {
        Some(fee) => amount0
            .add(&fee)
            .map_err(|e| anyhow::anyhow!("Failed to add amount0 and fee0: {:?}", e))?,
        None => amount0,
    };

    // Same treatment for the token1 side.
    let with_fee1 = match fee1 {
        Some(fee) => amount1
            .add(&fee)
            .map_err(|e| anyhow::anyhow!("Failed to add amount1 and fee1: {:?}", e))?,
        None => amount1,
    };

    calculate_value_in_target_currency(with_fee0, with_fee1, price, target_currency0)
        .map_err(|e| anyhow::anyhow!("Failed to calculate position value: {:?}", e))
}
/// Expresses two amounts as a single amount in one of the two currencies.
///
/// * `target_currency0 == true`: `amount1` is quoted through the inverted
///   `price` and the result is added to `amount0`.
/// * `target_currency0 == false`: `amount0` is quoted through `price` as
///   given and the result is added to `amount1`.
///
/// # Errors
///
/// Returns an error if quoting through the price fails or if the final
/// addition fails. The underlying error is embedded in the message using its
/// `Debug` representation.
pub fn calculate_value_in_target_currency(
    amount0: CurrencyAmount,
    amount1: CurrencyAmount,
    price: Price,
    target_currency0: bool,
) -> Result<CurrencyAmount> {
    if !target_currency0 {
        // Target is currency1: convert the currency0 side with the price as-is.
        let converted = price
            .quote(&amount0)
            .map_err(|e| anyhow::anyhow!("Failed to quote amount0 to currency1: {:?}", e))?;
        return amount1
            .add(&converted)
            .map_err(|e| anyhow::anyhow!("Failed to add amounts in currency1: {:?}", e));
    }

    // Target is currency0: the price has to be flipped before quoting amount1.
    let flipped = price.invert();
    let converted = flipped
        .quote(&amount1)
        .map_err(|e| anyhow::anyhow!("Failed to quote amount1 to currency0: {:?}", e))?;
    amount0
        .add(&converted)
        .map_err(|e| anyhow::anyhow!("Failed to add amounts in currency0: {:?}", e))
}
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use rust_decimal::Decimal;
    use solana_sdk::pubkey::Pubkey;

    use super::*;
    use crate::types::Currency;

    /// Currency used as currency0 throughout these tests.
    fn sol_currency() -> Currency {
        Currency::sol(None)
    }

    /// Token currency used as currency1; every call gets a fresh unique pubkey.
    fn usdc_currency() -> Currency {
        let key = Pubkey::new_unique();
        // NOTE(review): presumably symbol and display name — confirm against `Currency::token`.
        let short_label = Some("USDC".to_string());
        let long_label = Some("USD Coin".to_string());
        Currency::token(key, 6, short_label, long_label)
    }

    /// Price fixture shared by every test: decimal value 100 between the two currencies.
    fn sol_usdc_price(sol_ccy: &Currency, usdc_ccy: &Currency) -> Price {
        Price::from_decimal(sol_ccy.clone(), usdc_ccy.clone(), Decimal::from(100), Some(18))
            .unwrap()
    }

    // Valuing in currency1: both sides should come out as roughly 150 USDC.
    #[test]
    fn test_value_in_target_currency_target_quote() {
        let (sol_ccy, usdc_ccy) = (sol_currency(), usdc_currency());
        let sol_amount = CurrencyAmount::from_raw_amount(sol_ccy.clone(), 1_000_000_000);
        let usdc_amount = CurrencyAmount::from_raw_amount(usdc_ccy.clone(), 50_000_000);
        let rate = sol_usdc_price(&sol_ccy, &usdc_ccy);

        let total =
            calculate_value_in_target_currency(sol_amount, usdc_amount, rate, false).unwrap();
        assert_eq!(total.currency(), &usdc_ccy);

        let value = total.to_decimal();
        let (expected, tolerance) = (Decimal::from(150), Decimal::from(2));
        assert!((value - expected).abs() < tolerance, "Expected ~150 USDC, got {}", value);
    }

    // Valuing in currency0: both sides should come out as roughly 2 SOL.
    #[test]
    fn test_value_in_target_currency_target_base() {
        let (sol_ccy, usdc_ccy) = (sol_currency(), usdc_currency());
        let sol_amount = CurrencyAmount::from_raw_amount(sol_ccy.clone(), 1_000_000_000);
        let usdc_amount = CurrencyAmount::from_raw_amount(usdc_ccy.clone(), 100_000_000);
        let rate = sol_usdc_price(&sol_ccy, &usdc_ccy);

        let total =
            calculate_value_in_target_currency(sol_amount, usdc_amount, rate, true).unwrap();
        assert_eq!(total.currency(), &sol_ccy);

        let value = total.to_decimal();
        let expected = Decimal::from(2);
        let tolerance = Decimal::from_str("0.1").unwrap();
        assert!((value - expected).abs() < tolerance, "Expected ~2 SOL, got {}", value);
    }

    // Two zero amounts must value to a raw amount of zero.
    #[test]
    fn test_value_in_target_currency_zero_amounts() {
        let (sol_ccy, usdc_ccy) = (sol_currency(), usdc_currency());
        let rate = sol_usdc_price(&sol_ccy, &usdc_ccy);
        let no_sol = CurrencyAmount::from_raw_amount(sol_ccy.clone(), 0);
        let no_usdc = CurrencyAmount::from_raw_amount(usdc_ccy.clone(), 0);

        let total = calculate_value_in_target_currency(no_sol, no_usdc, rate, false).unwrap();
        assert_eq!(total.raw_amount(), 0);
    }

    // With nothing on the currency0 side, the value is exactly the currency1 amount.
    #[test]
    fn test_value_in_target_currency_only_one_token() {
        let (sol_ccy, usdc_ccy) = (sol_currency(), usdc_currency());
        let rate = sol_usdc_price(&sol_ccy, &usdc_ccy);
        let no_sol = CurrencyAmount::from_raw_amount(sol_ccy.clone(), 0);
        let usdc_amount = CurrencyAmount::from_raw_amount(usdc_ccy.clone(), 200_000_000);

        let total = calculate_value_in_target_currency(no_sol, usdc_amount, rate, false).unwrap();
        let value = total.to_decimal();
        assert_eq!(value, Decimal::from(200));
    }

    // Fees on both sides are included in the valuation.
    #[test]
    fn test_position_value_with_fees() {
        let (sol_ccy, usdc_ccy) = (sol_currency(), usdc_currency());
        let sol_amount = CurrencyAmount::from_raw_amount(sol_ccy.clone(), 2_000_000_000);
        let usdc_amount = CurrencyAmount::from_raw_amount(usdc_ccy.clone(), 100_000_000);
        let sol_fee = Some(CurrencyAmount::from_raw_amount(sol_ccy.clone(), 500_000_000));
        let usdc_fee = Some(CurrencyAmount::from_raw_amount(usdc_ccy.clone(), 10_000_000));
        let rate = sol_usdc_price(&sol_ccy, &usdc_ccy);

        let total =
            calculate_position_value(sol_amount, usdc_amount, sol_fee, usdc_fee, rate, false)
                .unwrap();

        let value = total.to_decimal();
        let (expected, tolerance) = (Decimal::from(360), Decimal::from(5));
        assert!((value - expected).abs() < tolerance, "Expected ~360 USDC, got {}", value);
    }

    // No fees supplied: only the two principal amounts are valued.
    #[test]
    fn test_position_value_without_fees() {
        let (sol_ccy, usdc_ccy) = (sol_currency(), usdc_currency());
        let sol_amount = CurrencyAmount::from_raw_amount(sol_ccy.clone(), 1_000_000_000);
        let usdc_amount = CurrencyAmount::from_raw_amount(usdc_ccy.clone(), 50_000_000);
        let rate = sol_usdc_price(&sol_ccy, &usdc_ccy);

        let total =
            calculate_position_value(sol_amount, usdc_amount, None, None, rate, false).unwrap();

        let value = total.to_decimal();
        let (expected, tolerance) = (Decimal::from(150), Decimal::from(2));
        assert!((value - expected).abs() < tolerance, "Expected ~150 USDC, got {}", value);
    }

    // Position valued in currency0 rather than currency1.
    #[test]
    fn test_position_value_target_base_currency() {
        let (sol_ccy, usdc_ccy) = (sol_currency(), usdc_currency());
        let sol_amount = CurrencyAmount::from_raw_amount(sol_ccy.clone(), 1_000_000_000);
        let usdc_amount = CurrencyAmount::from_raw_amount(usdc_ccy.clone(), 200_000_000);
        let rate = sol_usdc_price(&sol_ccy, &usdc_ccy);

        let total =
            calculate_position_value(sol_amount, usdc_amount, None, None, rate, true).unwrap();

        let value = total.to_decimal();
        let expected = Decimal::from(3);
        let tolerance = Decimal::from_str("0.1").unwrap();
        assert!((value - expected).abs() < tolerance, "Expected ~3 SOL, got {}", value);
    }

    // Only the currency0 fee is supplied; the currency1 side is empty.
    #[test]
    fn test_position_value_only_fee0() {
        let (sol_ccy, usdc_ccy) = (sol_currency(), usdc_currency());
        let sol_amount = CurrencyAmount::from_raw_amount(sol_ccy.clone(), 1_000_000_000);
        let no_usdc = CurrencyAmount::from_raw_amount(usdc_ccy.clone(), 0);
        let sol_fee = Some(CurrencyAmount::from_raw_amount(sol_ccy.clone(), 500_000_000));
        let rate = sol_usdc_price(&sol_ccy, &usdc_ccy);

        let total =
            calculate_position_value(sol_amount, no_usdc, sol_fee, None, rate, false).unwrap();

        let value = total.to_decimal();
        let (expected, tolerance) = (Decimal::from(150), Decimal::from(2));
        assert!((value - expected).abs() < tolerance, "Expected ~150 USDC, got {}", value);
    }

    // In either target currency, the fee-inclusive value must be at least the fee-less value.
    #[test]
    fn values_default_recombination_identity() {
        let (sol_ccy, usdc_ccy) = (sol_currency(), usdc_currency());
        let principal_sol = CurrencyAmount::from_raw_amount(sol_ccy.clone(), 2_000_000_000);
        let principal_usdc = CurrencyAmount::from_raw_amount(usdc_ccy.clone(), 100_000_000);
        let fees_sol = CurrencyAmount::from_raw_amount(sol_ccy.clone(), 500_000_000);
        let fees_usdc = CurrencyAmount::from_raw_amount(usdc_ccy.clone(), 10_000_000);
        let rate = sol_usdc_price(&sol_ccy, &usdc_ccy);

        let with_fees_in_sol = calculate_position_value(
            principal_sol.clone(),
            principal_usdc.clone(),
            Some(fees_sol.clone()),
            Some(fees_usdc.clone()),
            rate.clone(),
            true,
        )
        .unwrap();
        let with_fees_in_usdc = calculate_position_value(
            principal_sol.clone(),
            principal_usdc.clone(),
            Some(fees_sol.clone()),
            Some(fees_usdc.clone()),
            rate.clone(),
            false,
        )
        .unwrap();
        let bare_in_sol = calculate_position_value(
            principal_sol.clone(),
            principal_usdc.clone(),
            None,
            None,
            rate.clone(),
            true,
        )
        .unwrap();
        let bare_in_usdc =
            calculate_position_value(principal_sol, principal_usdc, None, None, rate, false)
                .unwrap();

        assert!(with_fees_in_sol.raw_amount() >= bare_in_sol.raw_amount());
        assert!(with_fees_in_usdc.raw_amount() >= bare_in_usdc.raw_amount());
    }
}