use std::cmp;
use pyra_types::SpotMarket;
use crate::error::{MathError, MathResult};
/// Fixed-point scale applied to spot asset/liability weights in the size
/// discount and size premium formulas below.
/// NOTE(review): presumably a weight of `10_000` means 1.0 (100%) — confirm.
pub const SPOT_WEIGHT_PRECISION: u128 = 10_000;
/// Fixed-point scale used for the IMF terms in the size discount and size
/// premium formulas below.
/// NOTE(review): presumably `imf_factor` is expressed in these units — confirm.
pub const SPOT_IMF_PRECISION: u128 = 1_000_000;
/// Number of base units per whole token that `to_amm_precision` rescales
/// balances into (10^9).
pub const AMM_RESERVE_PRECISION: u128 = 1_000_000_000;
/// Integer square root: returns the largest `r` such that `r * r <= n`.
///
/// Uses Newton's iteration starting from a power of two that is at least
/// `sqrt(n)`, so the estimate decreases monotonically until it stops improving.
fn isqrt(n: u128) -> u128 {
    if n <= 1 {
        return n;
    }

    // One Newton refinement: the average of `guess` and `n / guess`.
    // The fallbacks only exist to keep the arithmetic panic-free.
    let newton_step = |guess: u128| -> u128 {
        let quotient = n.checked_div(guess).unwrap_or(0);
        guess.checked_add(quotient).unwrap_or(guess) / 2
    };

    // Start at 2^ceil(bits / 2), which is never below the true root.
    let significant_bits = 128u32.saturating_sub(n.leading_zeros());
    let shift = significant_bits.saturating_add(1) / 2;
    let mut estimate = 1u128 << shift;

    loop {
        let refined = newton_step(estimate);
        if refined >= estimate {
            return estimate;
        }
        estimate = refined;
    }
}
/// Rescales `balance` from a token's native precision (`10^token_decimals`
/// base units per token) to `AMM_RESERVE_PRECISION` base units per token.
///
/// Tokens with more decimals than the target are divided down (truncating);
/// tokens with the same or fewer decimals are multiplied up.
///
/// # Errors
///
/// Returns `MathError::Overflow` if `10^token_decimals` does not fit in a
/// `u128`, or if `balance * AMM_RESERVE_PRECISION` overflows when scaling up.
pub fn to_amm_precision(balance: u128, token_decimals: u32) -> MathResult<u128> {
    let token_unit = 10u128
        .checked_pow(token_decimals)
        .ok_or(MathError::Overflow)?;

    // Same or coarser precision than the target: multiply first, then divide.
    if token_unit <= AMM_RESERVE_PRECISION {
        let widened = balance
            .checked_mul(AMM_RESERVE_PRECISION)
            .ok_or(MathError::Overflow)?;
        return widened.checked_div(token_unit).ok_or(MathError::Overflow);
    }

    // Finer precision than the target: divide by the ratio of the two scales.
    let divisor = token_unit
        .checked_div(AMM_RESERVE_PRECISION)
        .ok_or(MathError::Overflow)?;
    balance.checked_div(divisor).ok_or(MathError::Overflow)
}
/// Computes the market's initial asset weight, scaled down once the value of
/// all deposits reaches `scale_initial_asset_weight_start`.
///
/// * If `scale_initial_asset_weight_start` is zero, scaling is disabled and
///   the unmodified `initial_asset_weight` is returned.
/// * Otherwise the total deposit value is computed as
///   `deposit_balance * cumulative_deposit_interest / 10^(19 - decimals)`
///   tokens, multiplied by `oracle_price` and divided by `10^decimals`.
/// * Below the threshold the unmodified weight is returned; at or above it the
///   result is `initial_asset_weight * threshold / deposits_value`.
///
/// # Errors
///
/// Returns `MathError::Overflow` if any intermediate product does not fit in a
/// `u128`, or if `spot_market.decimals` exceeds 19 (the interest-precision
/// exponent would be negative, so the deposit amount cannot be scaled
/// correctly).
pub fn calculate_scaled_initial_asset_weight(
    spot_market: &SpotMarket,
    oracle_price: u64,
) -> MathResult<u128> {
    let initial_asset_weight = u128::from(spot_market.initial_asset_weight);
    if spot_market.scale_initial_asset_weight_start == 0 {
        return Ok(initial_asset_weight);
    }

    // Fail instead of clamping: a saturated exponent would silently mis-scale
    // the deposit total and yield a wrong (too generous) weight.
    let precision_exponent = 19u32
        .checked_sub(spot_market.decimals)
        .ok_or(MathError::Overflow)?;
    let precision_decrease = 10u128
        .checked_pow(precision_exponent)
        .ok_or(MathError::Overflow)?;

    // Scaled balance -> token amount in the token's native base units.
    let deposit_tokens = spot_market
        .deposit_balance
        .checked_mul(spot_market.cumulative_deposit_interest)
        .ok_or(MathError::Overflow)?
        .checked_div(precision_decrease)
        .ok_or(MathError::Overflow)?;

    let token_precision = 10u128
        .checked_pow(spot_market.decimals)
        .ok_or(MathError::Overflow)?;

    // Token amount -> value in the oracle price's units.
    let deposits_value = deposit_tokens
        .checked_mul(u128::from(oracle_price))
        .ok_or(MathError::Overflow)?
        .checked_div(token_precision)
        .ok_or(MathError::Overflow)?;

    let threshold = u128::from(spot_market.scale_initial_asset_weight_start);
    if deposits_value < threshold {
        return Ok(initial_asset_weight);
    }

    // `deposits_value >= threshold > 0` here, so the division cannot be by zero.
    initial_asset_weight
        .checked_mul(threshold)
        .ok_or(MathError::Overflow)?
        .checked_div(deposits_value)
        .ok_or(MathError::Overflow)
}
/// Applies the size-based discount to an asset weight.
///
/// With a zero `imf_factor` the weight is returned untouched. Otherwise the
/// discounted weight is
/// `1.1 * SPOT_IMF_PRECISION * SPOT_WEIGHT_PRECISION
///   / (SPOT_IMF_PRECISION + isqrt(10 * size_in_amm + 1) * imf_factor / 100_000)`
/// and the result is the smaller of that value and `asset_weight`, so the
/// weight can only decrease as the position grows.
///
/// # Errors
///
/// Returns `MathError::Overflow` if an intermediate value does not fit in a
/// `u128` (in particular `10 * size_in_amm + 1`).
pub fn calculate_size_discount_asset_weight(
    size_in_amm: u128,
    imf_factor: u32,
    asset_weight: u128,
) -> MathResult<u128> {
    if imf_factor == 0 {
        return Ok(asset_weight);
    }

    // sqrt(10 * size + 1): the `+ 1` keeps the radicand non-zero.
    let radicand = size_in_amm
        .checked_mul(10)
        .and_then(|scaled| scaled.checked_add(1))
        .ok_or(MathError::Overflow)?;
    let root = isqrt(radicand);

    // 1.1 in IMF precision, lifted into weight precision.
    let tenth = SPOT_IMF_PRECISION
        .checked_div(10)
        .ok_or(MathError::Overflow)?;
    let scaled_numerator = SPOT_IMF_PRECISION
        .checked_add(tenth)
        .and_then(|imf_numerator| imf_numerator.checked_mul(SPOT_WEIGHT_PRECISION))
        .ok_or(MathError::Overflow)?;

    // Size-dependent growth of the denominator.
    let size_term = root
        .checked_mul(imf_factor as u128)
        .and_then(|product| product.checked_div(100_000))
        .ok_or(MathError::Overflow)?;

    let discounted_weight = SPOT_IMF_PRECISION
        .checked_add(size_term)
        .and_then(|divisor| scaled_numerator.checked_div(divisor))
        .ok_or(MathError::Overflow)?;

    Ok(cmp::min(asset_weight, discounted_weight))
}
/// Applies the size-based premium to a liability weight.
///
/// With a zero `imf_factor` the weight is returned untouched. Otherwise the
/// candidate weight is
/// `liability_weight - liability_weight / 5
///   + isqrt(10 * size_in_amm + 1) * imf_factor
///     / (100_000 * SPOT_IMF_PRECISION / SPOT_WEIGHT_PRECISION)`
/// and the result is the larger of that value and `liability_weight`, so the
/// weight can only increase as the position grows.
///
/// # Errors
///
/// Returns `MathError::Overflow` if an intermediate value does not fit in a
/// `u128` (in particular `10 * size_in_amm + 1`).
pub fn calculate_size_premium_liability_weight(
    size_in_amm: u128,
    imf_factor: u32,
    liability_weight: u128,
) -> MathResult<u128> {
    if imf_factor == 0 {
        return Ok(liability_weight);
    }

    // sqrt(10 * size + 1): the `+ 1` keeps the radicand non-zero.
    let radicand = size_in_amm
        .checked_mul(10)
        .and_then(|scaled| scaled.checked_add(1))
        .ok_or(MathError::Overflow)?;
    let root = isqrt(radicand);

    // Start from four fifths of the base weight (integer division).
    let reduced_base = liability_weight
        .checked_div(5)
        .and_then(|fifth| liability_weight.checked_sub(fifth))
        .ok_or(MathError::Overflow)?;

    // Divisor that converts `root * imf_factor` into weight precision.
    let premium_divisor = 100_000u128
        .checked_mul(SPOT_IMF_PRECISION)
        .and_then(|product| product.checked_div(SPOT_WEIGHT_PRECISION))
        .ok_or(MathError::Overflow)?;

    let premium = root
        .checked_mul(imf_factor as u128)
        .and_then(|product| product.checked_div(premium_divisor))
        .ok_or(MathError::Overflow)?;

    let weight_with_premium = reduced_base
        .checked_add(premium)
        .ok_or(MathError::Overflow)?;

    Ok(cmp::max(liability_weight, weight_with_premium))
}
/// Computes the effective asset weight for `token_amount` of a spot market.
///
/// First derives the deposit-scaled initial weight from the market and
/// `oracle_price`, then converts `token_amount` to AMM precision and applies
/// the size discount using the market's `imf_factor`.
///
/// # Errors
///
/// Propagates any error from the three underlying calculations.
pub fn calculate_asset_weight(
    token_amount: u128,
    oracle_price: u64,
    spot_market: &SpotMarket,
) -> MathResult<u128> {
    calculate_scaled_initial_asset_weight(spot_market, oracle_price).and_then(|base_weight| {
        let amm_size = to_amm_precision(token_amount, spot_market.decimals)?;
        calculate_size_discount_asset_weight(amm_size, spot_market.imf_factor, base_weight)
    })
}
/// Computes the effective liability weight for `token_amount` of a spot market.
///
/// Converts `token_amount` to AMM precision and applies the size premium to
/// the market's `initial_liability_weight` using its `imf_factor`.
///
/// # Errors
///
/// Propagates any error from the precision conversion or the premium
/// calculation.
pub fn calculate_liability_weight(
    token_amount: u128,
    spot_market: &SpotMarket,
) -> MathResult<u128> {
    to_amm_precision(token_amount, spot_market.decimals).and_then(|amm_size| {
        calculate_size_premium_liability_weight(
            amm_size,
            spot_market.imf_factor,
            spot_market.initial_liability_weight as u128,
        )
    })
}
/// Picks the more conservative of the current price and the 5-minute TWAP.
///
/// A non-positive `twap5min` is ignored and the current price is used in its
/// place. For assets (`is_asset == true`) the lower of the two prices is
/// returned; for liabilities the higher one.
pub fn get_strict_price(price_usdc_base_units: u64, twap5min: i64, is_asset: bool) -> u64 {
    let reference_price = match twap5min {
        // Strictly positive, so the unsigned magnitude equals the value itself.
        positive if positive > 0 => positive.unsigned_abs(),
        _ => price_usdc_base_units,
    };

    if !is_asset {
        return cmp::max(price_usdc_base_units, reference_price);
    }
    cmp::min(price_usdc_base_units, reference_price)
}
#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects,
    reason = "test code"
)]
mod tests {
    use super::*;

    /// Builds a `SpotMarket` in which only the inputs of the weight-scaling
    /// calculation are populated; every other field is zero/empty/default.
    fn make_weight_market(
        initial_asset_weight: u32,
        scale_start: u64,
        decimals: u32,
        deposit_interest: u128,
        deposit_balance: u128,
    ) -> SpotMarket {
        SpotMarket {
            pubkey: vec![],
            market_index: 0,
            initial_asset_weight,
            initial_liability_weight: 0,
            imf_factor: 0,
            scale_initial_asset_weight_start: scale_start,
            decimals,
            cumulative_deposit_interest: deposit_interest,
            cumulative_borrow_interest: 0,
            deposit_balance,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: Default::default(),
            historical_oracle_data: Default::default(),
            oracle: None,
        }
    }

    #[test]
    fn isqrt_basic_values() {
        // (input, expected floor of the square root)
        let cases: [(u128, u128); 7] = [
            (0, 0),
            (1, 1),
            (4, 2),
            (9, 3),
            (10, 3),
            (100, 10),
            (10_000_000_000, 100_000),
        ];
        for (input, expected) in cases {
            assert_eq!(isqrt(input), expected);
        }
    }

    #[test]
    fn size_discount_asset_weight_no_imf() {
        let weight = calculate_size_discount_asset_weight(1_000_000_000, 0, 8_000);
        assert_eq!(weight.unwrap(), 8_000);
    }

    #[test]
    fn size_discount_asset_weight_with_imf() {
        // Small position: the discount does not bite yet.
        let small = calculate_size_discount_asset_weight(1_000_000_000, 1000, 8_000);
        assert_eq!(small.unwrap(), 8_000);

        // Large position: the weight must drop below the base weight.
        let large = calculate_size_discount_asset_weight(1_000_000_000_000_000, 1000, 8_000);
        assert!(
            large.unwrap() < 8_000,
            "Large position should have reduced weight"
        );
    }

    #[test]
    fn size_premium_liability_weight_no_imf() {
        let weight = calculate_size_premium_liability_weight(1_000_000_000, 0, 12_000);
        assert_eq!(weight.unwrap(), 12_000);
    }

    #[test]
    fn size_premium_liability_weight_with_imf() {
        // Small position: the premium does not bite yet.
        let small = calculate_size_premium_liability_weight(1_000_000_000, 1000, 12_000);
        assert_eq!(small.unwrap(), 12_000);

        // Large position: the weight must rise above the base weight.
        let large = calculate_size_premium_liability_weight(1_000_000_000_000_000, 1000, 12_000);
        assert!(
            large.unwrap() > 12_000,
            "Large position should have increased weight"
        );
    }

    #[test]
    fn strict_price_asset_uses_min() {
        let oracle = 1_000_000;
        assert_eq!(get_strict_price(oracle, 900_000, true), 900_000);
        assert_eq!(get_strict_price(oracle, 1_100_000, true), oracle);
    }

    #[test]
    fn strict_price_liability_uses_max() {
        let oracle = 1_000_000;
        assert_eq!(get_strict_price(oracle, 900_000, false), oracle);
        assert_eq!(get_strict_price(oracle, 1_100_000, false), 1_100_000);
    }

    #[test]
    fn strict_price_nonpositive_twap_falls_back() {
        let oracle = 1_000_000;
        for (twap, is_asset) in [(0i64, true), (-500, true), (0, false)] {
            assert_eq!(get_strict_price(oracle, twap, is_asset), oracle);
        }
    }

    #[test]
    fn scaled_initial_asset_weight_no_scaling() {
        // A zero scale start disables scaling regardless of the other fields.
        let unscaled = make_weight_market(8_000, 0, 0, 0, 0);
        let weight = calculate_scaled_initial_asset_weight(&unscaled, 1_000_000);
        assert_eq!(weight.unwrap(), 8_000);
    }

    #[test]
    fn scaled_initial_asset_weight_below_threshold() {
        let decimals = 6u32;
        // With this interest value, deposit tokens equal the raw balance.
        let unit_interest = 10u128.pow(19 - decimals);
        let threshold = 1_000_000_000_000;
        let balance = 500_000_000_000;
        let market = make_weight_market(8_000, threshold, decimals, unit_interest, balance);

        let weight = calculate_scaled_initial_asset_weight(&market, 1_000_000);
        assert_eq!(weight.unwrap(), 8_000);
    }

    #[test]
    fn scaled_initial_asset_weight_above_threshold() {
        let decimals = 6u32;
        // With this interest value, deposit tokens equal the raw balance.
        let unit_interest = 10u128.pow(19 - decimals);
        let threshold = 500_000_000_000;
        let balance = 1_000_000_000_000;
        let market = make_weight_market(8_000, threshold, decimals, unit_interest, balance);

        // Deposits are twice the threshold, so the weight is halved.
        let weight = calculate_scaled_initial_asset_weight(&market, 1_000_000);
        assert_eq!(weight.unwrap(), 4_000);
    }

    #[test]
    fn to_amm_precision_decimals_6() {
        let one_token = to_amm_precision(1_000_000, 6);
        assert_eq!(one_token.unwrap(), 1_000_000_000);
    }

    #[test]
    fn to_amm_precision_decimals_9() {
        let one_token = to_amm_precision(1_000_000_000, 9);
        assert_eq!(one_token.unwrap(), 1_000_000_000);
    }

    #[test]
    fn to_amm_precision_decimals_18() {
        let one_token = to_amm_precision(1_000_000_000_000_000_000, 18);
        assert_eq!(one_token.unwrap(), 1_000_000_000);
    }
}
// Property-based tests for the weight math, driven by the `proptest` crate.
// Each property is checked against randomly generated inputs drawn from the
// ranges given in its parameter list.
#[cfg(test)]
#[allow(
clippy::allow_attributes,
clippy::allow_attributes_without_reason,
clippy::unwrap_used,
clippy::expect_used,
clippy::panic,
clippy::arithmetic_side_effects,
reason = "test code"
)]
mod proptests {
use super::*;
use proptest::prelude::*;
proptest! {
// `isqrt(n)` is the floor of the square root: root^2 <= n < (root + 1)^2.
#[test]
fn isqrt_correct(n in 0u128..=1_000_000_000_000_000_000u128) {
let root = isqrt(n);
prop_assert!(root.checked_mul(root).unwrap() <= n);
let next = root + 1;
prop_assert!(next.checked_mul(next).unwrap() > n);
}
// The size discount never raises an asset weight above its base value.
#[test]
fn size_discount_weight_le_base(
size in 0u128..=1_000_000_000_000_000_000u128,
imf in 0u32..=100_000u32,
base_weight in 1u128..=20_000u128,
) {
let result = calculate_size_discount_asset_weight(size, imf, base_weight).unwrap();
prop_assert!(result <= base_weight, "discount weight {} > base {}", result, base_weight);
}
// The size premium never lowers a liability weight below its base value.
#[test]
fn size_premium_weight_ge_base(
size in 0u128..=1_000_000_000_000_000_000u128,
imf in 0u32..=100_000u32,
base_weight in 5u128..=20_000u128,
) {
let result = calculate_size_premium_liability_weight(size, imf, base_weight).unwrap();
prop_assert!(result >= base_weight, "premium weight {} < base {}", result, base_weight);
}
// For assets with a positive TWAP, the strict price is at most both the
// current price and the TWAP.
#[test]
fn strict_price_asset_le_oracle(price in 1u64..=u64::MAX / 2, twap in 1i64..=i64::MAX / 2) {
let result = get_strict_price(price, twap, true);
prop_assert!(result <= price);
prop_assert!(result <= twap as u64);
}
// For liabilities with a positive TWAP, the strict price is at least both
// the current price and the TWAP.
#[test]
fn strict_price_liability_ge_oracle(price in 1u64..=u64::MAX / 2, twap in 1i64..=i64::MAX / 2) {
let result = get_strict_price(price, twap, false);
prop_assert!(result >= price && result >= twap as u64);
}
}
}