Skip to main content

pyra_margin/drift/
weights.rs

1use std::cmp;
2
3use pyra_types::SpotMarket;
4
5use crate::error::{MathError, MathResult};
6
/// Fixed-point scale for spot-market asset/liability weights: `10_000` represents 100%.
pub const SPOT_WEIGHT_PRECISION: u128 = 10u128.pow(4);
/// Fixed-point scale (1e6) for the spot-market `imf_factor`.
pub const SPOT_IMF_PRECISION: u128 = 10u128.pow(6);
/// Fixed-point scale (1e9) of Drift AMM reserves; token sizes are normalised to this
/// scale by `to_amm_precision` before the IMF size adjustments are applied.
pub const AMM_RESERVE_PRECISION: u128 = 10u128.pow(9);
13
/// Floor of the square root of `n`, computed with integer Newton iteration.
///
/// The starting guess `2^ceil(bits(n) / 2)` is never below the true root, so the
/// iterates shrink toward it; the loop returns the first guess whose successor
/// fails to decrease, which is the floor square root.
fn isqrt(n: u128) -> u128 {
    if n < 2 {
        return n;
    }

    let bit_len = 128u32.saturating_sub(n.leading_zeros());
    let mut guess = 1u128 << (bit_len.saturating_add(1) / 2);

    loop {
        let quotient = n.checked_div(guess).unwrap_or(0);
        let next = guess.checked_add(quotient).unwrap_or(guess) / 2;
        if next >= guess {
            return guess;
        }
        guess = next;
    }
}
27
28/// Convert a token balance from native decimals to AMM_RESERVE_PRECISION (1e9).
29pub fn to_amm_precision(balance: u128, token_decimals: u32) -> MathResult<u128> {
30    let size_precision = 10u128
31        .checked_pow(token_decimals)
32        .ok_or(MathError::Overflow)?;
33
34    if size_precision > AMM_RESERVE_PRECISION {
35        let scale = size_precision
36            .checked_div(AMM_RESERVE_PRECISION)
37            .ok_or(MathError::Overflow)?;
38        balance.checked_div(scale).ok_or(MathError::Overflow)
39    } else {
40        balance
41            .checked_mul(AMM_RESERVE_PRECISION)
42            .ok_or(MathError::Overflow)?
43            .checked_div(size_precision)
44            .ok_or(MathError::Overflow)
45    }
46}
47
48/// Scales the initial asset weight down when total market deposits exceed a threshold.
49///
50/// Reference: Drift SDK `calculateScaledInitialAssetWeight` in `math/spotBalance.ts`.
51pub fn calculate_scaled_initial_asset_weight(
52    spot_market: &SpotMarket,
53    oracle_price: u64,
54) -> MathResult<u128> {
55    let initial_asset_weight = spot_market.initial_asset_weight as u128;
56
57    if spot_market.scale_initial_asset_weight_start == 0 {
58        return Ok(initial_asset_weight);
59    }
60
61    let precision_decrease = 10u128
62        .checked_pow(19u32.saturating_sub(spot_market.decimals))
63        .ok_or(MathError::Overflow)?;
64
65    let deposit_tokens = (spot_market.deposit_balance)
66        .checked_mul(spot_market.cumulative_deposit_interest)
67        .ok_or(MathError::Overflow)?
68        .checked_div(precision_decrease)
69        .ok_or(MathError::Overflow)?;
70
71    let token_precision = 10u128
72        .checked_pow(spot_market.decimals)
73        .ok_or(MathError::Overflow)?;
74
75    let deposits_value = deposit_tokens
76        .checked_mul(oracle_price as u128)
77        .ok_or(MathError::Overflow)?
78        .checked_div(token_precision)
79        .ok_or(MathError::Overflow)?;
80
81    let threshold = spot_market.scale_initial_asset_weight_start as u128;
82    if deposits_value < threshold {
83        return Ok(initial_asset_weight);
84    }
85
86    initial_asset_weight
87        .checked_mul(threshold)
88        .ok_or(MathError::Overflow)?
89        .checked_div(deposits_value)
90        .ok_or(MathError::Overflow)
91}
92
93/// Applies IMF size discount to asset weight — larger deposits get less collateral credit.
94///
95/// Reference: Drift SDK `calculateSizeDiscountAssetWeight` in `math/margin.ts`.
96pub fn calculate_size_discount_asset_weight(
97    size_in_amm: u128,
98    imf_factor: u32,
99    asset_weight: u128,
100) -> MathResult<u128> {
101    if imf_factor == 0 {
102        return Ok(asset_weight);
103    }
104
105    let size_times_10 = size_in_amm
106        .checked_mul(10)
107        .ok_or(MathError::Overflow)?
108        .checked_add(1)
109        .ok_or(MathError::Overflow)?;
110    let size_sqrt = isqrt(size_times_10);
111
112    let imf_numerator: u128 = SPOT_IMF_PRECISION
113        .checked_add(
114            SPOT_IMF_PRECISION
115                .checked_div(10)
116                .ok_or(MathError::Overflow)?,
117        )
118        .ok_or(MathError::Overflow)?;
119
120    let numerator = imf_numerator
121        .checked_mul(SPOT_WEIGHT_PRECISION)
122        .ok_or(MathError::Overflow)?;
123
124    let inner = size_sqrt
125        .checked_mul(imf_factor as u128)
126        .ok_or(MathError::Overflow)?
127        .checked_div(100_000)
128        .ok_or(MathError::Overflow)?;
129    let denominator = SPOT_IMF_PRECISION
130        .checked_add(inner)
131        .ok_or(MathError::Overflow)?;
132
133    let size_discount_weight = numerator
134        .checked_div(denominator)
135        .ok_or(MathError::Overflow)?;
136
137    Ok(cmp::min(asset_weight, size_discount_weight))
138}
139
140/// Applies IMF size premium to liability weight — larger borrows need more margin.
141///
142/// Reference: Drift SDK `calculateSizePremiumLiabilityWeight` in `math/margin.ts`.
143pub fn calculate_size_premium_liability_weight(
144    size_in_amm: u128,
145    imf_factor: u32,
146    liability_weight: u128,
147) -> MathResult<u128> {
148    if imf_factor == 0 {
149        return Ok(liability_weight);
150    }
151
152    let size_times_10 = size_in_amm
153        .checked_mul(10)
154        .ok_or(MathError::Overflow)?
155        .checked_add(1)
156        .ok_or(MathError::Overflow)?;
157    let size_sqrt = isqrt(size_times_10);
158
159    let lw_fifth = liability_weight.checked_div(5).ok_or(MathError::Overflow)?;
160    let liability_weight_numerator = liability_weight
161        .checked_sub(lw_fifth)
162        .ok_or(MathError::Overflow)?;
163
164    let denom = 100_000u128
165        .checked_mul(SPOT_IMF_PRECISION)
166        .ok_or(MathError::Overflow)?
167        .checked_div(SPOT_WEIGHT_PRECISION)
168        .ok_or(MathError::Overflow)?;
169
170    let premium_term = size_sqrt
171        .checked_mul(imf_factor as u128)
172        .ok_or(MathError::Overflow)?
173        .checked_div(denom)
174        .ok_or(MathError::Overflow)?;
175
176    let size_premium_weight = liability_weight_numerator
177        .checked_add(premium_term)
178        .ok_or(MathError::Overflow)?;
179
180    Ok(cmp::max(liability_weight, size_premium_weight))
181}
182
183/// Calculate the effective initial asset weight for a position, applying both
184/// scale-down (when market deposits are large) and IMF size discount.
185pub fn calculate_asset_weight(
186    token_amount: u128,
187    oracle_price: u64,
188    spot_market: &SpotMarket,
189) -> MathResult<u128> {
190    let scaled_weight = calculate_scaled_initial_asset_weight(spot_market, oracle_price)?;
191    let size_in_amm = to_amm_precision(token_amount, spot_market.decimals)?;
192    calculate_size_discount_asset_weight(size_in_amm, spot_market.imf_factor, scaled_weight)
193}
194
195/// Calculate the effective initial liability weight for a position, applying
196/// IMF size premium.
197pub fn calculate_liability_weight(
198    token_amount: u128,
199    spot_market: &SpotMarket,
200) -> MathResult<u128> {
201    let size_in_amm = to_amm_precision(token_amount, spot_market.decimals)?;
202    calculate_size_premium_liability_weight(
203        size_in_amm,
204        spot_market.imf_factor,
205        spot_market.initial_liability_weight as u128,
206    )
207}
208
/// Pick a conservative price for margin calculations from the oracle price and its 5-minute TWAP.
///
/// Assets are valued at `min(oracle, twap5min)`, since a lower price means less collateral.
/// Liabilities are valued at `max(oracle, twap5min)`, since a higher price means a larger debt.
/// A TWAP that is zero or negative is treated as unavailable, and the oracle price is used alone.
///
/// The TWAP comes from `historical_oracle_data.last_oracle_price_twap5min` on the SpotMarket
/// and is in PRICE_PRECISION (1e6), the same scale as `price_usdc_base_units`.
pub fn get_strict_price(price_usdc_base_units: u64, twap5min: i64, is_asset: bool) -> u64 {
    // Negative values fail the conversion; zero is filtered out explicitly.
    let twap = u64::try_from(twap5min)
        .ok()
        .filter(|&value| value > 0)
        .unwrap_or(price_usdc_base_units);

    let choose: fn(u64, u64) -> u64 = if is_asset { cmp::min } else { cmp::max };
    choose(price_usdc_base_units, twap)
}
228
#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects,
    reason = "test code"
)]
mod tests {
    use super::*;

    #[test]
    fn isqrt_basic_values() {
        assert_eq!(isqrt(0), 0);
        assert_eq!(isqrt(1), 1);
        assert_eq!(isqrt(4), 2);
        assert_eq!(isqrt(9), 3);
        assert_eq!(isqrt(10), 3);
        assert_eq!(isqrt(100), 10);
        assert_eq!(isqrt(10_000_000_000), 100_000);
    }

    #[test]
    fn isqrt_boundary_values() {
        // Non-squares just below/above small squares.
        assert_eq!(isqrt(2), 1);
        assert_eq!(isqrt(3), 1);
        assert_eq!(isqrt(99), 9);
        // Largest values: exercises the 2^32 / 2^64 initial guesses.
        assert_eq!(isqrt((1u128 << 64) - 1), 4_294_967_295);
        assert_eq!(isqrt(u128::MAX), u128::from(u64::MAX));
    }

    #[test]
    fn size_discount_asset_weight_no_imf() {
        let result = calculate_size_discount_asset_weight(1_000_000_000, 0, 8_000).unwrap();
        assert_eq!(result, 8_000);
    }

    #[test]
    fn size_discount_asset_weight_with_imf() {
        let result = calculate_size_discount_asset_weight(1_000_000_000, 1000, 8_000).unwrap();
        assert_eq!(result, 8_000);

        let result =
            calculate_size_discount_asset_weight(1_000_000_000_000_000, 1000, 8_000).unwrap();
        assert!(result < 8_000, "Large position should have reduced weight");
    }

    #[test]
    fn size_premium_liability_weight_no_imf() {
        let result = calculate_size_premium_liability_weight(1_000_000_000, 0, 12_000).unwrap();
        assert_eq!(result, 12_000);
    }

    #[test]
    fn size_premium_liability_weight_with_imf() {
        let result = calculate_size_premium_liability_weight(1_000_000_000, 1000, 12_000).unwrap();
        assert_eq!(result, 12_000);

        let result =
            calculate_size_premium_liability_weight(1_000_000_000_000_000, 1000, 12_000).unwrap();
        assert!(
            result > 12_000,
            "Large position should have increased weight"
        );
    }

    #[test]
    fn strict_price_asset_uses_min() {
        assert_eq!(get_strict_price(1_000_000, 900_000, true), 900_000);
        assert_eq!(get_strict_price(1_000_000, 1_100_000, true), 1_000_000);
    }

    #[test]
    fn strict_price_liability_uses_max() {
        assert_eq!(get_strict_price(1_000_000, 900_000, false), 1_000_000);
        assert_eq!(get_strict_price(1_000_000, 1_100_000, false), 1_100_000);
    }

    #[test]
    fn strict_price_nonpositive_twap_falls_back() {
        assert_eq!(get_strict_price(1_000_000, 0, true), 1_000_000);
        assert_eq!(get_strict_price(1_000_000, -500, true), 1_000_000);
        assert_eq!(get_strict_price(1_000_000, 0, false), 1_000_000);
    }

    fn make_weight_market(
        initial_asset_weight: u32,
        scale_start: u64,
        decimals: u32,
        deposit_interest: u128,
        deposit_balance: u128,
    ) -> SpotMarket {
        SpotMarket {
            pubkey: vec![],
            market_index: 0,
            initial_asset_weight,
            initial_liability_weight: 0,
            imf_factor: 0,
            scale_initial_asset_weight_start: scale_start,
            decimals,
            cumulative_deposit_interest: deposit_interest,
            cumulative_borrow_interest: 0,
            deposit_balance,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: Default::default(),
            historical_oracle_data: Default::default(),
            oracle: None,
        }
    }

    #[test]
    fn scaled_initial_asset_weight_no_scaling() {
        let market = make_weight_market(8_000, 0, 0, 0, 0);
        let result = calculate_scaled_initial_asset_weight(&market, 1_000_000).unwrap();
        assert_eq!(result, 8_000);
    }

    #[test]
    fn scaled_initial_asset_weight_below_threshold() {
        let decimals = 6u32;
        let precision_decrease = 10u128.pow(19 - decimals);
        let market = make_weight_market(
            8_000,
            1_000_000_000_000,
            decimals,
            precision_decrease,
            500_000_000_000,
        );
        let result = calculate_scaled_initial_asset_weight(&market, 1_000_000).unwrap();
        assert_eq!(result, 8_000);
    }

    #[test]
    fn scaled_initial_asset_weight_above_threshold() {
        let decimals = 6u32;
        let precision_decrease = 10u128.pow(19 - decimals);
        let market = make_weight_market(
            8_000,
            500_000_000_000,
            decimals,
            precision_decrease,
            1_000_000_000_000,
        );
        let result = calculate_scaled_initial_asset_weight(&market, 1_000_000).unwrap();
        assert_eq!(result, 4_000);
    }

    #[test]
    fn scaled_initial_asset_weight_exactly_at_threshold() {
        // deposits_value == threshold takes the scaling branch, which yields the base weight.
        let decimals = 6u32;
        let precision_decrease = 10u128.pow(19 - decimals);
        let market = make_weight_market(
            8_000,
            1_000_000_000_000,
            decimals,
            precision_decrease,
            1_000_000_000_000,
        );
        let result = calculate_scaled_initial_asset_weight(&market, 1_000_000).unwrap();
        assert_eq!(result, 8_000);
    }

    #[test]
    fn to_amm_precision_decimals_6() {
        let result = to_amm_precision(1_000_000, 6).unwrap();
        assert_eq!(result, 1_000_000_000);
    }

    #[test]
    fn to_amm_precision_decimals_9() {
        let result = to_amm_precision(1_000_000_000, 9).unwrap();
        assert_eq!(result, 1_000_000_000);
    }

    #[test]
    fn to_amm_precision_decimals_18() {
        let result = to_amm_precision(1_000_000_000_000_000_000, 18).unwrap();
        assert_eq!(result, 1_000_000_000);
    }

    #[test]
    fn to_amm_precision_decimals_0() {
        assert_eq!(to_amm_precision(1, 0).unwrap(), 1_000_000_000);
    }

    #[test]
    fn to_amm_precision_truncates_high_decimals() {
        assert_eq!(to_amm_precision(1_500_000_000, 18).unwrap(), 1);
        assert_eq!(to_amm_precision(999_999_999, 18).unwrap(), 0);
    }

    #[test]
    fn to_amm_precision_overflow_errors() {
        assert!(matches!(to_amm_precision(u128::MAX, 6), Err(MathError::Overflow)));
        assert!(matches!(to_amm_precision(1, 39), Err(MathError::Overflow)));
    }

    #[test]
    fn asset_weight_applies_size_discount() {
        let mut market = make_weight_market(8_000, 0, 6, 0, 0);
        market.imf_factor = 1_000;
        // 1 token (6 decimals) -> 1e9 in AMM precision: discount does not bind.
        assert_eq!(calculate_asset_weight(1_000_000, 1_000_000, &market).unwrap(), 8_000);
        // 1e9 tokens -> 1e18 in AMM precision: discount binds.
        assert_eq!(
            calculate_asset_weight(1_000_000_000_000_000, 1_000_000, &market).unwrap(),
            337
        );
    }

    #[test]
    fn liability_weight_applies_size_premium() {
        let mut market = make_weight_market(0, 0, 6, 0, 0);
        market.imf_factor = 1_000;
        market.initial_liability_weight = 12_000;
        assert_eq!(calculate_liability_weight(1_000_000, &market).unwrap(), 12_000);
        assert_eq!(
            calculate_liability_weight(1_000_000_000_000_000, &market).unwrap(),
            325_827
        );
    }
}
391
#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects,
    reason = "test code"
)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Covers the full u128 domain, including roots near u64::MAX where the
        // initial Newton guess is 2^64.
        #[test]
        fn isqrt_correct(n in any::<u128>()) {
            let root = isqrt(n);
            // root <= u64::MAX, so root^2 always fits in u128.
            prop_assert!(root.checked_mul(root).unwrap() <= n);
            // (root + 1)^2 > n. If the square overflows u128 it exceeds every
            // representable n, so only the non-overflowing case needs checking.
            let next = root + 1;
            if let Some(next_sq) = next.checked_mul(next) {
                prop_assert!(next_sq > n);
            }
        }

        #[test]
        fn size_discount_weight_le_base(
            size in 0u128..=1_000_000_000_000_000_000u128,
            imf in 0u32..=100_000u32,
            base_weight in 1u128..=20_000u128,
        ) {
            let result = calculate_size_discount_asset_weight(size, imf, base_weight).unwrap();
            prop_assert!(result <= base_weight, "discount weight {} > base {}", result, base_weight);
        }

        // A larger position must never receive a higher asset weight.
        #[test]
        fn size_discount_weight_non_increasing_in_size(
            a in 0u128..=1_000_000_000_000_000_000u128,
            b in 0u128..=1_000_000_000_000_000_000u128,
            imf in 0u32..=100_000u32,
            base_weight in 1u128..=20_000u128,
        ) {
            let (small, large) = (a.min(b), a.max(b));
            let w_small = calculate_size_discount_asset_weight(small, imf, base_weight).unwrap();
            let w_large = calculate_size_discount_asset_weight(large, imf, base_weight).unwrap();
            prop_assert!(w_large <= w_small, "weight rose from {} to {} as size grew", w_small, w_large);
        }

        #[test]
        fn size_premium_weight_ge_base(
            size in 0u128..=1_000_000_000_000_000_000u128,
            imf in 0u32..=100_000u32,
            base_weight in 5u128..=20_000u128,
        ) {
            let result = calculate_size_premium_liability_weight(size, imf, base_weight).unwrap();
            prop_assert!(result >= base_weight, "premium weight {} < base {}", result, base_weight);
        }

        // A larger borrow must never receive a lower liability weight.
        #[test]
        fn size_premium_weight_non_decreasing_in_size(
            a in 0u128..=1_000_000_000_000_000_000u128,
            b in 0u128..=1_000_000_000_000_000_000u128,
            imf in 0u32..=100_000u32,
            base_weight in 5u128..=20_000u128,
        ) {
            let (small, large) = (a.min(b), a.max(b));
            let w_small = calculate_size_premium_liability_weight(small, imf, base_weight).unwrap();
            let w_large = calculate_size_premium_liability_weight(large, imf, base_weight).unwrap();
            prop_assert!(w_large >= w_small, "weight fell from {} to {} as size grew", w_small, w_large);
        }

        #[test]
        fn strict_price_asset_le_oracle(price in 1u64..=u64::MAX / 2, twap in 1i64..=i64::MAX / 2) {
            let result = get_strict_price(price, twap, true);
            prop_assert!(result <= price);
            prop_assert!(result <= twap as u64);
        }

        #[test]
        fn strict_price_liability_ge_oracle(price in 1u64..=u64::MAX / 2, twap in 1i64..=i64::MAX / 2) {
            let result = get_strict_price(price, twap, false);
            prop_assert!(result >= price && result >= twap as u64);
        }
    }
}