1use std::collections::HashMap;
2
3use pyra_tokens::AssetId;
4use pyra_types::{SpotBalanceType, SpotMarket, SpotPosition};
5
6use super::balance::{calculate_value_usdc_base_units, get_token_balance};
7use super::weights::{calculate_asset_weight, calculate_liability_weight, get_strict_price};
8use crate::common::usdc_base_units_to_cents;
9use crate::error::{MathError, MathResult};
10
/// Fixed-point denominator for asset/liability weights: a weight equal to
/// `MARGIN_PRECISION` (10_000 bps) leaves a value unchanged.
const MARGIN_PRECISION: i128 = 10_i128.pow(4);
12
/// Per-position summary emitted for every spot position that could be resolved to a
/// known asset, spot market and price.
#[derive(Debug, Clone, PartialEq)]
pub struct PositionInfo {
    /// Asset resolved from the position's drift market index.
    pub asset_id: AssetId,
    /// Absolute token balance, in the token's base units.
    pub balance: u64,
    /// Position side (deposit or borrow), copied from the source `SpotPosition`.
    pub position_type: SpotBalanceType,
    /// Price taken from the caller-supplied price map, in USDC base units.
    pub price_usdc_base_units: u64,
    /// Weight in basis points. In this file it is filled from the market's
    /// `initial_asset_weight` for every position, including borrows.
    /// NOTE(review): confirm consumers do not expect the liability weight for borrows.
    pub weight_bps: u32,
}
26
/// Aggregate spending-capacity figures produced by `calculate_capacity`.
#[derive(Debug, Clone)]
pub struct CapacityResult {
    /// Raw collateral (excluding positive unliquidatable positions), minus the slippage
    /// buffer and raw liabilities, converted to cents; floored at zero.
    pub total_spendable_cents: u64,
    /// Weighted collateral minus weighted liabilities, converted to cents; floored at zero.
    pub available_credit_cents: u64,
    /// First positive USDC balance encountered, converted to cents.
    pub usdc_balance_cents: u64,
    /// Sum of weighted non-negative position values, in USDC base units.
    pub weighted_collateral_usdc_base_units: u64,
    /// Sum of weighted negative position values (as magnitudes), in USDC base units.
    pub weighted_liabilities_usdc_base_units: u64,
    /// One entry per position that was resolved and priced.
    pub position_infos: Vec<PositionInfo>,
}
44
45pub fn calculate_capacity(
62 spot_positions: &[SpotPosition],
63 spot_market_map: &HashMap<AssetId, SpotMarket>,
64 price_map: &HashMap<AssetId, u64>,
65 unliquidatable_asset_ids: &[AssetId],
66 max_slippage_bps: u64,
67) -> MathResult<CapacityResult> {
68 let mut total_collateral_usdc_base_units: u64 = 0;
69 let mut total_liabilities_usdc_base_units: u64 = 0;
70
71 let mut total_weighted_collateral_usdc_base_units: u64 = 0;
72 let mut total_weighted_liabilities_usdc_base_units: u64 = 0;
73
74 let mut usdc_balance_base_units: u64 = 0;
75
76 let mut position_infos: Vec<PositionInfo> = Vec::new();
77
78 for position in spot_positions {
79 let Some(token) = pyra_tokens::Token::find_by_drift_market_index(position.market_index)
81 else {
82 continue;
83 };
84 let asset_id = token.asset_id;
85
86 let Some(spot_market) = spot_market_map.get(&asset_id) else {
87 continue;
88 };
89 let Some(price_usdc_base_units) = price_map.get(&asset_id).copied() else {
90 continue;
91 };
92
93 let token_balance_base_units = get_token_balance(position, spot_market)?;
95
96 let is_asset = token_balance_base_units >= 0;
97 let twap5min = spot_market
98 .historical_oracle_data
99 .last_oracle_price_twap5min;
100 let strict_price = get_strict_price(price_usdc_base_units, twap5min, is_asset);
101
102 let value_usdc_base_units = calculate_value_usdc_base_units(
103 token_balance_base_units,
104 strict_price,
105 spot_market.decimals,
106 )?;
107
108 let is_unliquidatable_collateral =
110 unliquidatable_asset_ids.contains(&asset_id) && value_usdc_base_units > 0;
111 if !is_unliquidatable_collateral {
112 update_running_totals(
113 &mut total_collateral_usdc_base_units,
114 &mut total_liabilities_usdc_base_units,
115 value_usdc_base_units,
116 )?;
117 }
118
119 let token_amount_unsigned = token_balance_base_units.unsigned_abs();
121 let weight_bps = if is_asset {
122 calculate_asset_weight(token_amount_unsigned, price_usdc_base_units, spot_market)?
123 as i128
124 } else {
125 calculate_liability_weight(token_amount_unsigned, spot_market)? as i128
126 };
127 let weighted_value_usdc_base_units = value_usdc_base_units
128 .checked_mul(weight_bps)
129 .ok_or(MathError::Overflow)?
130 .checked_div(MARGIN_PRECISION)
131 .ok_or(MathError::Overflow)?;
132
133 update_running_totals(
134 &mut total_weighted_collateral_usdc_base_units,
135 &mut total_weighted_liabilities_usdc_base_units,
136 weighted_value_usdc_base_units,
137 )?;
138
139 if asset_id == pyra_tokens::AssetId::USDC
141 && usdc_balance_base_units == 0
142 && token_balance_base_units > 0
143 {
144 usdc_balance_base_units =
145 u64::try_from(token_balance_base_units).map_err(|_| MathError::Overflow)?;
146 }
147
148 let token_balance_unsigned = u64::try_from(token_balance_base_units.unsigned_abs())
150 .map_err(|_| MathError::Overflow)?;
151 position_infos.push(PositionInfo {
152 asset_id,
153 balance: token_balance_unsigned,
154 position_type: position.balance_type.clone(),
155 price_usdc_base_units,
156 weight_bps: spot_market.initial_asset_weight,
157 });
158 }
159
160 let available_credit_base_units = total_weighted_collateral_usdc_base_units
162 .saturating_sub(total_weighted_liabilities_usdc_base_units);
163 let available_credit_cents = usdc_base_units_to_cents(available_credit_base_units)?;
164
165 let max_slippage_usdc_base_units = total_collateral_usdc_base_units
167 .checked_mul(max_slippage_bps)
168 .ok_or(MathError::Overflow)?
169 .checked_div(10_000)
170 .ok_or(MathError::Overflow)?;
171 let total_spendable_base_units = total_collateral_usdc_base_units
172 .saturating_sub(max_slippage_usdc_base_units)
173 .saturating_sub(total_liabilities_usdc_base_units);
174 let total_spendable_cents = usdc_base_units_to_cents(total_spendable_base_units)?;
175
176 let usdc_balance_cents = usdc_base_units_to_cents(usdc_balance_base_units)?;
177
178 Ok(CapacityResult {
179 total_spendable_cents,
180 available_credit_cents,
181 usdc_balance_cents,
182 weighted_collateral_usdc_base_units: total_weighted_collateral_usdc_base_units,
183 weighted_liabilities_usdc_base_units: total_weighted_liabilities_usdc_base_units,
184 position_infos,
185 })
186}
187
188fn update_running_totals(
190 total_positive: &mut u64,
191 total_negative: &mut u64,
192 value: i128,
193) -> MathResult<()> {
194 let value_unsigned = u64::try_from(value.unsigned_abs()).map_err(|_| MathError::Overflow)?;
195
196 if value >= 0 {
197 *total_positive = total_positive
198 .checked_add(value_unsigned)
199 .ok_or(MathError::Overflow)?;
200 } else {
201 *total_negative = total_negative
202 .checked_add(value_unsigned)
203 .ok_or(MathError::Overflow)?;
204 }
205
206 Ok(())
207}
208
// Unit tests for `calculate_capacity` and `update_running_totals`. Lints that forbid
// unwrap/panic/raw arithmetic are relaxed because test fixtures use them freely.
#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects
)]
mod tests {
    use super::*;
    use pyra_types::{HistoricalOracleData, InsuranceFund};

    /// Builds a `SpotMarket` fixture with the given weights and 5-minute TWAP; rate,
    /// utilisation and balance fields are zeroed and no oracle is set.
    fn make_spot_market_with_twap(
        market_index: u16,
        decimals: u32,
        initial_asset_weight: u32,
        initial_liability_weight: u32,
        twap5min: i64,
    ) -> SpotMarket {
        // Both cumulative interest factors are set to 10^(19 - decimals).
        // NOTE(review): presumably this makes `get_token_balance` return the scaled balance
        // unchanged; the expected values below rely on that — confirm against the helper.
        let precision_decrease = 10u128.pow(19u32.saturating_sub(decimals));
        SpotMarket {
            pubkey: vec![],
            market_index,
            initial_asset_weight,
            initial_liability_weight,
            imf_factor: 0,
            scale_initial_asset_weight_start: 0,
            decimals,
            cumulative_deposit_interest: precision_decrease,
            cumulative_borrow_interest: precision_decrease,
            deposit_balance: 0,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: InsuranceFund::default(),
            historical_oracle_data: HistoricalOracleData {
                last_oracle_price_twap5min: twap5min,
            },
            oracle: None,
        }
    }

    /// Same as `make_spot_market_with_twap`, using `oracle_price` (cast to `i64`) as the
    /// 5-minute TWAP.
    fn make_spot_market(
        market_index: u16,
        decimals: u32,
        initial_asset_weight: u32,
        initial_liability_weight: u32,
        oracle_price: u64,
    ) -> SpotMarket {
        make_spot_market_with_twap(
            market_index,
            decimals,
            initial_asset_weight,
            initial_liability_weight,
            oracle_price as i64,
        )
    }

    /// Builds a `SpotPosition` on `drift_market_index` with `scaled_balance`, as a deposit
    /// when `is_deposit` is true and as a borrow otherwise; other fields use `Default`.
    fn make_position(
        drift_market_index: u16,
        scaled_balance: u64,
        is_deposit: bool,
    ) -> SpotPosition {
        SpotPosition {
            market_index: drift_market_index,
            scaled_balance,
            balance_type: if is_deposit {
                SpotBalanceType::Deposit
            } else {
                SpotBalanceType::Borrow
            },
            ..Default::default()
        }
    }

    // No positions: every figure is zero and no position info is produced.
    #[test]
    fn empty_positions() {
        let result = calculate_capacity(&[], &HashMap::new(), &HashMap::new(), &[], 0).unwrap();
        assert_eq!(result.total_spendable_cents, 0);
        assert_eq!(result.available_credit_cents, 0);
        assert_eq!(result.usdc_balance_cents, 0);
        assert_eq!(result.weighted_collateral_usdc_base_units, 0);
        assert_eq!(result.weighted_liabilities_usdc_base_units, 0);
        assert!(result.position_infos.is_empty());
    }

    // $100 USDC deposit at 100% weight: spendable, credit and USDC balance all $100.
    #[test]
    fn single_usdc_deposit() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![make_position(0, 100_000_000, true)];
        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);
        let result = calculate_capacity(&positions, &markets, &prices, &[], 0).unwrap();

        assert_eq!(result.usdc_balance_cents, 10_000);
        assert_eq!(result.total_spendable_cents, 10_000);
        assert_eq!(result.available_credit_cents, 10_000);
        assert_eq!(result.weighted_collateral_usdc_base_units, 100_000_000);
        assert_eq!(result.weighted_liabilities_usdc_base_units, 0);
        assert_eq!(result.position_infos.len(), 1);
        assert_eq!(result.position_infos[0].asset_id, AssetId::USDC);
    }

    // $100 deposit and $50 borrow on USDC: $50 spendable / credit; USDC balance is the
    // deposit only.
    #[test]
    fn deposit_and_borrow() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![
            make_position(0, 100_000_000, true),
            make_position(0, 50_000_000, false),
        ];

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);

        let result = calculate_capacity(&positions, &markets, &prices, &[], 0).unwrap();

        assert_eq!(result.usdc_balance_cents, 10_000);
        assert_eq!(result.total_spendable_cents, 5_000);
        assert_eq!(result.available_credit_cents, 5_000);
    }

    // $10 USDC + 1 WETH @ $100 (80% asset weight), WETH unliquidatable:
    // spendable = $10 (WETH excluded); credit = $10 + $100 * 0.8 = $90.
    #[test]
    fn unliquidatable_excluded_from_spendable() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let weth = make_spot_market(4, 9, 8_000, 12_000, 100_000_000);

        let positions = vec![
            make_position(0, 10_000_000, true),
            make_position(4, 1_000_000_000, true),
        ];

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        markets.insert(AssetId::WETH, weth);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);
        prices.insert(AssetId::WETH, 100_000_000u64);
        let unliquidatable = vec![AssetId::WETH];
        let result = calculate_capacity(&positions, &markets, &prices, &unliquidatable, 0).unwrap();

        assert_eq!(result.total_spendable_cents, 1_000);
        assert_eq!(result.available_credit_cents, 9_000);
        assert_eq!(result.position_infos.len(), 2);
    }

    // 10% slippage (1_000 bps) on $100 collateral: spendable $90; credit unaffected.
    #[test]
    fn slippage_reduces_spendable() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![make_position(0, 100_000_000, true)];
        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);

        let result = calculate_capacity(&positions, &markets, &prices, &[], 1_000).unwrap();

        assert_eq!(result.total_spendable_cents, 9_000);
        assert_eq!(result.available_credit_cents, 10_000);
    }

    // A position whose market is absent from the market map is ignored.
    #[test]
    fn missing_market_skipped() {
        let positions = vec![make_position(5, 1_000_000, true)];

        let result =
            calculate_capacity(&positions, &HashMap::new(), &HashMap::new(), &[], 0).unwrap();

        assert_eq!(result.total_spendable_cents, 0);
        assert!(result.position_infos.is_empty());
    }

    // A position whose asset has no price is ignored.
    #[test]
    fn missing_price_skipped() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![make_position(0, 1_000_000, true)];

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        let result = calculate_capacity(&positions, &markets, &HashMap::new(), &[], 0).unwrap();

        assert_eq!(result.total_spendable_cents, 0);
        assert!(result.position_infos.is_empty());
    }

    // $50 USDC, 1 WETH @ $200 (unliquidatable), 0.5 WSOL @ $100, $20 USDT borrow, 5% slippage:
    // spendable = ($50 + $50) - 5% - $20 = $75;
    // credit = $50 + $200 * 0.8 + $50 * 0.8 - $20 = $230.
    #[test]
    fn multi_position_with_unliquidatable_and_slippage() {
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let weth = make_spot_market(4, 9, 8_000, 12_000, 200_000_000);
        let wsol = make_spot_market(1, 9, 8_000, 12_000, 100_000_000);
        let usdt = make_spot_market(5, 6, 10_000, 10_000, 1_000_000);

        let positions = vec![
            make_position(0, 50_000_000, true),
            make_position(4, 1_000_000_000, true),
            make_position(1, 500_000_000, true),
            make_position(5, 20_000_000, false),
        ];

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        markets.insert(AssetId::WETH, weth);
        markets.insert(AssetId::WSOL, wsol);
        markets.insert(AssetId::USDT, usdt);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);
        prices.insert(AssetId::WETH, 200_000_000u64);
        prices.insert(AssetId::WSOL, 100_000_000u64);
        prices.insert(AssetId::USDT, 1_000_000u64);

        let unliquidatable = vec![AssetId::WETH];
        let result =
            calculate_capacity(&positions, &markets, &prices, &unliquidatable, 500).unwrap();

        assert_eq!(result.total_spendable_cents, 7_500);

        assert_eq!(result.available_credit_cents, 23_000);
        assert_eq!(result.usdc_balance_cents, 5_000);
        assert_eq!(result.position_infos.len(), 4);
    }

    // Positive values go to the positive total only.
    #[test]
    fn running_totals_positive() {
        let mut pos = 0u64;
        let mut neg = 0u64;
        update_running_totals(&mut pos, &mut neg, 100).unwrap();
        assert_eq!(pos, 100);
        assert_eq!(neg, 0);
    }

    // Negative values go to the negative total as a magnitude.
    #[test]
    fn running_totals_negative() {
        let mut pos = 0u64;
        let mut neg = 0u64;
        update_running_totals(&mut pos, &mut neg, -50).unwrap();
        assert_eq!(pos, 0);
        assert_eq!(neg, 50);
    }

    // Updates add onto existing totals rather than overwriting them.
    #[test]
    fn running_totals_accumulate() {
        let mut pos = 10u64;
        let mut neg = 5u64;
        update_running_totals(&mut pos, &mut neg, 20).unwrap();
        update_running_totals(&mut pos, &mut neg, -15).unwrap();
        assert_eq!(pos, 30);
        assert_eq!(neg, 20);
    }
}
496
// Property-based tests. Lints that forbid unwrap/panic/raw arithmetic are relaxed because
// test code uses them freely.
#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects
)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Rounding bound for the cents conversion: converting `collateral - liabilities`
        // to cents never exceeds the difference of the separately converted amounts by
        // more than one cent.
        // NOTE(review): despite its name, this property never calls `calculate_capacity`;
        // it only exercises `usdc_base_units_to_cents`.
        #[test]
        fn spendable_le_collateral_minus_liabilities(
            collateral_base in 0u64..=1_000_000_000_000u64,
            liabilities_base in 0u64..=500_000_000_000u64,
        ) {
            let collateral_cents = usdc_base_units_to_cents(collateral_base).unwrap();
            let liabilities_cents = usdc_base_units_to_cents(liabilities_base).unwrap();
            let max_possible = collateral_cents.saturating_sub(liabilities_cents);

            let spendable_base = collateral_base.saturating_sub(liabilities_base);
            let spendable_cents = usdc_base_units_to_cents(spendable_base).unwrap();
            prop_assert!(spendable_cents <= max_possible + 1, "rounding violation");
        }

        // After any sequence of updates, `update_running_totals` leaves the positive total
        // equal to the sum of the non-negative inputs and the negative total equal to the
        // sum of the magnitudes of the negative inputs.
        #[test]
        fn running_totals_partition_by_sign(
            values in prop::collection::vec(
                -1_000_000_000_000i64..=1_000_000_000_000i64,
                0..32,
            ),
        ) {
            let mut positive = 0u64;
            let mut negative = 0u64;
            for &value in &values {
                update_running_totals(&mut positive, &mut negative, i128::from(value)).unwrap();
            }

            let expected_positive: u64 = values
                .iter()
                .filter(|&&v| v >= 0)
                .map(|&v| v.unsigned_abs())
                .sum();
            let expected_negative: u64 = values
                .iter()
                .filter(|&&v| v < 0)
                .map(|&v| v.unsigned_abs())
                .sum();
            prop_assert_eq!(positive, expected_positive);
            prop_assert_eq!(negative, expected_negative);
        }
    }
}
527}