aagt_core/trading/risk/
checks.rs1use super::{RiskCheck, RiskCheckResult, TradeContext};
4use rust_decimal::Decimal;
5use std::sync::Arc;
6
7pub struct MaxTradeAmountCheck {
9 max_amount: Decimal,
10}
11
12impl MaxTradeAmountCheck {
13 pub fn new(max_amount: Decimal) -> Self {
14 Self { max_amount }
15 }
16}
17
18impl RiskCheck for MaxTradeAmountCheck {
19 fn name(&self) -> &str {
20 "max_trade_amount"
21 }
22
23 fn check(&self, context: &TradeContext) -> RiskCheckResult {
24 if context.amount_usd > self.max_amount {
25 RiskCheckResult::Rejected {
26 reason: format!(
27 "Trade amount ${} exceeds maximum ${}",
28 context.amount_usd, self.max_amount
29 ),
30 }
31 } else {
32 RiskCheckResult::Approved
33 }
34 }
35}
36
37pub struct SlippageCheck {
39 max_slippage_percent: Decimal,
40}
41
42impl SlippageCheck {
43 pub fn new(max_slippage_percent: Decimal) -> Self {
44 Self {
45 max_slippage_percent,
46 }
47 }
48}
49
50impl RiskCheck for SlippageCheck {
51 fn name(&self) -> &str {
52 "slippage"
53 }
54
55 fn check(&self, context: &TradeContext) -> RiskCheckResult {
56 if context.expected_slippage > self.max_slippage_percent {
57 RiskCheckResult::Rejected {
58 reason: format!(
59 "Slippage {}% exceeds maximum {}%",
60 context.expected_slippage, self.max_slippage_percent
61 ),
62 }
63 } else {
64 RiskCheckResult::Approved
65 }
66 }
67}
68
69pub struct LiquidityCheck {
71 min_liquidity: Decimal,
72}
73
74impl LiquidityCheck {
75 pub fn new(min_liquidity: Decimal) -> Self {
76 Self { min_liquidity }
77 }
78}
79
80impl RiskCheck for LiquidityCheck {
81 fn name(&self) -> &str {
82 "liquidity"
83 }
84
85 fn check(&self, context: &TradeContext) -> RiskCheckResult {
86 match context.liquidity_usd {
87 Some(liq) if liq < self.min_liquidity => RiskCheckResult::Rejected {
88 reason: format!("Liquidity ${} below minimum ${}", liq, self.min_liquidity),
89 },
90 None => RiskCheckResult::PendingReview {
91 reason: "Liquidity data unavailable".to_string(),
92 },
93 _ => RiskCheckResult::Approved,
94 }
95 }
96}
97
98pub struct TokenSecurityCheck {
100 blacklist: Vec<String>,
101}
102
103impl TokenSecurityCheck {
104 pub fn new(blacklist: Vec<String>) -> Self {
105 Self { blacklist }
106 }
107}
108
109impl RiskCheck for TokenSecurityCheck {
110 fn name(&self) -> &str {
111 "token_security"
112 }
113
114 fn check(&self, context: &TradeContext) -> RiskCheckResult {
115 if context.is_flagged {
116 return RiskCheckResult::Rejected {
117 reason: "Token is flagged as risky".to_string(),
118 };
119 }
120
121 if self.blacklist.contains(&context.to_token) {
122 return RiskCheckResult::Rejected {
123 reason: format!("Token {} is blacklisted", context.to_token),
124 };
125 }
126
127 RiskCheckResult::Approved
128 }
129}
130
131pub struct CompositeCheck {
133 checks: Vec<Arc<dyn RiskCheck>>,
134 name: String,
135}
136
137impl CompositeCheck {
138 pub fn new(name: String, checks: Vec<Arc<dyn RiskCheck>>) -> Self {
139 Self { name, checks }
140 }
141}
142
143impl RiskCheck for CompositeCheck {
144 fn name(&self) -> &str {
145 &self.name
146 }
147
148 fn check(&self, context: &TradeContext) -> RiskCheckResult {
149 for check in &self.checks {
150 match check.check(context) {
151 RiskCheckResult::Approved => continue,
152 other => return other,
153 }
154 }
155 RiskCheckResult::Approved
156 }
157}
158
159pub struct RiskCheckBuilder {
161 checks: Vec<Arc<dyn RiskCheck>>,
162}
163
164impl RiskCheckBuilder {
165 pub fn new() -> Self {
166 Self { checks: Vec::new() }
167 }
168
169 pub fn add_check(mut self, check: Arc<dyn RiskCheck>) -> Self {
170 self.checks.push(check);
171 self
172 }
173
174 pub fn max_trade_amount(self, max: Decimal) -> Self {
175 self.add_check(Arc::new(MaxTradeAmountCheck::new(max)))
176 }
177
178 pub fn max_slippage(self, max_percent: Decimal) -> Self {
179 self.add_check(Arc::new(SlippageCheck::new(max_percent)))
180 }
181
182 pub fn min_liquidity(self, min: Decimal) -> Self {
183 self.add_check(Arc::new(LiquidityCheck::new(min)))
184 }
185
186 pub fn token_security(self, blacklist: Vec<String>) -> Self {
187 self.add_check(Arc::new(TokenSecurityCheck::new(blacklist)))
188 }
189
190 pub fn build(self) -> Vec<Arc<dyn RiskCheck>> {
191 self.checks
192 }
193
194 pub fn build_composite(self, name: String) -> Arc<dyn RiskCheck> {
195 Arc::new(CompositeCheck::new(name, self.checks))
196 }
197}
198
199impl Default for RiskCheckBuilder {
200 fn default() -> Self {
201 Self::new()
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    /// A trade that satisfies every limit configured in the tests below.
    fn healthy_context() -> TradeContext {
        TradeContext {
            user_id: "test".to_string(),
            from_token: "USDC".to_string(),
            to_token: "SOL".to_string(),
            amount_usd: dec!(500.0),
            expected_slippage: dec!(1.0),
            liquidity_usd: Some(dec!(200000.0)),
            is_flagged: false,
        }
    }

    #[test]
    fn test_risk_check_builder() {
        let checks = RiskCheckBuilder::new()
            .max_trade_amount(dec!(1000.0))
            .max_slippage(dec!(2.0))
            .min_liquidity(dec!(100000.0))
            .build();
        assert_eq!(checks.len(), 3);

        let context = healthy_context();
        for check in checks.iter() {
            let outcome = check.check(&context);
            assert!(
                outcome.is_approved(),
                "check `{}` should approve a healthy trade",
                check.name()
            );
        }
    }

    #[test]
    fn test_composite_check() {
        let composite = RiskCheckBuilder::new()
            .max_trade_amount(dec!(1000.0))
            .max_slippage(dec!(2.0))
            .build_composite("test_composite".to_string());

        assert!(composite.check(&healthy_context()).is_approved());

        let oversized = TradeContext {
            amount_usd: dec!(2000.0),
            ..healthy_context()
        };
        assert!(!composite.check(&oversized).is_approved());
    }
}