1use std::collections::HashMap;
2
3use pyra_tokens::AssetId;
4use pyra_types::{SpotBalanceType, SpotMarket, SpotPosition};
5
6use super::balance::{calculate_value_usdc_base_units, get_token_balance};
7use super::weights::{calculate_asset_weight, calculate_liability_weight, get_strict_price};
8use crate::common::usdc_base_units_to_cents;
9use crate::error::{MathError, MathResult};
10
/// Denominator applied to asset/liability weights in `calculate_capacity`:
/// `weighted_value = value * weight / MARGIN_PRECISION`, so a weight equal to
/// `MARGIN_PRECISION` leaves the value unchanged.
const MARGIN_PRECISION: i128 = 10_000;
12
/// Per-position snapshot recorded by `calculate_capacity` for every spot
/// position it could resolve to a token, a spot market and a price.
#[derive(Debug, Clone, PartialEq)]
pub struct PositionInfo {
    /// Asset the position's market index resolved to.
    pub asset_id: AssetId,
    /// Magnitude (absolute value) of the position's token balance as returned
    /// by `get_token_balance`; the sign is not stored here.
    pub balance: u64,
    /// Balance type copied from the source `SpotPosition`.
    pub position_type: SpotBalanceType,
    /// Price taken from the caller-supplied price map for this asset.
    pub price_usdc_base_units: u64,
    /// Populated from the spot market's `initial_asset_weight`.
    /// NOTE(review): presumably expressed against a 10_000 denominator like
    /// the other weights in this module — confirm.
    pub weight_bps: u32,
}
26
/// Aggregate output of `calculate_capacity`.
#[derive(Debug, Clone)]
pub struct CapacityResult {
    /// Unweighted collateral, minus the slippage haircut, minus unweighted
    /// liabilities (each subtraction saturating at zero), converted to cents.
    /// Positive-valued positions in unliquidatable assets are excluded from
    /// the collateral used here.
    pub total_spendable_cents: u64,
    /// Weighted collateral minus weighted liabilities (saturating at zero),
    /// converted to cents.
    pub available_credit_cents: u64,
    /// Positive USDC token balance found among the positions, converted to
    /// cents; zero when there is none.
    pub usdc_balance_cents: u64,
    /// Sum of weight-adjusted values of all non-negative positions.
    pub weighted_collateral_usdc_base_units: u64,
    /// Sum of weight-adjusted magnitudes of all negative positions.
    pub weighted_liabilities_usdc_base_units: u64,
    /// One entry per position that was processed (skipped positions are
    /// absent), in input order.
    pub position_infos: Vec<PositionInfo>,
}
44
45pub fn calculate_capacity(
62 spot_positions: &[SpotPosition],
63 spot_market_map: &HashMap<AssetId, SpotMarket>,
64 price_map: &HashMap<AssetId, u64>,
65 unliquidatable_asset_ids: &[AssetId],
66 max_slippage_bps: u64,
67) -> MathResult<CapacityResult> {
68 let mut total_collateral_usdc_base_units: u64 = 0;
69 let mut total_liabilities_usdc_base_units: u64 = 0;
70
71 let mut total_weighted_collateral_usdc_base_units: u64 = 0;
72 let mut total_weighted_liabilities_usdc_base_units: u64 = 0;
73
74 let mut usdc_balance_base_units: u64 = 0;
75
76 let mut position_infos: Vec<PositionInfo> = Vec::new();
77
78 for position in spot_positions {
79 let Some(token) =
81 pyra_tokens::Token::find_by_drift_market_index(position.market_index)
82 else {
83 continue;
84 };
85 let asset_id = token.asset_id;
86
87 let Some(spot_market) = spot_market_map.get(&asset_id) else {
88 continue;
89 };
90 let Some(price_usdc_base_units) = price_map.get(&asset_id).copied() else {
91 continue;
92 };
93
94 let token_balance_base_units = get_token_balance(position, spot_market)?;
96
97 let is_asset = token_balance_base_units >= 0;
98 let twap5min = spot_market
99 .historical_oracle_data
100 .last_oracle_price_twap5min;
101 let strict_price = get_strict_price(price_usdc_base_units, twap5min, is_asset);
102
103 let value_usdc_base_units = calculate_value_usdc_base_units(
104 token_balance_base_units,
105 strict_price,
106 spot_market.decimals,
107 )?;
108
109 let is_unliquidatable_collateral =
111 unliquidatable_asset_ids.contains(&asset_id) && value_usdc_base_units > 0;
112 if !is_unliquidatable_collateral {
113 update_running_totals(
114 &mut total_collateral_usdc_base_units,
115 &mut total_liabilities_usdc_base_units,
116 value_usdc_base_units,
117 )?;
118 }
119
120 let token_amount_unsigned = token_balance_base_units.unsigned_abs();
122 let weight_bps = if is_asset {
123 calculate_asset_weight(token_amount_unsigned, price_usdc_base_units, spot_market)?
124 as i128
125 } else {
126 calculate_liability_weight(token_amount_unsigned, spot_market)? as i128
127 };
128 let weighted_value_usdc_base_units = value_usdc_base_units
129 .checked_mul(weight_bps)
130 .ok_or(MathError::Overflow)?
131 .checked_div(MARGIN_PRECISION)
132 .ok_or(MathError::Overflow)?;
133
134 update_running_totals(
135 &mut total_weighted_collateral_usdc_base_units,
136 &mut total_weighted_liabilities_usdc_base_units,
137 weighted_value_usdc_base_units,
138 )?;
139
140 if asset_id == pyra_tokens::AssetId::USDC && usdc_balance_base_units == 0 && token_balance_base_units > 0 {
142 usdc_balance_base_units =
143 u64::try_from(token_balance_base_units).map_err(|_| MathError::Overflow)?;
144 }
145
146 let token_balance_unsigned = u64::try_from(token_balance_base_units.unsigned_abs())
148 .map_err(|_| MathError::Overflow)?;
149 position_infos.push(PositionInfo {
150 asset_id,
151 balance: token_balance_unsigned,
152 position_type: position.balance_type.clone(),
153 price_usdc_base_units,
154 weight_bps: spot_market.initial_asset_weight,
155 });
156 }
157
158 let available_credit_base_units = total_weighted_collateral_usdc_base_units
160 .saturating_sub(total_weighted_liabilities_usdc_base_units);
161 let available_credit_cents = usdc_base_units_to_cents(available_credit_base_units)?;
162
163 let max_slippage_usdc_base_units = total_collateral_usdc_base_units
165 .checked_mul(max_slippage_bps)
166 .ok_or(MathError::Overflow)?
167 .checked_div(10_000)
168 .ok_or(MathError::Overflow)?;
169 let total_spendable_base_units = total_collateral_usdc_base_units
170 .saturating_sub(max_slippage_usdc_base_units)
171 .saturating_sub(total_liabilities_usdc_base_units);
172 let total_spendable_cents = usdc_base_units_to_cents(total_spendable_base_units)?;
173
174 let usdc_balance_cents = usdc_base_units_to_cents(usdc_balance_base_units)?;
175
176 Ok(CapacityResult {
177 total_spendable_cents,
178 available_credit_cents,
179 usdc_balance_cents,
180 weighted_collateral_usdc_base_units: total_weighted_collateral_usdc_base_units,
181 weighted_liabilities_usdc_base_units: total_weighted_liabilities_usdc_base_units,
182 position_infos,
183 })
184}
185
186fn update_running_totals(
188 total_positive: &mut u64,
189 total_negative: &mut u64,
190 value: i128,
191) -> MathResult<()> {
192 let value_unsigned = u64::try_from(value.unsigned_abs()).map_err(|_| MathError::Overflow)?;
193
194 if value >= 0 {
195 *total_positive = total_positive
196 .checked_add(value_unsigned)
197 .ok_or(MathError::Overflow)?;
198 } else {
199 *total_negative = total_negative
200 .checked_add(value_unsigned)
201 .ok_or(MathError::Overflow)?;
202 }
203
204 Ok(())
205}
206
#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects,
    reason = "test code"
)]
mod tests {
    use super::*;
    use pyra_types::{HistoricalOracleData, InsuranceFund};

    /// Builds a spot-market fixture with the given weights and 5-minute TWAP.
    ///
    /// Both cumulative interest fields are set to `10^(19 - decimals)`
    /// (saturating the exponent at zero); every other numeric field is zero.
    fn make_spot_market_with_twap(
        market_index: u16,
        decimals: u32,
        initial_asset_weight: u32,
        initial_liability_weight: u32,
        twap5min: i64,
    ) -> SpotMarket {
        let cumulative_interest = 10u128.pow(19u32.saturating_sub(decimals));
        let historical_oracle_data = HistoricalOracleData {
            last_oracle_price_twap5min: twap5min,
        };

        SpotMarket {
            pubkey: Vec::new(),
            market_index,
            decimals,
            initial_asset_weight,
            initial_liability_weight,
            imf_factor: 0,
            scale_initial_asset_weight_start: 0,
            cumulative_deposit_interest: cumulative_interest,
            cumulative_borrow_interest: cumulative_interest,
            deposit_balance: 0,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: InsuranceFund::default(),
            historical_oracle_data,
            oracle: None,
        }
    }

    /// Same as `make_spot_market_with_twap`, using `oracle_price` as the TWAP.
    fn make_spot_market(
        market_index: u16,
        decimals: u32,
        initial_asset_weight: u32,
        initial_liability_weight: u32,
        oracle_price: u64,
    ) -> SpotMarket {
        let twap5min = oracle_price as i64;
        make_spot_market_with_twap(
            market_index,
            decimals,
            initial_asset_weight,
            initial_liability_weight,
            twap5min,
        )
    }

    /// Builds a deposit (`is_deposit == true`) or borrow position on the given
    /// market index; remaining fields take their defaults.
    fn make_position(
        drift_market_index: u16,
        scaled_balance: u64,
        is_deposit: bool,
    ) -> SpotPosition {
        let balance_type = match is_deposit {
            true => SpotBalanceType::Deposit,
            false => SpotBalanceType::Borrow,
        };
        SpotPosition {
            market_index: drift_market_index,
            scaled_balance,
            balance_type,
            ..Default::default()
        }
    }

    // No positions at all: every figure is zero and nothing is reported.
    #[test]
    fn empty_positions() {
        let capacity =
            calculate_capacity(&[], &HashMap::new(), &HashMap::new(), &[], 0).unwrap();

        assert_eq!(capacity.total_spendable_cents, 0);
        assert_eq!(capacity.available_credit_cents, 0);
        assert_eq!(capacity.usdc_balance_cents, 0);
        assert_eq!(capacity.weighted_collateral_usdc_base_units, 0);
        assert_eq!(capacity.weighted_liabilities_usdc_base_units, 0);
        assert!(capacity.position_infos.is_empty());
    }

    // A lone USDC deposit at full weight: spendable, credit and balance agree.
    #[test]
    fn single_usdc_deposit() {
        let positions = [make_position(0, 100_000_000, true)];
        let markets = HashMap::from([(
            AssetId::USDC,
            make_spot_market(0, 6, 10_000, 10_000, 1_000_000),
        )]);
        let prices = HashMap::from([(AssetId::USDC, 1_000_000u64)]);

        let capacity = calculate_capacity(&positions, &markets, &prices, &[], 0).unwrap();

        assert_eq!(capacity.usdc_balance_cents, 10_000);
        assert_eq!(capacity.total_spendable_cents, 10_000);
        assert_eq!(capacity.available_credit_cents, 10_000);
        assert_eq!(capacity.weighted_collateral_usdc_base_units, 100_000_000);
        assert_eq!(capacity.weighted_liabilities_usdc_base_units, 0);
        assert_eq!(capacity.position_infos.len(), 1);
        assert_eq!(capacity.position_infos[0].asset_id, AssetId::USDC);
    }

    // A deposit and a borrow half its size on the same market: the borrow
    // halves spendable and credit but leaves the reported USDC balance alone.
    #[test]
    fn deposit_and_borrow() {
        let positions = [
            make_position(0, 100_000_000, true),
            make_position(0, 50_000_000, false),
        ];
        let markets = HashMap::from([(
            AssetId::USDC,
            make_spot_market(0, 6, 10_000, 10_000, 1_000_000),
        )]);
        let prices = HashMap::from([(AssetId::USDC, 1_000_000u64)]);

        let capacity = calculate_capacity(&positions, &markets, &prices, &[], 0).unwrap();

        assert_eq!(capacity.usdc_balance_cents, 10_000);
        assert_eq!(capacity.total_spendable_cents, 5_000);
        assert_eq!(capacity.available_credit_cents, 5_000);
    }

    // WETH is flagged unliquidatable: it contributes weighted credit but
    // nothing to the spendable figure, which reflects the USDC deposit only.
    #[test]
    fn unliquidatable_excluded_from_spendable() {
        let positions = [
            make_position(0, 10_000_000, true),
            make_position(4, 1_000_000_000, true),
        ];
        let markets = HashMap::from([
            (
                AssetId::USDC,
                make_spot_market(0, 6, 10_000, 10_000, 1_000_000),
            ),
            (
                AssetId::WETH,
                make_spot_market(4, 9, 8_000, 12_000, 100_000_000),
            ),
        ]);
        let prices = HashMap::from([
            (AssetId::USDC, 1_000_000u64),
            (AssetId::WETH, 100_000_000u64),
        ]);
        let unliquidatable = [AssetId::WETH];

        let capacity =
            calculate_capacity(&positions, &markets, &prices, &unliquidatable, 0).unwrap();

        assert_eq!(capacity.total_spendable_cents, 1_000);
        assert_eq!(capacity.available_credit_cents, 9_000);
        assert_eq!(capacity.position_infos.len(), 2);
    }

    // A 1_000 bps haircut trims spendable by 10% and leaves credit untouched.
    #[test]
    fn slippage_reduces_spendable() {
        let positions = [make_position(0, 100_000_000, true)];
        let markets = HashMap::from([(
            AssetId::USDC,
            make_spot_market(0, 6, 10_000, 10_000, 1_000_000),
        )]);
        let prices = HashMap::from([(AssetId::USDC, 1_000_000u64)]);

        let capacity = calculate_capacity(&positions, &markets, &prices, &[], 1_000).unwrap();

        assert_eq!(capacity.total_spendable_cents, 9_000);
        assert_eq!(capacity.available_credit_cents, 10_000);
    }

    // A position with no entry in the market map is skipped, not an error.
    #[test]
    fn missing_market_skipped() {
        let positions = [make_position(5, 1_000_000, true)];

        let capacity =
            calculate_capacity(&positions, &HashMap::new(), &HashMap::new(), &[], 0).unwrap();

        assert_eq!(capacity.total_spendable_cents, 0);
        assert!(capacity.position_infos.is_empty());
    }

    // A position with a market but no price is skipped as well.
    #[test]
    fn missing_price_skipped() {
        let positions = [make_position(0, 1_000_000, true)];
        let markets = HashMap::from([(
            AssetId::USDC,
            make_spot_market(0, 6, 10_000, 10_000, 1_000_000),
        )]);

        let capacity =
            calculate_capacity(&positions, &markets, &HashMap::new(), &[], 0).unwrap();

        assert_eq!(capacity.total_spendable_cents, 0);
        assert!(capacity.position_infos.is_empty());
    }

    // Three deposits (one of them unliquidatable WETH) plus a USDT borrow,
    // evaluated with a 500 bps slippage haircut.
    #[test]
    fn multi_position_with_unliquidatable_and_slippage() {
        let positions = [
            make_position(0, 50_000_000, true),
            make_position(4, 1_000_000_000, true),
            make_position(1, 500_000_000, true),
            make_position(5, 20_000_000, false),
        ];
        let markets = HashMap::from([
            (
                AssetId::USDC,
                make_spot_market(0, 6, 10_000, 10_000, 1_000_000),
            ),
            (
                AssetId::WETH,
                make_spot_market(4, 9, 8_000, 12_000, 200_000_000),
            ),
            (
                AssetId::WSOL,
                make_spot_market(1, 9, 8_000, 12_000, 100_000_000),
            ),
            (
                AssetId::USDT,
                make_spot_market(5, 6, 10_000, 10_000, 1_000_000),
            ),
        ]);
        let prices = HashMap::from([
            (AssetId::USDC, 1_000_000u64),
            (AssetId::WETH, 200_000_000u64),
            (AssetId::WSOL, 100_000_000u64),
            (AssetId::USDT, 1_000_000u64),
        ]);
        let unliquidatable = [AssetId::WETH];

        let capacity =
            calculate_capacity(&positions, &markets, &prices, &unliquidatable, 500).unwrap();

        assert_eq!(capacity.total_spendable_cents, 7_500);
        assert_eq!(capacity.available_credit_cents, 23_000);
        assert_eq!(capacity.usdc_balance_cents, 5_000);
        assert_eq!(capacity.position_infos.len(), 4);
    }

    // A positive value lands in the positive total only.
    #[test]
    fn running_totals_positive() {
        let (mut positive, mut negative) = (0u64, 0u64);
        update_running_totals(&mut positive, &mut negative, 100).unwrap();
        assert_eq!(positive, 100);
        assert_eq!(negative, 0);
    }

    // A negative value lands, as a magnitude, in the negative total only.
    #[test]
    fn running_totals_negative() {
        let (mut positive, mut negative) = (0u64, 0u64);
        update_running_totals(&mut positive, &mut negative, -50).unwrap();
        assert_eq!(positive, 0);
        assert_eq!(negative, 50);
    }

    // Successive calls add on top of whatever the totals already hold.
    #[test]
    fn running_totals_accumulate() {
        let (mut positive, mut negative) = (10u64, 5u64);
        update_running_totals(&mut positive, &mut negative, 20).unwrap();
        update_running_totals(&mut positive, &mut negative, -15).unwrap();
        assert_eq!(positive, 30);
        assert_eq!(negative, 20);
    }
}
495
#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects,
    reason = "test code"
)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Converting the net amount (collateral - liabilities) to cents never
        // exceeds the difference of the individually converted amounts by
        // more than one cent.
        #[test]
        fn spendable_le_collateral_minus_liabilities(
            collateral_base in 0u64..=1_000_000_000_000u64,
            liabilities_base in 0u64..=500_000_000_000u64,
        ) {
            // Upper bound: convert each side to cents first, then subtract.
            let upper_bound_cents = {
                let collateral = usdc_base_units_to_cents(collateral_base).unwrap();
                let liabilities = usdc_base_units_to_cents(liabilities_base).unwrap();
                collateral.saturating_sub(liabilities)
            };

            // Candidate: subtract in base units first, then convert once.
            let net_base = collateral_base.saturating_sub(liabilities_base);
            let net_cents = usdc_base_units_to_cents(net_base).unwrap();

            prop_assert!(net_cents <= upper_bound_cents + 1, "rounding violation");
        }
    }
}