pyra-margin 0.4.2

Margin weight, balance, and price calculations for Drift spot positions
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
use std::cmp;

use pyra_types::SpotMarket;

use crate::error::{MathError, MathResult};

/// Fixed-point scale for Drift spot-market margin weights.
///
/// A weight equal to this value (10_000) means 100%, so 8_000 means 80%.
pub const SPOT_WEIGHT_PRECISION: u128 = 10u128.pow(4);
/// Fixed-point scale of a Drift spot market's IMF factor (1_000_000).
pub const SPOT_IMF_PRECISION: u128 = 10u128.pow(6);
/// Drift AMM reserve precision (1e9).
///
/// Token sizes are normalised to this scale before the IMF size adjustments run.
pub const AMM_RESERVE_PRECISION: u128 = 10u128.pow(9);

/// Integer square root (floor) via Newton's method.
///
/// Returns the largest `r` with `r * r <= n`. The first guess is a power of two
/// that is guaranteed to be `>= sqrt(n)`. From there the integer Newton step
/// `(g + n / g) / 2` decreases monotonically. It stops as soon as a step fails to
/// decrease, at which point `g` is the floor root.
fn isqrt(n: u128) -> u128 {
    if n <= 1 {
        return n;
    }

    // ceil(bit_length / 2) as the exponent keeps the first guess at or above sqrt(n).
    let bit_len = 128u32.saturating_sub(n.leading_zeros());
    let mut guess = 1u128 << (bit_len.saturating_add(1) / 2);

    loop {
        let quotient = n.checked_div(guess).unwrap_or(0);
        let next = guess.checked_add(quotient).unwrap_or(guess) / 2;
        if next >= guess {
            return guess;
        }
        guess = next;
    }
}

/// Convert a token balance from native decimals to AMM_RESERVE_PRECISION (1e9).
///
/// `token_decimals` is the number of decimals of the token's native unit.
/// - More than 9 decimals: the balance is divided down, truncating toward zero.
/// - 9 or fewer decimals: the balance is multiplied up exactly.
///
/// # Errors
///
/// Returns [`MathError::Overflow`] in two cases:
/// - `10^token_decimals` does not fit in a `u128` (`token_decimals > 38`).
/// - The up-scaled balance itself does not fit in a `u128`.
pub fn to_amm_precision(balance: u128, token_decimals: u32) -> MathResult<u128> {
    let size_precision = 10u128
        .checked_pow(token_decimals)
        .ok_or(MathError::Overflow)?;

    if size_precision >= AMM_RESERVE_PRECISION {
        // Both precisions are powers of ten, so the ratio is an exact integer
        // (exactly 1 when the token already has 9 decimals).
        let scale = size_precision
            .checked_div(AMM_RESERVE_PRECISION)
            .ok_or(MathError::Overflow)?;
        balance.checked_div(scale).ok_or(MathError::Overflow)
    } else {
        // Multiply by the exact ratio 1e9 / 10^decimals instead of computing
        // `balance * 1e9 / 10^decimals`. The result is identical, but there is no
        // oversized intermediate product that could overflow while the final
        // value still fits.
        let scale = AMM_RESERVE_PRECISION
            .checked_div(size_precision)
            .ok_or(MathError::Overflow)?;
        balance.checked_mul(scale).ok_or(MathError::Overflow)
    }
}

/// Scales the initial asset weight down when total market deposits exceed a threshold.
///
/// The market's `initial_asset_weight` is returned unchanged in two cases:
/// - `scale_initial_asset_weight_start` is zero (scaling disabled).
/// - The oracle-priced value of all market deposits is below that threshold.
///
/// Otherwise the weight is `initial_asset_weight * start / deposits_value` (floor
/// division), so it shrinks in proportion to how far deposits exceed the start.
///
/// The deposit value is built in two steps:
/// - Token amount: `deposit_balance * cumulative_deposit_interest / 10^(19 - decimals)`.
/// - Value: `token amount * oracle_price / 10^decimals`.
///
/// Reference: Drift SDK `calculateScaledInitialAssetWeight` in `math/spotBalance.ts`.
///
/// # Errors
///
/// Returns `MathError::Overflow` if an intermediate product or power of ten overflows `u128`.
pub fn calculate_scaled_initial_asset_weight(
    spot_market: &SpotMarket,
    oracle_price: u64,
) -> MathResult<u128> {
    let base_weight = spot_market.initial_asset_weight as u128;
    let scale_start = spot_market.scale_initial_asset_weight_start as u128;

    if scale_start == 0 {
        return Ok(base_weight);
    }

    let decimals = spot_market.decimals;

    // Balance * cumulative interest carries 19 - decimals extra digits of precision.
    let interest_precision_gap = 10u128
        .checked_pow(19u32.saturating_sub(decimals))
        .ok_or(MathError::Overflow)?;
    let one_token = 10u128.checked_pow(decimals).ok_or(MathError::Overflow)?;

    let deposit_tokens = spot_market
        .deposit_balance
        .checked_mul(spot_market.cumulative_deposit_interest)
        .and_then(|raw| raw.checked_div(interest_precision_gap))
        .ok_or(MathError::Overflow)?;

    let deposits_value = deposit_tokens
        .checked_mul(u128::from(oracle_price))
        .and_then(|notional| notional.checked_div(one_token))
        .ok_or(MathError::Overflow)?;

    if deposits_value < scale_start {
        return Ok(base_weight);
    }

    // deposits_value >= scale_start > 0 here, so the division is well-defined.
    base_weight
        .checked_mul(scale_start)
        .and_then(|scaled| scaled.checked_div(deposits_value))
        .ok_or(MathError::Overflow)
}

/// Applies IMF size discount to asset weight — larger deposits get less collateral credit.
///
/// With `imf_factor == 0` the input `asset_weight` is returned as-is. Otherwise the
/// size-based weight is computed with integer floor division throughout:
///
/// ```text
///   (SPOT_IMF_PRECISION + SPOT_IMF_PRECISION / 10) * SPOT_WEIGHT_PRECISION
/// / (SPOT_IMF_PRECISION + isqrt(10 * size_in_amm + 1) * imf_factor / 100_000)
/// ```
///
/// The smaller of that and `asset_weight` is returned, so the result never
/// exceeds `asset_weight`.
///
/// Reference: Drift SDK `calculateSizeDiscountAssetWeight` in `math/margin.ts`.
///
/// # Errors
///
/// Returns `MathError::Overflow` if `10 * size_in_amm + 1` or a later product overflows.
pub fn calculate_size_discount_asset_weight(
    size_in_amm: u128,
    imf_factor: u32,
    asset_weight: u128,
) -> MathResult<u128> {
    if imf_factor == 0 {
        return Ok(asset_weight);
    }

    // AMM precision 1e9 -> 1e10 after the *10 -> roughly 1e5 after the root.
    let root = isqrt(
        size_in_amm
            .checked_mul(10)
            .and_then(|ten_x| ten_x.checked_add(1))
            .ok_or(MathError::Overflow)?,
    );

    // 1.1 in IMF precision, lifted into weight precision.
    let numerator = SPOT_IMF_PRECISION
        .checked_div(10)
        .and_then(|tenth| SPOT_IMF_PRECISION.checked_add(tenth))
        .and_then(|imf_num| imf_num.checked_mul(SPOT_WEIGHT_PRECISION))
        .ok_or(MathError::Overflow)?;

    let denominator = root
        .checked_mul(u128::from(imf_factor))
        .and_then(|scaled| scaled.checked_div(100_000))
        .and_then(|penalty| SPOT_IMF_PRECISION.checked_add(penalty))
        .ok_or(MathError::Overflow)?;

    let discounted = numerator
        .checked_div(denominator)
        .ok_or(MathError::Overflow)?;

    Ok(cmp::min(asset_weight, discounted))
}

/// Applies IMF size premium to liability weight — larger borrows need more margin.
///
/// With `imf_factor == 0` the input `liability_weight` is returned as-is. Otherwise
/// the size-based weight is computed with integer floor division throughout:
///
/// ```text
/// (liability_weight - liability_weight / 5)
///     + isqrt(10 * size_in_amm + 1) * imf_factor
///         / (100_000 * SPOT_IMF_PRECISION / SPOT_WEIGHT_PRECISION)
/// ```
///
/// The larger of that and `liability_weight` is returned, so the result is never
/// below `liability_weight`.
///
/// Reference: Drift SDK `calculateSizePremiumLiabilityWeight` in `math/margin.ts`.
///
/// # Errors
///
/// Returns `MathError::Overflow` if `10 * size_in_amm + 1` or a later product overflows.
pub fn calculate_size_premium_liability_weight(
    size_in_amm: u128,
    imf_factor: u32,
    liability_weight: u128,
) -> MathResult<u128> {
    if imf_factor == 0 {
        return Ok(liability_weight);
    }

    let root = isqrt(
        size_in_amm
            .checked_mul(10)
            .and_then(|ten_x| ten_x.checked_add(1))
            .ok_or(MathError::Overflow)?,
    );

    // Start from 80% of the base weight; the size premium is layered on top.
    let base_part = liability_weight
        .checked_div(5)
        .and_then(|fifth| liability_weight.checked_sub(fifth))
        .ok_or(MathError::Overflow)?;

    // Brings root * imf_factor back into weight precision (1e7 with the current constants).
    let premium_divisor = 100_000u128
        .checked_mul(SPOT_IMF_PRECISION)
        .and_then(|wide| wide.checked_div(SPOT_WEIGHT_PRECISION))
        .ok_or(MathError::Overflow)?;

    let premium = root
        .checked_mul(u128::from(imf_factor))
        .and_then(|scaled| scaled.checked_div(premium_divisor))
        .ok_or(MathError::Overflow)?;

    let premium_weight = base_part
        .checked_add(premium)
        .ok_or(MathError::Overflow)?;

    Ok(cmp::max(premium_weight, liability_weight))
}

/// Calculate the effective initial asset weight for a position, applying both
/// scale-down (when market deposits are large) and IMF size discount.
///
/// The computation runs in two steps:
/// 1. [`calculate_scaled_initial_asset_weight`] produces the market-level weight at
///    `oracle_price`.
/// 2. `token_amount` (in the token's native decimals) is converted to AMM precision
///    and passed to [`calculate_size_discount_asset_weight`], which can only lower
///    that weight further.
///
/// # Errors
///
/// Propagates any error from the helpers above (all of which report `MathError::Overflow`).
pub fn calculate_asset_weight(
    token_amount: u128,
    oracle_price: u64,
    spot_market: &SpotMarket,
) -> MathResult<u128> {
    calculate_scaled_initial_asset_weight(spot_market, oracle_price).and_then(|market_weight| {
        let position_size = to_amm_precision(token_amount, spot_market.decimals)?;
        calculate_size_discount_asset_weight(position_size, spot_market.imf_factor, market_weight)
    })
}

/// Calculate the effective initial liability weight for a position, applying
/// IMF size premium.
///
/// `token_amount` (in the token's native decimals) is converted to AMM precision.
/// The market's `initial_liability_weight` is then raised by
/// [`calculate_size_premium_liability_weight`]; the result is never lower than the
/// base weight.
///
/// # Errors
///
/// Propagates any error from the helpers above (all of which report `MathError::Overflow`).
pub fn calculate_liability_weight(
    token_amount: u128,
    spot_market: &SpotMarket,
) -> MathResult<u128> {
    let base_weight = spot_market.initial_liability_weight as u128;
    to_amm_precision(token_amount, spot_market.decimals).and_then(|position_size| {
        calculate_size_premium_liability_weight(position_size, spot_market.imf_factor, base_weight)
    })
}

/// Get a conservative oracle price for margin calculations.
///
/// - Assets: `min(oracle, twap5min)`. A lower price credits less collateral.
/// - Liabilities: `max(oracle, twap5min)`. A higher price counts a larger debt.
///
/// A TWAP that is zero or negative is treated as unavailable, and the oracle price
/// is returned unchanged.
///
/// The TWAP comes from `historical_oracle_data.last_oracle_price_twap5min` on the
/// SpotMarket. It is in PRICE_PRECISION (1e6), the same scale as `price_usdc_base_units`.
pub fn get_strict_price(price_usdc_base_units: u64, twap5min: i64, is_asset: bool) -> u64 {
    // Only a strictly positive TWAP is usable; otherwise compare the oracle with itself.
    let reference = (twap5min > 0)
        .then_some(twap5min as u64)
        .unwrap_or(price_usdc_base_units);

    if !is_asset {
        return cmp::max(price_usdc_base_units, reference);
    }
    cmp::min(price_usdc_base_units, reference)
}

// Unit tests with hand-computed fixtures for each helper in this module.
#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects,
    reason = "test code"
)]
mod tests {
    use super::*;

    // Floor square roots, including a non-perfect square (10 -> 3).
    #[test]
    fn isqrt_basic_values() {
        assert_eq!(isqrt(0), 0);
        assert_eq!(isqrt(1), 1);
        assert_eq!(isqrt(4), 2);
        assert_eq!(isqrt(9), 3);
        assert_eq!(isqrt(10), 3);
        assert_eq!(isqrt(100), 10);
        assert_eq!(isqrt(10_000_000_000), 100_000);
    }

    // imf_factor == 0 short-circuits to the base weight.
    #[test]
    fn size_discount_asset_weight_no_imf() {
        let result = calculate_size_discount_asset_weight(1_000_000_000, 0, 8_000).unwrap();
        assert_eq!(result, 8_000);
    }

    #[test]
    fn size_discount_asset_weight_with_imf() {
        // size 1e9: isqrt(1e10 + 1) = 100_000, so the denominator is 1_001_000.
        // Discount weight = 11_000_000_000 / 1_001_000 = 10_989 > 8_000, so the base weight wins.
        let result = calculate_size_discount_asset_weight(1_000_000_000, 1000, 8_000).unwrap();
        assert_eq!(result, 8_000);

        // size 1e15: isqrt(1e16 + 1) = 1e8, so the denominator is 2_000_000.
        // Discount weight = 11_000_000_000 / 2_000_000 = 5_500.
        let result =
            calculate_size_discount_asset_weight(1_000_000_000_000_000, 1000, 8_000).unwrap();
        assert!(result < 8_000, "Large position should have reduced weight");
    }

    // imf_factor == 0 short-circuits to the base weight.
    #[test]
    fn size_premium_liability_weight_no_imf() {
        let result = calculate_size_premium_liability_weight(1_000_000_000, 0, 12_000).unwrap();
        assert_eq!(result, 12_000);
    }

    #[test]
    fn size_premium_liability_weight_with_imf() {
        // size 1e9: 12_000 - 2_400 + 100_000 * 1000 / 10_000_000 = 9_610 < 12_000,
        // so the base weight wins.
        let result = calculate_size_premium_liability_weight(1_000_000_000, 1000, 12_000).unwrap();
        assert_eq!(result, 12_000);

        // size 1e15: 9_600 + 1e8 * 1000 / 1e7 = 19_600 > 12_000.
        let result =
            calculate_size_premium_liability_weight(1_000_000_000_000_000, 1000, 12_000).unwrap();
        assert!(
            result > 12_000,
            "Large position should have increased weight"
        );
    }

    // Assets take the lower of oracle price and TWAP.
    #[test]
    fn strict_price_asset_uses_min() {
        assert_eq!(get_strict_price(1_000_000, 900_000, true), 900_000);
        assert_eq!(get_strict_price(1_000_000, 1_100_000, true), 1_000_000);
    }

    // Liabilities take the higher of oracle price and TWAP.
    #[test]
    fn strict_price_liability_uses_max() {
        assert_eq!(get_strict_price(1_000_000, 900_000, false), 1_000_000);
        assert_eq!(get_strict_price(1_000_000, 1_100_000, false), 1_100_000);
    }

    // A zero or negative TWAP is ignored and the oracle price is returned.
    #[test]
    fn strict_price_nonpositive_twap_falls_back() {
        assert_eq!(get_strict_price(1_000_000, 0, true), 1_000_000);
        assert_eq!(get_strict_price(1_000_000, -500, true), 1_000_000);
        assert_eq!(get_strict_price(1_000_000, 0, false), 1_000_000);
    }

    /// Builds a `SpotMarket` fixture for `calculate_scaled_initial_asset_weight`.
    ///
    /// Only the fields that function reads are parameterised. Every other field is
    /// zeroed, empty, `None`, or `Default::default()`. `imf_factor: 0` also disables
    /// the IMF size adjustments.
    fn make_weight_market(
        initial_asset_weight: u32,
        scale_start: u64,
        decimals: u32,
        deposit_interest: u128,
        deposit_balance: u128,
    ) -> SpotMarket {
        SpotMarket {
            pubkey: vec![],
            market_index: 0,
            initial_asset_weight,
            initial_liability_weight: 0,
            imf_factor: 0,
            scale_initial_asset_weight_start: scale_start,
            decimals,
            cumulative_deposit_interest: deposit_interest,
            cumulative_borrow_interest: 0,
            deposit_balance,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: Default::default(),
            historical_oracle_data: Default::default(),
            oracle: None,
        }
    }

    // A scale start of 0 disables scaling, whatever the deposits.
    #[test]
    fn scaled_initial_asset_weight_no_scaling() {
        let market = make_weight_market(8_000, 0, 0, 0, 0);
        let result = calculate_scaled_initial_asset_weight(&market, 1_000_000).unwrap();
        assert_eq!(result, 8_000);
    }

    #[test]
    fn scaled_initial_asset_weight_below_threshold() {
        // Interest equal to precision_decrease makes deposit tokens == deposit_balance (5e11).
        // At price 1e6 with 6 decimals the value is 5e11 < threshold 1e12, so the weight is unscaled.
        let decimals = 6u32;
        let precision_decrease = 10u128.pow(19 - decimals);
        let market = make_weight_market(
            8_000,
            1_000_000_000_000,
            decimals,
            precision_decrease,
            500_000_000_000,
        );
        let result = calculate_scaled_initial_asset_weight(&market, 1_000_000).unwrap();
        assert_eq!(result, 8_000);
    }

    #[test]
    fn scaled_initial_asset_weight_above_threshold() {
        // Deposit value 1e12 >= threshold 5e11, so the weight is 8_000 * 5e11 / 1e12 = 4_000.
        let decimals = 6u32;
        let precision_decrease = 10u128.pow(19 - decimals);
        let market = make_weight_market(
            8_000,
            500_000_000_000,
            decimals,
            precision_decrease,
            1_000_000_000_000,
        );
        let result = calculate_scaled_initial_asset_weight(&market, 1_000_000).unwrap();
        assert_eq!(result, 4_000);
    }

    // One whole token maps to 1e9 at AMM precision, whatever the token's decimals.
    #[test]
    fn to_amm_precision_decimals_6() {
        let result = to_amm_precision(1_000_000, 6).unwrap();
        assert_eq!(result, 1_000_000_000);
    }

    #[test]
    fn to_amm_precision_decimals_9() {
        let result = to_amm_precision(1_000_000_000, 9).unwrap();
        assert_eq!(result, 1_000_000_000);
    }

    #[test]
    fn to_amm_precision_decimals_18() {
        let result = to_amm_precision(1_000_000_000_000_000_000, 18).unwrap();
        assert_eq!(result, 1_000_000_000);
    }
}

// Property tests for the invariants the margin helpers must uphold on any input.
#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects,
    reason = "test code"
)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Largest size / radicand sampled by the strategies below (1e18).
    const MAX_SAMPLE: u128 = 1_000_000_000_000_000_000;

    proptest! {
        // Floor-root contract: r^2 <= n < (r + 1)^2.
        #[test]
        fn isqrt_correct(n in 0u128..=MAX_SAMPLE) {
            let r = isqrt(n);
            let r_next = r + 1;
            prop_assert!(r.checked_mul(r).unwrap() <= n);
            prop_assert!(r_next.checked_mul(r_next).unwrap() > n);
        }

        // The IMF discount can only lower an asset weight.
        #[test]
        fn size_discount_weight_le_base(
            size in 0u128..=MAX_SAMPLE,
            imf in 0u32..=100_000u32,
            base_weight in 1u128..=20_000u128,
        ) {
            let weight = calculate_size_discount_asset_weight(size, imf, base_weight).unwrap();
            prop_assert!(weight <= base_weight, "discount weight {} > base {}", weight, base_weight);
        }

        // The IMF premium can only raise a liability weight.
        #[test]
        fn size_premium_weight_ge_base(
            size in 0u128..=MAX_SAMPLE,
            imf in 0u32..=100_000u32,
            base_weight in 5u128..=20_000u128,
        ) {
            let weight = calculate_size_premium_liability_weight(size, imf, base_weight).unwrap();
            prop_assert!(weight >= base_weight, "premium weight {} < base {}", weight, base_weight);
        }

        // Assets never price above either the oracle or the TWAP.
        #[test]
        fn strict_price_asset_le_oracle(price in 1u64..=u64::MAX / 2, twap in 1i64..=i64::MAX / 2) {
            let strict = get_strict_price(price, twap, true);
            let twap_unsigned = twap as u64;
            prop_assert!(strict <= price);
            prop_assert!(strict <= twap_unsigned);
        }

        // Liabilities never price below either the oracle or the TWAP.
        #[test]
        fn strict_price_liability_ge_oracle(price in 1u64..=u64::MAX / 2, twap in 1i64..=i64::MAX / 2) {
            let strict = get_strict_price(price, twap, false);
            let twap_unsigned = twap as u64;
            prop_assert!(strict >= price);
            prop_assert!(strict >= twap_unsigned);
        }
    }
}