//! Guard checks for `riptide-amm-math` 2.0.1 — the Riptide program math library.
#[cfg(feature = "wasm")]
use riptide_amm_macros::wasm_expose;

use super::{
    deviation_per_m, error::ARITHMETIC_OVERFLOW, Price, PER_CENT_DENOMINATOR, PER_M_DENOMINATOR,
};

/// Error type returned by every guard check in this module: a static, human-readable message.
pub type GuardError = &'static str;

/// Returned by [`check_oracle_validity`] when `current_slot` is strictly greater than
/// `valid_until`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const ORACLE_EXPIRED: GuardError = "oracle expired";

/// Returned when the absolute inventory deviation exceeds
/// `GuardParams::max_inventory_imbalance_per_m`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const INVENTORY_IMBALANCE: GuardError = "inventory imbalance";

/// Returned when a positive inventory deviation exceeds a non-zero
/// `GuardParams::max_a_inventory_per_m`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const INVENTORY_A_SIDE_EXCEEDED: GuardError = "A-side inventory cap exceeded";

/// Returned when the magnitude of a negative inventory deviation exceeds a non-zero
/// `GuardParams::max_b_inventory_per_m`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const INVENTORY_B_SIDE_EXCEEDED: GuardError = "B-side inventory cap exceeded";

/// Returned when `Price::spread_per_m` is below `GuardParams::min_spread_per_m`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const SPREAD_BELOW_MIN: GuardError = "spread below minimum";

/// Returned when the oracle price is below `GuardParams::min_oracle_price`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const ORACLE_PRICE_BELOW_MIN: GuardError = "oracle price below minimum";

/// Returned when the oracle price is above `GuardParams::max_oracle_price`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const ORACLE_PRICE_ABOVE_MAX: GuardError = "oracle price above maximum";

/// Limits consulted by [`check_guards`] before a quote is accepted.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
#[cfg_attr(feature = "wasm", wasm_expose)]
pub struct GuardParams {
    /// Maximum absolute inventory deviation (per-million units) accepted by the inventory guard.
    pub max_inventory_imbalance_per_m: i32,
    /// Cap on a positive inventory deviation (per-million units); `0` disables the check.
    pub max_a_inventory_per_m: u32,
    /// Cap on the magnitude of a negative inventory deviation (per-million units); `0`
    /// disables the check.
    pub max_b_inventory_per_m: u32,
    /// Minimum accepted `Price::spread_per_m` (inclusive; signed, so may be negative).
    pub min_spread_per_m: i32,
    /// Inclusive lower bound on `Price::oracle_price_q64_64`
    /// (presumably Q64.64 fixed point, per the field name).
    pub min_oracle_price: u128,
    /// Inclusive upper bound on `Price::oracle_price_q64_64`.
    pub max_oracle_price: u128,
    /// Not read by `check_guards`.
    // NOTE(review): presumably the slot passed to `check_oracle_validity` — confirm at call sites.
    pub valid_until: u64,
}

impl GuardParams {
    /// Builds guard parameters from the raw values stored on a market account.
    ///
    /// Only the overall imbalance limit is converted. The market stores it in percent; the guard
    /// compares it in per-million units. So it is scaled by
    /// `PER_M_DENOMINATOR / PER_CENT_DENOMINATOR`. Every other argument is copied through
    /// unchanged into the matching field.
    pub fn from_market_fields(
        max_inventory_imbalance_guard_per_cent: u8,
        max_a_inventory_per_m: u32,
        max_b_inventory_per_m: u32,
        min_spread_guard_per_m: i32,
        min_oracle_price_guard: u128,
        max_oracle_price_guard: u128,
        valid_until: u64,
    ) -> Self {
        // Ratio between the per-million and per-cent denominators.
        let per_m_per_cent = PER_M_DENOMINATOR / PER_CENT_DENOMINATOR as i32;
        let imbalance_limit_per_m =
            i32::from(max_inventory_imbalance_guard_per_cent) * per_m_per_cent;

        Self {
            max_inventory_imbalance_per_m: imbalance_limit_per_m,
            min_spread_per_m: min_spread_guard_per_m,
            min_oracle_price: min_oracle_price_guard,
            max_oracle_price: max_oracle_price_guard,
            max_a_inventory_per_m,
            max_b_inventory_per_m,
            valid_until,
        }
    }
}

/// Checks the pool's inventory against the configured imbalance limits.
///
/// The signed deviation is obtained from `deviation_per_m` at the oracle price for the given
/// reserves. It is then checked in this order:
///
/// 1. its absolute value must not exceed `max_inventory_imbalance_per_m`;
/// 2. if `max_a_inventory_per_m` is non-zero, a positive deviation must not exceed it;
/// 3. if `max_b_inventory_per_m` is non-zero, the magnitude of a negative deviation must not
///    exceed it.
///
/// All comparisons are done in `i64`, so every `u32` cap is honoured as-is. A plain
/// `cap as i32` would wrap caps above `i32::MAX` into negative thresholds that reject every
/// pool state. Widening also keeps `abs`/negation of an `i32::MIN` deviation from overflowing.
///
/// # Errors
///
/// * `ARITHMETIC_OVERFLOW` if `deviation_per_m` fails;
/// * [`INVENTORY_IMBALANCE`] if check 1 fails;
/// * [`INVENTORY_A_SIDE_EXCEEDED`] if check 2 fails;
/// * [`INVENTORY_B_SIDE_EXCEEDED`] if check 3 fails.
fn inventory_imbalance_guard(
    reserves_a: u64,
    reserves_b: u64,
    price: &Price,
    params: &GuardParams,
) -> Result<(), GuardError> {
    #[allow(clippy::useless_conversion)] // `U128` differs under the `wasm` feature.
    let signed_imbalance =
        deviation_per_m(price.oracle_price_q64_64.into(), reserves_a, reserves_b)
            .map_err(|_| ARITHMETIC_OVERFLOW)?;

    // Widen once; see the doc comment for why the checks below must not run in `i32`.
    let imbalance = i64::from(signed_imbalance);

    if imbalance.abs() > i64::from(params.max_inventory_imbalance_per_m) {
        return Err(INVENTORY_IMBALANCE);
    }

    // A positive deviation is charged to the A side, a negative one to the B side.
    let a_inventory_per_m = imbalance;
    let b_inventory_per_m = -imbalance;

    // A cap of zero disables the corresponding directional check.
    if params.max_a_inventory_per_m > 0
        && a_inventory_per_m > i64::from(params.max_a_inventory_per_m)
    {
        return Err(INVENTORY_A_SIDE_EXCEEDED);
    }

    if params.max_b_inventory_per_m > 0
        && b_inventory_per_m > i64::from(params.max_b_inventory_per_m)
    {
        return Err(INVENTORY_B_SIDE_EXCEEDED);
    }

    Ok(())
}

/// Rejects quotes whose spread is narrower than the configured minimum.
///
/// A spread exactly equal to `params.min_spread_per_m` is accepted.
///
/// # Errors
///
/// Returns [`SPREAD_BELOW_MIN`] when `price.spread_per_m < params.min_spread_per_m`.
fn spread_guard(price: &Price, params: &GuardParams) -> Result<(), GuardError> {
    let meets_minimum = price.spread_per_m >= params.min_spread_per_m;
    if meets_minimum {
        Ok(())
    } else {
        Err(SPREAD_BELOW_MIN)
    }
}

/// Ensures the oracle price lies within `[min_oracle_price, max_oracle_price]`, both inclusive.
///
/// # Errors
///
/// The lower bound is checked first:
/// * [`ORACLE_PRICE_BELOW_MIN`] if the price is under `params.min_oracle_price`;
/// * otherwise [`ORACLE_PRICE_ABOVE_MAX`] if it is over `params.max_oracle_price`.
fn prices_guard(price: &Price, params: &GuardParams) -> Result<(), GuardError> {
    let oracle = price.oracle_price_q64_64;
    match oracle {
        p if p < params.min_oracle_price => Err(ORACLE_PRICE_BELOW_MIN),
        p if p > params.max_oracle_price => Err(ORACLE_PRICE_ABOVE_MAX),
        _ => Ok(()),
    }
}

/// Runs every market guard against the current reserves and quoted price.
///
/// Guards run in a fixed order — inventory imbalance, then spread, then oracle price
/// bounds — and the first failure is returned without evaluating the rest.
///
/// Oracle freshness is not checked here; use [`check_oracle_validity`] for that.
///
/// # Errors
///
/// Propagates the first [`GuardError`] produced by the individual guards.
pub fn check_guards(
    reserves_a: u64,
    reserves_b: u64,
    price: &Price,
    params: &GuardParams,
) -> Result<(), GuardError> {
    inventory_imbalance_guard(reserves_a, reserves_b, price, params)
        .and_then(|()| spread_guard(price, params))
        .and_then(|()| prices_guard(price, params))
}

/// Verifies that an oracle quote is still usable at `current_slot`.
///
/// The quote remains valid up to and including the slot `valid_until`.
///
/// # Errors
///
/// Returns [`ORACLE_EXPIRED`] once `current_slot` is strictly greater than `valid_until`.
pub fn check_oracle_validity(current_slot: u64, valid_until: u64) -> Result<(), GuardError> {
    let still_valid = current_slot <= valid_until;
    if still_valid {
        Ok(())
    } else {
        Err(ORACLE_EXPIRED)
    }
}

// Unit tests for the guard checks, written as rstest case tables.
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// Builds `GuardParams` with the given inventory limits. The remaining fields get permissive
    /// values: minimum spread 0, the full `u128` oracle price range, and `valid_until` 0.
    fn make_params(
        max_inventory_imbalance_per_cent: u8,
        max_a_inventory_per_m: u32,
        max_b_inventory_per_m: u32,
    ) -> GuardParams {
        GuardParams {
            // NOTE(review): `* 10_000` presumably mirrors the percent -> per-m scaling in
            // `GuardParams::from_market_fields` (PER_M_DENOMINATOR / PER_CENT_DENOMINATOR) —
            // confirm against those constants.
            max_inventory_imbalance_per_m: max_inventory_imbalance_per_cent as i32 * 10_000,
            max_a_inventory_per_m,
            max_b_inventory_per_m,
            min_spread_per_m: 0,
            min_oracle_price: 0,
            max_oracle_price: u128::MAX,
            valid_until: 0,
        }
    }

    /// Builds a `Price` with the given oracle price and every other field defaulted.
    fn make_price(oracle_price_q64_64: u128) -> Price {
        Price {
            oracle_price_q64_64,
            ..Default::default()
        }
    }

    // `valid_until` is inclusive: only `current_slot > valid_until` is expired.
    #[rstest]
    #[case(1000, 2000, Ok(()))]
    #[case(2000, 2000, Ok(()))]
    #[case(2001, 2000, Err(ORACLE_EXPIRED))]
    #[case(0, 0, Ok(()))]
    #[case(1, 0, Err(ORACLE_EXPIRED))]
    #[case(u64::MAX, u64::MAX, Ok(()))]
    fn test_check_oracle_validity(
        #[case] current_slot: u64,
        #[case] valid_until: u64,
        #[case] expected: Result<(), GuardError>,
    ) {
        assert_eq!(check_oracle_validity(current_slot, valid_until), expected);
    }

    // Overall imbalance limit only (directional caps disabled with 0), at price `1 << 64`.
    // Swapping the reserves gives the same verdict. The cases place a 500/1000 split between
    // the 33% and 34% limits, and a fully one-sided pool above 34% but within 100%.
    #[rstest]
    #[case(1000, 1000, 100, true)]
    #[case(500, 1000, 100, true)]
    #[case(1000, 500, 100, true)]
    #[case(0, 2000, 100, true)]
    #[case(2000, 0, 100, true)]
    #[case(1000, 1000, 34, true)]
    #[case(500, 1000, 34, true)]
    #[case(1000, 500, 34, true)]
    #[case(0, 2000, 34, false)]
    #[case(2000, 0, 34, false)]
    #[case(1000, 1000, 33, true)]
    #[case(500, 1000, 33, false)]
    #[case(1000, 500, 33, false)]
    #[case(0, 2000, 33, false)]
    #[case(2000, 0, 33, false)]
    #[case(1000, 1000, 0, true)]
    #[case(500, 1000, 0, false)]
    #[case(1000, 500, 0, false)]
    #[case(0, 2000, 0, false)]
    #[case(2000, 0, 0, false)]
    fn test_inventory_imbalance_guard_symmetric(
        #[case] reserves_a: u64,
        #[case] reserves_b: u64,
        #[case] max_inventory_imbalance_per_cent: u8,
        #[case] expected_ok: bool,
    ) {
        let params = make_params(max_inventory_imbalance_per_cent, 0, 0);
        let price = make_price(1 << 64);

        let result = inventory_imbalance_guard(reserves_a, reserves_b, &price, &params);

        assert_eq!(result.is_ok(), expected_ok);
    }

    // Reserves chosen so that A valued at the oracle price equals B. Examples: 500 A at 2.0
    // vs 1000 B, or 2000 A at 0.5 vs 1000 B, assuming Q64.64 prices. Such a pool must pass
    // even a 1% imbalance limit.
    #[rstest]
    #[case(2u128 << 64, 500, 1000)]
    #[case(1u128 << 63, 2000, 1000)]
    #[case(4u128 << 64, 250, 1000)]
    fn balanced_market_with_non_unity_price_does_not_trigger(
        #[case] oracle_price_q64_64: u128,
        #[case] reserves_a: u64,
        #[case] reserves_b: u64,
    ) {
        let params = make_params(1, 0, 0);
        let price = make_price(oracle_price_q64_64);

        let result = inventory_imbalance_guard(reserves_a, reserves_b, &price, &params);

        assert!(
            result.is_ok(),
            "balanced market (price={}, a={}, b={}) should not trigger",
            oracle_price_q64_64,
            reserves_a,
            reserves_b
        );
    }

    // Directional caps with the overall limit at 100%. A 1500/500 pool only trips the A cap;
    // a 500/1500 pool only trips the B cap. A cap of 0 disables that side.
    #[rstest]
    #[case(1500, 500, 0, 0, true)]
    #[case(500, 1500, 0, 0, true)]
    #[case(1500, 500, 400_000, 0, false)]
    #[case(1500, 500, 600_000, 0, true)]
    #[case(500, 1500, 0, 400_000, false)]
    #[case(500, 1500, 0, 600_000, true)]
    #[case(500, 1500, 100_000, 0, true)]
    #[case(1500, 500, 0, 100_000, true)]
    #[case(1000, 1000, 1, 1, true)]
    fn test_inventory_directional_caps(
        #[case] reserves_a: u64,
        #[case] reserves_b: u64,
        #[case] max_a_inventory_per_m: u32,
        #[case] max_b_inventory_per_m: u32,
        #[case] expected_ok: bool,
    ) {
        let params = make_params(100, max_a_inventory_per_m, max_b_inventory_per_m);
        let price = make_price(1 << 64);

        let result = inventory_imbalance_guard(reserves_a, reserves_b, &price, &params);

        assert_eq!(result.is_ok(), expected_ok);
    }

    // The spread passes exactly when `spread_per_m >= min_spread_per_m`, for signed values.
    #[rstest]
    #[case(-10, -20, false)]
    #[case(-10, 0, true)]
    #[case(-10, 10, true)]
    #[case(-10, 20, true)]
    #[case(0, -20, false)]
    #[case(0, -10, false)]
    #[case(0, 0, true)]
    #[case(0, 10, true)]
    #[case(0, 20, true)]
    #[case(10, -20, false)]
    #[case(10, -10, false)]
    #[case(10, -0, false)]
    #[case(10, 10, true)]
    #[case(10, 20, true)]
    #[case(20, -20, false)]
    #[case(20, -10, false)]
    #[case(20, 0, false)]
    #[case(20, 10, false)]
    #[case(20, 20, true)]
    fn test_spread_guard(
        #[case] min_spread_per_m: i32,
        #[case] spread_per_m: i32,
        #[case] expected_ok: bool,
    ) {
        let params = GuardParams {
            min_spread_per_m,
            ..make_params(0, 0, 0)
        };
        let price = Price {
            spread_per_m,
            oracle_price_q64_64: 1 << 64,
            ..Default::default()
        };

        let result = spread_guard(&price, &params);

        assert_eq!(result.is_ok(), expected_ok);
    }

    // Both oracle price bounds are inclusive.
    #[rstest]
    #[case(100, true)]
    #[case(50, true)]
    #[case(150, true)]
    #[case(49, false)]
    #[case(151, false)]
    fn test_prices_guard(#[case] oracle_price: u128, #[case] expected_ok: bool) {
        let params = GuardParams {
            min_oracle_price: 50,
            max_oracle_price: 150,
            ..make_params(0, 0, 0)
        };
        let price = make_price(oracle_price);

        let result = prices_guard(&price, &params);

        assert_eq!(result.is_ok(), expected_ok);
    }

    // End-to-end `check_guards`: each case isolates one failing guard. `order_inventory_first`
    // violates both the inventory and spread limits and expects the inventory error, which
    // pins the evaluation order.
    #[rstest]
    #[case::all_pass(
        GuardParams { min_oracle_price: 0, max_oracle_price: u128::MAX, ..make_params(100, 0, 0) },
        Price { oracle_price_q64_64: 1 << 64, best_price_q64_64: 1 << 64, spread_per_m: 0 },
        1000,
        1000,
        Ok(()),
    )]
    #[case::inventory_fail(
        GuardParams { min_oracle_price: 0, max_oracle_price: u128::MAX, ..make_params(10, 0, 0) },
        Price { oracle_price_q64_64: 1 << 64, best_price_q64_64: 1 << 64, spread_per_m: 0 },
        2000,
        0,
        Err(INVENTORY_IMBALANCE),
    )]
    #[case::spread_fail(
        GuardParams { min_spread_per_m: 100, min_oracle_price: 0, max_oracle_price: u128::MAX, ..make_params(100, 0, 0) },
        Price { oracle_price_q64_64: 1 << 64, best_price_q64_64: 1 << 64, spread_per_m: 50 },
        1000,
        1000,
        Err(SPREAD_BELOW_MIN),
    )]
    #[case::price_below_min_fail(
        GuardParams { min_oracle_price: 100, max_oracle_price: u128::MAX, ..make_params(100, 0, 0) },
        Price { oracle_price_q64_64: 50, best_price_q64_64: 50, spread_per_m: 0 },
        1000,
        1000,
        Err(ORACLE_PRICE_BELOW_MIN),
    )]
    #[case::order_inventory_first(
        GuardParams { min_spread_per_m: 100, min_oracle_price: 0, max_oracle_price: u128::MAX, ..make_params(10, 0, 0) },
        Price { oracle_price_q64_64: 1 << 64, best_price_q64_64: 1 << 64, spread_per_m: 50 },
        2000,
        0,
        Err(INVENTORY_IMBALANCE),
    )]
    fn test_check_guards(
        #[case] params: GuardParams,
        #[case] price: Price,
        #[case] reserves_a: u64,
        #[case] reserves_b: u64,
        #[case] expected: Result<(), GuardError>,
    ) {
        let result = check_guards(reserves_a, reserves_b, &price, &params);

        assert_eq!(result, expected);
    }
}