Skip to main content

pyra_margin/drift/
capacity.rs

1use std::collections::HashMap;
2
3use pyra_tokens::AssetId;
4use pyra_types::{SpotBalanceType, SpotMarket, SpotPosition};
5
6use super::balance::{calculate_value_usdc_base_units, get_token_balance};
7use super::weights::{calculate_asset_weight, calculate_liability_weight, get_strict_price};
8use crate::common::usdc_base_units_to_cents;
9use crate::error::{MathError, MathResult};
10
/// Fixed-point denominator for basis-point weights (10^4 bps = 100%).
const MARGIN_PRECISION: i128 = 10_i128.pow(4);
12
13/// Per-position info emitted during capacity calculation.
14///
15/// Used downstream by liquidating spend jobs to know position sizes and weights.
16#[derive(Debug, Clone, PartialEq)]
17pub struct PositionInfo {
18    pub asset_id: AssetId,
19    /// Unsigned token balance in base units.
20    pub balance: u64,
21    pub position_type: SpotBalanceType,
22    pub price_usdc_base_units: u64,
23    /// Initial asset weight from the spot market (basis points).
24    pub weight_bps: u32,
25}
26
27/// Result of spending capacity calculation.
28#[derive(Debug, Clone)]
29pub struct CapacityResult {
30    /// Max spendable via liquidating spend (unweighted collateral - slippage - liabilities), in cents.
31    /// Excludes unliquidatable assets from collateral.
32    pub total_spendable_cents: u64,
33    /// Available credit line (weighted collateral - weighted liabilities), in cents.
34    pub available_credit_cents: u64,
35    /// USDC balance (market index 0) in cents.
36    pub usdc_balance_cents: u64,
37    /// Total weighted collateral in USDC base units.
38    pub weighted_collateral_usdc_base_units: u64,
39    /// Total weighted liabilities in USDC base units.
40    pub weighted_liabilities_usdc_base_units: u64,
41    /// Per-position breakdown for downstream use.
42    pub position_infos: Vec<PositionInfo>,
43}
44
45/// Calculate spending capacity from Drift spot positions.
46///
47/// This is the core calculation used for card transaction authorization.
48/// It computes:
49/// - **total_spendable**: max amount for liquidating spends (collateral minus slippage minus liabilities,
50///   excluding unliquidatable assets)
51/// - **available_credit**: credit line from weighted margin (weighted collateral minus weighted liabilities)
52/// - **usdc_balance**: direct USDC holdings
53///
54/// Positions in `unliquidatable_asset_ids` are excluded from unweighted collateral
55/// (affecting `total_spendable`) but still included in weighted calculations
56/// (affecting `available_credit`).
57///
58/// All maps are keyed by **asset_id** (not Drift market index).
59/// Positions whose asset_id is missing from `spot_market_map` or `price_map` are skipped.
60/// Positions for tokens not in `pyra_tokens` are also skipped.
61pub fn calculate_capacity(
62    spot_positions: &[SpotPosition],
63    spot_market_map: &HashMap<AssetId, SpotMarket>,
64    price_map: &HashMap<AssetId, u64>,
65    unliquidatable_asset_ids: &[AssetId],
66    max_slippage_bps: u64,
67) -> MathResult<CapacityResult> {
68    let mut total_collateral_usdc_base_units: u64 = 0;
69    let mut total_liabilities_usdc_base_units: u64 = 0;
70
71    let mut total_weighted_collateral_usdc_base_units: u64 = 0;
72    let mut total_weighted_liabilities_usdc_base_units: u64 = 0;
73
74    let mut usdc_balance_base_units: u64 = 0;
75
76    let mut position_infos: Vec<PositionInfo> = Vec::new();
77
78    for position in spot_positions {
79        // Convert Drift market_index to Pyra asset_id for map lookups
80        let Some(token) = pyra_tokens::Token::find_by_drift_market_index(position.market_index)
81        else {
82            continue;
83        };
84        let asset_id = token.asset_id;
85
86        let Some(spot_market) = spot_market_map.get(&asset_id) else {
87            continue;
88        };
89        let Some(price_usdc_base_units) = price_map.get(&asset_id).copied() else {
90            continue;
91        };
92
93        // Step 1: Calculate token balance and USDC value
94        let token_balance_base_units = get_token_balance(position, spot_market)?;
95
96        let is_asset = token_balance_base_units >= 0;
97        let twap5min = spot_market
98            .historical_oracle_data
99            .last_oracle_price_twap5min;
100        let strict_price = get_strict_price(price_usdc_base_units, twap5min, is_asset);
101
102        let value_usdc_base_units = calculate_value_usdc_base_units(
103            token_balance_base_units,
104            strict_price,
105            spot_market.decimals,
106        )?;
107
108        // Accumulate unweighted totals (excluding unliquidatable collateral)
109        let is_unliquidatable_collateral =
110            unliquidatable_asset_ids.contains(&asset_id) && value_usdc_base_units > 0;
111        if !is_unliquidatable_collateral {
112            update_running_totals(
113                &mut total_collateral_usdc_base_units,
114                &mut total_liabilities_usdc_base_units,
115                value_usdc_base_units,
116            )?;
117        }
118
119        // Step 2: Apply IMF-adjusted weights
120        let token_amount_unsigned = token_balance_base_units.unsigned_abs();
121        let weight_bps = if is_asset {
122            calculate_asset_weight(token_amount_unsigned, price_usdc_base_units, spot_market)?
123                as i128
124        } else {
125            calculate_liability_weight(token_amount_unsigned, spot_market)? as i128
126        };
127        let weighted_value_usdc_base_units = value_usdc_base_units
128            .checked_mul(weight_bps)
129            .ok_or(MathError::Overflow)?
130            .checked_div(MARGIN_PRECISION)
131            .ok_or(MathError::Overflow)?;
132
133        update_running_totals(
134            &mut total_weighted_collateral_usdc_base_units,
135            &mut total_weighted_liabilities_usdc_base_units,
136            weighted_value_usdc_base_units,
137        )?;
138
139        // Step 3: Track USDC balance (asset_id 0)
140        if asset_id == pyra_tokens::AssetId::USDC
141            && usdc_balance_base_units == 0
142            && token_balance_base_units > 0
143        {
144            usdc_balance_base_units =
145                u64::try_from(token_balance_base_units).map_err(|_| MathError::Overflow)?;
146        }
147
148        // Step 4: Store position info
149        let token_balance_unsigned = u64::try_from(token_balance_base_units.unsigned_abs())
150            .map_err(|_| MathError::Overflow)?;
151        position_infos.push(PositionInfo {
152            asset_id,
153            balance: token_balance_unsigned,
154            position_type: position.balance_type.clone(),
155            price_usdc_base_units,
156            weight_bps: spot_market.initial_asset_weight,
157        });
158    }
159
160    // Step 5: Available credit = weighted collateral - weighted liabilities
161    let available_credit_base_units = total_weighted_collateral_usdc_base_units
162        .saturating_sub(total_weighted_liabilities_usdc_base_units);
163    let available_credit_cents = usdc_base_units_to_cents(available_credit_base_units)?;
164
165    // Step 6: Total spendable = collateral - slippage - liabilities (for liquidating spends)
166    let max_slippage_usdc_base_units = total_collateral_usdc_base_units
167        .checked_mul(max_slippage_bps)
168        .ok_or(MathError::Overflow)?
169        .checked_div(10_000)
170        .ok_or(MathError::Overflow)?;
171    let total_spendable_base_units = total_collateral_usdc_base_units
172        .saturating_sub(max_slippage_usdc_base_units)
173        .saturating_sub(total_liabilities_usdc_base_units);
174    let total_spendable_cents = usdc_base_units_to_cents(total_spendable_base_units)?;
175
176    let usdc_balance_cents = usdc_base_units_to_cents(usdc_balance_base_units)?;
177
178    Ok(CapacityResult {
179        total_spendable_cents,
180        available_credit_cents,
181        usdc_balance_cents,
182        weighted_collateral_usdc_base_units: total_weighted_collateral_usdc_base_units,
183        weighted_liabilities_usdc_base_units: total_weighted_liabilities_usdc_base_units,
184        position_infos,
185    })
186}
187
188/// Accumulate a signed value into positive/negative running totals.
189fn update_running_totals(
190    total_positive: &mut u64,
191    total_negative: &mut u64,
192    value: i128,
193) -> MathResult<()> {
194    let value_unsigned = u64::try_from(value.unsigned_abs()).map_err(|_| MathError::Overflow)?;
195
196    if value >= 0 {
197        *total_positive = total_positive
198            .checked_add(value_unsigned)
199            .ok_or(MathError::Overflow)?;
200    } else {
201        *total_negative = total_negative
202            .checked_add(value_unsigned)
203            .ok_or(MathError::Overflow)?;
204    }
205
206    Ok(())
207}
208
#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects
)]
mod tests {
    use super::*;
    use pyra_types::{HistoricalOracleData, InsuranceFund};

    /// Spot market fixture with no IMF and cumulative interest set to the precision decrease,
    /// so scaled balances map 1:1 onto token amounts (e.g. 100_000_000 scaled = 100 USDC).
    fn make_spot_market_with_twap(
        market_index: u16,
        decimals: u32,
        initial_asset_weight: u32,
        initial_liability_weight: u32,
        twap5min: i64,
    ) -> SpotMarket {
        let precision_decrease = 10u128.pow(19u32.saturating_sub(decimals));
        SpotMarket {
            pubkey: vec![],
            market_index,
            initial_asset_weight,
            initial_liability_weight,
            imf_factor: 0,
            scale_initial_asset_weight_start: 0,
            decimals,
            cumulative_deposit_interest: precision_decrease,
            cumulative_borrow_interest: precision_decrease,
            deposit_balance: 0,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: InsuranceFund::default(),
            historical_oracle_data: HistoricalOracleData {
                last_oracle_price_twap5min: twap5min,
            },
            oracle: None,
        }
    }

    /// Convenience: creates a spot market where TWAP matches the oracle price.
    fn make_spot_market(
        market_index: u16,
        decimals: u32,
        initial_asset_weight: u32,
        initial_liability_weight: u32,
        oracle_price: u64,
    ) -> SpotMarket {
        make_spot_market_with_twap(
            market_index,
            decimals,
            initial_asset_weight,
            initial_liability_weight,
            oracle_price as i64,
        )
    }

    /// Create a SpotPosition with the given Drift market_index.
    fn make_position(
        drift_market_index: u16,
        scaled_balance: u64,
        is_deposit: bool,
    ) -> SpotPosition {
        SpotPosition {
            market_index: drift_market_index,
            scaled_balance,
            balance_type: if is_deposit {
                SpotBalanceType::Deposit
            } else {
                SpotBalanceType::Borrow
            },
            ..Default::default()
        }
    }

    // --- calculate_capacity ---

    #[test]
    fn empty_positions() {
        let result = calculate_capacity(&[], &HashMap::new(), &HashMap::new(), &[], 0).unwrap();
        assert_eq!(result.total_spendable_cents, 0);
        assert_eq!(result.available_credit_cents, 0);
        assert_eq!(result.usdc_balance_cents, 0);
        assert_eq!(result.weighted_collateral_usdc_base_units, 0);
        assert_eq!(result.weighted_liabilities_usdc_base_units, 0);
        assert!(result.position_infos.is_empty());
    }

    #[test]
    fn single_usdc_deposit() {
        // USDC: drift market_index=0, asset_id=0
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![make_position(0, 100_000_000, true)]; // 100 USDC

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc); // keyed by asset_id
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64); // keyed by asset_id

        let result = calculate_capacity(&positions, &markets, &prices, &[], 0).unwrap();

        assert_eq!(result.usdc_balance_cents, 10_000); // $100
        assert_eq!(result.total_spendable_cents, 10_000);
        assert_eq!(result.available_credit_cents, 10_000);
        assert_eq!(result.weighted_collateral_usdc_base_units, 100_000_000);
        assert_eq!(result.weighted_liabilities_usdc_base_units, 0);
        assert_eq!(result.position_infos.len(), 1);
        assert_eq!(result.position_infos[0].asset_id, AssetId::USDC);
    }

    #[test]
    fn deposit_and_borrow() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![
            make_position(0, 100_000_000, true), // 100 USDC deposit
            make_position(0, 50_000_000, false), // 50 USDC borrow
        ];

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);

        let result = calculate_capacity(&positions, &markets, &prices, &[], 0).unwrap();

        assert_eq!(result.usdc_balance_cents, 10_000); // 100 USDC deposit
        assert_eq!(result.total_spendable_cents, 5_000); // 100 - 50 = 50 USDC
        assert_eq!(result.available_credit_cents, 5_000);
    }

    #[test]
    fn unliquidatable_excluded_from_spendable() {
        // wETH: drift_market_index=4, asset_id=4
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let weth = make_spot_market(4, 9, 8_000, 12_000, 100_000_000);

        let positions = vec![
            make_position(0, 10_000_000, true),    // 10 USDC (drift idx 0)
            make_position(4, 1_000_000_000, true), // 1 wETH (drift idx 4)
        ];

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc); // asset_id 0 = USDC
        markets.insert(AssetId::WETH, weth); // asset_id 4 = wETH
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);
        prices.insert(AssetId::WETH, 100_000_000u64); // $100

        let unliquidatable = vec![AssetId::WETH]; // asset_id 4 = wETH
        let result = calculate_capacity(&positions, &markets, &prices, &unliquidatable, 0).unwrap();

        // total_spendable: only USDC (10M base = 1000 cents), wETH excluded
        assert_eq!(result.total_spendable_cents, 1_000);
        // available_credit: includes wETH weighted (100M * 80% = 80M) + USDC (10M * 100% = 10M) = 90M = 9000 cents
        assert_eq!(result.available_credit_cents, 9_000);
        assert_eq!(result.position_infos.len(), 2);
    }

    #[test]
    fn unliquidatable_liability_still_counted() {
        // Only unliquidatable *collateral* is excluded; a borrow in an unliquidatable asset
        // must still reduce both spendable and credit.
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let usdt = make_spot_market(5, 6, 10_000, 10_000, 1_000_000);

        let positions = vec![
            make_position(0, 100_000_000, true), // 100 USDC deposit (drift idx 0)
            make_position(5, 20_000_000, false), // 20 USDT borrow (drift idx 5)
        ];

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        markets.insert(AssetId::USDT, usdt);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);
        prices.insert(AssetId::USDT, 1_000_000u64);

        let unliquidatable = vec![AssetId::USDT];
        let result = calculate_capacity(&positions, &markets, &prices, &unliquidatable, 0).unwrap();

        // 100M collateral - 20M liabilities = 80M base = 8000 cents
        assert_eq!(result.total_spendable_cents, 8_000);
        assert_eq!(result.available_credit_cents, 8_000);
        assert_eq!(result.weighted_liabilities_usdc_base_units, 20_000_000);
    }

    #[test]
    fn slippage_reduces_spendable() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![make_position(0, 100_000_000, true)]; // 100 USDC

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);

        // 10% slippage = 1000 bps
        let result = calculate_capacity(&positions, &markets, &prices, &[], 1_000).unwrap();

        // 100 USDC collateral, 10% slippage = 10 USDC, spendable = 90 USDC = 9000 cents
        assert_eq!(result.total_spendable_cents, 9_000);
        // available_credit not affected by slippage
        assert_eq!(result.available_credit_cents, 10_000);
    }

    #[test]
    fn slippage_above_full_haircut_floors_at_zero() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![make_position(0, 100_000_000, true)]; // 100 USDC

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);

        // 200% slippage: the haircut exceeds collateral, so spendable floors at zero.
        let result = calculate_capacity(&positions, &markets, &prices, &[], 20_000).unwrap();

        assert_eq!(result.total_spendable_cents, 0);
        assert_eq!(result.available_credit_cents, 10_000);
    }

    #[test]
    fn missing_market_skipped() {
        // drift_market_index=5 maps to asset_id=5 (USDT) — but no market data provided
        let positions = vec![make_position(5, 1_000_000, true)];

        let result =
            calculate_capacity(&positions, &HashMap::new(), &HashMap::new(), &[], 0).unwrap();

        assert_eq!(result.total_spendable_cents, 0);
        assert!(result.position_infos.is_empty());
    }

    #[test]
    fn missing_price_skipped() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![make_position(0, 1_000_000, true)];

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc); // asset_id 0

        let result = calculate_capacity(&positions, &markets, &HashMap::new(), &[], 0).unwrap();

        assert_eq!(result.total_spendable_cents, 0);
        assert!(result.position_infos.is_empty());
    }

    #[test]
    fn multi_position_with_unliquidatable_and_slippage() {
        // Use real token ids (decimals below are the ones configured in these fixtures):
        // USDC:  drift_idx=0,  asset_id=0, 6 decimals
        // wETH:  drift_idx=4,  asset_id=4, 9 decimals -> unliquidatable
        // wSOL:  drift_idx=1,  asset_id=1, 9 decimals
        // USDT:  drift_idx=5,  asset_id=5, 6 decimals -> borrow
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let weth = make_spot_market(4, 9, 8_000, 12_000, 200_000_000); // unliquidatable
        let wsol = make_spot_market(1, 9, 8_000, 12_000, 100_000_000);
        let usdt = make_spot_market(5, 6, 10_000, 10_000, 1_000_000);

        let positions = vec![
            make_position(0, 50_000_000, true),    // 50 USDC (drift idx 0)
            make_position(4, 1_000_000_000, true), // 1 wETH @ $200 (drift idx 4, unliquidatable)
            make_position(1, 500_000_000, true),   // 0.5 wSOL @ $100 (drift idx 1)
            make_position(5, 20_000_000, false),   // 20 USDT borrow (drift idx 5)
        ];

        // Maps keyed by asset_id
        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc); // asset_id 0
        markets.insert(AssetId::WETH, weth); // asset_id 4
        markets.insert(AssetId::WSOL, wsol); // asset_id 1
        markets.insert(AssetId::USDT, usdt); // asset_id 5
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);
        prices.insert(AssetId::WETH, 200_000_000u64);
        prices.insert(AssetId::WSOL, 100_000_000u64);
        prices.insert(AssetId::USDT, 1_000_000u64);

        let unliquidatable = vec![AssetId::WETH]; // asset_id 4 = wETH
        let result =
            calculate_capacity(&positions, &markets, &prices, &unliquidatable, 500).unwrap();

        // Unweighted collateral (excluding wETH): 50M (USDC) + 50M (wSOL: 0.5 * $100) = 100M
        // Unweighted liabilities: 20M (USDT borrow)
        // Slippage: 100M * 500/10000 = 5M
        // total_spendable = 100M - 5M - 20M = 75M base = 7500 cents
        assert_eq!(result.total_spendable_cents, 7_500);

        // Weighted collateral: 50M*100% + 200M*80% + 50M*80% = 50M + 160M + 40M = 250M
        // Weighted liabilities: 20M*100% = 20M
        // available_credit = 250M - 20M = 230M base = 23000 cents
        assert_eq!(result.available_credit_cents, 23_000);
        assert_eq!(result.usdc_balance_cents, 5_000);
        assert_eq!(result.position_infos.len(), 4);
    }

    // --- update_running_totals ---

    #[test]
    fn running_totals_positive() {
        let mut pos = 0u64;
        let mut neg = 0u64;
        update_running_totals(&mut pos, &mut neg, 100).unwrap();
        assert_eq!(pos, 100);
        assert_eq!(neg, 0);
    }

    #[test]
    fn running_totals_negative() {
        let mut pos = 0u64;
        let mut neg = 0u64;
        update_running_totals(&mut pos, &mut neg, -50).unwrap();
        assert_eq!(pos, 0);
        assert_eq!(neg, 50);
    }

    #[test]
    fn running_totals_accumulate() {
        let mut pos = 10u64;
        let mut neg = 5u64;
        update_running_totals(&mut pos, &mut neg, 20).unwrap();
        update_running_totals(&mut pos, &mut neg, -15).unwrap();
        assert_eq!(pos, 30);
        assert_eq!(neg, 20);
    }
}
496
#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects
)]
mod proptests {
    use super::*;
    use proptest::prelude::*;
    use pyra_types::{HistoricalOracleData, InsuranceFund};

    /// USDC market fixture (Drift index 0, 6 decimals): 100% weights, no IMF, $1 TWAP.
    ///
    /// Cumulative interest equals the precision decrease, as in the unit-test fixtures,
    /// so scaled balances map 1:1 onto token amounts.
    fn usdc_market() -> SpotMarket {
        let precision_decrease = 10u128.pow(19 - 6);
        SpotMarket {
            pubkey: vec![],
            market_index: 0,
            initial_asset_weight: 10_000,
            initial_liability_weight: 10_000,
            imf_factor: 0,
            scale_initial_asset_weight_start: 0,
            decimals: 6,
            cumulative_deposit_interest: precision_decrease,
            cumulative_borrow_interest: precision_decrease,
            deposit_balance: 0,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: InsuranceFund::default(),
            historical_oracle_data: HistoricalOracleData {
                last_oracle_price_twap5min: 1_000_000,
            },
            oracle: None,
        }
    }

    /// USDC spot position (Drift market index 0) with the given scaled balance and direction.
    fn usdc_position(scaled_balance: u64, balance_type: SpotBalanceType) -> SpotPosition {
        SpotPosition {
            market_index: 0,
            scaled_balance,
            balance_type,
            ..Default::default()
        }
    }

    proptest! {
        #[test]
        fn spendable_le_collateral_minus_liabilities(
            deposit_base in 0u64..=1_000_000_000_000u64,
            borrow_base in 0u64..=500_000_000_000u64,
            slippage_bps in 0u64..=10_000u64,
        ) {
            // Drive the real calculation rather than re-deriving the formula.
            let mut markets = HashMap::new();
            markets.insert(AssetId::USDC, usdc_market());
            let mut prices = HashMap::new();
            prices.insert(AssetId::USDC, 1_000_000u64);
            let positions = vec![
                usdc_position(deposit_base, SpotBalanceType::Deposit),
                usdc_position(borrow_base, SpotBalanceType::Borrow),
            ];

            let with_slippage =
                calculate_capacity(&positions, &markets, &prices, &[], slippage_bps).unwrap();
            let without_slippage =
                calculate_capacity(&positions, &markets, &prices, &[], 0).unwrap();

            // Spendable never exceeds collateral - liabilities...
            let ceiling =
                usdc_base_units_to_cents(deposit_base.saturating_sub(borrow_base)).unwrap();
            prop_assert!(
                with_slippage.total_spendable_cents <= ceiling,
                "spendable {} exceeds collateral - liabilities {}",
                with_slippage.total_spendable_cents,
                ceiling
            );
            // ...and a slippage haircut can only lower it.
            prop_assert!(
                with_slippage.total_spendable_cents <= without_slippage.total_spendable_cents,
                "slippage increased spendable"
            );
        }
    }
}
523            let spendable_cents = usdc_base_units_to_cents(spendable_base).unwrap();
524            prop_assert!(spendable_cents <= max_possible + 1, "rounding violation");
525        }
526    }
527}