Skip to main content

aagt_core/trading/
risk.rs

1//! Risk control system for trading operations
2//!
3//! Provides safety checks before executing trades.
4//! Refactored to use the Actor Model for lock-free concurrency and durability.
5
6use std::sync::Arc;
7use std::collections::HashMap;
8use std::path::PathBuf;
9
10use serde::{Deserialize, Serialize};
11use tokio::sync::{mpsc, oneshot};
12use chrono::{DateTime, Utc};
13use futures::FutureExt;
14use rust_decimal::Decimal;
15use rust_decimal_macros::dec;
16
17use crate::error::{Error, Result};
18
19mod circuit_breaker;
20pub use circuit_breaker::DeadManSwitch;
21
22mod checks;
23pub use checks::{
24    CompositeCheck, LiquidityCheck, MaxTradeAmountCheck, 
25    RiskCheckBuilder, SlippageCheck, TokenSecurityCheck,
26};
27
/// Persistence trait for risk state
///
/// The whole per-user map is loaded and saved in one call. Implementations
/// must be `Send + Sync` because the store is shared behind an `Arc` and used
/// from the spawned actor task.
#[async_trait::async_trait]
pub trait RiskStateStore: Send + Sync {
    /// Loads the persisted per-user states; the actor looks entries up by
    /// `TradeContext::user_id`.
    async fn load(&self) -> Result<HashMap<String, UserState>>;
    /// Persists the given map of per-user states.
    async fn save(&self, states: &HashMap<String, UserState>) -> Result<()>;
}
34
/// Risk state store backed by a single JSON file.
pub struct FileRiskStore {
    /// Location of the JSON state file.
    path: PathBuf,
}

impl FileRiskStore {
    /// Builds a store that reads from and writes to `path`.
    ///
    /// Nothing is touched on disk here; the file is only accessed by `load`
    /// and `save`.
    pub fn new(path: impl Into<PathBuf>) -> Self {
        let path: PathBuf = path.into();
        Self { path }
    }
}
45
46#[async_trait::async_trait]
47impl RiskStateStore for FileRiskStore {
48    async fn load(&self) -> Result<HashMap<String, UserState>> {
49        if !self.path.exists() {
50            return Ok(HashMap::new());
51        }
52        let content = tokio::fs::read_to_string(&self.path).await?;
53        if content.trim().is_empty() {
54            return Ok(HashMap::new());
55        }
56        
57        serde_json::from_str(&content).map_err(|e| {
58            Error::Internal(format!("CORRUPTION: Risk state file at {:?} is malformed. Delete it to reset or fix JSON: {}", self.path, e))
59        })
60    }
61
62    async fn save(&self, states: &HashMap<String, UserState>) -> Result<()> {
63        if let Some(parent) = self.path.parent() {
64            tokio::fs::create_dir_all(parent).await.ok();
65        }
66
67        let path = self.path.clone();
68        let states = states.clone(); 
69
70        // Fix #3: Atomic Write Pattern (Write tmp -> Rename)
71        tokio::task::spawn_blocking(move || {
72            let tmp_path = path.with_extension(format!("tmp.{}", uuid::Uuid::new_v4()));
73            
74            // Scope for file to ensure it closes before rename
75            {
76                let file = std::fs::File::create(&tmp_path)
77                    .map_err(|e| Error::Internal(format!("Failed to create tmp risk file: {}", e)))?;
78                let writer = std::io::BufWriter::new(file);
79                
80                serde_json::to_writer_pretty(writer, &states)
81                    .map_err(|e| Error::Internal(format!("Failed to serialize risk state: {}", e)))?;
82                // File closes here
83            }
84
85            std::fs::rename(&tmp_path, &path)
86                .map_err(|e| {
87                    // Try to clean up tmp file if rename fails
88                    let _ = std::fs::remove_file(&tmp_path);
89                    Error::Internal(format!("Failed to rename risk file: {}", e))
90                })?;
91            
92            Ok::<(), Error>(())
93        }).await.map_err(|e| Error::Internal(format!("Join error: {}", e)))??;
94        
95        Ok(())
96    }
97}
98
99/// No-op store for in-memory only execution
100pub struct InMemoryRiskStore;
101
102#[async_trait::async_trait]
103impl RiskStateStore for InMemoryRiskStore {
104    async fn load(&self) -> Result<HashMap<String, UserState>> { Ok(HashMap::new()) }
105    async fn save(&self, _: &HashMap<String, UserState>) -> Result<()> { Ok(()) }
106}
107
/// Verdict produced by a single risk check.
#[derive(Debug, Clone)]
pub enum RiskCheckResult {
    /// The check found nothing wrong.
    Approved,
    /// The check failed; `reason` explains why.
    Rejected { reason: String },
    /// The check wants a manual review; `reason` explains why.
    PendingReview { reason: String },
}

impl RiskCheckResult {
    /// Returns `true` only for [`RiskCheckResult::Approved`].
    pub fn is_approved(&self) -> bool {
        match self {
            Self::Approved => true,
            Self::Rejected { .. } | Self::PendingReview { .. } => false,
        }
    }
}
125
/// Configuration for risk controls
///
/// The actor clones this value for every reservation and evaluates the
/// limits below against the incoming `TradeContext` and the user's state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskConfig {
    /// Maximum single trade amount in USD; a trade whose `amount_usd` is
    /// above this value is rejected.
    pub max_single_trade_usd: Decimal,
    /// Maximum daily volume in USD per user; committed plus pending volume
    /// plus the new trade must not exceed it.
    pub max_daily_volume_usd: Decimal,
    /// Maximum slippage percentage allowed; compared directly against
    /// `TradeContext::expected_slippage`.
    pub max_slippage_percent: Decimal,
    /// Minimum liquidity required in USD; only enforced when the context
    /// carries a liquidity figure.
    pub min_liquidity_usd: Decimal,
    /// Enable rug pull detection; when set, trades whose context is flagged
    /// are rejected.
    pub enable_rug_detection: bool,
    /// Cooldown between trades in seconds, measured from the user's last
    /// committed trade.
    pub trade_cooldown_secs: u64,
}

impl Default for RiskConfig {
    /// Defaults: 10,000 USD per trade, 50,000 USD per day, 5 max slippage,
    /// 100,000 USD minimum liquidity, rug detection on, 5 second cooldown.
    fn default() -> Self {
        Self {
            max_single_trade_usd: dec!(10000.0),
            max_daily_volume_usd: dec!(50000.0),
            max_slippage_percent: dec!(5.0),
            min_liquidity_usd: dec!(100000.0),
            enable_rug_detection: true,
            trade_cooldown_secs: 5,
        }
    }
}
155
/// A risk check that can be performed
///
/// Custom checks are shared as `Arc<dyn RiskCheck>` and are evaluated on a
/// blocking-pool task (`tokio::task::spawn_blocking`), hence the
/// `Send + Sync` bound.
pub trait RiskCheck: Send + Sync {
    /// Name of this check; reported as the check name when the check
    /// rejects a trade.
    fn name(&self) -> &str;

    /// Perform the check against `context` and return the verdict.
    fn check(&self, context: &TradeContext) -> RiskCheckResult;
}
164
/// Context for a trade being checked
///
/// Cloned into the actor for every reservation request.
#[derive(Debug, Clone)]
pub struct TradeContext {
    /// User ID; key of the per-user risk state.
    pub user_id: String,
    /// Token being sold
    pub from_token: String,
    /// Token being bought
    pub to_token: String,
    /// Amount in USD; must be strictly positive to pass validation.
    pub amount_usd: Decimal,
    /// Expected slippage; compared against `RiskConfig::max_slippage_percent`.
    pub expected_slippage: Decimal,
    /// Token liquidity in USD; `None` skips the minimum-liquidity check.
    pub liquidity_usd: Option<Decimal>,
    /// Is this token flagged as risky; only acted upon when rug detection is
    /// enabled in the config.
    pub is_flagged: bool,
}
183
/// Per-user risk bookkeeping, persisted through a `RiskStateStore`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserState {
    /// Daily volume traded (committed); reset to zero once the actor sees
    /// that the UTC calendar date has moved past `volume_reset`.
    pub daily_volume_usd: Decimal,
    /// Volume currently reserved by pending trades (not yet committed);
    /// cleared whenever state is loaded from the store.
    pub pending_volume_usd: Decimal,
    /// Last trade timestamp; set when a trade is committed.
    pub last_trade: Option<DateTime<Utc>>,
    /// Volume reset time (Last date processed)
    pub volume_reset: DateTime<Utc>,
}

impl Default for UserState {
    /// A fresh state: no committed or pending volume, no previous trade, and
    /// the reset marker set to the current UTC time.
    fn default() -> Self {
        Self {
            daily_volume_usd: Decimal::ZERO,
            pending_volume_usd: Decimal::ZERO,
            last_trade: None,
            volume_reset: Utc::now(),
        }
    }
}
206
207// --- Actor Implementation ---
208
/// Messages sent from `RiskManager` to the actor task. Variants carrying a
/// `reply` sender get their result sent back over that oneshot channel.
enum RiskCommand {
    /// Validate a trade and, on success, reserve its volume for the user.
    CheckAndReserve { context: TradeContext, checks: Vec<Arc<dyn RiskCheck>>, reply: oneshot::Sender<Result<()>> },
    /// Move `amount_usd` from pending to committed daily volume.
    Commit { user_id: String, amount_usd: Decimal, reply: oneshot::Sender<Result<()>> },
    /// Release a reservation; fire-and-forget, no reply is sent.
    Rollback { user_id: String, amount_usd: Decimal },
    /// Query how much of the daily limit the user has left.
    GetRemaining { user_id: String, reply: oneshot::Sender<Decimal> },
    /// Reload the state map from the store.
    LoadState { reply: oneshot::Sender<Result<()>> },
}
216
/// State owned by the actor task; all mutation of per-user risk state goes
/// through its `handle_*` methods.
struct RiskActor {
    /// Limits applied to every request.
    config: RiskConfig,
    /// Per-user state keyed by user id.
    state: HashMap<String, UserState>,
    /// Backing store used for loading and saving `state`.
    store: Arc<dyn RiskStateStore>,
    /// Receiving end of the command channel fed by `RiskManager`.
    receiver: mpsc::Receiver<RiskCommand>,
    /// Time of the most recent successful load; written by `handle_load`,
    /// not read anywhere in the code visible in this file.
    last_load_time: Option<DateTime<Utc>>,
}
224
225impl RiskActor {
226
227    async fn handle_load(&mut self) -> Result<()> {
228        let mut loaded = self.store.load().await?;
229        
230        // Fix #2.1: Clear zombie pending volumes on startup
231        // Anything pending during a crash is considered failed/not-executed.
232        for (_, state) in loaded.iter_mut() {
233            if !state.pending_volume_usd.is_zero() {
234                tracing::warn!("Resetting zombie pending volume of ${} for user", state.pending_volume_usd);
235                state.pending_volume_usd = Decimal::ZERO;
236            }
237        }
238
239        self.state = loaded;
240        self.last_load_time = Some(Utc::now());
241        Ok(())
242    }
243
244    async fn handle_check_and_reserve(&mut self, context: TradeContext, checks: Vec<Arc<dyn RiskCheck>>) -> Result<()> {
245        // 1. Offload heavy/STATLESS checks to blocking thread
246        // These checks don't need UserState (RAM) and could involve I/O in custom checks
247        let config = self.config.clone();
248        let ctx_clone = context.clone();
249        tokio::task::spawn_blocking(move || {
250             Self::validate_stateless(&config, &ctx_clone, &checks)
251        }).await.map_err(|e| Error::Internal(format!("Task panic: {}", e)))??;
252
253        // 2. Perform STATEFUL checks inside Actor (Atomic)
254        let state = self.state.entry(context.user_id.clone()).or_default();
255        
256        // Reset volume if day changed
257        let now = Utc::now();
258        if now.date_naive() > state.volume_reset.date_naive() {
259            state.daily_volume_usd = Decimal::ZERO;
260            state.volume_reset = now;
261        }
262
263        // Daily limit check
264        let projected = state.daily_volume_usd + state.pending_volume_usd + context.amount_usd;
265        if projected > self.config.max_daily_volume_usd {
266            return Err(Error::RiskLimitExceeded {
267                limit_type: "daily_volume".to_string(),
268                current: format!("${:.2}", projected),
269                max: format!("${:.2}", self.config.max_daily_volume_usd),
270            });
271        }
272
273        // Cooldown check
274        if let Some(last) = state.last_trade {
275            let elapsed = now - last;
276            if elapsed < chrono::Duration::seconds(self.config.trade_cooldown_secs as i64) {
277                 return Err(Error::risk_check_failed("cooldown", "Trading too fast"));
278            }
279        }
280
281        // Commit reservation
282        state.pending_volume_usd += context.amount_usd;
283        
284        // Immediate save for reservation
285        self.store.save(&self.state).await?;
286        
287        Ok(())
288    }
289
290    /// Stateless validation logic - can be run outside Actor
291    fn validate_stateless(config: &RiskConfig, context: &TradeContext, checks: &[Arc<dyn RiskCheck>]) -> Result<()> {
292        // Fix #2: Reject negative or zero amounts (Crucial Security Fix)
293        if context.amount_usd <= Decimal::ZERO {
294             return Err(Error::risk_check_failed("amount_validation", format!("Amount must be positive, got ${:.2}", context.amount_usd)));
295        }
296
297        if context.amount_usd > config.max_single_trade_usd {
298            return Err(Error::RiskLimitExceeded {
299                limit_type: "single_trade".to_string(),
300                current: format!("${:.2}", context.amount_usd),
301                max: format!("${:.2}", config.max_single_trade_usd),
302            });
303        }
304        if context.expected_slippage > config.max_slippage_percent {
305            return Err(Error::risk_check_failed("slippage", format!("Slippage {} > {}", context.expected_slippage, config.max_slippage_percent)));
306        }
307        if let Some(liq) = context.liquidity_usd {
308            if liq < config.min_liquidity_usd {
309                return Err(Error::risk_check_failed("liquidity", "Insufficient liquidity"));
310            }
311        }
312        if config.enable_rug_detection && context.is_flagged {
313            return Err(Error::risk_check_failed("rug_detection", "Token flagged as risky"));
314        }
315
316        for check in checks {
317            if let RiskCheckResult::Rejected { reason } = check.check(context) {
318                return Err(Error::RiskCheckFailed { check_name: check.name().to_string(), reason });
319            }
320        }
321        Ok(())
322    }
323
324    async fn handle_commit(&mut self, user_id: String, amount: Decimal) -> Result<()> {
325        let state = self.state.entry(user_id.clone()).or_default();
326        
327        let old_pending = state.pending_volume_usd;
328        let old_daily = state.daily_volume_usd;
329        let old_last = state.last_trade;
330
331        state.pending_volume_usd = (state.pending_volume_usd - amount).max(Decimal::ZERO);
332        state.daily_volume_usd += amount;
333        state.last_trade = Some(Utc::now());
334
335        if let Err(e) = self.store.save(&self.state).await {
336            // Rollback on failure
337            if let Some(s) = self.state.get_mut(&user_id) {
338                s.pending_volume_usd = old_pending;
339                s.daily_volume_usd = old_daily;
340                s.last_trade = old_last;
341            }
342            return Err(e);
343        }
344        Ok(())
345    }
346
347    fn handle_rollback(&mut self, user_id: String, amount: Decimal) {
348        if let Some(state) = self.state.get_mut(&user_id) {
349            state.pending_volume_usd = (state.pending_volume_usd - amount).max(Decimal::ZERO);
350        }
351    }
352
353    fn handle_get_remaining(&self, user_id: String) -> Decimal {
354        if let Some(state) = self.state.get(&user_id) {
355             (self.config.max_daily_volume_usd - (state.daily_volume_usd + state.pending_volume_usd)).max(Decimal::ZERO)
356        } else {
357            self.config.max_daily_volume_usd
358        }
359    }
360}
361
/// The main risk manager
///
/// A handle to the actor task: every operation is turned into a
/// `RiskCommand` and sent over `sender`.
pub struct RiskManager {
    /// Sending end of the actor's command channel.
    sender: mpsc::Sender<RiskCommand>,
    /// Copy of the config kept on the handle; the actor holds its own copy.
    config: RiskConfig,
    /// Custom checks registered through `add_check`. The list is cloned and
    /// sent along with every reservation request, which is cheap because the
    /// entries are `Arc<dyn RiskCheck>`.
    custom_checks: std::sync::RwLock<Vec<Arc<dyn RiskCheck>>>,
}
372
373impl RiskManager {
374    /// Create with default config and in-memory storage (Async)
375    pub async fn new() -> Result<Self> {
376        Self::with_config(RiskConfig::default(), Arc::new(InMemoryRiskStore)).await
377    }
378
379    /// Create with custom config and storage (Async)
380    pub async fn with_config(config: RiskConfig, store: Arc<dyn RiskStateStore>) -> Result<Self> {
381        let (tx, rx) = mpsc::channel(100);
382        
383        let mut actor = RiskActor {
384            config: config.clone(),
385            state: HashMap::new(),
386            store,
387            receiver: rx,
388            last_load_time: None,
389        };
390        
391        // Load initial state from store
392        actor.handle_load().await?;
393
394        tokio::spawn(async move {
395            let mut actor = actor;
396            loop {
397                let rx = &mut actor.receiver;
398                // If the receiver is closed, the manager is dropped, so we should exit
399                if rx.is_closed() {
400                    break;
401                }
402
403                tracing::info!("RiskActor starting/restarting");
404                let res = std::panic::AssertUnwindSafe(async {
405                    // We need to re-create the actor if we want to reset some state, 
406                    // but here we just want to keep the loop running and handle messages.
407                    // The state is already in the `actor` struct.
408                    
409                    // Actually, if it panics, we might want to reload state from disk
410                    // but we have to be careful about pending volume.
411                    // For now, just keep the task alive.
412                    // Fix: Track if state was modified during message processing
413                    let mut interval = tokio::time::interval(std::time::Duration::from_secs(5));
414                    let mut dirty = false;  // Fix L2: Track if state needs saving
415                    
416                    loop {
417                        tokio::select! {
418                            maybe_msg = actor.receiver.recv() => {
419                                match maybe_msg {
420                                    Some(msg) => {
421                                         match msg {
422                                             RiskCommand::CheckAndReserve { context, checks, reply } => {
423                                                 // Moved checks into the handler
424                                                 let res = actor.handle_check_and_reserve(context, checks).await;
425                                                 dirty = res.is_ok();  // Mark dirty if reservation succeeded
426                                                 let _ = reply.send(res);
427                                             }
428                                             RiskCommand::Commit { user_id, amount_usd, reply } => {
429                                                 let res = actor.handle_commit(user_id, amount_usd).await;
430                                                 // Commit already saves, no need to set dirty
431                                                 let _ = reply.send(res);
432                                             }
433                                             RiskCommand::Rollback { user_id, amount_usd } => {
434                                                 actor.handle_rollback(user_id, amount_usd);
435                                                 dirty = true;
436                                             }
437                                             RiskCommand::GetRemaining { user_id, reply } => {
438                                                 let val = actor.handle_get_remaining(user_id);
439                                                 let _ = reply.send(val);
440                                             }
441                                             RiskCommand::LoadState { reply } => {
442                                                 let res = actor.handle_load().await;
443                                                 let _ = reply.send(res);
444                                             }
445                                         }
446                                    }
447                                    None => break, // Channel closed
448                                }
449                            }
450                            _ = interval.tick() => {
451                                // Fix L2: Only save if state was modified
452                                if dirty {
453                                    tracing::debug!("RiskManager: performing periodic state flush");
454                                    if let Err(e) = actor.store.save(&actor.state).await {
455                                         tracing::error!("Periodic risk persistence failed: {}", e);
456                                    } else {
457                                        dirty = false;
458                                    }
459                                }
460                            }
461                        }
462                    }
463                }).catch_unwind().await;
464
465                if let Err(_) = res {
466                    tracing::error!("RiskActor PANICKED. Restarting in 1s...");
467                    tokio::time::sleep(std::time::Duration::from_secs(1)).await;
468                } else {
469                    // Normal exit (sender dropped)
470                    break;
471                }
472            }
473        });
474
475        let manager = Self {
476            sender: tx,
477            config,
478            custom_checks: std::sync::RwLock::new(Vec::new()),
479        };
480        
481        // Fix #1: Auto-load state on startup
482        manager.load_state().await?;
483        
484        Ok(manager)
485    }
486    
487    /// Backward compatible Strict constructor (already strict, now matches new behavior but keeps name)
488    pub async fn new_strict(config: RiskConfig, store: Arc<dyn RiskStateStore>) -> Result<Self> {
489        Self::with_config(config, store).await
490    }
491    
492    // ... load_state and other methods remain ...
493
494    /// Load state from store
495    pub async fn load_state(&self) -> Result<()> {
496        let (tx, rx) = oneshot::channel();
497        self.sender.send(RiskCommand::LoadState { reply: tx }).await
498            .map_err(|_| Error::Internal("Risk actor closed".to_string()))?;
499        rx.await.map_err(|_| Error::Internal("Risk actor dropped reply".to_string()))?
500    }
501
502    /// Add a custom risk check
503    pub fn add_check(&self, check: Arc<dyn RiskCheck>) {
504        if let Ok(mut checks) = self.custom_checks.write() {
505            checks.push(check);
506        }
507    }
508
509    /// Perform all risk checks for a trade AND reserve the volume.
510    pub async fn check_and_reserve(&self, context: &TradeContext) -> Result<()> {
511        let checks = self.custom_checks.read()
512            .map_err(|_| Error::Internal("Risk check lock poisoned".to_string()))?
513            .clone();
514        
515        let (tx, rx) = oneshot::channel();
516        self.sender.send(RiskCommand::CheckAndReserve { 
517            context: context.clone(), 
518            checks, 
519            reply: tx 
520        }).await.map_err(|_| Error::Internal("Risk actor closed".to_string()))?;
521        
522        rx.await.map_err(|_| Error::Internal("Risk actor dropped reply".to_string()))?
523    }
524
525    /// Backward compatible check
526    #[deprecated(note = "Use check_and_reserve for race-condition safety")]
527    pub async fn check_trade(&self, context: &TradeContext) -> Result<()> {
528        self.check_and_reserve(context).await?;
529        self.rollback_trade(&context.user_id, context.amount_usd).await;
530        Ok(())
531    }
532
533    /// Commit a trade that was previously reserved
534    pub async fn commit_trade(&self, user_id: &str, amount_usd: Decimal) -> Result<()> {
535        let (tx, rx) = oneshot::channel();
536        self.sender.send(RiskCommand::Commit { 
537            user_id: user_id.to_string(), 
538            amount_usd, 
539            reply: tx 
540        }).await.map_err(|_| Error::Internal("Risk actor closed".to_string()))?;
541        
542        rx.await.map_err(|_| Error::Internal("Risk actor dropped reply".to_string()))?
543    }
544
545    /// Rollback a reservation
546    pub async fn rollback_trade(&self, user_id: &str, amount_usd: Decimal) {
547        let _ = self.sender.send(RiskCommand::Rollback { 
548            user_id: user_id.to_string(), 
549            amount_usd 
550        }).await;
551    }
552
553    /// Record a trade immediately
554    pub async fn record_trade(&self, user_id: &str, amount_usd: Decimal) -> Result<()> {
555        self.commit_trade(user_id, amount_usd).await
556    }
557    
558    /// Get remaining daily limit for a user
559    pub async fn remaining_daily_limit(&self, user_id: &str) -> Decimal {
560        let (tx, rx) = oneshot::channel();
561        if let Err(_) = self.sender.send(RiskCommand::GetRemaining { 
562            user_id: user_id.to_string(), 
563            reply: tx 
564        }).await {
565            return Decimal::ZERO;
566        }
567        rx.await.unwrap_or(Decimal::ZERO)
568    }
569}
570
571// Default trait removed because new() is async. Use RiskManager::new().await instead.
572
573// Tests kept but might need async adjustment if logic changed (it mostly didn't, just interface)
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    /// Builds an unflagged, liquid USDC -> SOL trade of `amount_usd` for `user1`.
    fn usdc_to_sol(amount_usd: Decimal) -> TradeContext {
        TradeContext {
            user_id: "user1".to_string(),
            from_token: "USDC".to_string(),
            to_token: "SOL".to_string(),
            amount_usd,
            expected_slippage: dec!(0.5),
            liquidity_usd: Some(dec!(1_000_000.0)),
            is_flagged: false,
        }
    }

    /// A trade above the configured single-trade limit must be refused.
    #[tokio::test]
    async fn test_single_trade_limit() {
        let config = RiskConfig {
            max_single_trade_usd: dec!(1000.0),
            ..Default::default()
        };
        let manager = RiskManager::with_config(config, Arc::new(InMemoryRiskStore))
            .await
            .unwrap();

        let context = usdc_to_sol(dec!(5000.0));

        let result = manager.check_and_reserve(&context).await;
        assert!(result.is_err());
    }

    /// Reserving and committing a trade lowers the remaining daily limit by
    /// exactly the committed amount.
    #[tokio::test]
    async fn test_reserve_commit_flow() {
        let manager = RiskManager::new().await.unwrap();
        let context = usdc_to_sol(dec!(100.0));

        // 1. Reserve
        assert!(manager.check_and_reserve(&context).await.is_ok());

        // 2. Commit
        manager.commit_trade("user1", dec!(100.0)).await.unwrap();

        // 3. Check remaining
        let remaining = manager.remaining_daily_limit("user1").await;
        assert_eq!(remaining, dec!(50_000.0) - dec!(100.0));
    }
}
627}