1use std::sync::Arc;
7use std::collections::HashMap;
8use std::path::PathBuf;
9
10use serde::{Deserialize, Serialize};
11use tokio::sync::{mpsc, oneshot};
12use chrono::{DateTime, Utc};
13use futures::FutureExt;
14use rust_decimal::Decimal;
15use rust_decimal_macros::dec;
16
17use crate::error::{Error, Result};
18
19mod circuit_breaker;
20pub use circuit_breaker::DeadManSwitch;
21
22mod checks;
23pub use checks::{
24 CompositeCheck, LiquidityCheck, MaxTradeAmountCheck,
25 RiskCheckBuilder, SlippageCheck, TokenSecurityCheck,
26};
27
28#[async_trait::async_trait]
30pub trait RiskStateStore: Send + Sync {
31 async fn load(&self) -> Result<HashMap<String, UserState>>;
32 async fn save(&self, states: &HashMap<String, UserState>) -> Result<()>;
33}
34
/// [`RiskStateStore`] backed by a single JSON file on disk.
pub struct FileRiskStore {
    /// Location of the JSON state file; saves go through a temp file + rename.
    path: PathBuf,
}
39
40impl FileRiskStore {
41 pub fn new(path: impl Into<PathBuf>) -> Self {
42 Self { path: path.into() }
43 }
44}
45
46#[async_trait::async_trait]
47impl RiskStateStore for FileRiskStore {
48 async fn load(&self) -> Result<HashMap<String, UserState>> {
49 if !self.path.exists() {
50 return Ok(HashMap::new());
51 }
52 let content = tokio::fs::read_to_string(&self.path).await?;
53 if content.trim().is_empty() {
54 return Ok(HashMap::new());
55 }
56
57 serde_json::from_str(&content).map_err(|e| {
58 Error::Internal(format!("CORRUPTION: Risk state file at {:?} is malformed. Delete it to reset or fix JSON: {}", self.path, e))
59 })
60 }
61
62 async fn save(&self, states: &HashMap<String, UserState>) -> Result<()> {
63 if let Some(parent) = self.path.parent() {
64 tokio::fs::create_dir_all(parent).await.ok();
65 }
66
67 let path = self.path.clone();
68 let states = states.clone();
69
70 tokio::task::spawn_blocking(move || {
72 let tmp_path = path.with_extension(format!("tmp.{}", uuid::Uuid::new_v4()));
73
74 {
76 let file = std::fs::File::create(&tmp_path)
77 .map_err(|e| Error::Internal(format!("Failed to create tmp risk file: {}", e)))?;
78 let writer = std::io::BufWriter::new(file);
79
80 serde_json::to_writer_pretty(writer, &states)
81 .map_err(|e| Error::Internal(format!("Failed to serialize risk state: {}", e)))?;
82 }
84
85 std::fs::rename(&tmp_path, &path)
86 .map_err(|e| {
87 let _ = std::fs::remove_file(&tmp_path);
89 Error::Internal(format!("Failed to rename risk file: {}", e))
90 })?;
91
92 Ok::<(), Error>(())
93 }).await.map_err(|e| Error::Internal(format!("Join error: {}", e)))??;
94
95 Ok(())
96 }
97}
98
/// No-op [`RiskStateStore`]: loads nothing and discards every save, so risk
/// state is never persisted and does not survive a restart.
pub struct InMemoryRiskStore;
101
102#[async_trait::async_trait]
103impl RiskStateStore for InMemoryRiskStore {
104 async fn load(&self) -> Result<HashMap<String, UserState>> { Ok(HashMap::new()) }
105 async fn save(&self, _: &HashMap<String, UserState>) -> Result<()> { Ok(()) }
106}
107
/// Verdict produced by a single [`RiskCheck`].
#[derive(Debug, Clone)]
pub enum RiskCheckResult {
    /// The trade passes this check.
    Approved,
    /// The trade must not proceed; `reason` explains why.
    Rejected { reason: String },
    /// Not approved outright. The name suggests manual review is needed;
    /// `is_approved` reports `false` for it.
    PendingReview { reason: String },
}
118
119impl RiskCheckResult {
120 pub fn is_approved(&self) -> bool {
122 matches!(self, Self::Approved)
123 }
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct RiskConfig {
129 pub max_single_trade_usd: Decimal,
131 pub max_daily_volume_usd: Decimal,
133 pub max_slippage_percent: Decimal,
135 pub min_liquidity_usd: Decimal,
137 pub enable_rug_detection: bool,
139 pub trade_cooldown_secs: u64,
141}
142
143impl Default for RiskConfig {
144 fn default() -> Self {
145 Self {
146 max_single_trade_usd: dec!(10000.0),
147 max_daily_volume_usd: dec!(50000.0),
148 max_slippage_percent: dec!(5.0),
149 min_liquidity_usd: dec!(100000.0),
150 enable_rug_detection: true,
151 trade_cooldown_secs: 5,
152 }
153 }
154}
155
156pub trait RiskCheck: Send + Sync {
158 fn name(&self) -> &str;
160
161 fn check(&self, context: &TradeContext) -> RiskCheckResult;
163}
164
165#[derive(Debug, Clone)]
167pub struct TradeContext {
168 pub user_id: String,
170 pub from_token: String,
172 pub to_token: String,
174 pub amount_usd: Decimal,
176 pub expected_slippage: Decimal,
178 pub liquidity_usd: Option<Decimal>,
180 pub is_flagged: bool,
182}
183
184#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct UserState {
186 pub daily_volume_usd: Decimal,
188 pub pending_volume_usd: Decimal,
190 pub last_trade: Option<DateTime<Utc>>,
192 pub volume_reset: DateTime<Utc>,
194}
195
196impl Default for UserState {
197 fn default() -> Self {
198 Self {
199 daily_volume_usd: Decimal::ZERO,
200 pending_volume_usd: Decimal::ZERO,
201 last_trade: None,
202 volume_reset: Utc::now(),
203 }
204 }
205}
206
207enum RiskCommand {
210 CheckAndReserve { context: TradeContext, checks: Vec<Arc<dyn RiskCheck>>, reply: oneshot::Sender<Result<()>> },
211 Commit { user_id: String, amount_usd: Decimal, reply: oneshot::Sender<Result<()>> },
212 Rollback { user_id: String, amount_usd: Decimal },
213 GetRemaining { user_id: String, reply: oneshot::Sender<Decimal> },
214 LoadState { reply: oneshot::Sender<Result<()>> },
215}
216
217struct RiskActor {
218 config: RiskConfig,
219 state: HashMap<String, UserState>,
220 store: Arc<dyn RiskStateStore>,
221 receiver: mpsc::Receiver<RiskCommand>,
222 last_load_time: Option<DateTime<Utc>>,
223}
224
225impl RiskActor {
226
227 async fn handle_load(&mut self) -> Result<()> {
228 let mut loaded = self.store.load().await?;
229
230 for (_, state) in loaded.iter_mut() {
233 if !state.pending_volume_usd.is_zero() {
234 tracing::warn!("Resetting zombie pending volume of ${} for user", state.pending_volume_usd);
235 state.pending_volume_usd = Decimal::ZERO;
236 }
237 }
238
239 self.state = loaded;
240 self.last_load_time = Some(Utc::now());
241 Ok(())
242 }
243
244 async fn handle_check_and_reserve(&mut self, context: TradeContext, checks: Vec<Arc<dyn RiskCheck>>) -> Result<()> {
245 let config = self.config.clone();
248 let ctx_clone = context.clone();
249 tokio::task::spawn_blocking(move || {
250 Self::validate_stateless(&config, &ctx_clone, &checks)
251 }).await.map_err(|e| Error::Internal(format!("Task panic: {}", e)))??;
252
253 let state = self.state.entry(context.user_id.clone()).or_default();
255
256 let now = Utc::now();
258 if now.date_naive() > state.volume_reset.date_naive() {
259 state.daily_volume_usd = Decimal::ZERO;
260 state.volume_reset = now;
261 }
262
263 let projected = state.daily_volume_usd + state.pending_volume_usd + context.amount_usd;
265 if projected > self.config.max_daily_volume_usd {
266 return Err(Error::RiskLimitExceeded {
267 limit_type: "daily_volume".to_string(),
268 current: format!("${:.2}", projected),
269 max: format!("${:.2}", self.config.max_daily_volume_usd),
270 });
271 }
272
273 if let Some(last) = state.last_trade {
275 let elapsed = now - last;
276 if elapsed < chrono::Duration::seconds(self.config.trade_cooldown_secs as i64) {
277 return Err(Error::risk_check_failed("cooldown", "Trading too fast"));
278 }
279 }
280
281 state.pending_volume_usd += context.amount_usd;
283
284 self.store.save(&self.state).await?;
286
287 Ok(())
288 }
289
290 fn validate_stateless(config: &RiskConfig, context: &TradeContext, checks: &[Arc<dyn RiskCheck>]) -> Result<()> {
292 if context.amount_usd <= Decimal::ZERO {
294 return Err(Error::risk_check_failed("amount_validation", format!("Amount must be positive, got ${:.2}", context.amount_usd)));
295 }
296
297 if context.amount_usd > config.max_single_trade_usd {
298 return Err(Error::RiskLimitExceeded {
299 limit_type: "single_trade".to_string(),
300 current: format!("${:.2}", context.amount_usd),
301 max: format!("${:.2}", config.max_single_trade_usd),
302 });
303 }
304 if context.expected_slippage > config.max_slippage_percent {
305 return Err(Error::risk_check_failed("slippage", format!("Slippage {} > {}", context.expected_slippage, config.max_slippage_percent)));
306 }
307 if let Some(liq) = context.liquidity_usd {
308 if liq < config.min_liquidity_usd {
309 return Err(Error::risk_check_failed("liquidity", "Insufficient liquidity"));
310 }
311 }
312 if config.enable_rug_detection && context.is_flagged {
313 return Err(Error::risk_check_failed("rug_detection", "Token flagged as risky"));
314 }
315
316 for check in checks {
317 if let RiskCheckResult::Rejected { reason } = check.check(context) {
318 return Err(Error::RiskCheckFailed { check_name: check.name().to_string(), reason });
319 }
320 }
321 Ok(())
322 }
323
324 async fn handle_commit(&mut self, user_id: String, amount: Decimal) -> Result<()> {
325 let state = self.state.entry(user_id.clone()).or_default();
326
327 let old_pending = state.pending_volume_usd;
328 let old_daily = state.daily_volume_usd;
329 let old_last = state.last_trade;
330
331 state.pending_volume_usd = (state.pending_volume_usd - amount).max(Decimal::ZERO);
332 state.daily_volume_usd += amount;
333 state.last_trade = Some(Utc::now());
334
335 if let Err(e) = self.store.save(&self.state).await {
336 if let Some(s) = self.state.get_mut(&user_id) {
338 s.pending_volume_usd = old_pending;
339 s.daily_volume_usd = old_daily;
340 s.last_trade = old_last;
341 }
342 return Err(e);
343 }
344 Ok(())
345 }
346
347 fn handle_rollback(&mut self, user_id: String, amount: Decimal) {
348 if let Some(state) = self.state.get_mut(&user_id) {
349 state.pending_volume_usd = (state.pending_volume_usd - amount).max(Decimal::ZERO);
350 }
351 }
352
353 fn handle_get_remaining(&self, user_id: String) -> Decimal {
354 if let Some(state) = self.state.get(&user_id) {
355 (self.config.max_daily_volume_usd - (state.daily_volume_usd + state.pending_volume_usd)).max(Decimal::ZERO)
356 } else {
357 self.config.max_daily_volume_usd
358 }
359 }
360}
361
362pub struct RiskManager {
364 sender: mpsc::Sender<RiskCommand>,
365 config: RiskConfig,
367 custom_checks: std::sync::RwLock<Vec<Arc<dyn RiskCheck>>>,
371}
372
373impl RiskManager {
374 pub async fn new() -> Result<Self> {
376 Self::with_config(RiskConfig::default(), Arc::new(InMemoryRiskStore)).await
377 }
378
379 pub async fn with_config(config: RiskConfig, store: Arc<dyn RiskStateStore>) -> Result<Self> {
381 let (tx, rx) = mpsc::channel(100);
382
383 let mut actor = RiskActor {
384 config: config.clone(),
385 state: HashMap::new(),
386 store,
387 receiver: rx,
388 last_load_time: None,
389 };
390
391 actor.handle_load().await?;
393
394 tokio::spawn(async move {
395 let mut actor = actor;
396 loop {
397 let rx = &mut actor.receiver;
398 if rx.is_closed() {
400 break;
401 }
402
403 tracing::info!("RiskActor starting/restarting");
404 let res = std::panic::AssertUnwindSafe(async {
405 let mut interval = tokio::time::interval(std::time::Duration::from_secs(5));
414 let mut dirty = false; loop {
417 tokio::select! {
418 maybe_msg = actor.receiver.recv() => {
419 match maybe_msg {
420 Some(msg) => {
421 match msg {
422 RiskCommand::CheckAndReserve { context, checks, reply } => {
423 let res = actor.handle_check_and_reserve(context, checks).await;
425 dirty = res.is_ok(); let _ = reply.send(res);
427 }
428 RiskCommand::Commit { user_id, amount_usd, reply } => {
429 let res = actor.handle_commit(user_id, amount_usd).await;
430 let _ = reply.send(res);
432 }
433 RiskCommand::Rollback { user_id, amount_usd } => {
434 actor.handle_rollback(user_id, amount_usd);
435 dirty = true;
436 }
437 RiskCommand::GetRemaining { user_id, reply } => {
438 let val = actor.handle_get_remaining(user_id);
439 let _ = reply.send(val);
440 }
441 RiskCommand::LoadState { reply } => {
442 let res = actor.handle_load().await;
443 let _ = reply.send(res);
444 }
445 }
446 }
447 None => break, }
449 }
450 _ = interval.tick() => {
451 if dirty {
453 tracing::debug!("RiskManager: performing periodic state flush");
454 if let Err(e) = actor.store.save(&actor.state).await {
455 tracing::error!("Periodic risk persistence failed: {}", e);
456 } else {
457 dirty = false;
458 }
459 }
460 }
461 }
462 }
463 }).catch_unwind().await;
464
465 if let Err(_) = res {
466 tracing::error!("RiskActor PANICKED. Restarting in 1s...");
467 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
468 } else {
469 break;
471 }
472 }
473 });
474
475 let manager = Self {
476 sender: tx,
477 config,
478 custom_checks: std::sync::RwLock::new(Vec::new()),
479 };
480
481 manager.load_state().await?;
483
484 Ok(manager)
485 }
486
487 pub async fn new_strict(config: RiskConfig, store: Arc<dyn RiskStateStore>) -> Result<Self> {
489 Self::with_config(config, store).await
490 }
491
492 pub async fn load_state(&self) -> Result<()> {
496 let (tx, rx) = oneshot::channel();
497 self.sender.send(RiskCommand::LoadState { reply: tx }).await
498 .map_err(|_| Error::Internal("Risk actor closed".to_string()))?;
499 rx.await.map_err(|_| Error::Internal("Risk actor dropped reply".to_string()))?
500 }
501
502 pub fn add_check(&self, check: Arc<dyn RiskCheck>) {
504 if let Ok(mut checks) = self.custom_checks.write() {
505 checks.push(check);
506 }
507 }
508
509 pub async fn check_and_reserve(&self, context: &TradeContext) -> Result<()> {
511 let checks = self.custom_checks.read()
512 .map_err(|_| Error::Internal("Risk check lock poisoned".to_string()))?
513 .clone();
514
515 let (tx, rx) = oneshot::channel();
516 self.sender.send(RiskCommand::CheckAndReserve {
517 context: context.clone(),
518 checks,
519 reply: tx
520 }).await.map_err(|_| Error::Internal("Risk actor closed".to_string()))?;
521
522 rx.await.map_err(|_| Error::Internal("Risk actor dropped reply".to_string()))?
523 }
524
525 #[deprecated(note = "Use check_and_reserve for race-condition safety")]
527 pub async fn check_trade(&self, context: &TradeContext) -> Result<()> {
528 self.check_and_reserve(context).await?;
529 self.rollback_trade(&context.user_id, context.amount_usd).await;
530 Ok(())
531 }
532
533 pub async fn commit_trade(&self, user_id: &str, amount_usd: Decimal) -> Result<()> {
535 let (tx, rx) = oneshot::channel();
536 self.sender.send(RiskCommand::Commit {
537 user_id: user_id.to_string(),
538 amount_usd,
539 reply: tx
540 }).await.map_err(|_| Error::Internal("Risk actor closed".to_string()))?;
541
542 rx.await.map_err(|_| Error::Internal("Risk actor dropped reply".to_string()))?
543 }
544
545 pub async fn rollback_trade(&self, user_id: &str, amount_usd: Decimal) {
547 let _ = self.sender.send(RiskCommand::Rollback {
548 user_id: user_id.to_string(),
549 amount_usd
550 }).await;
551 }
552
553 pub async fn record_trade(&self, user_id: &str, amount_usd: Decimal) -> Result<()> {
555 self.commit_trade(user_id, amount_usd).await
556 }
557
558 pub async fn remaining_daily_limit(&self, user_id: &str) -> Decimal {
560 let (tx, rx) = oneshot::channel();
561 if let Err(_) = self.sender.send(RiskCommand::GetRemaining {
562 user_id: user_id.to_string(),
563 reply: tx
564 }).await {
565 return Decimal::ZERO;
566 }
567 rx.await.unwrap_or(Decimal::ZERO)
568 }
569}
570
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    /// A USDC→SOL trade for `user1` that satisfies the default slippage,
    /// liquidity and rug rules, so only amount-based limits can reject it.
    fn usdc_to_sol(amount_usd: Decimal) -> TradeContext {
        TradeContext {
            user_id: "user1".to_string(),
            from_token: "USDC".to_string(),
            to_token: "SOL".to_string(),
            amount_usd,
            expected_slippage: dec!(0.5),
            liquidity_usd: Some(dec!(1_000_000.0)),
            is_flagged: false,
        }
    }

    #[tokio::test]
    async fn test_single_trade_limit() {
        let config = RiskConfig {
            max_single_trade_usd: dec!(1000.0),
            ..Default::default()
        };
        let manager = RiskManager::with_config(config, Arc::new(InMemoryRiskStore))
            .await
            .unwrap();

        let outcome = manager.check_and_reserve(&usdc_to_sol(dec!(5000.0))).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_reserve_commit_flow() {
        let manager = RiskManager::new().await.unwrap();

        assert!(manager
            .check_and_reserve(&usdc_to_sol(dec!(100.0)))
            .await
            .is_ok());
        manager.commit_trade("user1", dec!(100.0)).await.unwrap();

        let remaining = manager.remaining_daily_limit("user1").await;
        assert_eq!(remaining, dec!(50_000.0) - dec!(100.0));
    }
}
627}