1#[cfg(feature = "wasm")]
2use riptide_amm_macros::wasm_expose;
3
4use super::{
5 deviation_per_m, error::ARITHMETIC_OVERFLOW, Price, PER_CENT_DENOMINATOR, PER_M_DENOMINATOR,
6};
7
/// Error type returned by every guard check: a static, human-readable message.
pub type GuardError = &'static str;

/// Returned by [`check_oracle_validity`] when `current_slot` is past `valid_until`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const ORACLE_EXPIRED: GuardError = "oracle expired";

/// Returned when the absolute inventory imbalance exceeds `max_inventory_imbalance_per_m`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const INVENTORY_IMBALANCE: GuardError = "inventory imbalance";

/// Returned when a positive (A-side) imbalance exceeds `max_a_inventory_per_m`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const INVENTORY_A_SIDE_EXCEEDED: GuardError = "A-side inventory cap exceeded";

/// Returned when the magnitude of a negative (B-side) imbalance exceeds `max_b_inventory_per_m`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const INVENTORY_B_SIDE_EXCEEDED: GuardError = "B-side inventory cap exceeded";

/// Returned when the quoted spread is below `min_spread_per_m`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const SPREAD_BELOW_MIN: GuardError = "spread below minimum";

/// Returned when the oracle price is below `min_oracle_price`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const ORACLE_PRICE_BELOW_MIN: GuardError = "oracle price below minimum";

/// Returned when the oracle price is above `max_oracle_price`.
#[cfg_attr(feature = "wasm", wasm_expose)]
pub const ORACLE_PRICE_ABOVE_MAX: GuardError = "oracle price above maximum";
30
/// Thresholds evaluated by [`check_guards`].
///
/// All `*_per_m` values are in per-million units.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
#[cfg_attr(feature = "wasm", wasm_expose)]
pub struct GuardParams {
    /// Maximum allowed absolute inventory imbalance (per-million).
    pub max_inventory_imbalance_per_m: i32,
    /// Cap on a positive (A-side) imbalance (per-million); `0` disables this cap.
    pub max_a_inventory_per_m: u32,
    /// Cap on the magnitude of a negative (B-side) imbalance (per-million); `0` disables this cap.
    pub max_b_inventory_per_m: u32,
    /// Minimum accepted `Price::spread_per_m`; strictly lower spreads are rejected.
    pub min_spread_per_m: i32,
    /// Inclusive lower bound, compared directly against `Price::oracle_price_q64_64`.
    pub min_oracle_price: u128,
    /// Inclusive upper bound, compared directly against `Price::oracle_price_q64_64`.
    pub max_oracle_price: u128,
    /// Oracle validity deadline. NOTE(review): not read by any guard in this file; presumably
    /// passed to `check_oracle_validity` by the caller — confirm.
    pub valid_until: u64,
}
42
43impl GuardParams {
44 pub fn from_market_fields(
45 max_inventory_imbalance_guard_per_cent: u8,
46 max_a_inventory_per_m: u32,
47 max_b_inventory_per_m: u32,
48 min_spread_guard_per_m: i32,
49 min_oracle_price_guard: u128,
50 max_oracle_price_guard: u128,
51 valid_until: u64,
52 ) -> Self {
53 Self {
54 max_inventory_imbalance_per_m: max_inventory_imbalance_guard_per_cent as i32
55 * (PER_M_DENOMINATOR / PER_CENT_DENOMINATOR as i32),
56 max_a_inventory_per_m,
57 max_b_inventory_per_m,
58 min_spread_per_m: min_spread_guard_per_m,
59 min_oracle_price: min_oracle_price_guard,
60 max_oracle_price: max_oracle_price_guard,
61 valid_until,
62 }
63 }
64}
65
66fn inventory_imbalance_guard(
67 reserves_a: u64,
68 reserves_b: u64,
69 price: &Price,
70 params: &GuardParams,
71) -> Result<(), GuardError> {
72 #[allow(clippy::useless_conversion)] let signed_imbalance =
74 deviation_per_m(price.oracle_price_q64_64.into(), reserves_a, reserves_b)
75 .map_err(|_| ARITHMETIC_OVERFLOW)?;
76 let imbalance_per_m = signed_imbalance.abs();
77
78 if imbalance_per_m > params.max_inventory_imbalance_per_m {
79 return Err(INVENTORY_IMBALANCE);
80 }
81
82 let a_inventory_per_m = signed_imbalance;
83 let b_inventory_per_m = -signed_imbalance;
84
85 if params.max_a_inventory_per_m > 0 && a_inventory_per_m > params.max_a_inventory_per_m as i32 {
86 return Err(INVENTORY_A_SIDE_EXCEEDED);
87 }
88
89 if params.max_b_inventory_per_m > 0 && b_inventory_per_m > params.max_b_inventory_per_m as i32 {
90 return Err(INVENTORY_B_SIDE_EXCEEDED);
91 }
92
93 Ok(())
94}
95
96fn spread_guard(price: &Price, params: &GuardParams) -> Result<(), GuardError> {
97 if price.spread_per_m < params.min_spread_per_m {
98 return Err(SPREAD_BELOW_MIN);
99 }
100
101 Ok(())
102}
103
104fn prices_guard(price: &Price, params: &GuardParams) -> Result<(), GuardError> {
105 if price.oracle_price_q64_64 < params.min_oracle_price {
106 return Err(ORACLE_PRICE_BELOW_MIN);
107 }
108
109 if price.oracle_price_q64_64 > params.max_oracle_price {
110 return Err(ORACLE_PRICE_ABOVE_MAX);
111 }
112
113 Ok(())
114}
115
116pub fn check_guards(
117 reserves_a: u64,
118 reserves_b: u64,
119 price: &Price,
120 params: &GuardParams,
121) -> Result<(), GuardError> {
122 inventory_imbalance_guard(reserves_a, reserves_b, price, params)?;
123 spread_guard(price, params)?;
124 prices_guard(price, params)?;
125
126 Ok(())
127}
128
129pub fn check_oracle_validity(current_slot: u64, valid_until: u64) -> Result<(), GuardError> {
130 if current_slot > valid_until {
131 return Err(ORACLE_EXPIRED);
132 }
133
134 Ok(())
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// Builds guard params with the given imbalance limit (in percent) and directional caps;
    /// the spread and oracle-price guards are left wide open.
    fn make_params(
        max_inventory_imbalance_per_cent: u8,
        max_a_inventory_per_m: u32,
        max_b_inventory_per_m: u32,
    ) -> GuardParams {
        GuardParams {
            // 1 % == 10_000 per-million.
            max_inventory_imbalance_per_m: max_inventory_imbalance_per_cent as i32 * 10_000,
            max_a_inventory_per_m,
            max_b_inventory_per_m,
            min_spread_per_m: 0,
            min_oracle_price: 0,
            max_oracle_price: u128::MAX,
            valid_until: 0,
        }
    }

    /// Builds a price with the given oracle price and every other field defaulted.
    fn make_price(oracle_price_q64_64: u128) -> Price {
        Price {
            oracle_price_q64_64,
            ..Default::default()
        }
    }

    #[rstest]
    #[case(0)]
    #[case(34)]
    #[case(100)]
    #[case(u8::MAX)]
    fn test_from_market_fields_matches_percent_scaling(#[case] per_cent: u8) {
        let params = GuardParams::from_market_fields(per_cent, 0, 0, 0, 0, u128::MAX, 0);

        assert_eq!(params, make_params(per_cent, 0, 0));
    }

    #[rstest]
    #[case(1000, 2000, Ok(()))]
    #[case(2000, 2000, Ok(()))]
    #[case(2001, 2000, Err(ORACLE_EXPIRED))]
    #[case(0, 0, Ok(()))]
    #[case(1, 0, Err(ORACLE_EXPIRED))]
    #[case(u64::MAX, u64::MAX, Ok(()))]
    fn test_check_oracle_validity(
        #[case] current_slot: u64,
        #[case] valid_until: u64,
        #[case] expected: Result<(), GuardError>,
    ) {
        assert_eq!(check_oracle_validity(current_slot, valid_until), expected);
    }

    #[rstest]
    #[case(1000, 1000, 100, true)]
    #[case(500, 1000, 100, true)]
    #[case(1000, 500, 100, true)]
    #[case(0, 2000, 100, true)]
    #[case(2000, 0, 100, true)]
    #[case(1000, 1000, 34, true)]
    #[case(500, 1000, 34, true)]
    #[case(1000, 500, 34, true)]
    #[case(0, 2000, 34, false)]
    #[case(2000, 0, 34, false)]
    #[case(1000, 1000, 33, true)]
    #[case(500, 1000, 33, false)]
    #[case(1000, 500, 33, false)]
    #[case(0, 2000, 33, false)]
    #[case(2000, 0, 33, false)]
    #[case(1000, 1000, 0, true)]
    #[case(500, 1000, 0, false)]
    #[case(1000, 500, 0, false)]
    #[case(0, 2000, 0, false)]
    #[case(2000, 0, 0, false)]
    fn test_inventory_imbalance_guard_symmetric(
        #[case] reserves_a: u64,
        #[case] reserves_b: u64,
        #[case] max_inventory_imbalance_per_cent: u8,
        #[case] expected_ok: bool,
    ) {
        let params = make_params(max_inventory_imbalance_per_cent, 0, 0);
        let price = make_price(1 << 64);

        let result = inventory_imbalance_guard(reserves_a, reserves_b, &price, &params);

        assert_eq!(result.is_ok(), expected_ok);
    }

    #[rstest]
    #[case(2u128 << 64, 500, 1000)]
    #[case(1u128 << 63, 2000, 1000)]
    #[case(4u128 << 64, 250, 1000)]
    fn balanced_market_with_non_unity_price_does_not_trigger(
        #[case] oracle_price_q64_64: u128,
        #[case] reserves_a: u64,
        #[case] reserves_b: u64,
    ) {
        let params = make_params(1, 0, 0);
        let price = make_price(oracle_price_q64_64);

        let result = inventory_imbalance_guard(reserves_a, reserves_b, &price, &params);

        assert!(
            result.is_ok(),
            "balanced market (price={}, a={}, b={}) should not trigger",
            oracle_price_q64_64,
            reserves_a,
            reserves_b
        );
    }

    #[rstest]
    #[case(1500, 500, 0, 0, true)]
    #[case(500, 1500, 0, 0, true)]
    #[case(1500, 500, 400_000, 0, false)]
    #[case(1500, 500, 600_000, 0, true)]
    #[case(500, 1500, 0, 400_000, false)]
    #[case(500, 1500, 0, 600_000, true)]
    #[case(500, 1500, 100_000, 0, true)]
    #[case(1500, 500, 0, 100_000, true)]
    #[case(1000, 1000, 1, 1, true)]
    fn test_inventory_directional_caps(
        #[case] reserves_a: u64,
        #[case] reserves_b: u64,
        #[case] max_a_inventory_per_m: u32,
        #[case] max_b_inventory_per_m: u32,
        #[case] expected_ok: bool,
    ) {
        let params = make_params(100, max_a_inventory_per_m, max_b_inventory_per_m);
        let price = make_price(1 << 64);

        let result = inventory_imbalance_guard(reserves_a, reserves_b, &price, &params);

        assert_eq!(result.is_ok(), expected_ok);
    }

    #[rstest]
    #[case(-10, -20, false)]
    #[case(-10, 0, true)]
    #[case(-10, 10, true)]
    #[case(-10, 20, true)]
    #[case(0, -20, false)]
    #[case(0, -10, false)]
    #[case(0, 0, true)]
    #[case(0, 10, true)]
    #[case(0, 20, true)]
    #[case(10, -20, false)]
    #[case(10, -10, false)]
    #[case(10, 0, false)]
    #[case(10, 10, true)]
    #[case(10, 20, true)]
    #[case(20, -20, false)]
    #[case(20, -10, false)]
    #[case(20, 0, false)]
    #[case(20, 10, false)]
    #[case(20, 20, true)]
    fn test_spread_guard(
        #[case] min_spread_per_m: i32,
        #[case] spread_per_m: i32,
        #[case] expected_ok: bool,
    ) {
        let params = GuardParams {
            min_spread_per_m,
            ..make_params(0, 0, 0)
        };
        let price = Price {
            spread_per_m,
            oracle_price_q64_64: 1 << 64,
            ..Default::default()
        };

        let result = spread_guard(&price, &params);

        assert_eq!(result.is_ok(), expected_ok);
    }

    #[rstest]
    #[case(100, true)]
    #[case(50, true)]
    #[case(150, true)]
    #[case(49, false)]
    #[case(151, false)]
    fn test_prices_guard(#[case] oracle_price: u128, #[case] expected_ok: bool) {
        let params = GuardParams {
            min_oracle_price: 50,
            max_oracle_price: 150,
            ..make_params(0, 0, 0)
        };
        let price = make_price(oracle_price);

        let result = prices_guard(&price, &params);

        assert_eq!(result.is_ok(), expected_ok);
    }

    #[rstest]
    #[case::all_pass(
        GuardParams { min_oracle_price: 0, max_oracle_price: u128::MAX, ..make_params(100, 0, 0) },
        Price { oracle_price_q64_64: 1 << 64, best_price_q64_64: 1 << 64, spread_per_m: 0 },
        1000,
        1000,
        Ok(()),
    )]
    #[case::inventory_fail(
        GuardParams { min_oracle_price: 0, max_oracle_price: u128::MAX, ..make_params(10, 0, 0) },
        Price { oracle_price_q64_64: 1 << 64, best_price_q64_64: 1 << 64, spread_per_m: 0 },
        2000,
        0,
        Err(INVENTORY_IMBALANCE),
    )]
    #[case::spread_fail(
        GuardParams { min_spread_per_m: 100, min_oracle_price: 0, max_oracle_price: u128::MAX, ..make_params(100, 0, 0) },
        Price { oracle_price_q64_64: 1 << 64, best_price_q64_64: 1 << 64, spread_per_m: 50 },
        1000,
        1000,
        Err(SPREAD_BELOW_MIN),
    )]
    #[case::price_below_min_fail(
        GuardParams { min_oracle_price: 100, max_oracle_price: u128::MAX, ..make_params(100, 0, 0) },
        Price { oracle_price_q64_64: 50, best_price_q64_64: 50, spread_per_m: 0 },
        1000,
        1000,
        Err(ORACLE_PRICE_BELOW_MIN),
    )]
    #[case::order_inventory_first(
        GuardParams { min_spread_per_m: 100, min_oracle_price: 0, max_oracle_price: u128::MAX, ..make_params(10, 0, 0) },
        Price { oracle_price_q64_64: 1 << 64, best_price_q64_64: 1 << 64, spread_per_m: 50 },
        2000,
        0,
        Err(INVENTORY_IMBALANCE),
    )]
    fn test_check_guards(
        #[case] params: GuardParams,
        #[case] price: Price,
        #[case] reserves_a: u64,
        #[case] reserves_b: u64,
        #[case] expected: Result<(), GuardError>,
    ) {
        let result = check_guards(reserves_a, reserves_b, &price, &params);

        assert_eq!(result, expected);
    }
}