Skip to main content

pyra_margin/drift/
capacity.rs

1use std::collections::HashMap;
2
3use pyra_types::{SpotBalanceType, SpotMarket, SpotPosition};
4
5use super::balance::{calculate_value_usdc_base_units, get_token_balance};
6use super::weights::{calculate_asset_weight, calculate_liability_weight, get_strict_price};
7use crate::common::usdc_base_units_to_cents;
8use crate::error::{MathError, MathResult};
9
/// Denominator for basis-point weights: weighted value = `value * weight_bps / MARGIN_PRECISION`
/// (10_000 bps = 100%).
const MARGIN_PRECISION: i128 = 10_000;
11
/// Per-position info emitted during capacity calculation.
///
/// Used downstream by liquidating spend jobs to know position sizes and weights.
#[derive(Debug, Clone, PartialEq)]
pub struct PositionInfo {
    /// Spot market index this position belongs to.
    pub market_index: u16,
    /// Unsigned token balance in base units.
    pub balance: u64,
    /// Deposit or borrow, copied from the source spot position.
    pub position_type: SpotBalanceType,
    /// Price for this market as supplied by the caller's price map (the plain price, not the
    /// strict price used to value the position in the capacity totals).
    pub price_usdc_base_units: u64,
    /// Initial asset weight from the spot market (basis points).
    ///
    /// NOTE(review): this is the market's raw `initial_asset_weight`, recorded for borrows as
    /// well as deposits; it is not the size-adjusted asset/liability weight applied to the
    /// weighted totals.
    pub weight_bps: u32,
}
25
26/// Result of spending capacity calculation.
27#[derive(Debug, Clone)]
28pub struct CapacityResult {
29    /// Max spendable via liquidating spend (unweighted collateral - slippage - liabilities), in cents.
30    /// Excludes unliquidatable assets from collateral.
31    pub total_spendable_cents: u64,
32    /// Available credit line (weighted collateral - weighted liabilities), in cents.
33    pub available_credit_cents: u64,
34    /// USDC balance (market index 0) in cents.
35    pub usdc_balance_cents: u64,
36    /// Total weighted collateral in USDC base units.
37    pub weighted_collateral_usdc_base_units: u64,
38    /// Total weighted liabilities in USDC base units.
39    pub weighted_liabilities_usdc_base_units: u64,
40    /// Per-position breakdown for downstream use.
41    pub position_infos: Vec<PositionInfo>,
42}
43
44/// Calculate spending capacity from Drift spot positions.
45///
46/// This is the core calculation used for card transaction authorization.
47/// It computes:
48/// - **total_spendable**: max amount for liquidating spends (collateral minus slippage minus liabilities,
49///   excluding unliquidatable assets)
50/// - **available_credit**: credit line from weighted margin (weighted collateral minus weighted liabilities)
51/// - **usdc_balance**: direct USDC holdings
52///
53/// Positions in `unliquidatable_market_indices` are excluded from unweighted collateral
54/// (affecting `total_spendable`) but still included in weighted calculations
55/// (affecting `available_credit`).
56///
57/// Positions whose market index is missing from `spot_market_map` or `price_map` are skipped.
58pub fn calculate_capacity(
59    spot_positions: &[SpotPosition],
60    spot_market_map: &HashMap<u16, SpotMarket>,
61    price_map: &HashMap<u16, u64>,
62    unliquidatable_market_indices: &[u16],
63    max_slippage_bps: u64,
64) -> MathResult<CapacityResult> {
65    let mut total_collateral_usdc_base_units: u64 = 0;
66    let mut total_liabilities_usdc_base_units: u64 = 0;
67
68    let mut total_weighted_collateral_usdc_base_units: u64 = 0;
69    let mut total_weighted_liabilities_usdc_base_units: u64 = 0;
70
71    let mut usdc_balance_base_units: u64 = 0;
72
73    let mut position_infos: Vec<PositionInfo> = Vec::new();
74
75    for position in spot_positions {
76        let market_index = position.market_index;
77
78        let Some(spot_market) = spot_market_map.get(&market_index) else {
79            continue;
80        };
81        let Some(price_usdc_base_units) = price_map.get(&market_index).copied() else {
82            continue;
83        };
84
85        // Step 1: Calculate token balance and USDC value
86        let token_balance_base_units = get_token_balance(position, spot_market)?;
87
88        let is_asset = token_balance_base_units >= 0;
89        let twap5min = spot_market
90            .historical_oracle_data
91            .last_oracle_price_twap5min;
92        let strict_price = get_strict_price(price_usdc_base_units, twap5min, is_asset);
93
94        let value_usdc_base_units = calculate_value_usdc_base_units(
95            token_balance_base_units,
96            strict_price,
97            spot_market.decimals,
98        )?;
99
100        // Accumulate unweighted totals (excluding unliquidatable collateral)
101        let is_unliquidatable_collateral =
102            unliquidatable_market_indices.contains(&market_index) && value_usdc_base_units > 0;
103        if !is_unliquidatable_collateral {
104            update_running_totals(
105                &mut total_collateral_usdc_base_units,
106                &mut total_liabilities_usdc_base_units,
107                value_usdc_base_units,
108            )?;
109        }
110
111        // Step 2: Apply IMF-adjusted weights
112        let token_amount_unsigned = token_balance_base_units.unsigned_abs();
113        let weight_bps = if is_asset {
114            calculate_asset_weight(token_amount_unsigned, price_usdc_base_units, spot_market)?
115                as i128
116        } else {
117            calculate_liability_weight(token_amount_unsigned, spot_market)? as i128
118        };
119        let weighted_value_usdc_base_units = value_usdc_base_units
120            .checked_mul(weight_bps)
121            .ok_or(MathError::Overflow)?
122            .checked_div(MARGIN_PRECISION)
123            .ok_or(MathError::Overflow)?;
124
125        update_running_totals(
126            &mut total_weighted_collateral_usdc_base_units,
127            &mut total_weighted_liabilities_usdc_base_units,
128            weighted_value_usdc_base_units,
129        )?;
130
131        // Step 3: Track USDC balance (market index 0)
132        if market_index == 0 && usdc_balance_base_units == 0 && token_balance_base_units > 0 {
133            usdc_balance_base_units =
134                u64::try_from(token_balance_base_units).map_err(|_| MathError::Overflow)?;
135        }
136
137        // Step 4: Store position info
138        let token_balance_unsigned = u64::try_from(token_balance_base_units.unsigned_abs())
139            .map_err(|_| MathError::Overflow)?;
140        position_infos.push(PositionInfo {
141            market_index,
142            balance: token_balance_unsigned,
143            position_type: position.balance_type.clone(),
144            price_usdc_base_units,
145            weight_bps: spot_market.initial_asset_weight,
146        });
147    }
148
149    // Step 5: Available credit = weighted collateral - weighted liabilities
150    let available_credit_base_units = total_weighted_collateral_usdc_base_units
151        .saturating_sub(total_weighted_liabilities_usdc_base_units);
152    let available_credit_cents = usdc_base_units_to_cents(available_credit_base_units)?;
153
154    // Step 6: Total spendable = collateral - slippage - liabilities (for liquidating spends)
155    let max_slippage_usdc_base_units = total_collateral_usdc_base_units
156        .checked_mul(max_slippage_bps)
157        .ok_or(MathError::Overflow)?
158        .checked_div(10_000)
159        .ok_or(MathError::Overflow)?;
160    let total_spendable_base_units = total_collateral_usdc_base_units
161        .saturating_sub(max_slippage_usdc_base_units)
162        .saturating_sub(total_liabilities_usdc_base_units);
163    let total_spendable_cents = usdc_base_units_to_cents(total_spendable_base_units)?;
164
165    let usdc_balance_cents = usdc_base_units_to_cents(usdc_balance_base_units)?;
166
167    Ok(CapacityResult {
168        total_spendable_cents,
169        available_credit_cents,
170        usdc_balance_cents,
171        weighted_collateral_usdc_base_units: total_weighted_collateral_usdc_base_units,
172        weighted_liabilities_usdc_base_units: total_weighted_liabilities_usdc_base_units,
173        position_infos,
174    })
175}
176
177/// Accumulate a signed value into positive/negative running totals.
178fn update_running_totals(
179    total_positive: &mut u64,
180    total_negative: &mut u64,
181    value: i128,
182) -> MathResult<()> {
183    let value_unsigned = u64::try_from(value.unsigned_abs()).map_err(|_| MathError::Overflow)?;
184
185    if value >= 0 {
186        *total_positive = total_positive
187            .checked_add(value_unsigned)
188            .ok_or(MathError::Overflow)?;
189    } else {
190        *total_negative = total_negative
191            .checked_add(value_unsigned)
192            .ok_or(MathError::Overflow)?;
193    }
194
195    Ok(())
196}
197
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects
)]
mod tests {
    use super::*;
    use pyra_types::{HistoricalOracleData, InsuranceFund};

    /// Builds a spot market whose cumulative interest equals its precision decrease, so a
    /// position's token balance equals its scaled balance.
    fn make_spot_market_with_twap(
        market_index: u16,
        decimals: u32,
        initial_asset_weight: u32,
        initial_liability_weight: u32,
        twap5min: i64,
    ) -> SpotMarket {
        let precision_decrease = 10u128.pow(19u32.saturating_sub(decimals));
        SpotMarket {
            pubkey: vec![],
            market_index,
            initial_asset_weight,
            initial_liability_weight,
            imf_factor: 0,
            scale_initial_asset_weight_start: 0,
            decimals,
            cumulative_deposit_interest: precision_decrease,
            cumulative_borrow_interest: precision_decrease,
            deposit_balance: 0,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: InsuranceFund::default(),
            historical_oracle_data: HistoricalOracleData {
                last_oracle_price_twap5min: twap5min,
            },
            oracle: None,
        }
    }

    /// Convenience: creates a spot market where TWAP matches the oracle price.
    fn make_spot_market(
        market_index: u16,
        decimals: u32,
        initial_asset_weight: u32,
        initial_liability_weight: u32,
        oracle_price: u64,
    ) -> SpotMarket {
        make_spot_market_with_twap(
            market_index,
            decimals,
            initial_asset_weight,
            initial_liability_weight,
            oracle_price as i64,
        )
    }

    fn make_position(market_index: u16, scaled_balance: u64, is_deposit: bool) -> SpotPosition {
        SpotPosition {
            market_index,
            scaled_balance,
            balance_type: if is_deposit {
                SpotBalanceType::Deposit
            } else {
                SpotBalanceType::Borrow
            },
            ..Default::default()
        }
    }

    // --- calculate_capacity ---

    #[test]
    fn empty_positions() {
        let result = calculate_capacity(&[], &HashMap::new(), &HashMap::new(), &[], 0).unwrap();
        assert_eq!(result.total_spendable_cents, 0);
        assert_eq!(result.available_credit_cents, 0);
        assert_eq!(result.usdc_balance_cents, 0);
        assert_eq!(result.weighted_collateral_usdc_base_units, 0);
        assert_eq!(result.weighted_liabilities_usdc_base_units, 0);
        assert!(result.position_infos.is_empty());
    }

    #[test]
    fn single_usdc_deposit() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![make_position(0, 100_000_000, true)]; // 100 USDC

        let mut markets = HashMap::new();
        markets.insert(0, usdc);
        let mut prices = HashMap::new();
        prices.insert(0, 1_000_000u64);

        let result = calculate_capacity(&positions, &markets, &prices, &[], 0).unwrap();

        assert_eq!(result.usdc_balance_cents, 10_000); // $100
        assert_eq!(result.total_spendable_cents, 10_000);
        assert_eq!(result.available_credit_cents, 10_000);
        assert_eq!(result.weighted_collateral_usdc_base_units, 100_000_000);
        assert_eq!(result.weighted_liabilities_usdc_base_units, 0);
        assert_eq!(result.position_infos.len(), 1);
        assert_eq!(result.position_infos[0].market_index, 0);
    }

    #[test]
    fn deposit_and_borrow() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![
            make_position(0, 100_000_000, true), // 100 USDC deposit
            make_position(0, 50_000_000, false), // 50 USDC borrow
        ];

        let mut markets = HashMap::new();
        markets.insert(0, usdc);
        let mut prices = HashMap::new();
        prices.insert(0, 1_000_000u64);

        let result = calculate_capacity(&positions, &markets, &prices, &[], 0).unwrap();

        assert_eq!(result.usdc_balance_cents, 10_000); // 100 USDC deposit
        assert_eq!(result.total_spendable_cents, 5_000); // 100 - 50 = 50 USDC
        assert_eq!(result.available_credit_cents, 5_000);
    }

    #[test]
    fn unliquidatable_excluded_from_spendable() {
        // Market 4 is unliquidatable
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let weth = make_spot_market(4, 9, 8_000, 12_000, 100_000_000);

        let positions = vec![
            make_position(0, 10_000_000, true),    // 10 USDC
            make_position(4, 1_000_000_000, true), // 1 wETH at $100
        ];

        let mut markets = HashMap::new();
        markets.insert(0, usdc);
        markets.insert(4, weth);
        let mut prices = HashMap::new();
        prices.insert(0, 1_000_000u64);
        prices.insert(4, 100_000_000u64); // $100

        let unliquidatable = vec![4u16];
        let result = calculate_capacity(&positions, &markets, &prices, &unliquidatable, 0).unwrap();

        // total_spendable: only USDC (10M base = 1000 cents), wETH excluded
        assert_eq!(result.total_spendable_cents, 1_000);
        // available_credit: includes wETH weighted (100M * 80% = 80M) + USDC (10M * 100% = 10M) = 90M = 9000 cents
        assert_eq!(result.available_credit_cents, 9_000);
        assert_eq!(result.position_infos.len(), 2);
    }

    #[test]
    fn unliquidatable_borrow_still_counts_as_liability() {
        // Market 7 is unliquidatable, but only its *collateral* is excluded from the
        // unweighted totals; a borrow there must still reduce total_spendable.
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let m7 = make_spot_market(7, 6, 10_000, 10_000, 1_000_000);

        let positions = vec![
            make_position(0, 100_000_000, true), // 100 USDC
            make_position(7, 20_000_000, false), // 20 tokens borrowed @ $1 (unliquidatable)
        ];

        let mut markets = HashMap::new();
        markets.insert(0, usdc);
        markets.insert(7, m7);
        let mut prices = HashMap::new();
        prices.insert(0, 1_000_000u64);
        prices.insert(7, 1_000_000u64);

        let unliquidatable = vec![7u16];
        let result = calculate_capacity(&positions, &markets, &prices, &unliquidatable, 0).unwrap();

        // 100M collateral - 20M liabilities = 80M base = 8000 cents
        assert_eq!(result.total_spendable_cents, 8_000);
        // Weighted: 100M * 100% - 20M * 100% = 80M base = 8000 cents
        assert_eq!(result.available_credit_cents, 8_000);
        assert_eq!(result.weighted_liabilities_usdc_base_units, 20_000_000);
        assert_eq!(result.position_infos.len(), 2);
    }

    #[test]
    fn slippage_reduces_spendable() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![make_position(0, 100_000_000, true)]; // 100 USDC

        let mut markets = HashMap::new();
        markets.insert(0, usdc);
        let mut prices = HashMap::new();
        prices.insert(0, 1_000_000u64);

        // 10% slippage = 1000 bps
        let result = calculate_capacity(&positions, &markets, &prices, &[], 1_000).unwrap();

        // 100 USDC collateral, 10% slippage = 10 USDC, spendable = 90 USDC = 9000 cents
        assert_eq!(result.total_spendable_cents, 9_000);
        // available_credit not affected by slippage
        assert_eq!(result.available_credit_cents, 10_000);
    }

    #[test]
    fn missing_market_skipped() {
        let positions = vec![make_position(5, 1_000_000, true)];

        let result =
            calculate_capacity(&positions, &HashMap::new(), &HashMap::new(), &[], 0).unwrap();

        assert_eq!(result.total_spendable_cents, 0);
        assert!(result.position_infos.is_empty());
    }

    #[test]
    fn missing_price_skipped() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![make_position(0, 1_000_000, true)];

        let mut markets = HashMap::new();
        markets.insert(0, usdc);

        let result = calculate_capacity(&positions, &markets, &HashMap::new(), &[], 0).unwrap();

        assert_eq!(result.total_spendable_cents, 0);
        assert!(result.position_infos.is_empty());
    }

    #[test]
    fn multi_position_with_unliquidatable_and_slippage() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let m4 = make_spot_market(4, 9, 8_000, 12_000, 200_000_000); // unliquidatable
        let m5 = make_spot_market(5, 9, 8_000, 12_000, 100_000_000);
        let m6 = make_spot_market(6, 6, 10_000, 10_000, 1_000_000);

        let positions = vec![
            make_position(0, 50_000_000, true),    // 50 USDC
            make_position(4, 1_000_000_000, true), // 1 token @ $200 (unliquidatable)
            make_position(5, 500_000_000, true),   // 0.5 token @ $100
            make_position(6, 20_000_000, false),   // 20 USDC-like borrow
        ];

        let mut markets = HashMap::new();
        markets.insert(0, usdc);
        markets.insert(4, m4);
        markets.insert(5, m5);
        markets.insert(6, m6);
        let mut prices = HashMap::new();
        prices.insert(0, 1_000_000u64);
        prices.insert(4, 200_000_000u64);
        prices.insert(5, 100_000_000u64);
        prices.insert(6, 1_000_000u64);

        let unliquidatable = vec![4u16, 32u16];
        let result =
            calculate_capacity(&positions, &markets, &prices, &unliquidatable, 500).unwrap();

        // Unweighted collateral (excluding market 4): 50M (USDC) + 50M (m5: 0.5 * $100) = 100M
        // Unweighted liabilities: 20M (m6 borrow)
        // Slippage: 100M * 500/10000 = 5M
        // total_spendable = 100M - 5M - 20M = 75M base = 7500 cents
        assert_eq!(result.total_spendable_cents, 7_500);

        // Weighted collateral: 50M*100% + 200M*80% + 50M*80% = 50M + 160M + 40M = 250M
        // Weighted liabilities: 20M*100% = 20M
        // available_credit = 250M - 20M = 230M base = 23000 cents
        assert_eq!(result.available_credit_cents, 23_000);
        assert_eq!(result.usdc_balance_cents, 5_000);
        assert_eq!(result.position_infos.len(), 4);
    }

    // --- update_running_totals ---

    #[test]
    fn running_totals_positive() {
        let mut pos = 0u64;
        let mut neg = 0u64;
        update_running_totals(&mut pos, &mut neg, 100).unwrap();
        assert_eq!(pos, 100);
        assert_eq!(neg, 0);
    }

    #[test]
    fn running_totals_negative() {
        let mut pos = 0u64;
        let mut neg = 0u64;
        update_running_totals(&mut pos, &mut neg, -50).unwrap();
        assert_eq!(pos, 0);
        assert_eq!(neg, 50);
    }

    #[test]
    fn running_totals_accumulate() {
        let mut pos = 10u64;
        let mut neg = 5u64;
        update_running_totals(&mut pos, &mut neg, 20).unwrap();
        update_running_totals(&mut pos, &mut neg, -15).unwrap();
        assert_eq!(pos, 30);
        assert_eq!(neg, 20);
    }

    #[test]
    fn running_totals_accept_u64_max_magnitude() {
        let mut pos = 0u64;
        let mut neg = 0u64;
        update_running_totals(&mut pos, &mut neg, i128::from(u64::MAX)).unwrap();
        update_running_totals(&mut pos, &mut neg, -i128::from(u64::MAX)).unwrap();
        assert_eq!(pos, u64::MAX);
        assert_eq!(neg, u64::MAX);
    }

    #[test]
    fn running_totals_magnitude_overflow_errors() {
        let mut pos = 0u64;
        let mut neg = 0u64;
        let too_big = i128::from(u64::MAX) + 1;
        assert!(matches!(
            update_running_totals(&mut pos, &mut neg, too_big),
            Err(MathError::Overflow)
        ));
        assert!(matches!(
            update_running_totals(&mut pos, &mut neg, -too_big),
            Err(MathError::Overflow)
        ));
        assert_eq!((pos, neg), (0, 0));
    }

    #[test]
    fn running_totals_sum_overflow_errors_without_mutation() {
        let mut pos = u64::MAX;
        let mut neg = u64::MAX;
        assert!(matches!(
            update_running_totals(&mut pos, &mut neg, 1),
            Err(MathError::Overflow)
        ));
        assert!(matches!(
            update_running_totals(&mut pos, &mut neg, -1),
            Err(MathError::Overflow)
        ));
        assert_eq!((pos, neg), (u64::MAX, u64::MAX));
    }
}
470
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects
)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Rounding property of the final cents conversion: converting a saturated difference
        // never yields more than one cent above the difference of the converted operands.
        // NOTE: this exercises `usdc_base_units_to_cents` on the spendable formula directly;
        // it does not call `calculate_capacity`.
        #[test]
        fn spendable_le_collateral_minus_liabilities(
            collateral_base in 0u64..=1_000_000_000_000u64,
            liabilities_base in 0u64..=500_000_000_000u64,
        ) {
            let collateral_cents = usdc_base_units_to_cents(collateral_base).unwrap();
            let liabilities_cents = usdc_base_units_to_cents(liabilities_base).unwrap();
            let max_possible = collateral_cents.saturating_sub(liabilities_cents);

            let spendable_base = collateral_base.saturating_sub(liabilities_base);
            let spendable_cents = usdc_base_units_to_cents(spendable_base).unwrap();
            prop_assert!(spendable_cents <= max_possible + 1, "rounding violation");
        }

        // `update_running_totals` partitions values by sign: the positive total is the sum of
        // the non-negative inputs and the negative total is the summed magnitude of the
        // negative inputs.
        #[test]
        fn running_totals_partition_by_sign(
            values in prop::collection::vec(
                -1_000_000_000_000i64..=1_000_000_000_000i64,
                0..32,
            ),
        ) {
            let mut pos = 0u64;
            let mut neg = 0u64;
            for &value in &values {
                update_running_totals(&mut pos, &mut neg, i128::from(value)).unwrap();
            }

            let expected_pos: u64 = values
                .iter()
                .filter(|&&v| v >= 0)
                .map(|&v| v.unsigned_abs())
                .sum();
            let expected_neg: u64 = values
                .iter()
                .filter(|&&v| v < 0)
                .map(|&v| v.unsigned_abs())
                .sum();
            prop_assert_eq!(pos, expected_pos);
            prop_assert_eq!(neg, expected_neg);
        }
    }
}
495            let spendable_cents = usdc_base_units_to_cents(spendable_base).unwrap();
496            prop_assert!(spendable_cents <= max_possible + 1, "rounding violation");
497        }
498    }
499}