#[cfg(feature = "wasm")]
use riptide_amm_macros::wasm_expose;
use super::{
deviation_per_m, error::ARITHMETIC_OVERFLOW, Price, PER_CENT_DENOMINATOR, PER_M_DENOMINATOR,
};
/// Error type of every guard function in this module: a static, human-readable message.
pub type GuardError = &'static str;
/// Returned by [`check_oracle_validity`] when `current_slot` is strictly greater than
/// `valid_until`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const ORACLE_EXPIRED: GuardError = "oracle expired";
/// Returned by the inventory guard when the absolute reserve deviation exceeds
/// `GuardParams::max_inventory_imbalance_per_m`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const INVENTORY_IMBALANCE: GuardError = "inventory imbalance";
/// Returned by the inventory guard when the signed deviation exceeds the (non-zero)
/// `GuardParams::max_a_inventory_per_m` cap.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const INVENTORY_A_SIDE_EXCEEDED: GuardError = "A-side inventory cap exceeded";
/// Returned by the inventory guard when the negated deviation exceeds the (non-zero)
/// `GuardParams::max_b_inventory_per_m` cap.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const INVENTORY_B_SIDE_EXCEEDED: GuardError = "B-side inventory cap exceeded";
/// Returned when `Price::spread_per_m` is strictly below `GuardParams::min_spread_per_m`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const SPREAD_BELOW_MIN: GuardError = "spread below minimum";
/// Returned when `Price::oracle_price_q64_64` is strictly below `GuardParams::min_oracle_price`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const ORACLE_PRICE_BELOW_MIN: GuardError = "oracle price below minimum";
/// Returned when `Price::oracle_price_q64_64` is strictly above `GuardParams::max_oracle_price`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const ORACLE_PRICE_ABOVE_MAX: GuardError = "oracle price above maximum";
/// Thresholds consulted by [`check_guards`].
///
/// `valid_until` is not read by any guard in this module. NOTE(review): it is presumably
/// what callers pass to [`check_oracle_validity`]; confirm at the call sites.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
#[cfg_attr(feature = "wasm", wasm_expose)]
pub struct GuardParams {
/// Upper bound on the absolute value of the signed reserve deviation computed by
/// `deviation_per_m`. `from_market_fields` builds it from a whole percentage scaled by
/// `PER_M_DENOMINATOR / PER_CENT_DENOMINATOR`.
pub max_inventory_imbalance_per_m: i32,
/// Cap on the signed deviation (A-side inventory). `0` disables this cap.
pub max_a_inventory_per_m: u32,
/// Cap on the negated deviation (B-side inventory). `0` disables this cap.
pub max_b_inventory_per_m: u32,
/// Minimum accepted `Price::spread_per_m`; a spread equal to this value passes.
pub min_spread_per_m: i32,
/// Inclusive lower bound on `Price::oracle_price_q64_64` (same fixed-point representation).
pub min_oracle_price: u128,
/// Inclusive upper bound on `Price::oracle_price_q64_64` (same fixed-point representation).
pub max_oracle_price: u128,
/// Expiry bound for the oracle. NOTE(review): presumably a slot compared via
/// `check_oracle_validity` — confirm.
pub valid_until: u64,
}
impl GuardParams {
    /// Builds guard parameters from the raw market configuration fields.
    ///
    /// Only the imbalance limit is transformed: the market stores it as a whole percentage,
    /// and it is rescaled here by `PER_M_DENOMINATOR / PER_CENT_DENOMINATOR` into per-m
    /// units. Every other argument is copied into the field of the same meaning unchanged.
    pub fn from_market_fields(
        max_inventory_imbalance_guard_per_cent: u8,
        max_a_inventory_per_m: u32,
        max_b_inventory_per_m: u32,
        min_spread_guard_per_m: i32,
        min_oracle_price_guard: u128,
        max_oracle_price_guard: u128,
        valid_until: u64,
    ) -> Self {
        // How many per-m units make up one percent.
        let per_m_in_one_per_cent = PER_M_DENOMINATOR / PER_CENT_DENOMINATOR as i32;
        let max_inventory_imbalance_per_m =
            i32::from(max_inventory_imbalance_guard_per_cent) * per_m_in_one_per_cent;

        Self {
            max_inventory_imbalance_per_m,
            max_a_inventory_per_m,
            max_b_inventory_per_m,
            min_spread_per_m: min_spread_guard_per_m,
            min_oracle_price: min_oracle_price_guard,
            max_oracle_price: max_oracle_price_guard,
            valid_until,
        }
    }
}
/// Checks the pool's reserves against the inventory-imbalance limits in `params`.
///
/// The signed deviation comes from `deviation_per_m` for the current oracle price and
/// reserves. This module's tests show that holding more A than B at unit price yields a
/// positive deviation, and holding more B yields a negative one.
///
/// The checks run in this order:
///
/// 1. The absolute deviation must not exceed `params.max_inventory_imbalance_per_m`.
/// 2. If `params.max_a_inventory_per_m` is non-zero, the signed deviation must not exceed it.
/// 3. If `params.max_b_inventory_per_m` is non-zero, the negated deviation must not exceed it.
///
/// # Errors
///
/// * [`ARITHMETIC_OVERFLOW`] if `deviation_per_m` fails.
/// * [`INVENTORY_IMBALANCE`] if check 1 fails.
/// * [`INVENTORY_A_SIDE_EXCEEDED`] if check 2 fails.
/// * [`INVENTORY_B_SIDE_EXCEEDED`] if check 3 fails.
fn inventory_imbalance_guard(
    reserves_a: u64,
    reserves_b: u64,
    price: &Price,
    params: &GuardParams,
) -> Result<(), GuardError> {
    #[allow(clippy::useless_conversion)]
    let signed_imbalance: i32 =
        deviation_per_m(price.oracle_price_q64_64.into(), reserves_a, reserves_b)
            .map_err(|_| ARITHMETIC_OVERFLOW)?;

    // Compare in i64 for two reasons:
    // - `abs()` and negation cannot overflow on `i32::MIN`.
    // - u32 caps above `i32::MAX` are compared by value. Casting them with `as i32` wraps
    //   them negative, which would reject every non-negative deviation.
    let imbalance = i64::from(signed_imbalance);

    if imbalance.abs() > i64::from(params.max_inventory_imbalance_per_m) {
        return Err(INVENTORY_IMBALANCE);
    }

    let a_inventory_per_m = imbalance;
    let b_inventory_per_m = -imbalance;

    // A directional cap of 0 means "no cap".
    if params.max_a_inventory_per_m > 0
        && a_inventory_per_m > i64::from(params.max_a_inventory_per_m)
    {
        return Err(INVENTORY_A_SIDE_EXCEEDED);
    }
    if params.max_b_inventory_per_m > 0
        && b_inventory_per_m > i64::from(params.max_b_inventory_per_m)
    {
        return Err(INVENTORY_B_SIDE_EXCEEDED);
    }
    Ok(())
}
/// Rejects a quote whose spread is tighter than the configured floor.
///
/// The floor is inclusive: a `price.spread_per_m` equal to `params.min_spread_per_m` passes.
///
/// # Errors
///
/// [`SPREAD_BELOW_MIN`] when `price.spread_per_m < params.min_spread_per_m`.
fn spread_guard(price: &Price, params: &GuardParams) -> Result<(), GuardError> {
    let meets_floor = price.spread_per_m >= params.min_spread_per_m;
    if meets_floor {
        Ok(())
    } else {
        Err(SPREAD_BELOW_MIN)
    }
}
/// Checks that the oracle price lies within `[params.min_oracle_price, params.max_oracle_price]`.
///
/// Both bounds are inclusive. The lower bound is checked first, so when the bounds are inverted
/// a price below the minimum reports [`ORACLE_PRICE_BELOW_MIN`].
///
/// # Errors
///
/// * [`ORACLE_PRICE_BELOW_MIN`] when the oracle price is below the minimum.
/// * [`ORACLE_PRICE_ABOVE_MAX`] when the oracle price is above the maximum.
fn prices_guard(price: &Price, params: &GuardParams) -> Result<(), GuardError> {
    let oracle = price.oracle_price_q64_64;
    if oracle < params.min_oracle_price {
        Err(ORACLE_PRICE_BELOW_MIN)
    } else if oracle > params.max_oracle_price {
        Err(ORACLE_PRICE_ABOVE_MAX)
    } else {
        Ok(())
    }
}
/// Runs every pricing/inventory guard against the given reserves and price.
///
/// The guards run in a fixed order and the first failure is returned:
///
/// 1. inventory imbalance,
/// 2. minimum spread,
/// 3. oracle price bounds.
///
/// # Errors
///
/// Whatever [`GuardError`] the first failing guard produces.
pub fn check_guards(
    reserves_a: u64,
    reserves_b: u64,
    price: &Price,
    params: &GuardParams,
) -> Result<(), GuardError> {
    inventory_imbalance_guard(reserves_a, reserves_b, price, params)
        .and_then(|()| spread_guard(price, params))
        .and_then(|()| prices_guard(price, params))
}
/// Checks that the oracle has not expired.
///
/// The bound is inclusive: the oracle is still valid when `current_slot == valid_until`.
///
/// # Errors
///
/// [`ORACLE_EXPIRED`] when `current_slot` is strictly greater than `valid_until`.
pub fn check_oracle_validity(current_slot: u64, valid_until: u64) -> Result<(), GuardError> {
    let still_valid = current_slot <= valid_until;
    if still_valid {
        Ok(())
    } else {
        Err(ORACLE_EXPIRED)
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the guard functions. Guard calls take `&params`; a corrupted
    //! `¶ms` spelling previously made this module fail to compile.
    use super::*;
    use rstest::rstest;

    /// Builds guard params with permissive spread/price bounds.
    ///
    /// The `* 10_000` percent-to-per-m factor mirrors `GuardParams::from_market_fields`,
    /// assuming `PER_M_DENOMINATOR / PER_CENT_DENOMINATOR == 10_000` — TODO confirm.
    fn make_params(
        max_inventory_imbalance_per_cent: u8,
        max_a_inventory_per_m: u32,
        max_b_inventory_per_m: u32,
    ) -> GuardParams {
        GuardParams {
            max_inventory_imbalance_per_m: max_inventory_imbalance_per_cent as i32 * 10_000,
            max_a_inventory_per_m,
            max_b_inventory_per_m,
            min_spread_per_m: 0,
            min_oracle_price: 0,
            max_oracle_price: u128::MAX,
            valid_until: 0,
        }
    }

    /// Builds a `Price` with the given oracle price and default values elsewhere.
    fn make_price(oracle_price_q64_64: u128) -> Price {
        Price {
            oracle_price_q64_64,
            ..Default::default()
        }
    }

    #[rstest]
    #[case(1000, 2000, Ok(()))]
    #[case(2000, 2000, Ok(()))]
    #[case(2001, 2000, Err(ORACLE_EXPIRED))]
    #[case(0, 0, Ok(()))]
    #[case(1, 0, Err(ORACLE_EXPIRED))]
    #[case(u64::MAX, u64::MAX, Ok(()))]
    fn test_check_oracle_validity(
        #[case] current_slot: u64,
        #[case] valid_until: u64,
        #[case] expected: Result<(), GuardError>,
    ) {
        assert_eq!(check_oracle_validity(current_slot, valid_until), expected);
    }

    #[rstest]
    #[case(1000, 1000, 100, true)]
    #[case(500, 1000, 100, true)]
    #[case(1000, 500, 100, true)]
    #[case(0, 2000, 100, true)]
    #[case(2000, 0, 100, true)]
    #[case(1000, 1000, 34, true)]
    #[case(500, 1000, 34, true)]
    #[case(1000, 500, 34, true)]
    #[case(0, 2000, 34, false)]
    #[case(2000, 0, 34, false)]
    #[case(1000, 1000, 33, true)]
    #[case(500, 1000, 33, false)]
    #[case(1000, 500, 33, false)]
    #[case(0, 2000, 33, false)]
    #[case(2000, 0, 33, false)]
    #[case(1000, 1000, 0, true)]
    #[case(500, 1000, 0, false)]
    #[case(1000, 500, 0, false)]
    #[case(0, 2000, 0, false)]
    #[case(2000, 0, 0, false)]
    fn test_inventory_imbalance_guard_symmetric(
        #[case] reserves_a: u64,
        #[case] reserves_b: u64,
        #[case] max_inventory_imbalance_per_cent: u8,
        #[case] expected_ok: bool,
    ) {
        let params = make_params(max_inventory_imbalance_per_cent, 0, 0);
        let price = make_price(1 << 64);
        let result = inventory_imbalance_guard(reserves_a, reserves_b, &price, &params);
        assert_eq!(result.is_ok(), expected_ok);
    }

    #[rstest]
    #[case(2u128 << 64, 500, 1000)]
    #[case(1u128 << 63, 2000, 1000)]
    #[case(4u128 << 64, 250, 1000)]
    fn balanced_market_with_non_unity_price_does_not_trigger(
        #[case] oracle_price_q64_64: u128,
        #[case] reserves_a: u64,
        #[case] reserves_b: u64,
    ) {
        let params = make_params(1, 0, 0);
        let price = make_price(oracle_price_q64_64);
        let result = inventory_imbalance_guard(reserves_a, reserves_b, &price, &params);
        assert!(
            result.is_ok(),
            "balanced market (price={}, a={}, b={}) should not trigger",
            oracle_price_q64_64,
            reserves_a,
            reserves_b
        );
    }

    #[rstest]
    #[case(1500, 500, 0, 0, true)]
    #[case(500, 1500, 0, 0, true)]
    #[case(1500, 500, 400_000, 0, false)]
    #[case(1500, 500, 600_000, 0, true)]
    #[case(500, 1500, 0, 400_000, false)]
    #[case(500, 1500, 0, 600_000, true)]
    #[case(500, 1500, 100_000, 0, true)]
    #[case(1500, 500, 0, 100_000, true)]
    #[case(1000, 1000, 1, 1, true)]
    fn test_inventory_directional_caps(
        #[case] reserves_a: u64,
        #[case] reserves_b: u64,
        #[case] max_a_inventory_per_m: u32,
        #[case] max_b_inventory_per_m: u32,
        #[case] expected_ok: bool,
    ) {
        let params = make_params(100, max_a_inventory_per_m, max_b_inventory_per_m);
        let price = make_price(1 << 64);
        let result = inventory_imbalance_guard(reserves_a, reserves_b, &price, &params);
        assert_eq!(result.is_ok(), expected_ok);
    }

    #[rstest]
    #[case(-10, -20, false)]
    #[case(-10, 0, true)]
    #[case(-10, 10, true)]
    #[case(-10, 20, true)]
    #[case(0, -20, false)]
    #[case(0, -10, false)]
    #[case(0, 0, true)]
    #[case(0, 10, true)]
    #[case(0, 20, true)]
    #[case(10, -20, false)]
    #[case(10, -10, false)]
    #[case(10, -0, false)]
    #[case(10, 10, true)]
    #[case(10, 20, true)]
    #[case(20, -20, false)]
    #[case(20, -10, false)]
    #[case(20, 0, false)]
    #[case(20, 10, false)]
    #[case(20, 20, true)]
    fn test_spread_guard(
        #[case] min_spread_per_m: i32,
        #[case] spread_per_m: i32,
        #[case] expected_ok: bool,
    ) {
        let params = GuardParams {
            min_spread_per_m,
            ..make_params(0, 0, 0)
        };
        let price = Price {
            spread_per_m,
            oracle_price_q64_64: 1 << 64,
            ..Default::default()
        };
        let result = spread_guard(&price, &params);
        assert_eq!(result.is_ok(), expected_ok);
    }

    #[rstest]
    #[case(100, true)]
    #[case(50, true)]
    #[case(150, true)]
    #[case(49, false)]
    #[case(151, false)]
    fn test_prices_guard(#[case] oracle_price: u128, #[case] expected_ok: bool) {
        let params = GuardParams {
            min_oracle_price: 50,
            max_oracle_price: 150,
            ..make_params(0, 0, 0)
        };
        let price = make_price(oracle_price);
        let result = prices_guard(&price, &params);
        assert_eq!(result.is_ok(), expected_ok);
    }

    #[rstest]
    #[case::all_pass(
        GuardParams { min_oracle_price: 0, max_oracle_price: u128::MAX, ..make_params(100, 0, 0) },
        Price { oracle_price_q64_64: 1 << 64, best_price_q64_64: 1 << 64, spread_per_m: 0 },
        1000,
        1000,
        Ok(()),
    )]
    #[case::inventory_fail(
        GuardParams { min_oracle_price: 0, max_oracle_price: u128::MAX, ..make_params(10, 0, 0) },
        Price { oracle_price_q64_64: 1 << 64, best_price_q64_64: 1 << 64, spread_per_m: 0 },
        2000,
        0,
        Err(INVENTORY_IMBALANCE),
    )]
    #[case::spread_fail(
        GuardParams { min_spread_per_m: 100, min_oracle_price: 0, max_oracle_price: u128::MAX, ..make_params(100, 0, 0) },
        Price { oracle_price_q64_64: 1 << 64, best_price_q64_64: 1 << 64, spread_per_m: 50 },
        1000,
        1000,
        Err(SPREAD_BELOW_MIN),
    )]
    #[case::price_below_min_fail(
        GuardParams { min_oracle_price: 100, max_oracle_price: u128::MAX, ..make_params(100, 0, 0) },
        Price { oracle_price_q64_64: 50, best_price_q64_64: 50, spread_per_m: 0 },
        1000,
        1000,
        Err(ORACLE_PRICE_BELOW_MIN),
    )]
    #[case::order_inventory_first(
        GuardParams { min_spread_per_m: 100, min_oracle_price: 0, max_oracle_price: u128::MAX, ..make_params(10, 0, 0) },
        Price { oracle_price_q64_64: 1 << 64, best_price_q64_64: 1 << 64, spread_per_m: 50 },
        2000,
        0,
        Err(INVENTORY_IMBALANCE),
    )]
    fn test_check_guards(
        #[case] params: GuardParams,
        #[case] price: Price,
        #[case] reserves_a: u64,
        #[case] reserves_b: u64,
        #[case] expected: Result<(), GuardError>,
    ) {
        let result = check_guards(reserves_a, reserves_b, &price, &params);
        assert_eq!(result, expected);
    }
}