Skip to main content

riptide_amm_math/
guards.rs

1#[cfg(feature = "wasm")]
2use riptide_amm_macros::wasm_expose;
3
4use super::{
5    deviation_per_m, error::ARITHMETIC_OVERFLOW, Price, PER_CENT_DENOMINATOR, PER_M_DENOMINATOR,
6};
7
/// Error type returned by every guard in this module: a static, human-readable message.
pub type GuardError = &'static str;

/// Returned by `check_oracle_validity` when `current_slot` is strictly greater than
/// `valid_until`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const ORACLE_EXPIRED: GuardError = "oracle expired";

/// Returned when the absolute inventory deviation exceeds
/// `GuardParams::max_inventory_imbalance_per_m`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const INVENTORY_IMBALANCE: GuardError = "inventory imbalance";

/// Returned when a positive (A-side) deviation exceeds a non-zero
/// `GuardParams::max_a_inventory_per_m`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const INVENTORY_A_SIDE_EXCEEDED: GuardError = "A-side inventory cap exceeded";

/// Returned when a negative (B-side) deviation's magnitude exceeds a non-zero
/// `GuardParams::max_b_inventory_per_m`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const INVENTORY_B_SIDE_EXCEEDED: GuardError = "B-side inventory cap exceeded";

/// Returned when `Price::spread_per_m` is strictly below `GuardParams::min_spread_per_m`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const SPREAD_BELOW_MIN: GuardError = "spread below minimum";

/// Returned when `Price::oracle_price_q64_64` is strictly below `GuardParams::min_oracle_price`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const ORACLE_PRICE_BELOW_MIN: GuardError = "oracle price below minimum";

/// Returned when `Price::oracle_price_q64_64` is strictly above `GuardParams::max_oracle_price`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const ORACLE_PRICE_ABOVE_MAX: GuardError = "oracle price above maximum";
30
/// Thresholds evaluated by `check_guards`, plus the stored oracle `valid_until` value.
///
/// A value of `0` for either directional inventory cap disables that cap.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
#[cfg_attr(feature = "wasm", wasm_expose)]
pub struct GuardParams {
    /// Largest allowed absolute inventory deviation, in per-m units. This is the magnitude of
    /// the signed value produced by `deviation_per_m`.
    pub max_inventory_imbalance_per_m: i32,
    /// Cap on a positive (A-side) deviation, in per-m units; `0` disables the check.
    pub max_a_inventory_per_m: u32,
    /// Cap on the magnitude of a negative (B-side) deviation, in per-m units; `0` disables the
    /// check.
    pub max_b_inventory_per_m: u32,
    /// Minimum accepted `Price::spread_per_m`. Equality passes.
    pub min_spread_per_m: i32,
    /// Inclusive lower bound on `Price::oracle_price_q64_64`, in the same representation.
    pub min_oracle_price: u128,
    /// Inclusive upper bound on `Price::oracle_price_q64_64`, in the same representation.
    pub max_oracle_price: u128,
    /// Not read by the guards in this file. Presumably this is the slot later passed as
    /// `valid_until` to `check_oracle_validity` — confirm with callers.
    pub valid_until: u64,
}
42
43impl GuardParams {
44    pub fn from_market_fields(
45        max_inventory_imbalance_guard_per_cent: u8,
46        max_a_inventory_per_m: u32,
47        max_b_inventory_per_m: u32,
48        min_spread_guard_per_m: i32,
49        min_oracle_price_guard: u128,
50        max_oracle_price_guard: u128,
51        valid_until: u64,
52    ) -> Self {
53        Self {
54            max_inventory_imbalance_per_m: max_inventory_imbalance_guard_per_cent as i32
55                * (PER_M_DENOMINATOR / PER_CENT_DENOMINATOR as i32),
56            max_a_inventory_per_m,
57            max_b_inventory_per_m,
58            min_spread_per_m: min_spread_guard_per_m,
59            min_oracle_price: min_oracle_price_guard,
60            max_oracle_price: max_oracle_price_guard,
61            valid_until,
62        }
63    }
64}
65
66fn inventory_imbalance_guard(
67    reserves_a: u64,
68    reserves_b: u64,
69    price: &Price,
70    params: &GuardParams,
71) -> Result<(), GuardError> {
72    #[allow(clippy::useless_conversion)] // `U128` differs under the `wasm` feature.
73    let signed_imbalance =
74        deviation_per_m(price.oracle_price_q64_64.into(), reserves_a, reserves_b)
75            .map_err(|_| ARITHMETIC_OVERFLOW)?;
76    let imbalance_per_m = signed_imbalance.abs();
77
78    if imbalance_per_m > params.max_inventory_imbalance_per_m {
79        return Err(INVENTORY_IMBALANCE);
80    }
81
82    let a_inventory_per_m = signed_imbalance;
83    let b_inventory_per_m = -signed_imbalance;
84
85    if params.max_a_inventory_per_m > 0 && a_inventory_per_m > params.max_a_inventory_per_m as i32 {
86        return Err(INVENTORY_A_SIDE_EXCEEDED);
87    }
88
89    if params.max_b_inventory_per_m > 0 && b_inventory_per_m > params.max_b_inventory_per_m as i32 {
90        return Err(INVENTORY_B_SIDE_EXCEEDED);
91    }
92
93    Ok(())
94}
95
96fn spread_guard(price: &Price, params: &GuardParams) -> Result<(), GuardError> {
97    if price.spread_per_m < params.min_spread_per_m {
98        return Err(SPREAD_BELOW_MIN);
99    }
100
101    Ok(())
102}
103
104fn prices_guard(price: &Price, params: &GuardParams) -> Result<(), GuardError> {
105    if price.oracle_price_q64_64 < params.min_oracle_price {
106        return Err(ORACLE_PRICE_BELOW_MIN);
107    }
108
109    if price.oracle_price_q64_64 > params.max_oracle_price {
110        return Err(ORACLE_PRICE_ABOVE_MAX);
111    }
112
113    Ok(())
114}
115
116pub fn check_guards(
117    reserves_a: u64,
118    reserves_b: u64,
119    price: &Price,
120    params: &GuardParams,
121) -> Result<(), GuardError> {
122    inventory_imbalance_guard(reserves_a, reserves_b, price, params)?;
123    spread_guard(price, params)?;
124    prices_guard(price, params)?;
125
126    Ok(())
127}
128
129pub fn check_oracle_validity(current_slot: u64, valid_until: u64) -> Result<(), GuardError> {
130    if current_slot > valid_until {
131        return Err(ORACLE_EXPIRED);
132    }
133
134    Ok(())
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// Builds guard params with the given inventory limits and every other guard wide open.
    ///
    /// "Wide open" means a zero minimum spread and the full `u128` oracle-price range.
    ///
    /// The per-cent imbalance limit is scaled by 10_000 to per-m units. This hard-codes the
    /// ratio that `GuardParams::from_market_fields` derives from
    /// `PER_M_DENOMINATOR / PER_CENT_DENOMINATOR` (assumed 10_000 — revisit if those change).
    fn make_params(
        max_inventory_imbalance_per_cent: u8,
        max_a_inventory_per_m: u32,
        max_b_inventory_per_m: u32,
    ) -> GuardParams {
        GuardParams {
            max_inventory_imbalance_per_m: max_inventory_imbalance_per_cent as i32 * 10_000,
            max_a_inventory_per_m,
            max_b_inventory_per_m,
            min_spread_per_m: 0,
            min_oracle_price: 0,
            max_oracle_price: u128::MAX,
            valid_until: 0,
        }
    }

    /// A default `Price` with only the oracle price set.
    fn make_price(oracle_price_q64_64: u128) -> Price {
        Price {
            oracle_price_q64_64,
            ..Default::default()
        }
    }

    #[rstest]
    #[case(1000, 2000, Ok(()))]
    #[case(2000, 2000, Ok(()))]
    #[case(2001, 2000, Err(ORACLE_EXPIRED))]
    #[case(0, 0, Ok(()))]
    #[case(1, 0, Err(ORACLE_EXPIRED))]
    #[case(u64::MAX, u64::MAX, Ok(()))]
    fn test_check_oracle_validity(
        #[case] current_slot: u64,
        #[case] valid_until: u64,
        #[case] expected: Result<(), GuardError>,
    ) {
        assert_eq!(check_oracle_validity(current_slot, valid_until), expected);
    }

    // At a unit oracle price, 500/1000 or 1000/500 reserves deviate by one third, so the
    // 33% limit rejects them while 34% accepts them.
    #[rstest]
    #[case(1000, 1000, 100, true)]
    #[case(500, 1000, 100, true)]
    #[case(1000, 500, 100, true)]
    #[case(0, 2000, 100, true)]
    #[case(2000, 0, 100, true)]
    #[case(1000, 1000, 34, true)]
    #[case(500, 1000, 34, true)]
    #[case(1000, 500, 34, true)]
    #[case(0, 2000, 34, false)]
    #[case(2000, 0, 34, false)]
    #[case(1000, 1000, 33, true)]
    #[case(500, 1000, 33, false)]
    #[case(1000, 500, 33, false)]
    #[case(0, 2000, 33, false)]
    #[case(2000, 0, 33, false)]
    #[case(1000, 1000, 0, true)]
    #[case(500, 1000, 0, false)]
    #[case(1000, 500, 0, false)]
    #[case(0, 2000, 0, false)]
    #[case(2000, 0, 0, false)]
    fn test_inventory_imbalance_guard_symmetric(
        #[case] reserves_a: u64,
        #[case] reserves_b: u64,
        #[case] max_inventory_imbalance_per_cent: u8,
        #[case] expected_ok: bool,
    ) {
        let params = make_params(max_inventory_imbalance_per_cent, 0, 0);
        let price = make_price(1 << 64);

        let result = inventory_imbalance_guard(reserves_a, reserves_b, &price, &params);

        assert_eq!(result.is_ok(), expected_ok);
    }

    #[rstest]
    #[case(2u128 << 64, 500, 1000)]
    #[case(1u128 << 63, 2000, 1000)]
    #[case(4u128 << 64, 250, 1000)]
    fn balanced_market_with_non_unity_price_does_not_trigger(
        #[case] oracle_price_q64_64: u128,
        #[case] reserves_a: u64,
        #[case] reserves_b: u64,
    ) {
        let params = make_params(1, 0, 0);
        let price = make_price(oracle_price_q64_64);

        let result = inventory_imbalance_guard(reserves_a, reserves_b, &price, &params);

        assert!(
            result.is_ok(),
            "balanced market (price={}, a={}, b={}) should not trigger",
            oracle_price_q64_64,
            reserves_a,
            reserves_b
        );
    }

    // 1500/500 is a +50% (A-heavy) deviation and 500/1500 a -50% (B-heavy) one; a cap of 0
    // disables the corresponding directional check.
    #[rstest]
    #[case(1500, 500, 0, 0, true)]
    #[case(500, 1500, 0, 0, true)]
    #[case(1500, 500, 400_000, 0, false)]
    #[case(1500, 500, 600_000, 0, true)]
    #[case(500, 1500, 0, 400_000, false)]
    #[case(500, 1500, 0, 600_000, true)]
    #[case(500, 1500, 100_000, 0, true)]
    #[case(1500, 500, 0, 100_000, true)]
    #[case(1000, 1000, 1, 1, true)]
    fn test_inventory_directional_caps(
        #[case] reserves_a: u64,
        #[case] reserves_b: u64,
        #[case] max_a_inventory_per_m: u32,
        #[case] max_b_inventory_per_m: u32,
        #[case] expected_ok: bool,
    ) {
        let params = make_params(100, max_a_inventory_per_m, max_b_inventory_per_m);
        let price = make_price(1 << 64);

        let result = inventory_imbalance_guard(reserves_a, reserves_b, &price, &params);

        assert_eq!(result.is_ok(), expected_ok);
    }

    #[rstest]
    #[case(-10, -20, false)]
    #[case(-10, 0, true)]
    #[case(-10, 10, true)]
    #[case(-10, 20, true)]
    #[case(0, -20, false)]
    #[case(0, -10, false)]
    #[case(0, 0, true)]
    #[case(0, 10, true)]
    #[case(0, 20, true)]
    #[case(10, -20, false)]
    #[case(10, -10, false)]
    #[case(10, 0, false)]
    #[case(10, 10, true)]
    #[case(10, 20, true)]
    #[case(20, -20, false)]
    #[case(20, -10, false)]
    #[case(20, 0, false)]
    #[case(20, 10, false)]
    #[case(20, 20, true)]
    fn test_spread_guard(
        #[case] min_spread_per_m: i32,
        #[case] spread_per_m: i32,
        #[case] expected_ok: bool,
    ) {
        let params = GuardParams {
            min_spread_per_m,
            ..make_params(0, 0, 0)
        };
        let price = Price {
            spread_per_m,
            oracle_price_q64_64: 1 << 64,
            ..Default::default()
        };

        let result = spread_guard(&price, &params);

        assert_eq!(result.is_ok(), expected_ok);
    }

    #[rstest]
    #[case(100, true)]
    #[case(50, true)]
    #[case(150, true)]
    #[case(49, false)]
    #[case(151, false)]
    fn test_prices_guard(#[case] oracle_price: u128, #[case] expected_ok: bool) {
        let params = GuardParams {
            min_oracle_price: 50,
            max_oracle_price: 150,
            ..make_params(0, 0, 0)
        };
        let price = make_price(oracle_price);

        let result = prices_guard(&price, &params);

        assert_eq!(result.is_ok(), expected_ok);
    }

    #[rstest]
    #[case::all_pass(
        GuardParams { min_oracle_price: 0, max_oracle_price: u128::MAX, ..make_params(100, 0, 0) },
        Price { oracle_price_q64_64: 1 << 64, best_price_q64_64: 1 << 64, spread_per_m: 0 },
        1000,
        1000,
        Ok(()),
    )]
    #[case::inventory_fail(
        GuardParams { min_oracle_price: 0, max_oracle_price: u128::MAX, ..make_params(10, 0, 0) },
        Price { oracle_price_q64_64: 1 << 64, best_price_q64_64: 1 << 64, spread_per_m: 0 },
        2000,
        0,
        Err(INVENTORY_IMBALANCE),
    )]
    #[case::spread_fail(
        GuardParams { min_spread_per_m: 100, min_oracle_price: 0, max_oracle_price: u128::MAX, ..make_params(100, 0, 0) },
        Price { oracle_price_q64_64: 1 << 64, best_price_q64_64: 1 << 64, spread_per_m: 50 },
        1000,
        1000,
        Err(SPREAD_BELOW_MIN),
    )]
    #[case::price_below_min_fail(
        GuardParams { min_oracle_price: 100, max_oracle_price: u128::MAX, ..make_params(100, 0, 0) },
        Price { oracle_price_q64_64: 50, best_price_q64_64: 50, spread_per_m: 0 },
        1000,
        1000,
        Err(ORACLE_PRICE_BELOW_MIN),
    )]
    #[case::price_above_max_fail(
        GuardParams { min_oracle_price: 0, max_oracle_price: (1u128 << 64) - 1, ..make_params(100, 0, 0) },
        Price { oracle_price_q64_64: 1 << 64, best_price_q64_64: 1 << 64, spread_per_m: 0 },
        1000,
        1000,
        Err(ORACLE_PRICE_ABOVE_MAX),
    )]
    #[case::order_inventory_first(
        GuardParams { min_spread_per_m: 100, min_oracle_price: 0, max_oracle_price: u128::MAX, ..make_params(10, 0, 0) },
        Price { oracle_price_q64_64: 1 << 64, best_price_q64_64: 1 << 64, spread_per_m: 50 },
        2000,
        0,
        Err(INVENTORY_IMBALANCE),
    )]
    fn test_check_guards(
        #[case] params: GuardParams,
        #[case] price: Price,
        #[case] reserves_a: u64,
        #[case] reserves_b: u64,
        #[case] expected: Result<(), GuardError>,
    ) {
        let result = check_guards(reserves_a, reserves_b, &price, &params);

        assert_eq!(result, expected);
    }
}
369}