Skip to main content

pyra_margin/drift/
capacity.rs

1use std::collections::HashMap;
2
3use pyra_tokens::AssetId;
4use pyra_types::{SpotBalanceType, SpotMarket, SpotPosition};
5
6use super::balance::{calculate_value_usdc_base_units, get_token_balance};
7use super::weights::{calculate_asset_weight, calculate_liability_weight, get_strict_price};
8use crate::common::usdc_base_units_to_cents;
9use crate::error::{MathError, MathResult};
10
/// Denominator for basis-point weights returned by the weight helpers (10_000 bps = 100%).
const MARGIN_PRECISION: i128 = 10_000;
12
13/// Per-position info emitted during capacity calculation.
14///
15/// Used downstream by liquidating spend jobs to know position sizes and weights.
16#[derive(Debug, Clone, PartialEq)]
17pub struct PositionInfo {
18    pub asset_id: AssetId,
19    /// Unsigned token balance in base units.
20    pub balance: u64,
21    pub position_type: SpotBalanceType,
22    pub price_usdc_base_units: u64,
23    /// Initial asset weight from the spot market (basis points).
24    pub weight_bps: u32,
25}
26
27/// Result of spending capacity calculation.
28#[derive(Debug, Clone)]
29pub struct CapacityResult {
30    /// Max spendable via liquidating spend (unweighted collateral - slippage - liabilities), in cents.
31    /// Excludes unliquidatable assets from collateral.
32    pub total_spendable_cents: u64,
33    /// Available credit line (weighted collateral - weighted liabilities), in cents.
34    pub available_credit_cents: u64,
35    /// USDC balance (market index 0) in cents.
36    pub usdc_balance_cents: u64,
37    /// Total weighted collateral in USDC base units.
38    pub weighted_collateral_usdc_base_units: u64,
39    /// Total weighted liabilities in USDC base units.
40    pub weighted_liabilities_usdc_base_units: u64,
41    /// Per-position breakdown for downstream use.
42    pub position_infos: Vec<PositionInfo>,
43}
44
45/// Calculate spending capacity from Drift spot positions.
46///
47/// This is the core calculation used for card transaction authorization.
48/// It computes:
49/// - **total_spendable**: max amount for liquidating spends (collateral minus slippage minus liabilities,
50///   excluding unliquidatable assets)
51/// - **available_credit**: credit line from weighted margin (weighted collateral minus weighted liabilities)
52/// - **usdc_balance**: direct USDC holdings
53///
54/// Positions in `unliquidatable_asset_ids` are excluded from unweighted collateral
55/// (affecting `total_spendable`) but still included in weighted calculations
56/// (affecting `available_credit`).
57///
58/// All maps are keyed by **asset_id** (not Drift market index).
59/// Positions whose asset_id is missing from `spot_market_map` or `price_map` are skipped.
60/// Positions for tokens not in `pyra_tokens` are also skipped.
61pub fn calculate_capacity(
62    spot_positions: &[SpotPosition],
63    spot_market_map: &HashMap<AssetId, SpotMarket>,
64    price_map: &HashMap<AssetId, u64>,
65    unliquidatable_asset_ids: &[AssetId],
66    max_slippage_bps: u64,
67) -> MathResult<CapacityResult> {
68    let mut total_collateral_usdc_base_units: u64 = 0;
69    let mut total_liabilities_usdc_base_units: u64 = 0;
70
71    let mut total_weighted_collateral_usdc_base_units: u64 = 0;
72    let mut total_weighted_liabilities_usdc_base_units: u64 = 0;
73
74    let mut usdc_balance_base_units: u64 = 0;
75
76    let mut position_infos: Vec<PositionInfo> = Vec::new();
77
78    for position in spot_positions {
79        // Convert Drift market_index to Pyra asset_id for map lookups
80        let Some(token) =
81            pyra_tokens::Token::find_by_drift_market_index(position.market_index)
82        else {
83            continue;
84        };
85        let asset_id = token.asset_id;
86
87        let Some(spot_market) = spot_market_map.get(&asset_id) else {
88            continue;
89        };
90        let Some(price_usdc_base_units) = price_map.get(&asset_id).copied() else {
91            continue;
92        };
93
94        // Step 1: Calculate token balance and USDC value
95        let token_balance_base_units = get_token_balance(position, spot_market)?;
96
97        let is_asset = token_balance_base_units >= 0;
98        let twap5min = spot_market
99            .historical_oracle_data
100            .last_oracle_price_twap5min;
101        let strict_price = get_strict_price(price_usdc_base_units, twap5min, is_asset);
102
103        let value_usdc_base_units = calculate_value_usdc_base_units(
104            token_balance_base_units,
105            strict_price,
106            spot_market.decimals,
107        )?;
108
109        // Accumulate unweighted totals (excluding unliquidatable collateral)
110        let is_unliquidatable_collateral =
111            unliquidatable_asset_ids.contains(&asset_id) && value_usdc_base_units > 0;
112        if !is_unliquidatable_collateral {
113            update_running_totals(
114                &mut total_collateral_usdc_base_units,
115                &mut total_liabilities_usdc_base_units,
116                value_usdc_base_units,
117            )?;
118        }
119
120        // Step 2: Apply IMF-adjusted weights
121        let token_amount_unsigned = token_balance_base_units.unsigned_abs();
122        let weight_bps = if is_asset {
123            calculate_asset_weight(token_amount_unsigned, price_usdc_base_units, spot_market)?
124                as i128
125        } else {
126            calculate_liability_weight(token_amount_unsigned, spot_market)? as i128
127        };
128        let weighted_value_usdc_base_units = value_usdc_base_units
129            .checked_mul(weight_bps)
130            .ok_or(MathError::Overflow)?
131            .checked_div(MARGIN_PRECISION)
132            .ok_or(MathError::Overflow)?;
133
134        update_running_totals(
135            &mut total_weighted_collateral_usdc_base_units,
136            &mut total_weighted_liabilities_usdc_base_units,
137            weighted_value_usdc_base_units,
138        )?;
139
140        // Step 3: Track USDC balance (asset_id 0)
141        if asset_id == pyra_tokens::AssetId::USDC && usdc_balance_base_units == 0 && token_balance_base_units > 0 {
142            usdc_balance_base_units =
143                u64::try_from(token_balance_base_units).map_err(|_| MathError::Overflow)?;
144        }
145
146        // Step 4: Store position info
147        let token_balance_unsigned = u64::try_from(token_balance_base_units.unsigned_abs())
148            .map_err(|_| MathError::Overflow)?;
149        position_infos.push(PositionInfo {
150            asset_id,
151            balance: token_balance_unsigned,
152            position_type: position.balance_type.clone(),
153            price_usdc_base_units,
154            weight_bps: spot_market.initial_asset_weight,
155        });
156    }
157
158    // Step 5: Available credit = weighted collateral - weighted liabilities
159    let available_credit_base_units = total_weighted_collateral_usdc_base_units
160        .saturating_sub(total_weighted_liabilities_usdc_base_units);
161    let available_credit_cents = usdc_base_units_to_cents(available_credit_base_units)?;
162
163    // Step 6: Total spendable = collateral - slippage - liabilities (for liquidating spends)
164    let max_slippage_usdc_base_units = total_collateral_usdc_base_units
165        .checked_mul(max_slippage_bps)
166        .ok_or(MathError::Overflow)?
167        .checked_div(10_000)
168        .ok_or(MathError::Overflow)?;
169    let total_spendable_base_units = total_collateral_usdc_base_units
170        .saturating_sub(max_slippage_usdc_base_units)
171        .saturating_sub(total_liabilities_usdc_base_units);
172    let total_spendable_cents = usdc_base_units_to_cents(total_spendable_base_units)?;
173
174    let usdc_balance_cents = usdc_base_units_to_cents(usdc_balance_base_units)?;
175
176    Ok(CapacityResult {
177        total_spendable_cents,
178        available_credit_cents,
179        usdc_balance_cents,
180        weighted_collateral_usdc_base_units: total_weighted_collateral_usdc_base_units,
181        weighted_liabilities_usdc_base_units: total_weighted_liabilities_usdc_base_units,
182        position_infos,
183    })
184}
185
186/// Accumulate a signed value into positive/negative running totals.
187fn update_running_totals(
188    total_positive: &mut u64,
189    total_negative: &mut u64,
190    value: i128,
191) -> MathResult<()> {
192    let value_unsigned = u64::try_from(value.unsigned_abs()).map_err(|_| MathError::Overflow)?;
193
194    if value >= 0 {
195        *total_positive = total_positive
196            .checked_add(value_unsigned)
197            .ok_or(MathError::Overflow)?;
198    } else {
199        *total_negative = total_negative
200            .checked_add(value_unsigned)
201            .ok_or(MathError::Overflow)?;
202    }
203
204    Ok(())
205}
206
#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects,
    reason = "test code"
)]
mod tests {
    use super::*;
    use pyra_types::{HistoricalOracleData, InsuranceFund};

    /// Build a spot market with an explicit 5-minute TWAP.
    ///
    /// Cumulative deposit/borrow interest is set equal to the precision decrease; the tests
    /// below rely on this so that scaled balances equal token amounts.
    fn make_spot_market_with_twap(
        market_index: u16,
        decimals: u32,
        initial_asset_weight: u32,
        initial_liability_weight: u32,
        twap5min: i64,
    ) -> SpotMarket {
        let precision_decrease = 10u128.pow(19u32.saturating_sub(decimals));
        SpotMarket {
            pubkey: vec![],
            market_index,
            initial_asset_weight,
            initial_liability_weight,
            imf_factor: 0,
            scale_initial_asset_weight_start: 0,
            decimals,
            cumulative_deposit_interest: precision_decrease,
            cumulative_borrow_interest: precision_decrease,
            deposit_balance: 0,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: InsuranceFund::default(),
            historical_oracle_data: HistoricalOracleData {
                last_oracle_price_twap5min: twap5min,
            },
            oracle: None,
        }
    }

    /// Convenience: creates a spot market where TWAP matches the oracle price.
    fn make_spot_market(
        market_index: u16,
        decimals: u32,
        initial_asset_weight: u32,
        initial_liability_weight: u32,
        oracle_price: u64,
    ) -> SpotMarket {
        make_spot_market_with_twap(
            market_index,
            decimals,
            initial_asset_weight,
            initial_liability_weight,
            oracle_price as i64,
        )
    }

    /// Create a SpotPosition with the given Drift market_index.
    fn make_position(drift_market_index: u16, scaled_balance: u64, is_deposit: bool) -> SpotPosition {
        SpotPosition {
            market_index: drift_market_index,
            scaled_balance,
            balance_type: if is_deposit {
                SpotBalanceType::Deposit
            } else {
                SpotBalanceType::Borrow
            },
            ..Default::default()
        }
    }

    // --- calculate_capacity ---

    #[test]
    fn empty_positions() {
        let result = calculate_capacity(&[], &HashMap::new(), &HashMap::new(), &[], 0).unwrap();
        assert_eq!(result.total_spendable_cents, 0);
        assert_eq!(result.available_credit_cents, 0);
        assert_eq!(result.usdc_balance_cents, 0);
        assert_eq!(result.weighted_collateral_usdc_base_units, 0);
        assert_eq!(result.weighted_liabilities_usdc_base_units, 0);
        assert!(result.position_infos.is_empty());
    }

    #[test]
    fn single_usdc_deposit() {
        // USDC: drift market_index=0, asset_id=0
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![make_position(0, 100_000_000, true)]; // 100 USDC

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc); // keyed by asset_id
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64); // keyed by asset_id

        let result = calculate_capacity(&positions, &markets, &prices, &[], 0).unwrap();

        assert_eq!(result.usdc_balance_cents, 10_000); // $100
        assert_eq!(result.total_spendable_cents, 10_000);
        assert_eq!(result.available_credit_cents, 10_000);
        assert_eq!(result.weighted_collateral_usdc_base_units, 100_000_000);
        assert_eq!(result.weighted_liabilities_usdc_base_units, 0);
        assert_eq!(result.position_infos.len(), 1);
        assert_eq!(result.position_infos[0].asset_id, AssetId::USDC);
    }

    #[test]
    fn deposit_and_borrow() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![
            make_position(0, 100_000_000, true), // 100 USDC deposit
            make_position(0, 50_000_000, false), // 50 USDC borrow
        ];

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);

        let result = calculate_capacity(&positions, &markets, &prices, &[], 0).unwrap();

        assert_eq!(result.usdc_balance_cents, 10_000); // 100 USDC deposit
        assert_eq!(result.total_spendable_cents, 5_000); // 100 - 50 = 50 USDC
        assert_eq!(result.available_credit_cents, 5_000);
    }

    #[test]
    fn borrow_only_saturates_to_zero() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![make_position(0, 20_000_000, false)]; // 20 USDC borrow

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);

        let result = calculate_capacity(&positions, &markets, &prices, &[], 0).unwrap();

        // Liabilities without collateral clamp to zero instead of going negative,
        // and a USDC borrow is not reported as a USDC balance.
        assert_eq!(result.total_spendable_cents, 0);
        assert_eq!(result.available_credit_cents, 0);
        assert_eq!(result.usdc_balance_cents, 0);
        assert_eq!(result.weighted_collateral_usdc_base_units, 0);
        assert_eq!(result.weighted_liabilities_usdc_base_units, 20_000_000);
        assert_eq!(result.position_infos.len(), 1);
        assert_eq!(result.position_infos[0].asset_id, AssetId::USDC);
        assert_eq!(result.position_infos[0].balance, 20_000_000);
        assert_eq!(
            result.position_infos[0].position_type,
            SpotBalanceType::Borrow
        );
    }

    #[test]
    fn unliquidatable_excluded_from_spendable() {
        // wETH: drift_market_index=4, asset_id=4
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let weth = make_spot_market(4, 9, 8_000, 12_000, 100_000_000);

        let positions = vec![
            make_position(0, 10_000_000, true),    // 10 USDC (drift idx 0)
            make_position(4, 1_000_000_000, true), // 1 wETH (drift idx 4)
        ];

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc); // asset_id 0 = USDC
        markets.insert(AssetId::WETH, weth); // asset_id 4 = wETH
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);
        prices.insert(AssetId::WETH, 100_000_000u64); // $100

        let unliquidatable = vec![AssetId::WETH]; // asset_id 4 = wETH
        let result = calculate_capacity(&positions, &markets, &prices, &unliquidatable, 0).unwrap();

        // total_spendable: only USDC (10M base = 1000 cents), wETH excluded
        assert_eq!(result.total_spendable_cents, 1_000);
        // available_credit: includes wETH weighted (100M * 80% = 80M) + USDC (10M * 100% = 10M) = 90M = 9000 cents
        assert_eq!(result.available_credit_cents, 9_000);
        assert_eq!(result.position_infos.len(), 2);
    }

    #[test]
    fn unliquidatable_borrow_still_counts_as_liability() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let weth = make_spot_market(4, 9, 8_000, 12_000, 100_000_000);

        let positions = vec![
            make_position(0, 100_000_000, true),  // 100 USDC
            make_position(4, 100_000_000, false), // 0.1 wETH borrow @ $100 = $10
        ];

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        markets.insert(AssetId::WETH, weth);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);
        prices.insert(AssetId::WETH, 100_000_000u64);

        let unliquidatable = vec![AssetId::WETH];
        let result = calculate_capacity(&positions, &markets, &prices, &unliquidatable, 0).unwrap();

        // Only unliquidatable *collateral* is excluded; the $10 wETH debt still reduces spendable.
        assert_eq!(result.total_spendable_cents, 9_000);
        assert_eq!(result.usdc_balance_cents, 10_000);
        assert_eq!(result.position_infos.len(), 2);
    }

    #[test]
    fn slippage_reduces_spendable() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![make_position(0, 100_000_000, true)]; // 100 USDC

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);

        // 10% slippage = 1000 bps
        let result = calculate_capacity(&positions, &markets, &prices, &[], 1_000).unwrap();

        // 100 USDC collateral, 10% slippage = 10 USDC, spendable = 90 USDC = 9000 cents
        assert_eq!(result.total_spendable_cents, 9_000);
        // available_credit not affected by slippage
        assert_eq!(result.available_credit_cents, 10_000);
    }

    #[test]
    fn slippage_above_100_percent_saturates_to_zero() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![make_position(0, 100_000_000, true)]; // 100 USDC

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);

        // 200% slippage: the haircut exceeds collateral, so spendable clamps to zero.
        let result = calculate_capacity(&positions, &markets, &prices, &[], 20_000).unwrap();

        assert_eq!(result.total_spendable_cents, 0);
        assert_eq!(result.available_credit_cents, 10_000);
    }

    #[test]
    fn missing_market_skipped() {
        // drift_market_index=5 maps to asset_id=5 (USDT) — but no market data provided
        let positions = vec![make_position(5, 1_000_000, true)];

        let result =
            calculate_capacity(&positions, &HashMap::new(), &HashMap::new(), &[], 0).unwrap();

        assert_eq!(result.total_spendable_cents, 0);
        assert!(result.position_infos.is_empty());
    }

    #[test]
    fn missing_price_skipped() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![make_position(0, 1_000_000, true)];

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc); // asset_id 0

        let result = calculate_capacity(&positions, &markets, &HashMap::new(), &[], 0).unwrap();

        assert_eq!(result.total_spendable_cents, 0);
        assert!(result.position_infos.is_empty());
    }

    #[test]
    fn multi_position_with_unliquidatable_and_slippage() {
        // Use real tokens:
        // USDC:  drift_idx=0,  asset_id=0, 6 decimals
        // wETH:  drift_idx=4,  asset_id=4, 9 decimals (as built below) -> unliquidatable
        // wSOL:  drift_idx=1,  asset_id=1, 9 decimals
        // USDT:  drift_idx=5,  asset_id=5, 6 decimals -> borrow
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let weth = make_spot_market(4, 9, 8_000, 12_000, 200_000_000); // unliquidatable
        let wsol = make_spot_market(1, 9, 8_000, 12_000, 100_000_000);
        let usdt = make_spot_market(5, 6, 10_000, 10_000, 1_000_000);

        let positions = vec![
            make_position(0, 50_000_000, true),    // 50 USDC (drift idx 0)
            make_position(4, 1_000_000_000, true), // 1 wETH @ $200 (drift idx 4, unliquidatable)
            make_position(1, 500_000_000, true),   // 0.5 wSOL @ $100 (drift idx 1)
            make_position(5, 20_000_000, false),   // 20 USDT borrow (drift idx 5)
        ];

        // Maps keyed by asset_id
        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc); // asset_id 0
        markets.insert(AssetId::WETH, weth); // asset_id 4
        markets.insert(AssetId::WSOL, wsol); // asset_id 1
        markets.insert(AssetId::USDT, usdt); // asset_id 5
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);
        prices.insert(AssetId::WETH, 200_000_000u64);
        prices.insert(AssetId::WSOL, 100_000_000u64);
        prices.insert(AssetId::USDT, 1_000_000u64);

        let unliquidatable = vec![AssetId::WETH]; // asset_id 4 = wETH
        let result =
            calculate_capacity(&positions, &markets, &prices, &unliquidatable, 500).unwrap();

        // Unweighted collateral (excluding wETH): 50M (USDC) + 50M (wSOL: 0.5 * $100) = 100M
        // Unweighted liabilities: 20M (USDT borrow)
        // Slippage: 100M * 500/10000 = 5M
        // total_spendable = 100M - 5M - 20M = 75M base = 7500 cents
        assert_eq!(result.total_spendable_cents, 7_500);

        // Weighted collateral: 50M*100% + 200M*80% + 50M*80% = 50M + 160M + 40M = 250M
        // Weighted liabilities: 20M*100% = 20M
        // available_credit = 250M - 20M = 230M base = 23000 cents
        assert_eq!(result.available_credit_cents, 23_000);
        assert_eq!(result.usdc_balance_cents, 5_000);
        assert_eq!(result.position_infos.len(), 4);
    }

    // --- update_running_totals ---

    #[test]
    fn running_totals_positive() {
        let mut pos = 0u64;
        let mut neg = 0u64;
        update_running_totals(&mut pos, &mut neg, 100).unwrap();
        assert_eq!(pos, 100);
        assert_eq!(neg, 0);
    }

    #[test]
    fn running_totals_negative() {
        let mut pos = 0u64;
        let mut neg = 0u64;
        update_running_totals(&mut pos, &mut neg, -50).unwrap();
        assert_eq!(pos, 0);
        assert_eq!(neg, 50);
    }

    #[test]
    fn running_totals_accumulate() {
        let mut pos = 10u64;
        let mut neg = 5u64;
        update_running_totals(&mut pos, &mut neg, 20).unwrap();
        update_running_totals(&mut pos, &mut neg, -15).unwrap();
        assert_eq!(pos, 30);
        assert_eq!(neg, 20);
    }
}
495
#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects,
    reason = "test code"
)]
mod proptests {
    use super::*;
    use proptest::prelude::*;
    use pyra_types::{HistoricalOracleData, InsuranceFund};

    /// A $1 USDC spot market (6 decimals) with 100% asset/liability weights, no IMF factor,
    /// a TWAP equal to the oracle price, and cumulative interest equal to the precision decrease.
    fn usdc_market() -> SpotMarket {
        let precision_decrease = 10u128.pow(19 - 6);
        SpotMarket {
            pubkey: vec![],
            market_index: 0,
            initial_asset_weight: 10_000,
            initial_liability_weight: 10_000,
            imf_factor: 0,
            scale_initial_asset_weight_start: 0,
            decimals: 6,
            cumulative_deposit_interest: precision_decrease,
            cumulative_borrow_interest: precision_decrease,
            deposit_balance: 0,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: InsuranceFund::default(),
            historical_oracle_data: HistoricalOracleData {
                last_oracle_price_twap5min: 1_000_000,
            },
            oracle: None,
        }
    }

    /// A USDC spot position on Drift market 0.
    fn usdc_position(scaled_balance: u64, balance_type: SpotBalanceType) -> SpotPosition {
        SpotPosition {
            market_index: 0,
            scaled_balance,
            balance_type,
            ..Default::default()
        }
    }

    proptest! {
        // Exercises `calculate_capacity` end to end: spendable must never exceed
        // collateral - liabilities, must equal it when there is no slippage haircut,
        // and the reported USDC balance must be the deposit.
        #[test]
        fn spendable_le_collateral_minus_liabilities(
            collateral_base in 0u64..=1_000_000_000_000u64,
            liabilities_base in 0u64..=500_000_000_000u64,
            max_slippage_bps in 0u64..=10_000u64,
        ) {
            let positions = vec![
                usdc_position(collateral_base, SpotBalanceType::Deposit),
                usdc_position(liabilities_base, SpotBalanceType::Borrow),
            ];
            let mut markets = HashMap::new();
            markets.insert(AssetId::USDC, usdc_market());
            let mut prices = HashMap::new();
            prices.insert(AssetId::USDC, 1_000_000u64);

            let result =
                calculate_capacity(&positions, &markets, &prices, &[], max_slippage_bps).unwrap();

            let max_possible =
                usdc_base_units_to_cents(collateral_base.saturating_sub(liabilities_base)).unwrap();
            prop_assert!(
                result.total_spendable_cents <= max_possible,
                "spendable {} exceeds collateral - liabilities {}",
                result.total_spendable_cents,
                max_possible
            );
            if max_slippage_bps == 0 {
                prop_assert_eq!(result.total_spendable_cents, max_possible);
            }
            prop_assert_eq!(
                result.usdc_balance_cents,
                usdc_base_units_to_cents(collateral_base).unwrap()
            );
        }
    }
}
523            let spendable_cents = usdc_base_units_to_cents(spendable_base).unwrap();
524            prop_assert!(spendable_cents <= max_possible + 1, "rounding violation");
525        }
526    }
527}