1use std::collections::HashMap;
2
3use pyra_tokens::AssetId;
4use pyra_types::{SpotBalanceType, SpotMarket, SpotPosition};
5
6use super::balance::{calculate_value_usdc_base_units, get_token_balance};
7use super::weights::{calculate_asset_weight, calculate_liability_weight, get_strict_price};
8use crate::common::usdc_base_units_to_cents;
9use crate::error::{MathError, MathResult};
10
/// Fixed-point scale for asset/liability weights: a weighted value is
/// `value * weight / MARGIN_PRECISION`, so a weight equal to this constant
/// leaves the value unchanged (1.0x).
const MARGIN_PRECISION: i128 = 10_000;
12
/// Per-position summary produced by `calculate_capacity` for every spot
/// position whose token, spot market and price were all found.
#[derive(Debug, Clone, PartialEq)]
pub struct PositionInfo {
    /// Asset the position holds or owes.
    pub asset_id: AssetId,
    /// Magnitude of the token balance in the token's base units; the sign is
    /// dropped, see `position_type` for the side.
    pub balance: u64,
    /// Deposit/borrow side, copied from the source `SpotPosition`.
    pub position_type: SpotBalanceType,
    /// Price taken from the caller's price map (not the strict price used for
    /// valuation), in USDC base units.
    pub price_usdc_base_units: u64,
    // NOTE(review): populated with the market's `initial_asset_weight` even for
    // borrow positions, and not the (possibly size-adjusted) weight actually used
    // in the credit calculation — confirm this is what consumers expect.
    /// The spot market's `initial_asset_weight` (10_000 = 1.0x).
    pub weight_bps: u32,
}
26
27#[derive(Debug, Clone)]
29pub struct CapacityResult {
30 pub total_spendable_cents: u64,
33 pub available_credit_cents: u64,
35 pub usdc_balance_cents: u64,
37 pub weighted_collateral_usdc_base_units: u64,
39 pub weighted_liabilities_usdc_base_units: u64,
41 pub position_infos: Vec<PositionInfo>,
43}
44
45pub fn calculate_capacity(
62 spot_positions: &[SpotPosition],
63 spot_market_map: &HashMap<AssetId, SpotMarket>,
64 price_map: &HashMap<AssetId, u64>,
65 unliquidatable_asset_ids: &[AssetId],
66 max_slippage_bps: u64,
67) -> MathResult<CapacityResult> {
68 let mut total_collateral_usdc_base_units: u64 = 0;
69 let mut total_liabilities_usdc_base_units: u64 = 0;
70
71 let mut total_weighted_collateral_usdc_base_units: u64 = 0;
72 let mut total_weighted_liabilities_usdc_base_units: u64 = 0;
73
74 let mut usdc_balance_base_units: u64 = 0;
75
76 let mut position_infos: Vec<PositionInfo> = Vec::new();
77
78 for position in spot_positions {
79 let Some(token) =
81 pyra_tokens::Token::find_by_drift_market_index(position.market_index)
82 else {
83 continue;
84 };
85 let asset_id = token.asset_id;
86
87 let Some(spot_market) = spot_market_map.get(&asset_id) else {
88 continue;
89 };
90 let Some(price_usdc_base_units) = price_map.get(&asset_id).copied() else {
91 continue;
92 };
93
94 let token_balance_base_units = get_token_balance(position, spot_market)?;
96
97 let is_asset = token_balance_base_units >= 0;
98 let twap5min = spot_market
99 .historical_oracle_data
100 .last_oracle_price_twap5min;
101 let strict_price = get_strict_price(price_usdc_base_units, twap5min, is_asset);
102
103 let value_usdc_base_units = calculate_value_usdc_base_units(
104 token_balance_base_units,
105 strict_price,
106 spot_market.decimals,
107 )?;
108
109 let is_unliquidatable_collateral =
111 unliquidatable_asset_ids.contains(&asset_id) && value_usdc_base_units > 0;
112 if !is_unliquidatable_collateral {
113 update_running_totals(
114 &mut total_collateral_usdc_base_units,
115 &mut total_liabilities_usdc_base_units,
116 value_usdc_base_units,
117 )?;
118 }
119
120 let token_amount_unsigned = token_balance_base_units.unsigned_abs();
122 let weight_bps = if is_asset {
123 calculate_asset_weight(token_amount_unsigned, price_usdc_base_units, spot_market)?
124 as i128
125 } else {
126 calculate_liability_weight(token_amount_unsigned, spot_market)? as i128
127 };
128 let weighted_value_usdc_base_units = value_usdc_base_units
129 .checked_mul(weight_bps)
130 .ok_or(MathError::Overflow)?
131 .checked_div(MARGIN_PRECISION)
132 .ok_or(MathError::Overflow)?;
133
134 update_running_totals(
135 &mut total_weighted_collateral_usdc_base_units,
136 &mut total_weighted_liabilities_usdc_base_units,
137 weighted_value_usdc_base_units,
138 )?;
139
140 if asset_id == pyra_tokens::AssetId::USDC && usdc_balance_base_units == 0 && token_balance_base_units > 0 {
142 usdc_balance_base_units =
143 u64::try_from(token_balance_base_units).map_err(|_| MathError::Overflow)?;
144 }
145
146 let token_balance_unsigned = u64::try_from(token_balance_base_units.unsigned_abs())
148 .map_err(|_| MathError::Overflow)?;
149 position_infos.push(PositionInfo {
150 asset_id,
151 balance: token_balance_unsigned,
152 position_type: position.balance_type.clone(),
153 price_usdc_base_units,
154 weight_bps: spot_market.initial_asset_weight,
155 });
156 }
157
158 let available_credit_base_units = total_weighted_collateral_usdc_base_units
160 .saturating_sub(total_weighted_liabilities_usdc_base_units);
161 let available_credit_cents = usdc_base_units_to_cents(available_credit_base_units)?;
162
163 let max_slippage_usdc_base_units = total_collateral_usdc_base_units
165 .checked_mul(max_slippage_bps)
166 .ok_or(MathError::Overflow)?
167 .checked_div(10_000)
168 .ok_or(MathError::Overflow)?;
169 let total_spendable_base_units = total_collateral_usdc_base_units
170 .saturating_sub(max_slippage_usdc_base_units)
171 .saturating_sub(total_liabilities_usdc_base_units);
172 let total_spendable_cents = usdc_base_units_to_cents(total_spendable_base_units)?;
173
174 let usdc_balance_cents = usdc_base_units_to_cents(usdc_balance_base_units)?;
175
176 Ok(CapacityResult {
177 total_spendable_cents,
178 available_credit_cents,
179 usdc_balance_cents,
180 weighted_collateral_usdc_base_units: total_weighted_collateral_usdc_base_units,
181 weighted_liabilities_usdc_base_units: total_weighted_liabilities_usdc_base_units,
182 position_infos,
183 })
184}
185
186fn update_running_totals(
188 total_positive: &mut u64,
189 total_negative: &mut u64,
190 value: i128,
191) -> MathResult<()> {
192 let value_unsigned = u64::try_from(value.unsigned_abs()).map_err(|_| MathError::Overflow)?;
193
194 if value >= 0 {
195 *total_positive = total_positive
196 .checked_add(value_unsigned)
197 .ok_or(MathError::Overflow)?;
198 } else {
199 *total_negative = total_negative
200 .checked_add(value_unsigned)
201 .ok_or(MathError::Overflow)?;
202 }
203
204 Ok(())
205}
206
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects
)]
mod tests {
    use super::*;
    use pyra_types::{HistoricalOracleData, InsuranceFund};

    /// Spot-market fixture with explicit weights and 5-minute TWAP.
    /// Both cumulative interest indices are `10^(19 - decimals)`; every
    /// rate/balance field is zero.
    fn make_spot_market_with_twap(
        market_index: u16,
        decimals: u32,
        initial_asset_weight: u32,
        initial_liability_weight: u32,
        twap5min: i64,
    ) -> SpotMarket {
        let interest_index = 10u128.pow(19u32.saturating_sub(decimals));
        SpotMarket {
            pubkey: Vec::new(),
            market_index,
            decimals,
            initial_asset_weight,
            initial_liability_weight,
            imf_factor: 0,
            scale_initial_asset_weight_start: 0,
            cumulative_deposit_interest: interest_index,
            cumulative_borrow_interest: interest_index,
            deposit_balance: 0,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: InsuranceFund::default(),
            historical_oracle_data: HistoricalOracleData {
                last_oracle_price_twap5min: twap5min,
            },
            oracle: None,
        }
    }

    /// Spot-market fixture whose 5-minute TWAP equals `oracle_price`.
    fn make_spot_market(
        market_index: u16,
        decimals: u32,
        initial_asset_weight: u32,
        initial_liability_weight: u32,
        oracle_price: u64,
    ) -> SpotMarket {
        let twap = oracle_price as i64;
        make_spot_market_with_twap(
            market_index,
            decimals,
            initial_asset_weight,
            initial_liability_weight,
            twap,
        )
    }

    /// Spot-position fixture; `is_deposit` picks `Deposit` vs `Borrow`.
    fn make_position(drift_market_index: u16, scaled_balance: u64, is_deposit: bool) -> SpotPosition {
        let balance_type = match is_deposit {
            true => SpotBalanceType::Deposit,
            false => SpotBalanceType::Borrow,
        };
        SpotPosition {
            market_index: drift_market_index,
            scaled_balance,
            balance_type,
            ..Default::default()
        }
    }

    /// USDC market at 6 decimals with 1.0x weights and a $1 TWAP.
    fn usdc_market() -> SpotMarket {
        make_spot_market(0, 6, 10_000, 10_000, 1_000_000)
    }

    #[test]
    fn empty_positions() {
        let no_markets = HashMap::new();
        let no_prices = HashMap::new();
        let result = calculate_capacity(&[], &no_markets, &no_prices, &[], 0).unwrap();

        assert_eq!(result.total_spendable_cents, 0);
        assert_eq!(result.available_credit_cents, 0);
        assert_eq!(result.usdc_balance_cents, 0);
        assert_eq!(result.weighted_collateral_usdc_base_units, 0);
        assert_eq!(result.weighted_liabilities_usdc_base_units, 0);
        assert!(result.position_infos.is_empty());
    }

    #[test]
    fn single_usdc_deposit() {
        let markets = HashMap::from([(AssetId::USDC, usdc_market())]);
        let prices = HashMap::from([(AssetId::USDC, 1_000_000u64)]);
        let positions = [make_position(0, 100_000_000, true)];

        let result = calculate_capacity(&positions, &markets, &prices, &[], 0).unwrap();

        assert_eq!(result.usdc_balance_cents, 10_000);
        assert_eq!(result.total_spendable_cents, 10_000);
        assert_eq!(result.available_credit_cents, 10_000);
        assert_eq!(result.weighted_collateral_usdc_base_units, 100_000_000);
        assert_eq!(result.weighted_liabilities_usdc_base_units, 0);
        assert_eq!(result.position_infos.len(), 1);
        assert_eq!(result.position_infos[0].asset_id, AssetId::USDC);
    }

    #[test]
    fn deposit_and_borrow() {
        let markets = HashMap::from([(AssetId::USDC, usdc_market())]);
        let prices = HashMap::from([(AssetId::USDC, 1_000_000u64)]);
        let positions = [
            make_position(0, 100_000_000, true),
            make_position(0, 50_000_000, false),
        ];

        let result = calculate_capacity(&positions, &markets, &prices, &[], 0).unwrap();

        assert_eq!(result.usdc_balance_cents, 10_000);
        assert_eq!(result.total_spendable_cents, 5_000);
        assert_eq!(result.available_credit_cents, 5_000);
    }

    #[test]
    fn unliquidatable_excluded_from_spendable() {
        let markets = HashMap::from([
            (AssetId::USDC, usdc_market()),
            (AssetId::WETH, make_spot_market(4, 9, 8_000, 12_000, 100_000_000)),
        ]);
        let prices = HashMap::from([
            (AssetId::USDC, 1_000_000u64),
            (AssetId::WETH, 100_000_000u64),
        ]);
        let positions = [
            make_position(0, 10_000_000, true),
            make_position(4, 1_000_000_000, true),
        ];

        let result =
            calculate_capacity(&positions, &markets, &prices, &[AssetId::WETH], 0).unwrap();

        assert_eq!(result.total_spendable_cents, 1_000);
        assert_eq!(result.available_credit_cents, 9_000);
        assert_eq!(result.position_infos.len(), 2);
    }

    #[test]
    fn slippage_reduces_spendable() {
        let markets = HashMap::from([(AssetId::USDC, usdc_market())]);
        let prices = HashMap::from([(AssetId::USDC, 1_000_000u64)]);
        let positions = [make_position(0, 100_000_000, true)];

        let result = calculate_capacity(&positions, &markets, &prices, &[], 1_000).unwrap();

        assert_eq!(result.total_spendable_cents, 9_000);
        assert_eq!(result.available_credit_cents, 10_000);
    }

    #[test]
    fn missing_market_skipped() {
        let positions = [make_position(5, 1_000_000, true)];

        let result =
            calculate_capacity(&positions, &HashMap::new(), &HashMap::new(), &[], 0).unwrap();

        assert_eq!(result.total_spendable_cents, 0);
        assert!(result.position_infos.is_empty());
    }

    #[test]
    fn missing_price_skipped() {
        let markets = HashMap::from([(AssetId::USDC, usdc_market())]);
        let positions = [make_position(0, 1_000_000, true)];

        let result = calculate_capacity(&positions, &markets, &HashMap::new(), &[], 0).unwrap();

        assert_eq!(result.total_spendable_cents, 0);
        assert!(result.position_infos.is_empty());
    }

    #[test]
    fn multi_position_with_unliquidatable_and_slippage() {
        let markets = HashMap::from([
            (AssetId::USDC, usdc_market()),
            (AssetId::WETH, make_spot_market(4, 9, 8_000, 12_000, 200_000_000)),
            (AssetId::WSOL, make_spot_market(1, 9, 8_000, 12_000, 100_000_000)),
            (AssetId::USDT, make_spot_market(5, 6, 10_000, 10_000, 1_000_000)),
        ]);
        let prices = HashMap::from([
            (AssetId::USDC, 1_000_000u64),
            (AssetId::WETH, 200_000_000u64),
            (AssetId::WSOL, 100_000_000u64),
            (AssetId::USDT, 1_000_000u64),
        ]);
        let positions = [
            make_position(0, 50_000_000, true),
            make_position(4, 1_000_000_000, true),
            make_position(1, 500_000_000, true),
            make_position(5, 20_000_000, false),
        ];

        let result =
            calculate_capacity(&positions, &markets, &prices, &[AssetId::WETH], 500).unwrap();

        assert_eq!(result.total_spendable_cents, 7_500);
        assert_eq!(result.available_credit_cents, 23_000);
        assert_eq!(result.usdc_balance_cents, 5_000);
        assert_eq!(result.position_infos.len(), 4);
    }

    #[test]
    fn running_totals_positive() {
        let (mut pos, mut neg) = (0u64, 0u64);
        update_running_totals(&mut pos, &mut neg, 100).unwrap();
        assert_eq!((pos, neg), (100, 0));
    }

    #[test]
    fn running_totals_negative() {
        let (mut pos, mut neg) = (0u64, 0u64);
        update_running_totals(&mut pos, &mut neg, -50).unwrap();
        assert_eq!((pos, neg), (0, 50));
    }

    #[test]
    fn running_totals_accumulate() {
        let (mut pos, mut neg) = (10u64, 5u64);
        for delta in [20i128, -15] {
            update_running_totals(&mut pos, &mut neg, delta).unwrap();
        }
        assert_eq!((pos, neg), (30, 20));
    }
}
492
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects
)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Converting `collateral - liabilities` to cents never exceeds the
        // difference of the individually converted amounts by more than one
        // cent of rounding.
        #[test]
        fn spendable_le_collateral_minus_liabilities(
            collateral_base in 0u64..=1_000_000_000_000u64,
            liabilities_base in 0u64..=500_000_000_000u64,
        ) {
            let to_cents = |base_units: u64| usdc_base_units_to_cents(base_units).unwrap();

            let collateral_cents = to_cents(collateral_base);
            let liabilities_cents = to_cents(liabilities_base);
            let upper_bound = collateral_cents.saturating_sub(liabilities_cents) + 1;

            let spendable_cents = to_cents(collateral_base.saturating_sub(liabilities_base));
            prop_assert!(spendable_cents <= upper_bound, "rounding violation");
        }
    }
}