use std::collections::HashMap;
use pyra_tokens::AssetId;
use pyra_types::{SpotBalanceType, SpotMarket, SpotPosition};
use super::balance::{calculate_value_usdc_base_units, get_token_balance};
use super::weights::{calculate_asset_weight, calculate_liability_weight, get_strict_price};
use crate::common::usdc_base_units_to_cents;
use crate::error::{MathError, MathResult};
/// Fixed-point scale for margin weights. Weights are expressed in basis points, so a
/// position weighted at `MARGIN_PRECISION` (10_000) counts at 100% of its value.
const MARGIN_PRECISION: i128 = 10_000;
/// Per-position snapshot recorded by `calculate_capacity` for every spot position it
/// was able to price (token, spot market and price all known).
#[derive(Debug, Clone, PartialEq)]
pub struct PositionInfo {
    /// Asset the position is held in.
    pub asset_id: AssetId,
    /// Absolute value of the signed token balance, in the token's base units.
    pub balance: u64,
    /// Deposit or borrow, copied from the source `SpotPosition`.
    pub position_type: SpotBalanceType,
    /// Price taken from the caller's price map (not the strict/TWAP-adjusted price
    /// used for valuation), in USDC base units.
    pub price_usdc_base_units: u64,
    /// The spot market's configured `initial_asset_weight`, in basis points.
    // NOTE(review): this is the asset weight even for borrow positions, and it is the
    // configured weight rather than the weight returned by
    // `calculate_asset_weight` / `calculate_liability_weight` — confirm consumers expect that.
    pub weight_bps: u32,
}
/// Aggregate spending capacity produced by `calculate_capacity`.
///
/// Derives `PartialEq` (like `PositionInfo`) so whole results can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct CapacityResult {
    /// Collateral (excluding collateral in unliquidatable assets) minus the slippage
    /// allowance minus liabilities, floored at zero, in cents.
    pub total_spendable_cents: u64,
    /// Weighted collateral minus weighted liabilities, floored at zero, in cents.
    pub available_credit_cents: u64,
    /// Balance of the first USDC position with a positive balance, in cents.
    pub usdc_balance_cents: u64,
    /// Sum of the weighted values of all positive-value positions, in USDC base units.
    pub weighted_collateral_usdc_base_units: u64,
    /// Sum of the magnitudes of the weighted values of all negative-value positions,
    /// in USDC base units.
    pub weighted_liabilities_usdc_base_units: u64,
    /// One entry per priced position, in input order.
    pub position_infos: Vec<PositionInfo>,
}
/// Computes a user's spending capacity from their spot positions.
///
/// Positions whose token, spot market, or price cannot be found are skipped. Every
/// other position is valued at the strict price (oracle price vs. 5-minute TWAP, see
/// `get_strict_price`). Its value then feeds two families of running totals:
///
/// * **Unweighted** collateral / liabilities drive `total_spendable_cents`. Positive
///   positions in `unliquidatable_asset_ids` are left out of collateral here.
/// * **Weighted** collateral / liabilities drive `available_credit_cents`. Every
///   position contributes here, including unliquidatable collateral, after scaling by
///   its asset or liability weight (in units of `MARGIN_PRECISION`).
///
/// Spendable is collateral, minus a slippage allowance of `max_slippage_bps` basis
/// points of that collateral, minus liabilities, floored at zero. Available credit is
/// weighted collateral minus weighted liabilities, floored at zero. The USDC balance is
/// taken from the first USDC position with a positive balance.
///
/// # Errors
///
/// Returns `MathError::Overflow` when a balance, value, or running total does not fit
/// its integer type. Errors from the balance, value and weight helpers are propagated.
pub fn calculate_capacity(
    spot_positions: &[SpotPosition],
    spot_market_map: &HashMap<AssetId, SpotMarket>,
    price_map: &HashMap<AssetId, u64>,
    unliquidatable_asset_ids: &[AssetId],
    max_slippage_bps: u64,
) -> MathResult<CapacityResult> {
    /// Denominator for `max_slippage_bps` (10_000 bps == 100%).
    const BPS_DENOMINATOR: u128 = 10_000;

    let mut total_collateral_usdc_base_units: u64 = 0;
    let mut total_liabilities_usdc_base_units: u64 = 0;
    let mut total_weighted_collateral_usdc_base_units: u64 = 0;
    let mut total_weighted_liabilities_usdc_base_units: u64 = 0;
    let mut usdc_balance_base_units: u64 = 0;
    let mut position_infos: Vec<PositionInfo> = Vec::with_capacity(spot_positions.len());

    for position in spot_positions {
        // Positions that cannot be priced are skipped rather than treated as errors.
        let Some(token) = pyra_tokens::Token::find_by_drift_market_index(position.market_index)
        else {
            continue;
        };
        let asset_id = token.asset_id;
        let Some(spot_market) = spot_market_map.get(&asset_id) else {
            continue;
        };
        let Some(price_usdc_base_units) = price_map.get(&asset_id).copied() else {
            continue;
        };

        let token_balance_base_units = get_token_balance(position, spot_market)?;
        // A zero balance is treated as an asset.
        let is_asset = token_balance_base_units >= 0;
        let twap5min = spot_market
            .historical_oracle_data
            .last_oracle_price_twap5min;
        let strict_price = get_strict_price(price_usdc_base_units, twap5min, is_asset);
        let value_usdc_base_units = calculate_value_usdc_base_units(
            token_balance_base_units,
            strict_price,
            spot_market.decimals,
        )?;

        // Unliquidatable collateral is excluded from spendable collateral, but it still
        // backs credit through the weighted totals below.
        let is_unliquidatable_collateral =
            unliquidatable_asset_ids.contains(&asset_id) && value_usdc_base_units > 0;
        if !is_unliquidatable_collateral {
            update_running_totals(
                &mut total_collateral_usdc_base_units,
                &mut total_liabilities_usdc_base_units,
                value_usdc_base_units,
            )?;
        }

        let token_amount_unsigned = token_balance_base_units.unsigned_abs();
        let weight_bps = if is_asset {
            calculate_asset_weight(token_amount_unsigned, price_usdc_base_units, spot_market)?
                as i128
        } else {
            calculate_liability_weight(token_amount_unsigned, spot_market)? as i128
        };
        let weighted_value_usdc_base_units = value_usdc_base_units
            .checked_mul(weight_bps)
            .ok_or(MathError::Overflow)?
            .checked_div(MARGIN_PRECISION)
            .ok_or(MathError::Overflow)?;
        update_running_totals(
            &mut total_weighted_collateral_usdc_base_units,
            &mut total_weighted_liabilities_usdc_base_units,
            weighted_value_usdc_base_units,
        )?;

        // Only the first positive USDC position is reported as the USDC balance.
        if asset_id == pyra_tokens::AssetId::USDC
            && usdc_balance_base_units == 0
            && token_balance_base_units > 0
        {
            usdc_balance_base_units =
                u64::try_from(token_balance_base_units).map_err(|_| MathError::Overflow)?;
        }

        position_infos.push(PositionInfo {
            asset_id,
            // Reuse the magnitude computed above instead of recomputing it.
            balance: u64::try_from(token_amount_unsigned).map_err(|_| MathError::Overflow)?,
            position_type: position.balance_type.clone(),
            price_usdc_base_units,
            // NOTE(review): reports the configured initial asset weight for every
            // position (borrows included), not the `weight_bps` applied above.
            weight_bps: spot_market.initial_asset_weight,
        });
    }

    let available_credit_base_units = total_weighted_collateral_usdc_base_units
        .saturating_sub(total_weighted_liabilities_usdc_base_units);
    let available_credit_cents = usdc_base_units_to_cents(available_credit_base_units)?;

    // Widen to u128 so a large collateral balance times the slippage bps cannot spuriously
    // overflow u64. An allowance that itself exceeds u64 (only possible above 10_000 bps)
    // saturates, which drives spendable to zero exactly as the true value would.
    let max_slippage_wide = u128::from(total_collateral_usdc_base_units)
        .saturating_mul(u128::from(max_slippage_bps))
        .checked_div(BPS_DENOMINATOR)
        .ok_or(MathError::Overflow)?;
    let max_slippage_usdc_base_units = u64::try_from(max_slippage_wide).unwrap_or(u64::MAX);

    let total_spendable_base_units = total_collateral_usdc_base_units
        .saturating_sub(max_slippage_usdc_base_units)
        .saturating_sub(total_liabilities_usdc_base_units);
    let total_spendable_cents = usdc_base_units_to_cents(total_spendable_base_units)?;
    let usdc_balance_cents = usdc_base_units_to_cents(usdc_balance_base_units)?;

    Ok(CapacityResult {
        total_spendable_cents,
        available_credit_cents,
        usdc_balance_cents,
        weighted_collateral_usdc_base_units: total_weighted_collateral_usdc_base_units,
        weighted_liabilities_usdc_base_units: total_weighted_liabilities_usdc_base_units,
        position_infos,
    })
}
/// Adds the magnitude of a signed `value` to one of two unsigned accumulators.
///
/// Non-negative values go to `total_positive` and negative values go to
/// `total_negative`. Returns `MathError::Overflow` if `|value|` does not fit in a `u64`
/// or if the chosen accumulator would overflow. In that case neither total is modified.
fn update_running_totals(
    total_positive: &mut u64,
    total_negative: &mut u64,
    value: i128,
) -> MathResult<()> {
    let magnitude = u64::try_from(value.unsigned_abs()).map_err(|_| MathError::Overflow)?;
    let bucket = if value < 0 { total_negative } else { total_positive };
    *bucket = bucket.checked_add(magnitude).ok_or(MathError::Overflow)?;
    Ok(())
}
#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects,
    reason = "test code"
)]
mod tests {
    use super::*;
    use pyra_types::{HistoricalOracleData, InsuranceFund};

    /// Builds a minimal `SpotMarket` fixture with an explicit 5-minute oracle TWAP.
    ///
    /// Both cumulative interest fields are set to `10^(19 - decimals)`. The expectations
    /// below rely on this making a scaled balance equal to token base units (e.g. a
    /// scaled balance of 100_000_000 in a 6-decimal market is worth 100 USDC at $1).
    /// IMF factor and weight-scaling start are zero, and there is no oracle account.
    fn make_spot_market_with_twap(
        market_index: u16,
        decimals: u32,
        initial_asset_weight: u32,
        initial_liability_weight: u32,
        twap5min: i64,
    ) -> SpotMarket {
        let precision_decrease = 10u128.pow(19u32.saturating_sub(decimals));
        SpotMarket {
            pubkey: vec![],
            market_index,
            initial_asset_weight,
            initial_liability_weight,
            imf_factor: 0,
            scale_initial_asset_weight_start: 0,
            decimals,
            cumulative_deposit_interest: precision_decrease,
            cumulative_borrow_interest: precision_decrease,
            deposit_balance: 0,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: InsuranceFund::default(),
            historical_oracle_data: HistoricalOracleData {
                last_oracle_price_twap5min: twap5min,
            },
            oracle: None,
        }
    }

    /// Same as `make_spot_market_with_twap`, using `oracle_price` as the TWAP.
    ///
    /// The tests pass the same value in the price map, presumably so the strict price
    /// equals the spot price. Test prices are small, so the `as i64` cast cannot truncate.
    fn make_spot_market(
        market_index: u16,
        decimals: u32,
        initial_asset_weight: u32,
        initial_liability_weight: u32,
        oracle_price: u64,
    ) -> SpotMarket {
        make_spot_market_with_twap(
            market_index,
            decimals,
            initial_asset_weight,
            initial_liability_weight,
            oracle_price as i64,
        )
    }

    /// Builds a `SpotPosition` in the given Drift market. `is_deposit` selects `Deposit`
    /// vs `Borrow`, and all other fields take their defaults.
    fn make_position(
        drift_market_index: u16,
        scaled_balance: u64,
        is_deposit: bool,
    ) -> SpotPosition {
        SpotPosition {
            market_index: drift_market_index,
            scaled_balance,
            balance_type: if is_deposit {
                SpotBalanceType::Deposit
            } else {
                SpotBalanceType::Borrow
            },
            ..Default::default()
        }
    }

    #[test]
    fn empty_positions() {
        let result = calculate_capacity(&[], &HashMap::new(), &HashMap::new(), &[], 0).unwrap();
        assert_eq!(result.total_spendable_cents, 0);
        assert_eq!(result.available_credit_cents, 0);
        assert_eq!(result.usdc_balance_cents, 0);
        assert_eq!(result.weighted_collateral_usdc_base_units, 0);
        assert_eq!(result.weighted_liabilities_usdc_base_units, 0);
        assert!(result.position_infos.is_empty());
    }

    // 100 USDC at $1 with a 100% asset weight: spendable, credit and balance all equal
    // $100 (10_000 cents), and the weighted collateral equals the raw value.
    #[test]
    fn single_usdc_deposit() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![make_position(0, 100_000_000, true)];
        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);
        let result = calculate_capacity(&positions, &markets, &prices, &[], 0).unwrap();
        assert_eq!(result.usdc_balance_cents, 10_000);
        assert_eq!(result.total_spendable_cents, 10_000);
        assert_eq!(result.available_credit_cents, 10_000);
        assert_eq!(result.weighted_collateral_usdc_base_units, 100_000_000);
        assert_eq!(result.weighted_liabilities_usdc_base_units, 0);
        assert_eq!(result.position_infos.len(), 1);
        assert_eq!(result.position_infos[0].asset_id, AssetId::USDC);
    }

    // $100 USDC deposit and $50 USDC borrow. The USDC balance reports only the first
    // positive USDC position ($100). Spendable and credit are the net $50.
    #[test]
    fn deposit_and_borrow() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![
            make_position(0, 100_000_000, true),
            make_position(0, 50_000_000, false),
        ];
        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);
        let result = calculate_capacity(&positions, &markets, &prices, &[], 0).unwrap();
        assert_eq!(result.usdc_balance_cents, 10_000);
        assert_eq!(result.total_spendable_cents, 5_000);
        assert_eq!(result.available_credit_cents, 5_000);
    }

    // $10 USDC + 1 WETH at $100 (80% asset weight), with WETH unliquidatable.
    // Spendable counts only USDC ($10). Credit counts USDC $10 + WETH $100 * 0.8 = $90.
    #[test]
    fn unliquidatable_excluded_from_spendable() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let weth = make_spot_market(4, 9, 8_000, 12_000, 100_000_000);
        let positions = vec![
            make_position(0, 10_000_000, true),
            make_position(4, 1_000_000_000, true),
        ];
        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        markets.insert(AssetId::WETH, weth);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);
        prices.insert(AssetId::WETH, 100_000_000u64);
        let unliquidatable = vec![AssetId::WETH];
        let result = calculate_capacity(&positions, &markets, &prices, &unliquidatable, 0).unwrap();
        assert_eq!(result.total_spendable_cents, 1_000);
        assert_eq!(result.available_credit_cents, 9_000);
        assert_eq!(result.position_infos.len(), 2);
    }

    // 1_000 bps (10%) slippage on $100 collateral leaves $90 spendable. Credit is
    // unaffected by slippage.
    #[test]
    fn slippage_reduces_spendable() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![make_position(0, 100_000_000, true)];
        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);
        let result = calculate_capacity(&positions, &markets, &prices, &[], 1_000).unwrap();
        assert_eq!(result.total_spendable_cents, 9_000);
        assert_eq!(result.available_credit_cents, 10_000);
    }

    // A position whose spot market is absent is skipped, not reported as an error.
    #[test]
    fn missing_market_skipped() {
        let positions = vec![make_position(5, 1_000_000, true)];
        let result =
            calculate_capacity(&positions, &HashMap::new(), &HashMap::new(), &[], 0).unwrap();
        assert_eq!(result.total_spendable_cents, 0);
        assert!(result.position_infos.is_empty());
    }

    // A position whose price is absent is skipped, not reported as an error.
    #[test]
    fn missing_price_skipped() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![make_position(0, 1_000_000, true)];
        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        let result = calculate_capacity(&positions, &markets, &HashMap::new(), &[], 0).unwrap();
        assert_eq!(result.total_spendable_cents, 0);
        assert!(result.position_infos.is_empty());
    }

    // $50 USDC, 1 WETH at $200 (unliquidatable, 80%), 0.5 WSOL at $100 (80%), and a
    // $20 USDT borrow, with 500 bps slippage.
    // Spendable: collateral $100 (USDC + WSOL), minus 5% slippage ($5), minus $20 = $75.
    // Credit: $50 + $160 + $40 - $20 = $230.
    #[test]
    fn multi_position_with_unliquidatable_and_slippage() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let weth = make_spot_market(4, 9, 8_000, 12_000, 200_000_000);
        let wsol = make_spot_market(1, 9, 8_000, 12_000, 100_000_000);
        let usdt = make_spot_market(5, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![
            make_position(0, 50_000_000, true),
            make_position(4, 1_000_000_000, true),
            make_position(1, 500_000_000, true),
            make_position(5, 20_000_000, false),
        ];
        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        markets.insert(AssetId::WETH, weth);
        markets.insert(AssetId::WSOL, wsol);
        markets.insert(AssetId::USDT, usdt);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);
        prices.insert(AssetId::WETH, 200_000_000u64);
        prices.insert(AssetId::WSOL, 100_000_000u64);
        prices.insert(AssetId::USDT, 1_000_000u64);
        let unliquidatable = vec![AssetId::WETH];
        let result =
            calculate_capacity(&positions, &markets, &prices, &unliquidatable, 500).unwrap();
        assert_eq!(result.total_spendable_cents, 7_500);
        assert_eq!(result.available_credit_cents, 23_000);
        assert_eq!(result.usdc_balance_cents, 5_000);
        assert_eq!(result.position_infos.len(), 4);
    }

    #[test]
    fn running_totals_positive() {
        let mut pos = 0u64;
        let mut neg = 0u64;
        update_running_totals(&mut pos, &mut neg, 100).unwrap();
        assert_eq!(pos, 100);
        assert_eq!(neg, 0);
    }

    #[test]
    fn running_totals_negative() {
        let mut pos = 0u64;
        let mut neg = 0u64;
        update_running_totals(&mut pos, &mut neg, -50).unwrap();
        assert_eq!(pos, 0);
        assert_eq!(neg, 50);
    }

    // Existing totals are added to, not overwritten.
    #[test]
    fn running_totals_accumulate() {
        let mut pos = 10u64;
        let mut neg = 5u64;
        update_running_totals(&mut pos, &mut neg, 20).unwrap();
        update_running_totals(&mut pos, &mut neg, -15).unwrap();
        assert_eq!(pos, 30);
        assert_eq!(neg, 20);
    }
}
#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects,
    reason = "test code"
)]
mod proptests {
    use super::*;
    use proptest::prelude::*;
    use pyra_types::{HistoricalOracleData, InsuranceFund};

    /// Builds a $1-pegged spot market fixture: TWAP is 1 USDC, weights are 100%, and both
    /// cumulative interest fields equal `10^(19 - decimals)` (the same setup the unit
    /// tests rely on for scaled balances to equal token base units).
    fn make_stable_market(market_index: u16, decimals: u32) -> SpotMarket {
        let precision_decrease = 10u128.pow(19u32.saturating_sub(decimals));
        SpotMarket {
            pubkey: vec![],
            market_index,
            initial_asset_weight: 10_000,
            initial_liability_weight: 10_000,
            imf_factor: 0,
            scale_initial_asset_weight_start: 0,
            decimals,
            cumulative_deposit_interest: precision_decrease,
            cumulative_borrow_interest: precision_decrease,
            deposit_balance: 0,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: InsuranceFund::default(),
            historical_oracle_data: HistoricalOracleData {
                last_oracle_price_twap5min: 1_000_000,
            },
            oracle: None,
        }
    }

    /// Builds a `SpotPosition` in the given Drift market; other fields take defaults.
    fn make_position(
        market_index: u16,
        scaled_balance: u64,
        balance_type: SpotBalanceType,
    ) -> SpotPosition {
        SpotPosition {
            market_index,
            scaled_balance,
            balance_type,
            ..Default::default()
        }
    }

    proptest! {
        // Converting a base-unit difference to cents never exceeds the difference of the
        // converted amounts by more than one cent of rounding.
        #[test]
        fn cents_of_difference_within_one_cent(
            collateral_base in 0u64..=1_000_000_000_000u64,
            liabilities_base in 0u64..=500_000_000_000u64,
        ) {
            let collateral_cents = usdc_base_units_to_cents(collateral_base).unwrap();
            let liabilities_cents = usdc_base_units_to_cents(liabilities_base).unwrap();
            let max_possible = collateral_cents.saturating_sub(liabilities_cents);
            let spendable_base = collateral_base.saturating_sub(liabilities_base);
            let spendable_cents = usdc_base_units_to_cents(spendable_base).unwrap();
            prop_assert!(spendable_cents <= max_possible + 1, "rounding violation");
        }

        // End-to-end: with a USDC deposit, a USDT borrow and any slippage up to 100%,
        // `calculate_capacity` never reports more spendable than collateral minus
        // liabilities (allowing one cent of rounding).
        #[test]
        fn spendable_le_collateral_minus_liabilities(
            collateral_base in 0u64..=1_000_000_000_000u64,
            liabilities_base in 0u64..=500_000_000_000u64,
            max_slippage_bps in 0u64..=10_000u64,
        ) {
            let positions = vec![
                make_position(0, collateral_base, SpotBalanceType::Deposit),
                make_position(5, liabilities_base, SpotBalanceType::Borrow),
            ];
            let mut markets = HashMap::new();
            markets.insert(AssetId::USDC, make_stable_market(0, 6));
            markets.insert(AssetId::USDT, make_stable_market(5, 6));
            let mut prices = HashMap::new();
            prices.insert(AssetId::USDC, 1_000_000u64);
            prices.insert(AssetId::USDT, 1_000_000u64);

            let result =
                calculate_capacity(&positions, &markets, &prices, &[], max_slippage_bps).unwrap();

            let collateral_cents = usdc_base_units_to_cents(collateral_base).unwrap();
            let liabilities_cents = usdc_base_units_to_cents(liabilities_base).unwrap();
            let max_possible = collateral_cents.saturating_sub(liabilities_cents);
            prop_assert!(result.total_spendable_cents <= max_possible + 1, "rounding violation");
        }
    }
}