wp-solana-pool-traits 0.1.1

Traits and utilities for Solana liquidity pool operations: PoolViewer, PoolInfuser, PositionViewer
Documentation
use anyhow::Result;

use crate::types::{CurrencyAmount, Price};

/// Computes the full worth of a position — token balances plus any
/// uncollected fees — expressed in a single target currency.
///
/// Each fee, when present, is added to the amount of the same token. The two
/// totals are then passed to [`calculate_value_in_target_currency`], which
/// converts the non-target side with `price` and sums the result.
///
/// # Arguments
/// * `amount0` - Amount of token0 in the position
/// * `amount1` - Amount of token1 in the position
/// * `fee0` - Uncollected token0 fees; `None` leaves `amount0` unchanged
/// * `fee1` - Uncollected token1 fees; `None` leaves `amount1` unchanged
/// * `price` - Price from token0 to token1 (`Price<Currency0, Currency1>`)
/// * `target_currency0` - `true` to report the value in currency0, `false`
///   to report it in currency1
///
/// # Returns
/// Total value of the position in the target currency
///
/// # Errors
/// Fails when a fee cannot be added to its amount, or when the conversion
/// step reports an error. The underlying error is embedded in the message
/// using its `Debug` representation.
pub fn calculate_position_value(
    amount0: CurrencyAmount,
    amount1: CurrencyAmount,
    fee0: Option<CurrencyAmount>,
    fee1: Option<CurrencyAmount>,
    price: Price,
    target_currency0: bool,
) -> Result<CurrencyAmount> {
    // Fold the optional fee into each side; a missing fee contributes nothing.
    let with_fees0 = match fee0 {
        Some(pending) => amount0
            .add(&pending)
            .map_err(|e| anyhow::anyhow!("Failed to add amount0 and fee0: {:?}", e))?,
        None => amount0,
    };
    let with_fees1 = match fee1 {
        Some(pending) => amount1
            .add(&pending)
            .map_err(|e| anyhow::anyhow!("Failed to add amount1 and fee1: {:?}", e))?,
        None => amount1,
    };

    // Delegate the currency conversion and final summation.
    calculate_value_in_target_currency(with_fees0, with_fees1, price, target_currency0)
        .map_err(|e| anyhow::anyhow!("Failed to calculate position value: {:?}", e))
}

/// Expresses a pair of token amounts as one amount in the chosen currency.
///
/// The amount already denominated in the target currency is kept as is; the
/// other amount is converted with `price` (inverted when the target is
/// currency0) and the two are added together.
///
/// # Arguments
/// * `amount0` - Amount of token0 in the position
/// * `amount1` - Amount of token1 in the position
/// * `price` - Price from token0 to token1 (`Price<Currency0, Currency1>`)
/// * `target_currency0` - `true` to report the value in currency0, `false`
///   to report it in currency1
///
/// # Returns
/// Total value of the position in the target currency
///
/// # Errors
/// Fails when quoting the non-target amount or adding the two amounts
/// reports an error. The underlying error is embedded in the message using
/// its `Debug` representation.
pub fn calculate_value_in_target_currency(
    amount0: CurrencyAmount,
    amount1: CurrencyAmount,
    price: Price,
    target_currency0: bool,
) -> Result<CurrencyAmount> {
    if !target_currency0 {
        // Target is currency1: `price` already maps currency0 -> currency1,
        // so it can quote amount0 directly.
        let converted = price
            .quote(&amount0)
            .map_err(|e| anyhow::anyhow!("Failed to quote amount0 to currency1: {:?}", e))?;
        return amount1
            .add(&converted)
            .map_err(|e| anyhow::anyhow!("Failed to add amounts in currency1: {:?}", e));
    }

    // Target is currency0: flip the price so it maps currency1 -> currency0
    // before quoting amount1.
    let currency1_to_currency0 = price.invert();
    let converted = currency1_to_currency0
        .quote(&amount1)
        .map_err(|e| anyhow::anyhow!("Failed to quote amount1 to currency0: {:?}", e))?;
    amount0
        .add(&converted)
        .map_err(|e| anyhow::anyhow!("Failed to add amounts in currency0: {:?}", e))
}

#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use rust_decimal::Decimal;
    use solana_sdk::pubkey::Pubkey;

    use super::*;
    use crate::types::Currency;

    /// Test fixture: the currency returned by `Currency::sol(None)`.
    fn sol_currency() -> Currency {
        // 9 decimals
        Currency::sol(None)
    }

    /// Test fixture: a token currency labelled "USDC" / "USD Coin".
    ///
    /// Every call generates a fresh `Pubkey`, so a test that needs one
    /// consistent currency should call this once and clone the result.
    fn usdc_currency() -> Currency {
        let mint = Pubkey::new_unique();
        let symbol = String::from("USDC");
        let name = String::from("USD Coin");
        // Second argument is 6 — presumably the token's decimals, matching
        // the raw amounts used by the tests (e.g. 50_000_000 for 50 USDC).
        Currency::token(mint, 6, Some(symbol), Some(name))
    }

    // -----------------------------------------------------------------------
    // Tests for calculate_value_in_target_currency
    // -----------------------------------------------------------------------

    #[test]
    fn test_value_in_target_currency_target_quote() {
        // Valuing in USDC (currency1): 1 SOL at 100 USDC/SOL is worth 100 USDC,
        // plus the 50 USDC already held => roughly 150 USDC.
        let base = sol_currency();
        let quote = usdc_currency();

        let sol_to_usdc =
            Price::from_decimal(base.clone(), quote.clone(), Decimal::from(100), Some(18)).unwrap();
        let held_sol = CurrencyAmount::from_raw_amount(base, 1_000_000_000); // 1 SOL
        let held_usdc = CurrencyAmount::from_raw_amount(quote.clone(), 50_000_000); // 50 USDC

        let total =
            calculate_value_in_target_currency(held_sol, held_usdc, sol_to_usdc, false).unwrap();

        assert_eq!(total.currency(), &quote);
        let value = total.to_decimal();
        let max_error = Decimal::from(2);
        assert!(
            (value - Decimal::from(150)).abs() < max_error,
            "Expected ~150 USDC, got {}",
            value
        );
    }

    #[test]
    fn test_value_in_target_currency_target_base() {
        // Valuing in SOL (currency0): 100 USDC at 100 USDC/SOL is worth 1 SOL,
        // plus the 1 SOL already held => roughly 2 SOL.
        let base = sol_currency();
        let quote = usdc_currency();

        let sol_to_usdc =
            Price::from_decimal(base.clone(), quote.clone(), Decimal::from(100), Some(18)).unwrap();
        let held_sol = CurrencyAmount::from_raw_amount(base.clone(), 1_000_000_000); // 1 SOL
        let held_usdc = CurrencyAmount::from_raw_amount(quote, 100_000_000); // 100 USDC

        let total =
            calculate_value_in_target_currency(held_sol, held_usdc, sol_to_usdc, true).unwrap();

        assert_eq!(total.currency(), &base);
        let value = total.to_decimal();
        let max_error = Decimal::from_str("0.1").unwrap();
        assert!((value - Decimal::from(2)).abs() < max_error, "Expected ~2 SOL, got {}", value);
    }

    #[test]
    fn test_value_in_target_currency_zero_amounts() {
        // An empty position must come out as exactly zero raw units.
        let base = sol_currency();
        let quote = usdc_currency();

        let sol_to_usdc =
            Price::from_decimal(base.clone(), quote.clone(), Decimal::from(100), Some(18)).unwrap();
        let no_sol = CurrencyAmount::from_raw_amount(base, 0);
        let no_usdc = CurrencyAmount::from_raw_amount(quote, 0);

        let total =
            calculate_value_in_target_currency(no_sol, no_usdc, sol_to_usdc, false).unwrap();

        assert_eq!(total.raw_amount(), 0);
    }

    #[test]
    fn test_value_in_target_currency_only_one_token() {
        // The position holds only token1 (USDC) and is valued in USDC, so the
        // zero token0 side contributes nothing to the total.
        let base = sol_currency();
        let quote = usdc_currency();

        let sol_to_usdc =
            Price::from_decimal(base.clone(), quote.clone(), Decimal::from(100), Some(18)).unwrap();
        let no_sol = CurrencyAmount::from_raw_amount(base, 0);
        let held_usdc = CurrencyAmount::from_raw_amount(quote, 200_000_000); // 200 USDC

        let total =
            calculate_value_in_target_currency(no_sol, held_usdc, sol_to_usdc, false).unwrap();

        assert_eq!(total.to_decimal(), Decimal::from(200));
    }

    // -----------------------------------------------------------------------
    // Tests for calculate_position_value
    // -----------------------------------------------------------------------

    #[test]
    fn test_position_value_with_fees() {
        // Holdings: 2 SOL + 100 USDC; uncollected fees: 0.5 SOL + 10 USDC.
        // At 100 USDC/SOL, valued in USDC:
        //   (2 + 0.5) * 100 + (100 + 10) = 250 + 110 = 360
        let base = sol_currency();
        let quote = usdc_currency();

        let sol_to_usdc =
            Price::from_decimal(base.clone(), quote.clone(), Decimal::from(100), Some(18)).unwrap();
        let held_sol = CurrencyAmount::from_raw_amount(base.clone(), 2_000_000_000);
        let held_usdc = CurrencyAmount::from_raw_amount(quote.clone(), 100_000_000);
        let pending_sol = CurrencyAmount::from_raw_amount(base, 500_000_000);
        let pending_usdc = CurrencyAmount::from_raw_amount(quote, 10_000_000);

        let total = calculate_position_value(
            held_sol,
            held_usdc,
            Some(pending_sol),
            Some(pending_usdc),
            sol_to_usdc,
            false,
        )
        .unwrap();

        let value = total.to_decimal();
        let max_error = Decimal::from(5);
        assert!(
            (value - Decimal::from(360)).abs() < max_error,
            "Expected ~360 USDC, got {}",
            value
        );
    }

    #[test]
    fn test_position_value_without_fees() {
        // Holdings: 1 SOL + 50 USDC, with both fees passed as `None`.
        // At 100 USDC/SOL, valued in USDC: 1 * 100 + 50 = 150
        let base = sol_currency();
        let quote = usdc_currency();

        let sol_to_usdc =
            Price::from_decimal(base.clone(), quote.clone(), Decimal::from(100), Some(18)).unwrap();
        let held_sol = CurrencyAmount::from_raw_amount(base, 1_000_000_000);
        let held_usdc = CurrencyAmount::from_raw_amount(quote, 50_000_000);

        let total =
            calculate_position_value(held_sol, held_usdc, None, None, sol_to_usdc, false).unwrap();

        let value = total.to_decimal();
        let max_error = Decimal::from(2);
        assert!(
            (value - Decimal::from(150)).abs() < max_error,
            "Expected ~150 USDC, got {}",
            value
        );
    }

    #[test]
    fn test_position_value_target_base_currency() {
        // Holdings: 1 SOL + 200 USDC, with both fees passed as `None`.
        // At 100 USDC/SOL, valued in SOL: 1 + 200 / 100 = 3
        let base = sol_currency();
        let quote = usdc_currency();

        let sol_to_usdc =
            Price::from_decimal(base.clone(), quote.clone(), Decimal::from(100), Some(18)).unwrap();
        let held_sol = CurrencyAmount::from_raw_amount(base, 1_000_000_000);
        let held_usdc = CurrencyAmount::from_raw_amount(quote, 200_000_000);

        let total =
            calculate_position_value(held_sol, held_usdc, None, None, sol_to_usdc, true).unwrap();

        let value = total.to_decimal();
        let max_error = Decimal::from_str("0.1").unwrap();
        assert!((value - Decimal::from(3)).abs() < max_error, "Expected ~3 SOL, got {}", value);
    }

    #[test]
    fn test_position_value_only_fee0() {
        // A fee is supplied for token0 only; the token1 fee is `None`.
        // Holdings: 1 SOL + 0 USDC, fee: 0.5 SOL.
        // At 100 USDC/SOL, valued in USDC: (1 + 0.5) * 100 = 150
        let base = sol_currency();
        let quote = usdc_currency();

        let sol_to_usdc =
            Price::from_decimal(base.clone(), quote.clone(), Decimal::from(100), Some(18)).unwrap();
        let held_sol = CurrencyAmount::from_raw_amount(base.clone(), 1_000_000_000);
        let held_usdc = CurrencyAmount::from_raw_amount(quote, 0);
        let pending_sol = CurrencyAmount::from_raw_amount(base, 500_000_000);

        let total = calculate_position_value(
            held_sol,
            held_usdc,
            Some(pending_sol),
            None,
            sol_to_usdc,
            false,
        )
        .unwrap();

        let value = total.to_decimal();
        let max_error = Decimal::from(2);
        assert!(
            (value - Decimal::from(150)).abs() < max_error,
            "Expected ~150 USDC, got {}",
            value
        );
    }

    #[test]
    fn values_default_recombination_identity() {
        // Including uncollected fees must never make a position worth less
        // than the same position valued without them, whichever currency the
        // value is reported in.
        let base = sol_currency();
        let quote = usdc_currency();

        let held_sol = CurrencyAmount::from_raw_amount(base.clone(), 2_000_000_000);
        let held_usdc = CurrencyAmount::from_raw_amount(quote.clone(), 100_000_000);
        let pending_sol = CurrencyAmount::from_raw_amount(base.clone(), 500_000_000);
        let pending_usdc = CurrencyAmount::from_raw_amount(quote.clone(), 10_000_000);
        let sol_to_usdc =
            Price::from_decimal(base.clone(), quote.clone(), Decimal::from(100), Some(18)).unwrap();

        // Values the same holdings, optionally with fees, in either currency.
        let value_of = |include_fees: bool, in_currency0: bool| {
            let (fee_sol, fee_usdc) = if include_fees {
                (Some(pending_sol.clone()), Some(pending_usdc.clone()))
            } else {
                (None, None)
            };
            calculate_position_value(
                held_sol.clone(),
                held_usdc.clone(),
                fee_sol,
                fee_usdc,
                sol_to_usdc.clone(),
                in_currency0,
            )
            .unwrap()
        };

        let with_fees_in_sol = value_of(true, true);
        let with_fees_in_usdc = value_of(true, false);
        let bare_in_sol = value_of(false, true);
        let bare_in_usdc = value_of(false, false);

        assert!(with_fees_in_sol.raw_amount() >= bare_in_sol.raw_amount());
        assert!(with_fees_in_usdc.raw_amount() >= bare_in_usdc.raw_amount());
    }
}