1use std::collections::HashMap;
2
3use pyra_tokens::AssetId;
4use pyra_types::{SpotBalanceType, SpotMarket, SpotPosition};
5
6use super::balance::{calculate_value_usdc_base_units, get_token_balance};
7use super::weights::{calculate_asset_weight, calculate_liability_weight, get_strict_price};
8use crate::common::usdc_base_units_to_cents;
9use crate::error::{MathError, MathResult};
10
/// Denominator for asset/liability weights expressed in basis points:
/// a weight equal to `MARGIN_PRECISION` (10_000) means 100%.
const MARGIN_PRECISION: i128 = 10_000;
12
/// Per-position details collected by [`calculate_capacity`] for every position
/// whose token, spot market, and price could all be resolved.
#[derive(Debug, Clone, PartialEq)]
pub struct PositionInfo {
    /// Asset the position is denominated in.
    pub asset_id: AssetId,
    /// Absolute token balance of the position, in the token's base units.
    pub balance: u64,
    /// Deposit or borrow, copied from the Drift spot position.
    pub position_type: SpotBalanceType,
    /// Price taken from the caller-supplied price map, in USDC base units
    /// (the raw map value, not the strict/TWAP-adjusted price).
    pub price_usdc_base_units: u64,
    /// The spot market's `initial_asset_weight`, in basis points (10_000 = 100%).
    /// This is recorded for borrows as well, and is not the size-adjusted weight
    /// used for the weighted totals.
    pub weight_bps: u32,
}
26
/// Aggregate spending capacity computed by [`calculate_capacity`].
#[derive(Debug, Clone)]
pub struct CapacityResult {
    /// Unweighted collateral, minus the `max_slippage_bps` slippage allowance on that
    /// collateral, minus unweighted liabilities, in USDC cents (floored at zero).
    /// Positive balances of unliquidatable assets are excluded from the collateral.
    pub total_spendable_cents: u64,
    /// Weighted collateral minus weighted liabilities, in USDC cents (floored at zero).
    pub available_credit_cents: u64,
    /// The first positive USDC balance encountered, in cents.
    pub usdc_balance_cents: u64,
    /// Sum of weighted collateral values, in USDC base units.
    pub weighted_collateral_usdc_base_units: u64,
    /// Sum of weighted liability values (as a positive magnitude), in USDC base units.
    pub weighted_liabilities_usdc_base_units: u64,
    /// One entry per valued position, in the same order as the input positions.
    pub position_infos: Vec<PositionInfo>,
}
44
45pub fn calculate_capacity(
62 spot_positions: &[SpotPosition],
63 spot_market_map: &HashMap<AssetId, SpotMarket>,
64 price_map: &HashMap<AssetId, u64>,
65 unliquidatable_asset_ids: &[AssetId],
66 max_slippage_bps: u64,
67) -> MathResult<CapacityResult> {
68 let mut total_collateral_usdc_base_units: u64 = 0;
69 let mut total_liabilities_usdc_base_units: u64 = 0;
70
71 let mut total_weighted_collateral_usdc_base_units: u64 = 0;
72 let mut total_weighted_liabilities_usdc_base_units: u64 = 0;
73
74 let mut usdc_balance_base_units: u64 = 0;
75
76 let mut position_infos: Vec<PositionInfo> = Vec::new();
77
78 for position in spot_positions {
79 let Some(token) = pyra_tokens::Token::find_by_drift_market_index(position.market_index)
81 else {
82 continue;
83 };
84 let asset_id = token.asset_id;
85
86 let Some(spot_market) = spot_market_map.get(&asset_id) else {
87 continue;
88 };
89 let Some(price_usdc_base_units) = price_map.get(&asset_id).copied() else {
90 continue;
91 };
92
93 let token_balance_base_units = get_token_balance(position, spot_market)?;
95
96 let is_asset = token_balance_base_units >= 0;
97 let twap5min = spot_market
98 .historical_oracle_data
99 .last_oracle_price_twap5min;
100 let strict_price = get_strict_price(price_usdc_base_units, twap5min, is_asset);
101
102 let value_usdc_base_units = calculate_value_usdc_base_units(
103 token_balance_base_units,
104 strict_price,
105 spot_market.decimals,
106 )?;
107
108 let is_unliquidatable_collateral =
110 unliquidatable_asset_ids.contains(&asset_id) && value_usdc_base_units > 0;
111 if !is_unliquidatable_collateral {
112 update_running_totals(
113 &mut total_collateral_usdc_base_units,
114 &mut total_liabilities_usdc_base_units,
115 value_usdc_base_units,
116 )?;
117 }
118
119 let token_amount_unsigned = token_balance_base_units.unsigned_abs();
121 let weight_bps = if is_asset {
122 calculate_asset_weight(token_amount_unsigned, price_usdc_base_units, spot_market)?
123 as i128
124 } else {
125 calculate_liability_weight(token_amount_unsigned, spot_market)? as i128
126 };
127 let weighted_value_usdc_base_units = value_usdc_base_units
128 .checked_mul(weight_bps)
129 .ok_or(MathError::Overflow)?
130 .checked_div(MARGIN_PRECISION)
131 .ok_or(MathError::Overflow)?;
132
133 update_running_totals(
134 &mut total_weighted_collateral_usdc_base_units,
135 &mut total_weighted_liabilities_usdc_base_units,
136 weighted_value_usdc_base_units,
137 )?;
138
139 if asset_id == pyra_tokens::AssetId::USDC
141 && usdc_balance_base_units == 0
142 && token_balance_base_units > 0
143 {
144 usdc_balance_base_units =
145 u64::try_from(token_balance_base_units).map_err(|_| MathError::Overflow)?;
146 }
147
148 let token_balance_unsigned = u64::try_from(token_balance_base_units.unsigned_abs())
150 .map_err(|_| MathError::Overflow)?;
151 position_infos.push(PositionInfo {
152 asset_id,
153 balance: token_balance_unsigned,
154 position_type: position.balance_type.clone(),
155 price_usdc_base_units,
156 weight_bps: spot_market.initial_asset_weight,
157 });
158 }
159
160 let available_credit_base_units = total_weighted_collateral_usdc_base_units
162 .saturating_sub(total_weighted_liabilities_usdc_base_units);
163 let available_credit_cents = usdc_base_units_to_cents(available_credit_base_units)?;
164
165 let max_slippage_usdc_base_units = total_collateral_usdc_base_units
167 .checked_mul(max_slippage_bps)
168 .ok_or(MathError::Overflow)?
169 .checked_div(10_000)
170 .ok_or(MathError::Overflow)?;
171 let total_spendable_base_units = total_collateral_usdc_base_units
172 .saturating_sub(max_slippage_usdc_base_units)
173 .saturating_sub(total_liabilities_usdc_base_units);
174 let total_spendable_cents = usdc_base_units_to_cents(total_spendable_base_units)?;
175
176 let usdc_balance_cents = usdc_base_units_to_cents(usdc_balance_base_units)?;
177
178 Ok(CapacityResult {
179 total_spendable_cents,
180 available_credit_cents,
181 usdc_balance_cents,
182 weighted_collateral_usdc_base_units: total_weighted_collateral_usdc_base_units,
183 weighted_liabilities_usdc_base_units: total_weighted_liabilities_usdc_base_units,
184 position_infos,
185 })
186}
187
188fn update_running_totals(
190 total_positive: &mut u64,
191 total_negative: &mut u64,
192 value: i128,
193) -> MathResult<()> {
194 let value_unsigned = u64::try_from(value.unsigned_abs()).map_err(|_| MathError::Overflow)?;
195
196 if value >= 0 {
197 *total_positive = total_positive
198 .checked_add(value_unsigned)
199 .ok_or(MathError::Overflow)?;
200 } else {
201 *total_negative = total_negative
202 .checked_add(value_unsigned)
203 .ok_or(MathError::Overflow)?;
204 }
205
206 Ok(())
207}
208
#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects,
    reason = "test code"
)]
mod tests {
    use super::*;
    use pyra_types::{HistoricalOracleData, InsuranceFund};

    /// Builds a `SpotMarket` fixture with the given weights and 5-minute oracle TWAP.
    ///
    /// Both cumulative interest fields are set to `10^(19 - decimals)`. The expected values
    /// in the tests below assume this makes a position's token balance equal to its
    /// `scaled_balance` (NOTE(review): relies on `get_token_balance`'s scaling).
    fn make_spot_market_with_twap(
        market_index: u16,
        decimals: u32,
        initial_asset_weight: u32,
        initial_liability_weight: u32,
        twap5min: i64,
    ) -> SpotMarket {
        let precision_decrease = 10u128.pow(19u32.saturating_sub(decimals));
        SpotMarket {
            pubkey: vec![],
            market_index,
            initial_asset_weight,
            initial_liability_weight,
            imf_factor: 0,
            scale_initial_asset_weight_start: 0,
            decimals,
            cumulative_deposit_interest: precision_decrease,
            cumulative_borrow_interest: precision_decrease,
            deposit_balance: 0,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: InsuranceFund::default(),
            historical_oracle_data: HistoricalOracleData {
                last_oracle_price_twap5min: twap5min,
            },
            oracle: None,
        }
    }

    /// Builds a `SpotMarket` fixture whose 5-minute TWAP is `oracle_price`.
    ///
    /// The tests put the same value in the price map, so the live price and TWAP agree.
    fn make_spot_market(
        market_index: u16,
        decimals: u32,
        initial_asset_weight: u32,
        initial_liability_weight: u32,
        oracle_price: u64,
    ) -> SpotMarket {
        make_spot_market_with_twap(
            market_index,
            decimals,
            initial_asset_weight,
            initial_liability_weight,
            oracle_price as i64,
        )
    }

    /// Builds a deposit (`is_deposit == true`) or borrow position on a Drift spot market.
    fn make_position(
        drift_market_index: u16,
        scaled_balance: u64,
        is_deposit: bool,
    ) -> SpotPosition {
        SpotPosition {
            market_index: drift_market_index,
            scaled_balance,
            balance_type: if is_deposit {
                SpotBalanceType::Deposit
            } else {
                SpotBalanceType::Borrow
            },
            ..Default::default()
        }
    }

    #[test]
    fn empty_positions() {
        // No positions: every figure is zero and no position infos are produced.
        let result = calculate_capacity(&[], &HashMap::new(), &HashMap::new(), &[], 0).unwrap();
        assert_eq!(result.total_spendable_cents, 0);
        assert_eq!(result.available_credit_cents, 0);
        assert_eq!(result.usdc_balance_cents, 0);
        assert_eq!(result.weighted_collateral_usdc_base_units, 0);
        assert_eq!(result.weighted_liabilities_usdc_base_units, 0);
        assert!(result.position_infos.is_empty());
    }

    #[test]
    fn single_usdc_deposit() {
        // 100 USDC (6 decimals) at $1.00 with 100% weights: $100.00 everywhere.
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![make_position(0, 100_000_000, true)];
        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);
        let result = calculate_capacity(&positions, &markets, &prices, &[], 0).unwrap();

        assert_eq!(result.usdc_balance_cents, 10_000);
        assert_eq!(result.total_spendable_cents, 10_000);
        assert_eq!(result.available_credit_cents, 10_000);
        assert_eq!(result.weighted_collateral_usdc_base_units, 100_000_000);
        assert_eq!(result.weighted_liabilities_usdc_base_units, 0);
        assert_eq!(result.position_infos.len(), 1);
        assert_eq!(result.position_infos[0].asset_id, AssetId::USDC);
    }

    #[test]
    fn deposit_and_borrow() {
        // $100 USDC deposit and $50 USDC borrow: $50 net spendable and available.
        // The reported USDC balance is the deposit only.
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![
            make_position(0, 100_000_000, true),
            make_position(0, 50_000_000, false),
        ];

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);

        let result = calculate_capacity(&positions, &markets, &prices, &[], 0).unwrap();

        assert_eq!(result.usdc_balance_cents, 10_000);
        assert_eq!(result.total_spendable_cents, 5_000);
        assert_eq!(result.available_credit_cents, 5_000);
    }

    #[test]
    fn unliquidatable_excluded_from_spendable() {
        // $10 USDC plus 1 WETH (9 decimals) at $100 with an 80% asset weight.
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let weth = make_spot_market(4, 9, 8_000, 12_000, 100_000_000);

        let positions = vec![
            make_position(0, 10_000_000, true),
            make_position(4, 1_000_000_000, true),
        ];

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        markets.insert(AssetId::WETH, weth);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);
        prices.insert(AssetId::WETH, 100_000_000u64);
        let unliquidatable = vec![AssetId::WETH];
        let result = calculate_capacity(&positions, &markets, &prices, &unliquidatable, 0).unwrap();

        // WETH is unliquidatable, so spendable is the USDC only: $10.
        assert_eq!(result.total_spendable_cents, 1_000);
        // WETH still counts towards credit: $10 + 0.8 * $100 = $90.
        assert_eq!(result.available_credit_cents, 9_000);
        assert_eq!(result.position_infos.len(), 2);
    }

    #[test]
    fn slippage_reduces_spendable() {
        // 1_000 bps (10%) slippage on $100 collateral leaves $90 spendable.
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![make_position(0, 100_000_000, true)];
        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);

        let result = calculate_capacity(&positions, &markets, &prices, &[], 1_000).unwrap();

        assert_eq!(result.total_spendable_cents, 9_000);
        // Slippage only affects spendable; weighted credit stays at $100.
        assert_eq!(result.available_credit_cents, 10_000);
    }

    #[test]
    fn missing_market_skipped() {
        // No market or price data is supplied, so the position is skipped entirely.
        let positions = vec![make_position(5, 1_000_000, true)];

        let result =
            calculate_capacity(&positions, &HashMap::new(), &HashMap::new(), &[], 0).unwrap();

        assert_eq!(result.total_spendable_cents, 0);
        assert!(result.position_infos.is_empty());
    }

    #[test]
    fn missing_price_skipped() {
        // The USDC market exists but has no price, so the position is skipped entirely.
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let positions = vec![make_position(0, 1_000_000, true)];

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        let result = calculate_capacity(&positions, &markets, &HashMap::new(), &[], 0).unwrap();

        assert_eq!(result.total_spendable_cents, 0);
        assert!(result.position_infos.is_empty());
    }

    #[test]
    fn multi_position_with_unliquidatable_and_slippage() {
        // Positions:
        //   USDC: $50 deposit, 100% weight
        //   WETH: 1 WETH at $200 deposit, 80% weight, unliquidatable
        //   WSOL: 0.5 WSOL at $100 = $50 deposit, 80% weight
        //   USDT: $20 borrow, 10_000 bps liability weight
        let usdc = make_spot_market(0, 6, 10_000, 10_000, 1_000_000);
        let weth = make_spot_market(4, 9, 8_000, 12_000, 200_000_000);
        let wsol = make_spot_market(1, 9, 8_000, 12_000, 100_000_000);
        let usdt = make_spot_market(5, 6, 10_000, 10_000, 1_000_000);

        let positions = vec![
            make_position(0, 50_000_000, true),
            make_position(4, 1_000_000_000, true),
            make_position(1, 500_000_000, true),
            make_position(5, 20_000_000, false),
        ];

        let mut markets = HashMap::new();
        markets.insert(AssetId::USDC, usdc);
        markets.insert(AssetId::WETH, weth);
        markets.insert(AssetId::WSOL, wsol);
        markets.insert(AssetId::USDT, usdt);
        let mut prices = HashMap::new();
        prices.insert(AssetId::USDC, 1_000_000u64);
        prices.insert(AssetId::WETH, 200_000_000u64);
        prices.insert(AssetId::WSOL, 100_000_000u64);
        prices.insert(AssetId::USDT, 1_000_000u64);

        let unliquidatable = vec![AssetId::WETH];
        let result =
            calculate_capacity(&positions, &markets, &prices, &unliquidatable, 500).unwrap();

        // Spendable excludes WETH: ($50 + $50) - 5% * $100 - $20 = $75.
        assert_eq!(result.total_spendable_cents, 7_500);

        // Credit: $50 + 0.8 * $200 + 0.8 * $50 - $20 = $230.
        assert_eq!(result.available_credit_cents, 23_000);
        assert_eq!(result.usdc_balance_cents, 5_000);
        assert_eq!(result.position_infos.len(), 4);
    }

    #[test]
    fn running_totals_positive() {
        // Non-negative values are added to the positive total.
        let mut pos = 0u64;
        let mut neg = 0u64;
        update_running_totals(&mut pos, &mut neg, 100).unwrap();
        assert_eq!(pos, 100);
        assert_eq!(neg, 0);
    }

    #[test]
    fn running_totals_negative() {
        // Negative values add their magnitude to the negative total.
        let mut pos = 0u64;
        let mut neg = 0u64;
        update_running_totals(&mut pos, &mut neg, -50).unwrap();
        assert_eq!(pos, 0);
        assert_eq!(neg, 50);
    }

    #[test]
    fn running_totals_accumulate() {
        // Existing totals are added to rather than overwritten.
        let mut pos = 10u64;
        let mut neg = 5u64;
        update_running_totals(&mut pos, &mut neg, 20).unwrap();
        update_running_totals(&mut pos, &mut neg, -15).unwrap();
        assert_eq!(pos, 30);
        assert_eq!(neg, 20);
    }
}
497
#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects,
    reason = "test code"
)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Converting the saturated difference (collateral - liabilities) to cents must never
        // exceed the difference of the separately converted amounts by more than one cent.
        // NOTE(review): this exercises `usdc_base_units_to_cents` only, not
        // `calculate_capacity` itself.
        #[test]
        fn spendable_le_collateral_minus_liabilities(
            collateral_base in 0u64..=1_000_000_000_000u64,
            liabilities_base in 0u64..=500_000_000_000u64,
        ) {
            let to_cents = |base_units: u64| usdc_base_units_to_cents(base_units).unwrap();

            let upper_bound_cents =
                to_cents(collateral_base).saturating_sub(to_cents(liabilities_base));
            let spendable_cents = to_cents(collateral_base.saturating_sub(liabilities_base));

            prop_assert!(spendable_cents <= upper_bound_cents + 1, "rounding violation");
        }
    }
}