1use std::cmp;
2
3use pyra_types::SpotMarket;
4
5use crate::error::{MathError, MathResult};
6
/// Fixed-point scale for spot asset/liability weights as used by the weight
/// formulas in this module (presumably `10_000` represents a weight of 1.0 —
/// confirm against the market configuration).
pub const SPOT_WEIGHT_PRECISION: u128 = 10_000;
/// Fixed-point scale of the IMF terms in the size discount/premium formulas
/// in this module.
pub const SPOT_IMF_PRECISION: u128 = 1_000_000;
/// Precision (10^9) that token amounts are rescaled to by `to_amm_precision`.
pub const AMM_RESERVE_PRECISION: u128 = 1_000_000_000;
13
/// Integer square root: returns the largest `r` with `r * r <= n`.
///
/// Runs Newton's iteration from a power-of-two starting guess that is never
/// below the true root, so successive estimates only decrease; the first
/// estimate that fails to decrease is the answer.
fn isqrt(n: u128) -> u128 {
    if n <= 1 {
        return n;
    }

    // Half the bit length of `n` (rounded up) yields a power of two >= sqrt(n).
    let bit_len = 128u32.saturating_sub(n.leading_zeros());
    let half_bits = bit_len.saturating_add(1) / 2;

    // One Newton step: average of the guess and n / guess. The fallbacks are
    // never hit for a non-zero guess; they only keep the arithmetic checked.
    let newton_step = |guess: u128| -> u128 {
        let quotient = n.checked_div(guess).unwrap_or(0);
        guess.checked_add(quotient).unwrap_or(guess) / 2
    };

    let mut root = 1u128 << half_bits;
    loop {
        let candidate = newton_step(root);
        if candidate >= root {
            return root;
        }
        root = candidate;
    }
}
27
28pub fn to_amm_precision(balance: u128, token_decimals: u32) -> MathResult<u128> {
30 let size_precision = 10u128
31 .checked_pow(token_decimals)
32 .ok_or(MathError::Overflow)?;
33
34 if size_precision > AMM_RESERVE_PRECISION {
35 let scale = size_precision
36 .checked_div(AMM_RESERVE_PRECISION)
37 .ok_or(MathError::Overflow)?;
38 balance.checked_div(scale).ok_or(MathError::Overflow)
39 } else {
40 balance
41 .checked_mul(AMM_RESERVE_PRECISION)
42 .ok_or(MathError::Overflow)?
43 .checked_div(size_precision)
44 .ok_or(MathError::Overflow)
45 }
46}
47
48pub fn calculate_scaled_initial_asset_weight(
52 spot_market: &SpotMarket,
53 oracle_price: u64,
54) -> MathResult<u128> {
55 let initial_asset_weight = spot_market.initial_asset_weight as u128;
56
57 if spot_market.scale_initial_asset_weight_start == 0 {
58 return Ok(initial_asset_weight);
59 }
60
61 let precision_decrease = 10u128
62 .checked_pow(19u32.saturating_sub(spot_market.decimals))
63 .ok_or(MathError::Overflow)?;
64
65 let deposit_tokens = (spot_market.deposit_balance)
66 .checked_mul(spot_market.cumulative_deposit_interest)
67 .ok_or(MathError::Overflow)?
68 .checked_div(precision_decrease)
69 .ok_or(MathError::Overflow)?;
70
71 let token_precision = 10u128
72 .checked_pow(spot_market.decimals)
73 .ok_or(MathError::Overflow)?;
74
75 let deposits_value = deposit_tokens
76 .checked_mul(oracle_price as u128)
77 .ok_or(MathError::Overflow)?
78 .checked_div(token_precision)
79 .ok_or(MathError::Overflow)?;
80
81 let threshold = spot_market.scale_initial_asset_weight_start as u128;
82 if deposits_value < threshold {
83 return Ok(initial_asset_weight);
84 }
85
86 initial_asset_weight
87 .checked_mul(threshold)
88 .ok_or(MathError::Overflow)?
89 .checked_div(deposits_value)
90 .ok_or(MathError::Overflow)
91}
92
93pub fn calculate_size_discount_asset_weight(
97 size_in_amm: u128,
98 imf_factor: u32,
99 asset_weight: u128,
100) -> MathResult<u128> {
101 if imf_factor == 0 {
102 return Ok(asset_weight);
103 }
104
105 let size_times_10 = size_in_amm
106 .checked_mul(10)
107 .ok_or(MathError::Overflow)?
108 .checked_add(1)
109 .ok_or(MathError::Overflow)?;
110 let size_sqrt = isqrt(size_times_10);
111
112 let imf_numerator: u128 = SPOT_IMF_PRECISION
113 .checked_add(
114 SPOT_IMF_PRECISION
115 .checked_div(10)
116 .ok_or(MathError::Overflow)?,
117 )
118 .ok_or(MathError::Overflow)?;
119
120 let numerator = imf_numerator
121 .checked_mul(SPOT_WEIGHT_PRECISION)
122 .ok_or(MathError::Overflow)?;
123
124 let inner = size_sqrt
125 .checked_mul(imf_factor as u128)
126 .ok_or(MathError::Overflow)?
127 .checked_div(100_000)
128 .ok_or(MathError::Overflow)?;
129 let denominator = SPOT_IMF_PRECISION
130 .checked_add(inner)
131 .ok_or(MathError::Overflow)?;
132
133 let size_discount_weight = numerator
134 .checked_div(denominator)
135 .ok_or(MathError::Overflow)?;
136
137 Ok(cmp::min(asset_weight, size_discount_weight))
138}
139
140pub fn calculate_size_premium_liability_weight(
144 size_in_amm: u128,
145 imf_factor: u32,
146 liability_weight: u128,
147) -> MathResult<u128> {
148 if imf_factor == 0 {
149 return Ok(liability_weight);
150 }
151
152 let size_times_10 = size_in_amm
153 .checked_mul(10)
154 .ok_or(MathError::Overflow)?
155 .checked_add(1)
156 .ok_or(MathError::Overflow)?;
157 let size_sqrt = isqrt(size_times_10);
158
159 let lw_fifth = liability_weight.checked_div(5).ok_or(MathError::Overflow)?;
160 let liability_weight_numerator = liability_weight
161 .checked_sub(lw_fifth)
162 .ok_or(MathError::Overflow)?;
163
164 let denom = 100_000u128
165 .checked_mul(SPOT_IMF_PRECISION)
166 .ok_or(MathError::Overflow)?
167 .checked_div(SPOT_WEIGHT_PRECISION)
168 .ok_or(MathError::Overflow)?;
169
170 let premium_term = size_sqrt
171 .checked_mul(imf_factor as u128)
172 .ok_or(MathError::Overflow)?
173 .checked_div(denom)
174 .ok_or(MathError::Overflow)?;
175
176 let size_premium_weight = liability_weight_numerator
177 .checked_add(premium_term)
178 .ok_or(MathError::Overflow)?;
179
180 Ok(cmp::max(liability_weight, size_premium_weight))
181}
182
183pub fn calculate_asset_weight(
186 token_amount: u128,
187 oracle_price: u64,
188 spot_market: &SpotMarket,
189) -> MathResult<u128> {
190 let scaled_weight = calculate_scaled_initial_asset_weight(spot_market, oracle_price)?;
191 let size_in_amm = to_amm_precision(token_amount, spot_market.decimals)?;
192 calculate_size_discount_asset_weight(size_in_amm, spot_market.imf_factor, scaled_weight)
193}
194
195pub fn calculate_liability_weight(
198 token_amount: u128,
199 spot_market: &SpotMarket,
200) -> MathResult<u128> {
201 let size_in_amm = to_amm_precision(token_amount, spot_market.decimals)?;
202 calculate_size_premium_liability_weight(
203 size_in_amm,
204 spot_market.imf_factor,
205 spot_market.initial_liability_weight as u128,
206 )
207}
208
/// Chooses the conservative price between the current price and `twap5min`.
///
/// * Assets (`is_asset == true`) are valued at the lower of the two.
/// * Liabilities are valued at the higher of the two.
///
/// A non-positive `twap5min` is treated as unavailable, in which case the
/// current price is returned for both sides.
pub fn get_strict_price(price_usdc_base_units: u64, twap5min: i64, is_asset: bool) -> u64 {
    let spot = price_usdc_base_units;

    // `unsigned_abs` is a lossless i64 -> u64 conversion on the positive branch.
    let reference = if twap5min > 0 {
        twap5min.unsigned_abs()
    } else {
        spot
    };

    if is_asset {
        cmp::min(spot, reference)
    } else {
        cmp::max(spot, reference)
    }
}
228
#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects,
    reason = "test code"
)]
mod tests {
    use super::*;

    // Position sizes, already expressed in AMM precision (10^9 per whole token).
    const ONE_TOKEN: u128 = 1_000_000_000;
    const ONE_MILLION_TOKENS: u128 = 1_000_000_000_000_000;

    // IMF factor shared by the size discount/premium tests.
    const TEST_IMF: u32 = 1000;

    #[test]
    fn isqrt_basic_values() {
        // (input, expected floor square root)
        let cases: [(u128, u128); 7] = [
            (0, 0),
            (1, 1),
            (4, 2),
            (9, 3),
            (10, 3),
            (100, 10),
            (10_000_000_000, 100_000),
        ];
        for (input, expected) in cases {
            assert_eq!(isqrt(input), expected);
        }
    }

    #[test]
    fn size_discount_asset_weight_no_imf() {
        // A zero IMF factor leaves the weight untouched.
        let weight = calculate_size_discount_asset_weight(ONE_TOKEN, 0, 8_000).unwrap();
        assert_eq!(weight, 8_000);
    }

    #[test]
    fn size_discount_asset_weight_with_imf() {
        // Small position: the formula yields more than the base, so the base wins.
        let small = calculate_size_discount_asset_weight(ONE_TOKEN, TEST_IMF, 8_000).unwrap();
        assert_eq!(small, 8_000);

        let large =
            calculate_size_discount_asset_weight(ONE_MILLION_TOKENS, TEST_IMF, 8_000).unwrap();
        assert!(large < 8_000, "Large position should have reduced weight");
    }

    #[test]
    fn size_premium_liability_weight_no_imf() {
        // A zero IMF factor leaves the weight untouched.
        let weight = calculate_size_premium_liability_weight(ONE_TOKEN, 0, 12_000).unwrap();
        assert_eq!(weight, 12_000);
    }

    #[test]
    fn size_premium_liability_weight_with_imf() {
        // Small position: the formula yields less than the base, so the base wins.
        let small = calculate_size_premium_liability_weight(ONE_TOKEN, TEST_IMF, 12_000).unwrap();
        assert_eq!(small, 12_000);

        let large =
            calculate_size_premium_liability_weight(ONE_MILLION_TOKENS, TEST_IMF, 12_000).unwrap();
        assert!(
            large > 12_000,
            "Large position should have increased weight"
        );
    }

    #[test]
    fn strict_price_asset_uses_min() {
        let below = get_strict_price(1_000_000, 900_000, true);
        let above = get_strict_price(1_000_000, 1_100_000, true);
        assert_eq!(below, 900_000);
        assert_eq!(above, 1_000_000);
    }

    #[test]
    fn strict_price_liability_uses_max() {
        let below = get_strict_price(1_000_000, 900_000, false);
        let above = get_strict_price(1_000_000, 1_100_000, false);
        assert_eq!(below, 1_000_000);
        assert_eq!(above, 1_100_000);
    }

    #[test]
    fn strict_price_nonpositive_twap_falls_back() {
        // (twap, is_asset): every combination must return the plain price.
        for (twap, is_asset) in [(0i64, true), (-500, true), (0, false)] {
            assert_eq!(get_strict_price(1_000_000, twap, is_asset), 1_000_000);
        }
    }

    /// Builds a `SpotMarket` with only the fields that the asset-weight scaling
    /// reads set to meaningful values; everything else is zeroed or defaulted.
    fn make_weight_market(
        initial_asset_weight: u32,
        scale_start: u64,
        decimals: u32,
        deposit_interest: u128,
        deposit_balance: u128,
    ) -> SpotMarket {
        SpotMarket {
            pubkey: Vec::new(),
            market_index: 0,
            initial_asset_weight,
            initial_liability_weight: 0,
            imf_factor: 0,
            scale_initial_asset_weight_start: scale_start,
            decimals,
            cumulative_deposit_interest: deposit_interest,
            cumulative_borrow_interest: 0,
            deposit_balance,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: Default::default(),
            historical_oracle_data: Default::default(),
            oracle: None,
        }
    }

    #[test]
    fn scaled_initial_asset_weight_no_scaling() {
        // Threshold 0 disables scaling regardless of the other fields.
        let market = make_weight_market(8_000, 0, 0, 0, 0);
        let weight = calculate_scaled_initial_asset_weight(&market, 1_000_000).unwrap();
        assert_eq!(weight, 8_000);
    }

    #[test]
    fn scaled_initial_asset_weight_below_threshold() {
        let decimals = 6u32;
        // Interest equal to the precision divisor makes deposits == deposit_balance.
        let unit_interest = 10u128.pow(19 - decimals);
        let market = make_weight_market(
            8_000,
            1_000_000_000_000,
            decimals,
            unit_interest,
            500_000_000_000,
        );
        let weight = calculate_scaled_initial_asset_weight(&market, 1_000_000).unwrap();
        assert_eq!(weight, 8_000);
    }

    #[test]
    fn scaled_initial_asset_weight_above_threshold() {
        let decimals = 6u32;
        // Interest equal to the precision divisor makes deposits == deposit_balance.
        let unit_interest = 10u128.pow(19 - decimals);
        let market = make_weight_market(
            8_000,
            500_000_000_000,
            decimals,
            unit_interest,
            1_000_000_000_000,
        );
        // Deposits are worth twice the threshold, so the weight is halved.
        let weight = calculate_scaled_initial_asset_weight(&market, 1_000_000).unwrap();
        assert_eq!(weight, 4_000);
    }

    #[test]
    fn to_amm_precision_decimals_6() {
        let one_token = to_amm_precision(1_000_000, 6).unwrap();
        assert_eq!(one_token, 1_000_000_000);
    }

    #[test]
    fn to_amm_precision_decimals_9() {
        let one_token = to_amm_precision(1_000_000_000, 9).unwrap();
        assert_eq!(one_token, 1_000_000_000);
    }

    #[test]
    fn to_amm_precision_decimals_18() {
        let one_token = to_amm_precision(1_000_000_000_000_000_000, 18).unwrap();
        assert_eq!(one_token, 1_000_000_000);
    }
}
391
#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects,
    reason = "test code"
)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // `isqrt` is the floor square root: root^2 <= n < (root + 1)^2.
        #[test]
        fn isqrt_correct(value in 0u128..=1_000_000_000_000_000_000u128) {
            let floor_root = isqrt(value);
            let lower_square = floor_root.checked_mul(floor_root).unwrap();
            prop_assert!(lower_square <= value);

            let successor = floor_root + 1;
            let upper_square = successor.checked_mul(successor).unwrap();
            prop_assert!(upper_square > value);
        }

        // The size discount can only lower an asset weight, never raise it.
        #[test]
        fn size_discount_weight_le_base(
            position_size in 0u128..=1_000_000_000_000_000_000u128,
            imf_factor in 0u32..=100_000u32,
            base_weight in 1u128..=20_000u128,
        ) {
            let result =
                calculate_size_discount_asset_weight(position_size, imf_factor, base_weight)
                    .unwrap();
            prop_assert!(result <= base_weight, "discount weight {} > base {}", result, base_weight);
        }

        // The size premium can only raise a liability weight, never lower it.
        #[test]
        fn size_premium_weight_ge_base(
            position_size in 0u128..=1_000_000_000_000_000_000u128,
            imf_factor in 0u32..=100_000u32,
            base_weight in 5u128..=20_000u128,
        ) {
            let result =
                calculate_size_premium_liability_weight(position_size, imf_factor, base_weight)
                    .unwrap();
            prop_assert!(result >= base_weight, "premium weight {} < base {}", result, base_weight);
        }

        // For assets the strict price is bounded above by both inputs.
        #[test]
        fn strict_price_asset_le_oracle(
            price in 1u64..=u64::MAX / 2,
            twap in 1i64..=i64::MAX / 2,
        ) {
            let result = get_strict_price(price, twap, true);
            prop_assert!(result <= price);
            prop_assert!(result <= twap as u64);
        }

        // For liabilities the strict price is bounded below by both inputs.
        #[test]
        fn strict_price_liability_ge_oracle(
            price in 1u64..=u64::MAX / 2,
            twap in 1i64..=i64::MAX / 2,
        ) {
            let result = get_strict_price(price, twap, false);
            prop_assert!(result >= price && result >= twap as u64);
        }
    }
}