Skip to main content

aagt_core/trading/risk/
checks.rs

1//! Enhanced Risk Check system with composable checks
2
3use super::{RiskCheck, RiskCheckResult, TradeContext};
4use rust_decimal::Decimal;
5use std::sync::Arc;
6
7/// Maximum trade amount check
8pub struct MaxTradeAmountCheck {
9    max_amount: Decimal,
10}
11
12impl MaxTradeAmountCheck {
13    pub fn new(max_amount: Decimal) -> Self {
14        Self { max_amount }
15    }
16}
17
18impl RiskCheck for MaxTradeAmountCheck {
19    fn name(&self) -> &str {
20        "max_trade_amount"
21    }
22
23    fn check(&self, context: &TradeContext) -> RiskCheckResult {
24        if context.amount_usd > self.max_amount {
25            RiskCheckResult::Rejected {
26                reason: format!(
27                    "Trade amount ${} exceeds maximum ${}",
28                    context.amount_usd, self.max_amount
29                ),
30            }
31        } else {
32            RiskCheckResult::Approved
33        }
34    }
35}
36
37/// Slippage tolerance check
38pub struct SlippageCheck {
39    max_slippage_percent: Decimal,
40}
41
42impl SlippageCheck {
43    pub fn new(max_slippage_percent: Decimal) -> Self {
44        Self {
45            max_slippage_percent,
46        }
47    }
48}
49
50impl RiskCheck for SlippageCheck {
51    fn name(&self) -> &str {
52        "slippage"
53    }
54
55    fn check(&self, context: &TradeContext) -> RiskCheckResult {
56        if context.expected_slippage > self.max_slippage_percent {
57            RiskCheckResult::Rejected {
58                reason: format!(
59                    "Slippage {}% exceeds maximum {}%",
60                    context.expected_slippage, self.max_slippage_percent
61                ),
62            }
63        } else {
64            RiskCheckResult::Approved
65        }
66    }
67}
68
69/// Liquidity check
70pub struct LiquidityCheck {
71    min_liquidity: Decimal,
72}
73
74impl LiquidityCheck {
75    pub fn new(min_liquidity: Decimal) -> Self {
76        Self { min_liquidity }
77    }
78}
79
80impl RiskCheck for LiquidityCheck {
81    fn name(&self) -> &str {
82        "liquidity"
83    }
84
85    fn check(&self, context: &TradeContext) -> RiskCheckResult {
86        match context.liquidity_usd {
87            Some(liq) if liq < self.min_liquidity => RiskCheckResult::Rejected {
88                reason: format!("Liquidity ${} below minimum ${}", liq, self.min_liquidity),
89            },
90            None => RiskCheckResult::PendingReview {
91                reason: "Liquidity data unavailable".to_string(),
92            },
93            _ => RiskCheckResult::Approved,
94        }
95    }
96}
97
98/// Token security check
99pub struct TokenSecurityCheck {
100    blacklist: Vec<String>,
101}
102
103impl TokenSecurityCheck {
104    pub fn new(blacklist: Vec<String>) -> Self {
105        Self { blacklist }
106    }
107}
108
109impl RiskCheck for TokenSecurityCheck {
110    fn name(&self) -> &str {
111        "token_security"
112    }
113
114    fn check(&self, context: &TradeContext) -> RiskCheckResult {
115        if context.is_flagged {
116            return RiskCheckResult::Rejected {
117                reason: "Token is flagged as risky".to_string(),
118            };
119        }
120
121        if self.blacklist.contains(&context.to_token) {
122            return RiskCheckResult::Rejected {
123                reason: format!("Token {} is blacklisted", context.to_token),
124            };
125        }
126
127        RiskCheckResult::Approved
128    }
129}
130
131/// Composite check that combines multiple checks
132pub struct CompositeCheck {
133    checks: Vec<Arc<dyn RiskCheck>>,
134    name: String,
135}
136
137impl CompositeCheck {
138    pub fn new(name: String, checks: Vec<Arc<dyn RiskCheck>>) -> Self {
139        Self { name, checks }
140    }
141}
142
143impl RiskCheck for CompositeCheck {
144    fn name(&self) -> &str {
145        &self.name
146    }
147
148    fn check(&self, context: &TradeContext) -> RiskCheckResult {
149        for check in &self.checks {
150            match check.check(context) {
151                RiskCheckResult::Approved => continue,
152                other => return other,
153            }
154        }
155        RiskCheckResult::Approved
156    }
157}
158
159/// Builder for creating risk check pipelines
160pub struct RiskCheckBuilder {
161    checks: Vec<Arc<dyn RiskCheck>>,
162}
163
164impl RiskCheckBuilder {
165    pub fn new() -> Self {
166        Self { checks: Vec::new() }
167    }
168
169    pub fn add_check(mut self, check: Arc<dyn RiskCheck>) -> Self {
170        self.checks.push(check);
171        self
172    }
173
174    pub fn max_trade_amount(self, max: Decimal) -> Self {
175        self.add_check(Arc::new(MaxTradeAmountCheck::new(max)))
176    }
177
178    pub fn max_slippage(self, max_percent: Decimal) -> Self {
179        self.add_check(Arc::new(SlippageCheck::new(max_percent)))
180    }
181
182    pub fn min_liquidity(self, min: Decimal) -> Self {
183        self.add_check(Arc::new(LiquidityCheck::new(min)))
184    }
185
186    pub fn token_security(self, blacklist: Vec<String>) -> Self {
187        self.add_check(Arc::new(TokenSecurityCheck::new(blacklist)))
188    }
189
190    pub fn build(self) -> Vec<Arc<dyn RiskCheck>> {
191        self.checks
192    }
193
194    pub fn build_composite(self, name: String) -> Arc<dyn RiskCheck> {
195        Arc::new(CompositeCheck::new(name, self.checks))
196    }
197}
198
199impl Default for RiskCheckBuilder {
200    fn default() -> Self {
201        Self::new()
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    /// A trade context that every check configured in these tests approves.
    fn base_context() -> TradeContext {
        TradeContext {
            user_id: "test".to_string(),
            from_token: "USDC".to_string(),
            to_token: "SOL".to_string(),
            amount_usd: dec!(500.0),
            expected_slippage: dec!(1.0),
            liquidity_usd: Some(dec!(200000.0)),
            is_flagged: false,
        }
    }

    #[test]
    fn test_risk_check_builder() {
        let checks = RiskCheckBuilder::new()
            .max_trade_amount(dec!(1000.0))
            .max_slippage(dec!(2.0))
            .min_liquidity(dec!(100000.0))
            .build();

        assert_eq!(checks.len(), 3);

        let context = base_context();
        for check in &checks {
            assert!(check.check(&context).is_approved());
        }
    }

    #[test]
    fn test_composite_check() {
        let composite = RiskCheckBuilder::new()
            .max_trade_amount(dec!(1000.0))
            .max_slippage(dec!(2.0))
            .build_composite("test_composite".to_string());

        let good_context = base_context();
        assert!(composite.check(&good_context).is_approved());

        let bad_context = TradeContext {
            amount_usd: dec!(2000.0),
            ..good_context
        };
        assert!(!composite.check(&bad_context).is_approved());
    }

    #[test]
    fn test_max_trade_amount_boundary() {
        let check = MaxTradeAmountCheck::new(dec!(1000.0));

        // Exactly at the cap is allowed; the comparison is strict.
        let at_limit = TradeContext {
            amount_usd: dec!(1000.0),
            ..base_context()
        };
        assert!(check.check(&at_limit).is_approved());

        let over_limit = TradeContext {
            amount_usd: dec!(1000.01),
            ..base_context()
        };
        assert!(matches!(
            check.check(&over_limit),
            RiskCheckResult::Rejected { .. }
        ));
    }

    #[test]
    fn test_slippage_rejected() {
        let check = SlippageCheck::new(dec!(2.0));
        let context = TradeContext {
            expected_slippage: dec!(2.5),
            ..base_context()
        };
        assert!(matches!(
            check.check(&context),
            RiskCheckResult::Rejected { .. }
        ));
    }

    #[test]
    fn test_liquidity_outcomes() {
        let check = LiquidityCheck::new(dec!(100000.0));

        let unknown = TradeContext {
            liquidity_usd: None,
            ..base_context()
        };
        assert!(matches!(
            check.check(&unknown),
            RiskCheckResult::PendingReview { .. }
        ));

        let thin = TradeContext {
            liquidity_usd: Some(dec!(50000.0)),
            ..base_context()
        };
        assert!(matches!(
            check.check(&thin),
            RiskCheckResult::Rejected { .. }
        ));
    }

    #[test]
    fn test_token_security() {
        let check = TokenSecurityCheck::new(vec!["SCAM".to_string()]);
        assert!(check.check(&base_context()).is_approved());

        let blacklisted = TradeContext {
            to_token: "SCAM".to_string(),
            ..base_context()
        };
        assert!(matches!(
            check.check(&blacklisted),
            RiskCheckResult::Rejected { .. }
        ));

        let flagged = TradeContext {
            is_flagged: true,
            ..base_context()
        };
        assert!(matches!(
            check.check(&flagged),
            RiskCheckResult::Rejected { .. }
        ));
    }
}