Skip to main content

pyra_margin/drift/
capacity.rs

1use std::collections::HashMap;
2
3use pyra_tokens::AssetId;
4use pyra_types::{SpotBalanceType, SpotMarket, SpotPosition};
5
6use super::balance::{calculate_value_usdc_base_units, get_token_balance};
7use super::weights::{calculate_asset_weight, calculate_liability_weight, get_strict_price};
8use crate::common::usdc_base_units_to_cents;
9use crate::error::{MathError, MathResult};
10
/// Fixed-point denominator for spot-market margin weights.
///
/// Weights are expressed in basis points, so a weight of `10_000` means 100% and `8_000`
/// means 80%. Weighted values are computed as `value * weight / MARGIN_PRECISION`.
const MARGIN_PRECISION: i128 = 10_000;
12
/// Per-position info emitted during capacity calculation.
///
/// Used downstream by liquidating spend jobs to know position sizes and weights.
/// `calculate_capacity` pushes one entry for every input position it does not skip
/// (positions with an unknown token, missing market, or missing price are skipped).
#[derive(Debug, Clone, PartialEq)]
pub struct PositionInfo {
    /// Pyra asset id of the position (mapped from the Drift market index).
    pub asset_id: AssetId,
    /// Unsigned token balance in base units.
    ///
    /// This is the magnitude for both deposits and borrows; see `position_type` for the side.
    pub balance: u64,
    /// Deposit or borrow, copied from the Drift spot position.
    pub position_type: SpotBalanceType,
    /// Price taken from the caller's price map, in USDC base units.
    ///
    /// This is the raw map price, not the TWAP-adjusted strict price used for valuation.
    pub price_usdc_base_units: u64,
    /// Initial asset weight from the spot market (basis points).
    ///
    /// This is the market's raw `initial_asset_weight` for every position, including
    /// borrows; it is neither IMF-adjusted nor the liability weight.
    pub weight_bps: u32,
}
26
27/// Result of spending capacity calculation.
28#[derive(Debug, Clone)]
29pub struct CapacityResult {
30    /// Max spendable via liquidating spend (unweighted collateral - slippage - liabilities), in cents.
31    /// Excludes unliquidatable assets from collateral.
32    pub total_spendable_cents: u64,
33    /// Available credit line (weighted collateral - weighted liabilities), in cents.
34    pub available_credit_cents: u64,
35    /// USDC balance (market index 0) in cents.
36    pub usdc_balance_cents: u64,
37    /// Total weighted collateral in USDC base units.
38    pub weighted_collateral_usdc_base_units: u64,
39    /// Total weighted liabilities in USDC base units.
40    pub weighted_liabilities_usdc_base_units: u64,
41    /// Per-position breakdown for downstream use.
42    pub position_infos: Vec<PositionInfo>,
43}
44
45/// Calculate spending capacity from Drift spot positions.
46///
47/// This is the core calculation used for card transaction authorization.
48/// It computes:
49/// - **total_spendable**: max amount for liquidating spends (collateral minus slippage minus liabilities,
50///   excluding unliquidatable assets)
51/// - **available_credit**: credit line from weighted margin (weighted collateral minus weighted liabilities)
52/// - **usdc_balance**: direct USDC holdings
53///
54/// Positions in `unliquidatable_asset_ids` are excluded from unweighted collateral
55/// (affecting `total_spendable`) but still included in weighted calculations
56/// (affecting `available_credit`).
57///
58/// All maps are keyed by **asset_id** (not Drift market index).
59/// Positions whose asset_id is missing from `spot_market_map` or `price_map` are skipped.
60/// Positions for tokens not in `pyra_tokens` are also skipped.
61pub fn calculate_capacity(
62    spot_positions: &[SpotPosition],
63    spot_market_map: &HashMap<AssetId, SpotMarket>,
64    price_map: &HashMap<AssetId, u64>,
65    unliquidatable_asset_ids: &[AssetId],
66    max_slippage_bps: u64,
67) -> MathResult<CapacityResult> {
68    let mut total_collateral_usdc_base_units: u64 = 0;
69    let mut total_liabilities_usdc_base_units: u64 = 0;
70
71    let mut total_weighted_collateral_usdc_base_units: u64 = 0;
72    let mut total_weighted_liabilities_usdc_base_units: u64 = 0;
73
74    let mut usdc_balance_base_units: u64 = 0;
75
76    let mut position_infos: Vec<PositionInfo> = Vec::new();
77
78    for position in spot_positions {
79        // Convert Drift market_index to Pyra asset_id for map lookups
80        let Some(token) = pyra_tokens::Token::find_by_drift_market_index(position.market_index)
81        else {
82            continue;
83        };
84        let asset_id = token.asset_id;
85
86        let Some(spot_market) = spot_market_map.get(&asset_id) else {
87            continue;
88        };
89        let Some(price_usdc_base_units) = price_map.get(&asset_id).copied() else {
90            continue;
91        };
92
93        // Step 1: Calculate token balance and USDC value
94        let token_balance_base_units = get_token_balance(position, spot_market)?;
95
96        let is_asset = token_balance_base_units >= 0;
97        let twap5min = spot_market
98            .historical_oracle_data
99            .last_oracle_price_twap5min;
100        let strict_price = get_strict_price(price_usdc_base_units, twap5min, is_asset);
101
102        let value_usdc_base_units = calculate_value_usdc_base_units(
103            token_balance_base_units,
104            strict_price,
105            spot_market.decimals,
106        )?;
107
108        // Accumulate unweighted totals (excluding unliquidatable collateral)
109        let is_unliquidatable_collateral =
110            unliquidatable_asset_ids.contains(&asset_id) && value_usdc_base_units > 0;
111        if !is_unliquidatable_collateral {
112            update_running_totals(
113                &mut total_collateral_usdc_base_units,
114                &mut total_liabilities_usdc_base_units,
115                value_usdc_base_units,
116            )?;
117        }
118
119        // Step 2: Apply IMF-adjusted weights
120        let token_amount_unsigned = token_balance_base_units.unsigned_abs();
121        let weight_bps = if is_asset {
122            calculate_asset_weight(token_amount_unsigned, price_usdc_base_units, spot_market)?
123                as i128
124        } else {
125            calculate_liability_weight(token_amount_unsigned, spot_market)? as i128
126        };
127        let weighted_value_usdc_base_units = value_usdc_base_units
128            .checked_mul(weight_bps)
129            .ok_or(MathError::Overflow)?
130            .checked_div(MARGIN_PRECISION)
131            .ok_or(MathError::Overflow)?;
132
133        update_running_totals(
134            &mut total_weighted_collateral_usdc_base_units,
135            &mut total_weighted_liabilities_usdc_base_units,
136            weighted_value_usdc_base_units,
137        )?;
138
139        // Step 3: Track USDC balance (asset_id 0)
140        if asset_id == pyra_tokens::AssetId::USDC
141            && usdc_balance_base_units == 0
142            && token_balance_base_units > 0
143        {
144            usdc_balance_base_units =
145                u64::try_from(token_balance_base_units).map_err(|_| MathError::Overflow)?;
146        }
147
148        // Step 4: Store position info
149        let token_balance_unsigned = u64::try_from(token_balance_base_units.unsigned_abs())
150            .map_err(|_| MathError::Overflow)?;
151        position_infos.push(PositionInfo {
152            asset_id,
153            balance: token_balance_unsigned,
154            position_type: position.balance_type.clone(),
155            price_usdc_base_units,
156            weight_bps: spot_market.initial_asset_weight,
157        });
158    }
159
160    // Step 5: Available credit = weighted collateral - weighted liabilities
161    let available_credit_base_units = total_weighted_collateral_usdc_base_units
162        .saturating_sub(total_weighted_liabilities_usdc_base_units);
163    let available_credit_cents = usdc_base_units_to_cents(available_credit_base_units)?;
164
165    // Step 6: Total spendable = collateral - slippage - liabilities (for liquidating spends)
166    let max_slippage_usdc_base_units = total_collateral_usdc_base_units
167        .checked_mul(max_slippage_bps)
168        .ok_or(MathError::Overflow)?
169        .checked_div(10_000)
170        .ok_or(MathError::Overflow)?;
171    let total_spendable_base_units = total_collateral_usdc_base_units
172        .saturating_sub(max_slippage_usdc_base_units)
173        .saturating_sub(total_liabilities_usdc_base_units);
174    let total_spendable_cents = usdc_base_units_to_cents(total_spendable_base_units)?;
175
176    let usdc_balance_cents = usdc_base_units_to_cents(usdc_balance_base_units)?;
177
178    Ok(CapacityResult {
179        total_spendable_cents,
180        available_credit_cents,
181        usdc_balance_cents,
182        weighted_collateral_usdc_base_units: total_weighted_collateral_usdc_base_units,
183        weighted_liabilities_usdc_base_units: total_weighted_liabilities_usdc_base_units,
184        position_infos,
185    })
186}
187
188/// Accumulate a signed value into positive/negative running totals.
189fn update_running_totals(
190    total_positive: &mut u64,
191    total_negative: &mut u64,
192    value: i128,
193) -> MathResult<()> {
194    let value_unsigned = u64::try_from(value.unsigned_abs()).map_err(|_| MathError::Overflow)?;
195
196    if value >= 0 {
197        *total_positive = total_positive
198            .checked_add(value_unsigned)
199            .ok_or(MathError::Overflow)?;
200    } else {
201        *total_negative = total_negative
202            .checked_add(value_unsigned)
203            .ok_or(MathError::Overflow)?;
204    }
205
206    Ok(())
207}
208
209#[cfg(test)]
210#[allow(
211    clippy::allow_attributes,
212    clippy::allow_attributes_without_reason,
213    clippy::unwrap_used,
214    clippy::expect_used,
215    clippy::panic,
216    clippy::arithmetic_side_effects,
217    reason = "test code"
218)]
219mod tests {
220    use super::*;
221    use pyra_types::{HistoricalOracleData, InsuranceFund};
222
223    fn make_spot_market_with_twap(
224        market_index: u16,
225        decimals: u32,
226        initial_asset_weight: u32,
227        initial_liability_weight: u32,
228        twap5min: i64,
229    ) -> SpotMarket {
230        let precision_decrease = 10u128.pow(19u32.saturating_sub(decimals));
231        SpotMarket {
232            pubkey: vec![],
233            market_index,
234            initial_asset_weight,
235            initial_liability_weight,
236            imf_factor: 0,
237            scale_initial_asset_weight_start: 0,
238            decimals,
239            cumulative_deposit_interest: precision_decrease,
240            cumulative_borrow_interest: precision_decrease,
241            deposit_balance: 0,
242            borrow_balance: 0,
243            optimal_utilization: 0,
244            optimal_borrow_rate: 0,
245            max_borrow_rate: 0,
246            min_borrow_rate: 0,
247            insurance_fund: InsuranceFund::default(),
248            historical_oracle_data: HistoricalOracleData {
249                last_oracle_price_twap5min: twap5min,
250            },
251            oracle: None,
252        }
253    }
254
255    /// Convenience: creates a spot market where TWAP matches the oracle price.
256    fn make_spot_market(
257        market_index: u16,
258        decimals: u32,
259        initial_asset_weight: u32,
260        initial_liability_weight: u32,
261        oracle_price: u64,
262    ) -> SpotMarket {
263        make_spot_market_with_twap(
264            market_index,
265            decimals,
266            initial_asset_weight,
267            initial_liability_weight,
268            oracle_price as i64,
269        )
270    }
271
272    /// Create a SpotPosition with the given Drift market_index.
273    fn make_position(
274        drift_market_index: u16,
275        scaled_balance: u64,
276        is_deposit: bool,
277    ) -> SpotPosition {
278        SpotPosition {
279            market_index: drift_market_index,
280            scaled_balance,
281            balance_type: if is_deposit {
282                SpotBalanceType::Deposit
283            } else {
284                SpotBalanceType::Borrow
285            },
286            ..Default::default()
287        }
288    }
289
290    // --- calculate_capacity ---
291
292    #[test]
293    fn empty_positions() {
294        let result = calculate_capacity(&[], &HashMap::new(), &HashMap::new(), &[], 0).unwrap();
295        assert_eq!(result.total_spendable_cents, 0);
296        assert_eq!(result.available_credit_cents, 0);
297        assert_eq!(result.usdc_balance_cents, 0);
298        assert_eq!(result.weighted_collateral_usdc_base_units, 0);
299        assert_eq!(result.weighted_liabilities_usdc_base_units, 0);
300        assert!(result.position_infos.is_empty());
301    }
302
303    #[test]
304    fn single_usdc_deposit() {
305        // USDC: drift market_index=0, asset_id=0
306        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
307        let positions = vec![make_position(0, 100_000_000, true)]; // 100 USDC
308
309        let mut markets = HashMap::new();
310        markets.insert(AssetId::USDC, usdc); // keyed by asset_id
311        let mut prices = HashMap::new();
312        prices.insert(AssetId::USDC, 1_000_000u64); // keyed by asset_id
313
314        let result = calculate_capacity(&positions, &markets, &prices, &[], 0).unwrap();
315
316        assert_eq!(result.usdc_balance_cents, 10_000); // $100
317        assert_eq!(result.total_spendable_cents, 10_000);
318        assert_eq!(result.available_credit_cents, 10_000);
319        assert_eq!(result.weighted_collateral_usdc_base_units, 100_000_000);
320        assert_eq!(result.weighted_liabilities_usdc_base_units, 0);
321        assert_eq!(result.position_infos.len(), 1);
322        assert_eq!(result.position_infos[0].asset_id, AssetId::USDC);
323    }
324
325    #[test]
326    fn deposit_and_borrow() {
327        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
328        let positions = vec![
329            make_position(0, 100_000_000, true), // 100 USDC deposit
330            make_position(0, 50_000_000, false), // 50 USDC borrow
331        ];
332
333        let mut markets = HashMap::new();
334        markets.insert(AssetId::USDC, usdc);
335        let mut prices = HashMap::new();
336        prices.insert(AssetId::USDC, 1_000_000u64);
337
338        let result = calculate_capacity(&positions, &markets, &prices, &[], 0).unwrap();
339
340        assert_eq!(result.usdc_balance_cents, 10_000); // 100 USDC deposit
341        assert_eq!(result.total_spendable_cents, 5_000); // 100 - 50 = 50 USDC
342        assert_eq!(result.available_credit_cents, 5_000);
343    }
344
345    #[test]
346    fn unliquidatable_excluded_from_spendable() {
347        // wETH: drift_market_index=4, asset_id=4
348        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
349        let weth = make_spot_market(4, 9, 8_000, 12_000, 100_000_000);
350
351        let positions = vec![
352            make_position(0, 10_000_000, true),    // 10 USDC (drift idx 0)
353            make_position(4, 1_000_000_000, true), // 1 wETH (drift idx 4)
354        ];
355
356        let mut markets = HashMap::new();
357        markets.insert(AssetId::USDC, usdc); // asset_id 0 = USDC
358        markets.insert(AssetId::WETH, weth); // asset_id 4 = wETH
359        let mut prices = HashMap::new();
360        prices.insert(AssetId::USDC, 1_000_000u64);
361        prices.insert(AssetId::WETH, 100_000_000u64); // $100
362
363        let unliquidatable = vec![AssetId::WETH]; // asset_id 4 = wETH
364        let result = calculate_capacity(&positions, &markets, &prices, &unliquidatable, 0).unwrap();
365
366        // total_spendable: only USDC (10M base = 1000 cents), wETH excluded
367        assert_eq!(result.total_spendable_cents, 1_000);
368        // available_credit: includes wETH weighted (100M * 80% = 80M) + USDC (10M * 100% = 10M) = 90M = 9000 cents
369        assert_eq!(result.available_credit_cents, 9_000);
370        assert_eq!(result.position_infos.len(), 2);
371    }
372
373    #[test]
374    fn slippage_reduces_spendable() {
375        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
376        let positions = vec![make_position(0, 100_000_000, true)]; // 100 USDC
377
378        let mut markets = HashMap::new();
379        markets.insert(AssetId::USDC, usdc);
380        let mut prices = HashMap::new();
381        prices.insert(AssetId::USDC, 1_000_000u64);
382
383        // 10% slippage = 1000 bps
384        let result = calculate_capacity(&positions, &markets, &prices, &[], 1_000).unwrap();
385
386        // 100 USDC collateral, 10% slippage = 10 USDC, spendable = 90 USDC = 9000 cents
387        assert_eq!(result.total_spendable_cents, 9_000);
388        // available_credit not affected by slippage
389        assert_eq!(result.available_credit_cents, 10_000);
390    }
391
392    #[test]
393    fn missing_market_skipped() {
394        // drift_market_index=5 maps to asset_id=5 (USDT) — but no market data provided
395        let positions = vec![make_position(5, 1_000_000, true)];
396
397        let result =
398            calculate_capacity(&positions, &HashMap::new(), &HashMap::new(), &[], 0).unwrap();
399
400        assert_eq!(result.total_spendable_cents, 0);
401        assert!(result.position_infos.is_empty());
402    }
403
404    #[test]
405    fn missing_price_skipped() {
406        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
407        let positions = vec![make_position(0, 1_000_000, true)];
408
409        let mut markets = HashMap::new();
410        markets.insert(AssetId::USDC, usdc); // asset_id 0
411
412        let result = calculate_capacity(&positions, &markets, &HashMap::new(), &[], 0).unwrap();
413
414        assert_eq!(result.total_spendable_cents, 0);
415        assert!(result.position_infos.is_empty());
416    }
417
418    #[test]
419    fn multi_position_with_unliquidatable_and_slippage() {
420        // Use real tokens:
421        // USDC:  drift_idx=0,  asset_id=0, 6 decimals
422        // wETH:  drift_idx=4,  asset_id=4, 8 decimals -> unliquidatable
423        // wSOL:  drift_idx=1,  asset_id=1, 9 decimals
424        // USDT:  drift_idx=5,  asset_id=5, 6 decimals -> borrow
425        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
426        let weth = make_spot_market(4, 9, 8_000, 12_000, 200_000_000); // unliquidatable
427        let wsol = make_spot_market(1, 9, 8_000, 12_000, 100_000_000);
428        let usdt = make_spot_market(5, 6, 10_000, 10_000, 1_000_000);
429
430        let positions = vec![
431            make_position(0, 50_000_000, true),    // 50 USDC (drift idx 0)
432            make_position(4, 1_000_000_000, true), // 1 wETH @ $200 (drift idx 4, unliquidatable)
433            make_position(1, 500_000_000, true),   // 0.5 wSOL @ $100 (drift idx 1)
434            make_position(5, 20_000_000, false),   // 20 USDT borrow (drift idx 5)
435        ];
436
437        // Maps keyed by asset_id
438        let mut markets = HashMap::new();
439        markets.insert(AssetId::USDC, usdc); // asset_id 0
440        markets.insert(AssetId::WETH, weth); // asset_id 4
441        markets.insert(AssetId::WSOL, wsol); // asset_id 1
442        markets.insert(AssetId::USDT, usdt); // asset_id 5
443        let mut prices = HashMap::new();
444        prices.insert(AssetId::USDC, 1_000_000u64);
445        prices.insert(AssetId::WETH, 200_000_000u64);
446        prices.insert(AssetId::WSOL, 100_000_000u64);
447        prices.insert(AssetId::USDT, 1_000_000u64);
448
449        let unliquidatable = vec![AssetId::WETH]; // asset_id 4 = wETH
450        let result =
451            calculate_capacity(&positions, &markets, &prices, &unliquidatable, 500).unwrap();
452
453        // Unweighted collateral (excluding wETH): 50M (USDC) + 50M (wSOL: 0.5 * $100) = 100M
454        // Unweighted liabilities: 20M (USDT borrow)
455        // Slippage: 100M * 500/10000 = 5M
456        // total_spendable = 100M - 5M - 20M = 75M base = 7500 cents
457        assert_eq!(result.total_spendable_cents, 7_500);
458
459        // Weighted collateral: 50M*100% + 200M*80% + 50M*80% = 50M + 160M + 40M = 250M
460        // Weighted liabilities: 20M*100% = 20M
461        // available_credit = 250M - 20M = 230M base = 23000 cents
462        assert_eq!(result.available_credit_cents, 23_000);
463        assert_eq!(result.usdc_balance_cents, 5_000);
464        assert_eq!(result.position_infos.len(), 4);
465    }
466
467    // --- update_running_totals ---
468
469    #[test]
470    fn running_totals_positive() {
471        let mut pos = 0u64;
472        let mut neg = 0u64;
473        update_running_totals(&mut pos, &mut neg, 100).unwrap();
474        assert_eq!(pos, 100);
475        assert_eq!(neg, 0);
476    }
477
478    #[test]
479    fn running_totals_negative() {
480        let mut pos = 0u64;
481        let mut neg = 0u64;
482        update_running_totals(&mut pos, &mut neg, -50).unwrap();
483        assert_eq!(pos, 0);
484        assert_eq!(neg, 50);
485    }
486
487    #[test]
488    fn running_totals_accumulate() {
489        let mut pos = 10u64;
490        let mut neg = 5u64;
491        update_running_totals(&mut pos, &mut neg, 20).unwrap();
492        update_running_totals(&mut pos, &mut neg, -15).unwrap();
493        assert_eq!(pos, 30);
494        assert_eq!(neg, 20);
495    }
496}
497
#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects,
    reason = "test code"
)]
mod proptests {
    use super::*;
    use proptest::prelude::*;
    use pyra_types::{HistoricalOracleData, InsuranceFund};

    /// USDC spot market fixture: 6 decimals, 100% weights, $1 TWAP, no IMF.
    ///
    /// Cumulative interest equals the precision decrease (`10^(19 - decimals)`), which is
    /// assumed to make token balance equal to scaled balance (as the unit tests rely on).
    fn usdc_market() -> SpotMarket {
        let precision_decrease = 10u128.pow(19 - 6);
        SpotMarket {
            pubkey: vec![],
            market_index: 0,
            initial_asset_weight: 10_000,
            initial_liability_weight: 10_000,
            imf_factor: 0,
            scale_initial_asset_weight_start: 0,
            decimals: 6,
            cumulative_deposit_interest: precision_decrease,
            cumulative_borrow_interest: precision_decrease,
            deposit_balance: 0,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: InsuranceFund::default(),
            historical_oracle_data: HistoricalOracleData {
                last_oracle_price_twap5min: 1_000_000,
            },
            oracle: None,
        }
    }

    /// USDC spot position (Drift market index 0) with the given side.
    fn usdc_position(scaled_balance: u64, balance_type: SpotBalanceType) -> SpotPosition {
        SpotPosition {
            market_index: 0,
            scaled_balance,
            balance_type,
            ..Default::default()
        }
    }

    proptest! {
        #[test]
        fn spendable_le_collateral_minus_liabilities(
            collateral_base in 0u64..=1_000_000_000_000u64,
            liabilities_base in 0u64..=500_000_000_000u64,
            max_slippage_bps in 0u64..=10_000u64,
        ) {
            // Drive the real calculation: one USDC deposit and one USDC borrow at $1.
            let positions = vec![
                usdc_position(collateral_base, SpotBalanceType::Deposit),
                usdc_position(liabilities_base, SpotBalanceType::Borrow),
            ];
            let mut markets = HashMap::new();
            markets.insert(AssetId::USDC, usdc_market());
            let mut prices = HashMap::new();
            prices.insert(AssetId::USDC, 1_000_000u64);

            let result =
                calculate_capacity(&positions, &markets, &prices, &[], max_slippage_bps).unwrap();

            // Slippage only ever reduces spendable, so it can never exceed
            // collateral - liabilities.
            let max_possible =
                usdc_base_units_to_cents(collateral_base.saturating_sub(liabilities_base)).unwrap();
            prop_assert!(
                result.total_spendable_cents <= max_possible,
                "spendable {} exceeds collateral - liabilities {}",
                result.total_spendable_cents,
                max_possible
            );

            // The USDC balance reports the deposit only; the borrow is not netted.
            prop_assert_eq!(
                result.usdc_balance_cents,
                usdc_base_units_to_cents(collateral_base).unwrap()
            );
        }
    }
}