//! pyra-margin 0.4.2
//!
//! Margin weight, balance, and price calculations for Drift spot positions.
use pyra_types::{SpotBalanceType, SpotMarket, SpotPosition};

use crate::error::{MathError, MathResult};
use crate::math::CheckedDivCeil;

/// Calculate token balance from raw Drift position fields.
///
/// This is the low-level variant that accepts individual fields rather than
/// typed structs, useful when the caller has a different SpotPosition/SpotMarket
/// representation (e.g. Carbon decoder types).
///
/// Returns signed balance in base units (positive = deposit, negative = borrow).
pub fn compute_token_balance(
    scaled_balance: u64,
    is_deposit: bool,
    cumulative_deposit_interest: u128,
    cumulative_borrow_interest: u128,
    decimals: u32,
) -> MathResult<i128> {
    let precision_decrease = 10u128
        .checked_pow(19u32.saturating_sub(decimals))
        .ok_or(MathError::Overflow)?;

    // u64 always fits in u128.
    let balance = scaled_balance as u128;

    let token_balance = if is_deposit {
        let raw_balance = balance
            .checked_mul(cumulative_deposit_interest)
            .ok_or(MathError::Overflow)?;
        // Safe: result of u128/u128 division is ≤ the dividend, which is bounded
        // by scaled_balance (u64::MAX) × cumulative_interest — well within i128::MAX.
        raw_balance
            .checked_div(precision_decrease)
            .ok_or(MathError::Overflow)? as i128
    } else {
        let raw_balance = balance
            .checked_mul(cumulative_borrow_interest)
            .ok_or(MathError::Overflow)?;
        let balance_unsigned = raw_balance
            .checked_div_ceil(precision_decrease)
            .ok_or(MathError::Overflow)?;
        // Safe: same bound as deposit path — division result fits in i128.
        (balance_unsigned as i128).saturating_neg()
    };

    Ok(token_balance)
}

/// Calculate the signed token balance of a Drift spot position.
///
/// Positive results are deposits and negative results are borrows, in the
/// market's base units. This only unpacks the `pyra_types` structs and
/// delegates the arithmetic (and its error cases) to [`compute_token_balance`].
pub fn get_token_balance(
    spot_position: &SpotPosition,
    spot_market: &SpotMarket,
) -> MathResult<i128> {
    let (deposit_interest, borrow_interest, decimals) = (
        spot_market.cumulative_deposit_interest,
        spot_market.cumulative_borrow_interest,
        spot_market.decimals,
    );

    compute_token_balance(
        spot_position.scaled_balance,
        matches!(spot_position.balance_type, SpotBalanceType::Deposit),
        deposit_interest,
        borrow_interest,
        decimals,
    )
}

/// Convert a signed token balance into its value in USDC base units.
///
/// Computes `token_balance_base_units × price_usdc_base_units / 10^token_decimals`
/// in `i128`. The balance's sign carries through, so a borrow (negative balance)
/// yields a negative value. The division truncates toward zero.
///
/// # Errors
///
/// Returns [`MathError::Overflow`] if `10^token_decimals` is not representable
/// (`token_decimals > 38`) or if the multiplication overflows `i128`.
pub fn calculate_value_usdc_base_units(
    token_balance_base_units: i128,
    price_usdc_base_units: u64,
    token_decimals: u32,
) -> MathResult<i128> {
    // 10^38 is the largest power of ten that fits in i128 (and in u128), so
    // building the scale directly in i128 needs no narrowing cast.
    let scale = 10i128
        .checked_pow(token_decimals)
        .ok_or(MathError::Overflow)?;

    let notional = token_balance_base_units
        .checked_mul(i128::from(price_usdc_base_units))
        .ok_or(MathError::Overflow)?;

    notional.checked_div(scale).ok_or(MathError::Overflow)
}

#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects,
    reason = "test code"
)]
mod tests {
    use super::*;

    /// `10^(19 - 6)`: the divisor `compute_token_balance` applies for a
    /// 6-decimal market. Using it as the cumulative interest makes the token
    /// balance equal the scaled balance.
    const SIX_DECIMAL_UNIT: u128 = 10_000_000_000_000;

    /// Market with the given decimals and cumulative interest values; every
    /// other field is zeroed or defaulted.
    fn make_market(
        decimals: u32,
        cumulative_deposit_interest: u128,
        cumulative_borrow_interest: u128,
    ) -> SpotMarket {
        SpotMarket {
            pubkey: Vec::new(),
            market_index: 0,
            decimals,
            cumulative_deposit_interest,
            cumulative_borrow_interest,
            initial_asset_weight: 0,
            initial_liability_weight: 0,
            imf_factor: 0,
            scale_initial_asset_weight_start: 0,
            deposit_balance: 0,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: Default::default(),
            historical_oracle_data: Default::default(),
            oracle: None,
        }
    }

    /// Position on the given side with the given scaled balance; other fields default.
    fn make_position(scaled_balance: u64, balance_type: SpotBalanceType) -> SpotPosition {
        SpotPosition {
            scaled_balance,
            balance_type,
            ..Default::default()
        }
    }

    #[test]
    fn deposit_balance_basic() {
        let market = make_market(6, SIX_DECIMAL_UNIT, SIX_DECIMAL_UNIT);
        let deposit = make_position(1_000_000, SpotBalanceType::Deposit);
        // 1 USDC
        assert_eq!(get_token_balance(&deposit, &market).unwrap(), 1_000_000);
    }

    #[test]
    fn borrow_balance_is_negative() {
        let market = make_market(6, SIX_DECIMAL_UNIT, SIX_DECIMAL_UNIT);
        let borrow = make_position(500_000, SpotBalanceType::Borrow);
        assert_eq!(get_token_balance(&borrow, &market).unwrap(), -500_000);
    }

    #[test]
    fn deposit_with_interest() {
        // 10% accrued deposit interest: 1.1 × the no-interest value.
        let accrued = SIX_DECIMAL_UNIT * 11 / 10;
        let market = make_market(6, accrued, SIX_DECIMAL_UNIT);
        let deposit = make_position(1_000_000, SpotBalanceType::Deposit);
        // 1.1 USDC
        assert_eq!(get_token_balance(&deposit, &market).unwrap(), 1_100_000);
    }

    #[test]
    fn value_usdc_basic() {
        // 1 SOL (9 decimals) priced at $100 = 100_000_000 USDC base units.
        let value = calculate_value_usdc_base_units(1_000_000_000, 100_000_000, 9);
        assert_eq!(value.unwrap(), 100_000_000);
    }

    #[test]
    fn value_usdc_negative_balance() {
        let value = calculate_value_usdc_base_units(-1_000_000_000, 100_000_000, 9);
        assert_eq!(value.unwrap(), -100_000_000);
    }

    #[test]
    fn value_usdc_usdc_token() {
        // 1 USDC priced at $1 (price = 1_000_000).
        let value = calculate_value_usdc_base_units(1_000_000, 1_000_000, 6);
        assert_eq!(value.unwrap(), 1_000_000);
    }
}

#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects,
    reason = "test code"
)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Market whose cumulative deposit and borrow interest both equal
    /// `10^(19 - decimals)`, so token balance equals scaled balance. Asset and
    /// liability weights are set to 10_000; everything else is zeroed/defaulted.
    fn unit_interest_market(decimals: u32) -> SpotMarket {
        let unit = 10u128.pow(19u32.saturating_sub(decimals));
        SpotMarket {
            pubkey: Vec::new(),
            market_index: 0,
            decimals,
            cumulative_deposit_interest: unit,
            cumulative_borrow_interest: unit,
            initial_asset_weight: 10_000,
            initial_liability_weight: 10_000,
            imf_factor: 0,
            scale_initial_asset_weight_start: 0,
            deposit_balance: 0,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: Default::default(),
            historical_oracle_data: Default::default(),
            oracle: None,
        }
    }

    /// Position on the given side with the given scaled balance; other fields default.
    fn make_position(scaled_balance: u64, balance_type: SpotBalanceType) -> SpotPosition {
        SpotPosition {
            scaled_balance,
            balance_type,
            ..Default::default()
        }
    }

    proptest! {
        #[test]
        fn deposit_balance_always_non_negative(
            scaled_balance in 0u64..=1_000_000_000_000u64,
            decimals in 6u32..=9u32,
        ) {
            let deposit = make_position(scaled_balance, SpotBalanceType::Deposit);
            let balance = get_token_balance(&deposit, &unit_interest_market(decimals)).unwrap();
            prop_assert!(balance >= 0, "deposit balance {} should be >= 0", balance);
        }

        #[test]
        fn borrow_balance_always_non_positive(
            scaled_balance in 0u64..=1_000_000_000_000u64,
            decimals in 6u32..=9u32,
        ) {
            let borrow = make_position(scaled_balance, SpotBalanceType::Borrow);
            let balance = get_token_balance(&borrow, &unit_interest_market(decimals)).unwrap();
            prop_assert!(balance <= 0, "borrow balance {} should be <= 0", balance);
        }

        #[test]
        fn value_preserves_sign(
            balance in -1_000_000_000_000i128..=1_000_000_000_000i128,
            price in 1u64..=1_000_000_000u64,
            decimals in 6u32..=9u32,
        ) {
            let value = calculate_value_usdc_base_units(balance, price, decimals).unwrap();
            match balance.signum() {
                1 => {
                    prop_assert!(value >= 0, "positive balance should give non-negative value");
                }
                -1 => {
                    prop_assert!(value <= 0, "negative balance should give non-positive value");
                }
                _ => {
                    prop_assert_eq!(value, 0);
                }
            }
        }
    }
}