1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::fs;
4use std::path::PathBuf;
5use std::sync::Mutex;
6use std::time::Instant;
7
8use crate::alerts::{alert_history_path, AlertHistoryEntry};
9use hyper_ta::TechnicalIndicators;
10
/// Exposure caps applied to every order before it is accepted.
///
/// Values are in the same units as an order's notional (`size * price`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PositionLimits {
    /// When `false`, no position-limit check is performed.
    pub enabled: bool,
    /// Maximum total position value allowed after adding the order's notional.
    pub max_total_position: f64,
    /// Maximum position value allowed for a single symbol after adding the order's notional.
    pub max_per_symbol: f64,
}
22
/// Limits on realized loss within the current trading day.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DailyLossLimits {
    /// When `false`, daily-loss checks are skipped.
    pub enabled: bool,
    /// Absolute loss at or above which new orders are rejected.
    pub max_daily_loss: f64,
    /// Loss as a percentage (0–100 scale) of the day's starting equity at or
    /// above which new orders are rejected.
    pub max_daily_loss_percent: f64,
}
30
/// Heuristics that reject unusually large, too frequent, or repeated orders.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AnomalyDetection {
    /// When `false`, none of the anomaly checks run.
    pub enabled: bool,
    /// Maximum notional (`size * price`) of a single order.
    pub max_order_size: f64,
    /// Number of accepted orders in the trailing minute at which further orders are rejected.
    pub max_orders_per_minute: u32,
    /// Reject an order whose symbol, side and size match an order accepted in
    /// the trailing minute (price is not compared).
    pub block_duplicate_orders: bool,
}
39
/// Halts all trading when recent losses exceed a threshold.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CircuitBreaker {
    /// When `false`, the breaker never trips and is never consulted.
    pub enabled: bool,
    /// Windowed loss at or above which the breaker trips.
    pub trigger_loss: f64,
    /// Length of the loss window. In this file it only appears in the violation
    /// message; the windowed loss itself is supplied by the caller in `AccountState`.
    pub trigger_window_minutes: u32,
    /// Action name (default `"pause_all"`). NOTE(review): not read by the code
    /// visible in this file — confirm where it is interpreted.
    pub action: String,
    /// How long the breaker stays tripped before orders are allowed again.
    pub cooldown_minutes: u32,
}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51#[serde(rename_all = "camelCase")]
52pub struct RiskConfig {
53 pub position_limits: PositionLimits,
54 pub daily_loss_limits: DailyLossLimits,
55 pub anomaly_detection: AnomalyDetection,
56 pub circuit_breaker: CircuitBreaker,
57}
58
59impl Default for RiskConfig {
60 fn default() -> Self {
61 Self {
62 position_limits: PositionLimits {
63 enabled: true,
64 max_total_position: 100_000.0,
65 max_per_symbol: 25_000.0,
66 },
67 daily_loss_limits: DailyLossLimits {
68 enabled: true,
69 max_daily_loss: 5_000.0,
70 max_daily_loss_percent: 5.0,
71 },
72 anomaly_detection: AnomalyDetection {
73 enabled: true,
74 max_order_size: 50_000.0,
75 max_orders_per_minute: 10,
76 block_duplicate_orders: true,
77 },
78 circuit_breaker: CircuitBreaker {
79 enabled: false,
80 trigger_loss: 10_000.0,
81 trigger_window_minutes: 60,
82 action: "pause_all".to_string(),
83 cooldown_minutes: 30,
84 },
85 }
86 }
87}
88
89pub fn risk_config_path() -> PathBuf {
90 let mut path = dirs::config_dir().unwrap_or_else(|| PathBuf::from("."));
91 path.push("hyper-agent");
92 let _ = fs::create_dir_all(&path);
93 path.push("risk-config.json");
94 path
95}
96
97pub fn get_risk_config_sync() -> RiskConfig {
99 let path = risk_config_path();
100 if path.exists() {
101 fs::read_to_string(&path)
102 .ok()
103 .and_then(|data| serde_json::from_str(&data).ok())
104 .unwrap_or_default()
105 } else {
106 RiskConfig::default()
107 }
108}
109
110pub fn save_risk_config_to_disk(config: &RiskConfig) -> Result<(), String> {
111 let path = risk_config_path();
112 let json = serde_json::to_string_pretty(config).map_err(|e| e.to_string())?;
113 fs::write(&path, json).map_err(|e| e.to_string())?;
114 Ok(())
115}
116
/// An order submitted for risk checking.
#[derive(Debug, Clone)]
pub struct OrderRequest {
    /// Instrument symbol (e.g. `"BTC"`); used as the per-symbol position key.
    pub symbol: String,
    /// Order side as free text (e.g. `"buy"` / `"sell"`). Not validated here;
    /// only compared for duplicate detection.
    pub side: String,
    /// Quantity; notional is computed as `size * price`.
    pub size: f64,
    /// Limit/reference price used for notional and TA checks.
    pub price: f64,
}
133
/// Snapshot of account figures supplied by the caller for each risk check.
#[derive(Debug, Clone)]
pub struct AccountState {
    /// Current total position value across all symbols.
    pub total_position_value: f64,
    /// Current position value keyed by symbol; missing symbols count as 0.
    pub position_by_symbol: HashMap<String, f64>,
    /// Realized loss so far today, as a positive number.
    pub daily_realized_loss: f64,
    /// Equity at the start of the day; the percentage loss rule is skipped when not > 0.
    pub daily_starting_equity: f64,
    /// Loss over the circuit breaker's trigger window (computed by the caller).
    pub windowed_loss: f64,
}
148
149impl Default for AccountState {
150 fn default() -> Self {
151 Self {
152 total_position_value: 0.0,
153 position_by_symbol: HashMap::new(),
154 daily_realized_loss: 0.0,
155 daily_starting_equity: 100_000.0,
156 windowed_loss: 0.0,
157 }
158 }
159}
160
/// A rejected check: which rule fired and a human-readable explanation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskViolation {
    /// Rule identifier, e.g. `"MaxPositionGuard"`, `"DailyLossGuard"`,
    /// `"AnomalyDetection"`, `"CircuitBreaker"`, `"RsiExtreme"`, `"AtrSurge"`,
    /// `"BollingerBreakout"`.
    pub rule: String,
    /// Explanation including the offending values and limits.
    pub message: String,
}
167
168impl std::fmt::Display for RiskViolation {
169 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170 write!(f, "[{}] {}", self.rule, self.message)
171 }
172}
173
/// An accepted order remembered for rate limiting and duplicate detection.
#[derive(Debug, Clone)]
struct RecentOrder {
    /// Symbol of the accepted order.
    symbol: String,
    /// Side of the accepted order, compared verbatim.
    side: String,
    /// Size of the accepted order.
    size: f64,
    /// When the order was recorded; entries older than one minute are pruned.
    timestamp: Instant,
}
182
/// Technical-analysis rule evaluated by [`check_ta_risk`].
///
/// Serialized as an internally tagged object whose `type` field is the
/// snake_case variant name (e.g. `rsi_extreme`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TaRiskRule {
    /// Fires when RSI(14) is strictly below `lower` or strictly above `upper`.
    RsiExtreme {
        /// Lower RSI threshold.
        lower: f64,
        /// Upper RSI threshold.
        upper: f64,
    },
    /// Fires when ATR(14) as a percentage of price strictly exceeds the threshold.
    AtrSurge {
        /// Threshold in percent (e.g. `5.0` for 5%).
        threshold_pct: f64,
    },
    /// Fires when price leaves Bollinger bands rebuilt at this many σ.
    BollingerBreakout {
        /// Band width in standard deviations.
        sigma_multiplier: f64,
    },
}
209
210pub fn check_ta_risk(
214 indicators: &TechnicalIndicators,
215 price: f64,
216 rules: &[TaRiskRule],
217) -> Vec<RiskViolation> {
218 let mut violations = Vec::new();
219
220 for rule in rules {
221 match rule {
222 TaRiskRule::RsiExtreme { lower, upper } => {
223 if let Some(rsi) = indicators.rsi_14 {
224 if rsi < *lower {
225 violations.push(RiskViolation {
226 rule: "RsiExtreme".to_string(),
227 message: format!(
228 "RSI ({:.2}) is below the lower extreme threshold ({:.2}). Opening new positions is restricted.",
229 rsi, lower,
230 ),
231 });
232 } else if rsi > *upper {
233 violations.push(RiskViolation {
234 rule: "RsiExtreme".to_string(),
235 message: format!(
236 "RSI ({:.2}) is above the upper extreme threshold ({:.2}). Opening new positions is restricted.",
237 rsi, upper,
238 ),
239 });
240 }
241 }
242 }
243 TaRiskRule::AtrSurge { threshold_pct } => {
244 if let Some(atr) = indicators.atr_14 {
245 if price > 0.0 {
246 let atr_pct = (atr / price) * 100.0;
247 if atr_pct > *threshold_pct {
248 violations.push(RiskViolation {
249 rule: "AtrSurge".to_string(),
250 message: format!(
251 "ATR as percentage of price ({:.2}%) exceeds threshold ({:.2}%). Consider reducing position size.",
252 atr_pct, threshold_pct,
253 ),
254 });
255 }
256 }
257 }
258 }
259 TaRiskRule::BollingerBreakout { sigma_multiplier } => {
260 if let (Some(bb_middle), Some(bb_upper)) =
263 (indicators.bb_middle, indicators.bb_upper)
264 {
265 let default_sigma = 2.0_f64;
267 let std_dev = (bb_upper - bb_middle) / default_sigma;
268 if std_dev > 0.0 {
269 let custom_upper = bb_middle + sigma_multiplier * std_dev;
270 let custom_lower = bb_middle - sigma_multiplier * std_dev;
271 if price > custom_upper {
272 violations.push(RiskViolation {
273 rule: "BollingerBreakout".to_string(),
274 message: format!(
275 "Price ({:.2}) is above the {:.1}σ Bollinger upper band ({:.2}). Trading is paused.",
276 price, sigma_multiplier, custom_upper,
277 ),
278 });
279 } else if price < custom_lower {
280 violations.push(RiskViolation {
281 rule: "BollingerBreakout".to_string(),
282 message: format!(
283 "Price ({:.2}) is below the {:.1}σ Bollinger lower band ({:.2}). Trading is paused.",
284 price, sigma_multiplier, custom_lower,
285 ),
286 });
287 }
288 }
289 }
290 }
291 }
292 }
293
294 violations
295}
296
/// Order validator built from a [`RiskConfig`] plus optional TA rules.
///
/// Circuit-breaker state and the recent-order history sit behind `Mutex`es,
/// so the checks can update them through `&self`.
pub struct RiskGuard {
    /// Limits and on/off switches used by the checks.
    config: RiskConfig,
    /// Technical-analysis rules evaluated by `check_order_with_ta`.
    ta_rules: Vec<TaRiskRule>,
    /// Whether the breaker is tripped; cleared lazily by the next circuit-breaker
    /// check once the cooldown has elapsed.
    circuit_breaker_tripped: Mutex<bool>,
    /// When the breaker was last tripped; used to measure the cooldown.
    circuit_breaker_tripped_at: Mutex<Option<Instant>>,
    /// Recently recorded orders, used for rate limiting and duplicate detection.
    recent_orders: Mutex<Vec<RecentOrder>>,
}
312
313impl RiskGuard {
314 pub fn new(config: RiskConfig) -> Self {
315 Self {
316 config,
317 ta_rules: Vec::new(),
318 circuit_breaker_tripped: Mutex::new(false),
319 circuit_breaker_tripped_at: Mutex::new(None),
320 recent_orders: Mutex::new(Vec::new()),
321 }
322 }
323
324 pub fn with_ta_rules(config: RiskConfig, ta_rules: Vec<TaRiskRule>) -> Self {
326 Self {
327 config,
328 ta_rules,
329 circuit_breaker_tripped: Mutex::new(false),
330 circuit_breaker_tripped_at: Mutex::new(None),
331 recent_orders: Mutex::new(Vec::new()),
332 }
333 }
334
335 pub fn set_ta_rules(&mut self, rules: Vec<TaRiskRule>) {
337 self.ta_rules = rules;
338 }
339
340 #[allow(dead_code)]
342 pub fn update_config(&mut self, config: RiskConfig) {
343 self.config = config;
344 }
345
346 pub fn check_order_with_ta(
352 &self,
353 order: &OrderRequest,
354 current_state: &AccountState,
355 indicators: &TechnicalIndicators,
356 ) -> Result<(), RiskViolation> {
357 self.check_order(order, current_state)?;
359
360 if !self.ta_rules.is_empty() {
364 let ta_violations = check_ta_risk(indicators, order.price, &self.ta_rules);
365 if !ta_violations.is_empty() {
366 let highest = ta_violations
369 .into_iter()
370 .max_by_key(|v| match v.rule.as_str() {
371 "BollingerBreakout" | "CircuitBreaker" | "DailyLossGuard" => 2,
372 _ => 1, })
374 .unwrap(); return Err(highest);
376 }
377 }
378
379 Ok(())
380 }
381
382 pub fn check_order(
385 &self,
386 order: &OrderRequest,
387 current_state: &AccountState,
388 ) -> Result<(), RiskViolation> {
389 self.check_circuit_breaker(current_state)?;
391
392 self.check_position_limits(order, current_state)?;
394
395 self.check_daily_loss(current_state)?;
397
398 self.check_anomaly(order)?;
400
401 self.record_order(order);
403
404 Ok(())
405 }
406
407 fn check_circuit_breaker(&self, current_state: &AccountState) -> Result<(), RiskViolation> {
409 if !self.config.circuit_breaker.enabled {
410 return Ok(());
411 }
412
413 {
415 let mut tripped = self.circuit_breaker_tripped.lock().unwrap();
416 let tripped_at = self.circuit_breaker_tripped_at.lock().unwrap();
417 if *tripped {
418 if let Some(at) = *tripped_at {
419 let cooldown = std::time::Duration::from_secs(
420 self.config.circuit_breaker.cooldown_minutes as u64 * 60,
421 );
422 if at.elapsed() >= cooldown {
423 *tripped = false;
425 }
426 }
427 }
428
429 if *tripped {
430 return Err(RiskViolation {
431 rule: "CircuitBreaker".to_string(),
432 message: format!(
433 "Circuit breaker is active. All trading is paused for {} minutes.",
434 self.config.circuit_breaker.cooldown_minutes
435 ),
436 });
437 }
438 }
439
440 if current_state.windowed_loss >= self.config.circuit_breaker.trigger_loss {
442 self.trip_circuit_breaker();
443 return Err(RiskViolation {
444 rule: "CircuitBreaker".to_string(),
445 message: format!(
446 "Circuit breaker triggered: loss ${:.2} in the last {} minutes exceeds threshold ${:.2}. All trading paused.",
447 current_state.windowed_loss,
448 self.config.circuit_breaker.trigger_window_minutes,
449 self.config.circuit_breaker.trigger_loss,
450 ),
451 });
452 }
453
454 Ok(())
455 }
456
457 fn trip_circuit_breaker(&self) {
459 let mut tripped = self.circuit_breaker_tripped.lock().unwrap();
460 let mut tripped_at = self.circuit_breaker_tripped_at.lock().unwrap();
461 *tripped = true;
462 *tripped_at = Some(Instant::now());
463 }
464
465 pub fn is_circuit_breaker_tripped(&self) -> bool {
467 *self.circuit_breaker_tripped.lock().unwrap()
468 }
469
470 fn check_position_limits(
472 &self,
473 order: &OrderRequest,
474 current_state: &AccountState,
475 ) -> Result<(), RiskViolation> {
476 if !self.config.position_limits.enabled {
477 return Ok(());
478 }
479
480 let order_notional = order.size * order.price;
481
482 let new_total = current_state.total_position_value + order_notional;
484 if new_total > self.config.position_limits.max_total_position {
485 return Err(RiskViolation {
486 rule: "MaxPositionGuard".to_string(),
487 message: format!(
488 "Order would exceed total position limit. Current: ${:.2}, Order: ${:.2}, Limit: ${:.2}",
489 current_state.total_position_value,
490 order_notional,
491 self.config.position_limits.max_total_position,
492 ),
493 });
494 }
495
496 let current_symbol_pos = current_state
498 .position_by_symbol
499 .get(&order.symbol)
500 .copied()
501 .unwrap_or(0.0);
502 let new_symbol_pos = current_symbol_pos + order_notional;
503 if new_symbol_pos > self.config.position_limits.max_per_symbol {
504 return Err(RiskViolation {
505 rule: "MaxPositionGuard".to_string(),
506 message: format!(
507 "Order would exceed per-symbol position limit for '{}'. Current: ${:.2}, Order: ${:.2}, Limit: ${:.2}",
508 order.symbol,
509 current_symbol_pos,
510 order_notional,
511 self.config.position_limits.max_per_symbol,
512 ),
513 });
514 }
515
516 Ok(())
517 }
518
519 fn check_daily_loss(&self, current_state: &AccountState) -> Result<(), RiskViolation> {
521 if !self.config.daily_loss_limits.enabled {
522 return Ok(());
523 }
524
525 if current_state.daily_realized_loss >= self.config.daily_loss_limits.max_daily_loss {
527 return Err(RiskViolation {
528 rule: "DailyLossGuard".to_string(),
529 message: format!(
530 "Daily loss limit reached. Loss today: ${:.2}, Limit: ${:.2}",
531 current_state.daily_realized_loss, self.config.daily_loss_limits.max_daily_loss,
532 ),
533 });
534 }
535
536 if current_state.daily_starting_equity > 0.0 {
538 let loss_pct =
539 (current_state.daily_realized_loss / current_state.daily_starting_equity) * 100.0;
540 if loss_pct >= self.config.daily_loss_limits.max_daily_loss_percent {
541 return Err(RiskViolation {
542 rule: "DailyLossGuard".to_string(),
543 message: format!(
544 "Daily loss percentage limit reached. Loss: {:.2}%, Limit: {:.2}%",
545 loss_pct, self.config.daily_loss_limits.max_daily_loss_percent,
546 ),
547 });
548 }
549 }
550
551 Ok(())
552 }
553
554 fn check_anomaly(&self, order: &OrderRequest) -> Result<(), RiskViolation> {
556 if !self.config.anomaly_detection.enabled {
557 return Ok(());
558 }
559
560 let order_notional = order.size * order.price;
561
562 if order_notional > self.config.anomaly_detection.max_order_size {
564 return Err(RiskViolation {
565 rule: "AnomalyDetection".to_string(),
566 message: format!(
567 "Order size ${:.2} exceeds maximum allowed order size ${:.2}",
568 order_notional, self.config.anomaly_detection.max_order_size,
569 ),
570 });
571 }
572
573 let mut recent = self.recent_orders.lock().unwrap();
574
575 let one_minute_ago = Instant::now() - std::time::Duration::from_secs(60);
577 recent.retain(|o| o.timestamp > one_minute_ago);
578
579 if recent.len() >= self.config.anomaly_detection.max_orders_per_minute as usize {
581 return Err(RiskViolation {
582 rule: "AnomalyDetection".to_string(),
583 message: format!(
584 "Rate limit exceeded: {} orders in the last minute (limit: {})",
585 recent.len(),
586 self.config.anomaly_detection.max_orders_per_minute,
587 ),
588 });
589 }
590
591 if self.config.anomaly_detection.block_duplicate_orders {
593 let is_duplicate = recent.iter().any(|o| {
594 o.symbol == order.symbol
595 && o.side == order.side
596 && (o.size - order.size).abs() < f64::EPSILON
597 });
598 if is_duplicate {
599 return Err(RiskViolation {
600 rule: "AnomalyDetection".to_string(),
601 message: format!(
602 "Duplicate order detected: {} {} {:.4} on symbol '{}'",
603 order.side, order.size, order.price, order.symbol,
604 ),
605 });
606 }
607 }
608
609 Ok(())
610 }
611
612 fn record_order(&self, order: &OrderRequest) {
614 let mut recent = self.recent_orders.lock().unwrap();
615 recent.push(RecentOrder {
616 symbol: order.symbol.clone(),
617 side: order.side.clone(),
618 size: order.size,
619 timestamp: Instant::now(),
620 });
621 }
622}
623
624pub fn record_risk_alert(violation: &RiskViolation) {
630 let entry = AlertHistoryEntry {
631 id: uuid::Uuid::new_v4().to_string(),
632 timestamp: chrono::Utc::now().to_rfc3339(),
633 severity: match violation.rule.as_str() {
634 "CircuitBreaker" | "DailyLossGuard" | "BollingerBreakout" => "critical".to_string(),
635 "RsiExtreme" | "AtrSurge" => "warning".to_string(),
636 _ => "warning".to_string(),
637 },
638 message: violation.to_string(),
639 channel: "risk_guard".to_string(),
640 status: "fired".to_string(),
641 };
642
643 let path = alert_history_path();
644 let mut history: Vec<AlertHistoryEntry> = if path.exists() {
645 fs::read_to_string(&path)
646 .ok()
647 .and_then(|data| serde_json::from_str(&data).ok())
648 .unwrap_or_default()
649 } else {
650 Vec::new()
651 };
652
653 history.push(entry);
654
655 if let Ok(json) = serde_json::to_string_pretty(&history) {
656 let _ = fs::write(&path, json);
657 }
658}
659
660pub fn check_risk(order: &OrderRequest, account_state: &AccountState) -> Result<(), String> {
676 let config = get_risk_config_sync();
677 let guard = RiskGuard::new(config);
678 match guard.check_order(order, account_state) {
679 Ok(()) => Ok(()),
680 Err(violation) => {
681 record_risk_alert(&violation);
682 Err(format!("Risk violation: {}", violation))
683 }
684 }
685}
686
687#[cfg(test)]
692mod tests {
693 use super::*;
694
695 fn default_guard() -> RiskGuard {
696 RiskGuard::new(RiskConfig::default())
697 }
698
699 fn default_state() -> AccountState {
700 AccountState::default()
701 }
702
703 fn make_order(symbol: &str, side: &str, size: f64, price: f64) -> OrderRequest {
704 OrderRequest {
705 symbol: symbol.to_string(),
706 side: side.to_string(),
707 size,
708 price,
709 }
710 }
711
712 #[test]
715 fn test_order_within_position_limits_passes() {
716 let guard = default_guard();
717 let state = default_state();
718 let order = make_order("BTC", "buy", 0.1, 30_000.0); assert!(guard.check_order(&order, &state).is_ok());
720 }
721
722 #[test]
723 fn test_order_exceeds_total_position_limit() {
724 let guard = default_guard();
725 let state = AccountState {
726 total_position_value: 99_000.0,
727 ..default_state()
728 };
729 let order = make_order("BTC", "buy", 1.0, 2_000.0);
731 let result = guard.check_order(&order, &state);
732 assert!(result.is_err());
733 let violation = result.unwrap_err();
734 assert_eq!(violation.rule, "MaxPositionGuard");
735 assert!(violation.message.contains("total position limit"));
736 }
737
738 #[test]
739 fn test_order_exceeds_per_symbol_position_limit() {
740 let guard = default_guard();
741 let mut position_by_symbol = HashMap::new();
742 position_by_symbol.insert("BTC".to_string(), 24_000.0);
743 let state = AccountState {
744 total_position_value: 24_000.0,
745 position_by_symbol,
746 ..default_state()
747 };
748 let order = make_order("BTC", "buy", 1.0, 2_000.0);
750 let result = guard.check_order(&order, &state);
751 assert!(result.is_err());
752 let violation = result.unwrap_err();
753 assert_eq!(violation.rule, "MaxPositionGuard");
754 assert!(violation.message.contains("per-symbol"));
755 }
756
757 #[test]
758 fn test_position_limits_disabled() {
759 let config = RiskConfig {
760 position_limits: PositionLimits {
761 enabled: false,
762 max_total_position: 1.0,
763 max_per_symbol: 1.0,
764 },
765 anomaly_detection: AnomalyDetection {
766 enabled: true,
767 max_order_size: 1_000_000.0, max_orders_per_minute: 10,
769 block_duplicate_orders: false,
770 },
771 ..RiskConfig::default()
772 };
773 let guard = RiskGuard::new(config);
774 let state = AccountState {
775 total_position_value: 999_999.0,
776 ..default_state()
777 };
778 let order = make_order("BTC", "buy", 1.0, 100_000.0);
779 assert!(guard.check_order(&order, &state).is_ok());
780 }
781
782 #[test]
785 fn test_daily_loss_limit_absolute() {
786 let guard = default_guard();
787 let state = AccountState {
788 daily_realized_loss: 5_000.0, ..default_state()
790 };
791 let order = make_order("BTC", "buy", 0.01, 30_000.0);
792 let result = guard.check_order(&order, &state);
793 assert!(result.is_err());
794 let violation = result.unwrap_err();
795 assert_eq!(violation.rule, "DailyLossGuard");
796 }
797
798 #[test]
799 fn test_daily_loss_limit_percentage() {
800 let guard = default_guard();
801 let state = AccountState {
802 daily_realized_loss: 5_500.0,
803 daily_starting_equity: 100_000.0, ..default_state()
805 };
806 let order = make_order("BTC", "buy", 0.01, 30_000.0);
807 let result = guard.check_order(&order, &state);
808 assert!(result.is_err());
809 let violation = result.unwrap_err();
810 assert_eq!(violation.rule, "DailyLossGuard");
811 }
812
813 #[test]
814 fn test_daily_loss_below_limit_passes() {
815 let guard = default_guard();
816 let state = AccountState {
817 daily_realized_loss: 1_000.0,
818 daily_starting_equity: 100_000.0,
819 ..default_state()
820 };
821 let order = make_order("BTC", "buy", 0.01, 30_000.0);
822 assert!(guard.check_order(&order, &state).is_ok());
823 }
824
825 #[test]
828 fn test_max_order_size_exceeded() {
829 let config = RiskConfig {
831 position_limits: PositionLimits {
832 enabled: false,
833 ..RiskConfig::default().position_limits
834 },
835 anomaly_detection: AnomalyDetection {
836 enabled: true,
837 max_order_size: 50_000.0,
838 max_orders_per_minute: 10,
839 block_duplicate_orders: true,
840 },
841 ..RiskConfig::default()
842 };
843 let guard = RiskGuard::new(config);
844 let state = default_state();
845 let order = make_order("BTC", "buy", 2.0, 30_000.0);
847 let result = guard.check_order(&order, &state);
848 assert!(result.is_err());
849 let violation = result.unwrap_err();
850 assert_eq!(violation.rule, "AnomalyDetection");
851 assert!(violation.message.contains("maximum allowed order size"));
852 }
853
854 #[test]
855 fn test_orders_per_minute_exceeded() {
856 let config = RiskConfig {
857 anomaly_detection: AnomalyDetection {
858 enabled: true,
859 max_order_size: 1_000_000.0,
860 max_orders_per_minute: 3,
861 block_duplicate_orders: false,
862 },
863 ..RiskConfig::default()
864 };
865 let guard = RiskGuard::new(config);
866 let state = default_state();
867
868 for i in 0..3 {
870 let order = make_order("BTC", "buy", 0.01 * (i as f64 + 1.0), 30_000.0);
871 assert!(guard.check_order(&order, &state).is_ok());
872 }
873
874 let order = make_order("ETH", "buy", 0.1, 2_000.0);
876 let result = guard.check_order(&order, &state);
877 assert!(result.is_err());
878 let violation = result.unwrap_err();
879 assert_eq!(violation.rule, "AnomalyDetection");
880 assert!(violation.message.contains("Rate limit"));
881 }
882
883 #[test]
884 fn test_duplicate_order_detection() {
885 let guard = default_guard();
886 let state = default_state();
887
888 let order = make_order("BTC", "buy", 0.1, 30_000.0);
889 assert!(guard.check_order(&order, &state).is_ok());
890
891 let duplicate = make_order("BTC", "buy", 0.1, 31_000.0);
893 let result = guard.check_order(&duplicate, &state);
894 assert!(result.is_err());
895 let violation = result.unwrap_err();
896 assert_eq!(violation.rule, "AnomalyDetection");
897 assert!(violation.message.contains("Duplicate"));
898 }
899
900 #[test]
901 fn test_non_duplicate_different_side_passes() {
902 let guard = default_guard();
903 let state = default_state();
904
905 let order1 = make_order("BTC", "buy", 0.1, 30_000.0);
906 assert!(guard.check_order(&order1, &state).is_ok());
907
908 let order2 = make_order("BTC", "sell", 0.1, 30_000.0);
910 assert!(guard.check_order(&order2, &state).is_ok());
911 }
912
913 #[test]
916 fn test_circuit_breaker_triggers_on_windowed_loss() {
917 let config = RiskConfig {
918 circuit_breaker: CircuitBreaker {
919 enabled: true,
920 trigger_loss: 10_000.0,
921 trigger_window_minutes: 60,
922 action: "pause_all".to_string(),
923 cooldown_minutes: 30,
924 },
925 ..RiskConfig::default()
926 };
927 let guard = RiskGuard::new(config);
928 let state = AccountState {
929 windowed_loss: 10_000.0, ..default_state()
931 };
932
933 let order = make_order("BTC", "buy", 0.01, 30_000.0);
934 let result = guard.check_order(&order, &state);
935 assert!(result.is_err());
936 let violation = result.unwrap_err();
937 assert_eq!(violation.rule, "CircuitBreaker");
938 assert!(guard.is_circuit_breaker_tripped());
939 }
940
941 #[test]
942 fn test_circuit_breaker_blocks_subsequent_orders() {
943 let config = RiskConfig {
944 circuit_breaker: CircuitBreaker {
945 enabled: true,
946 trigger_loss: 10_000.0,
947 trigger_window_minutes: 60,
948 action: "pause_all".to_string(),
949 cooldown_minutes: 30,
950 },
951 ..RiskConfig::default()
952 };
953 let guard = RiskGuard::new(config);
954
955 let bad_state = AccountState {
957 windowed_loss: 15_000.0,
958 ..default_state()
959 };
960 let order = make_order("BTC", "buy", 0.01, 30_000.0);
961 let _ = guard.check_order(&order, &bad_state);
962
963 let clean_state = default_state();
965 let result = guard.check_order(&order, &clean_state);
966 assert!(result.is_err());
967 assert_eq!(result.unwrap_err().rule, "CircuitBreaker");
968 }
969
970 #[test]
971 fn test_circuit_breaker_disabled() {
972 let config = RiskConfig {
973 circuit_breaker: CircuitBreaker {
974 enabled: false,
975 trigger_loss: 1.0,
976 trigger_window_minutes: 60,
977 action: "pause_all".to_string(),
978 cooldown_minutes: 30,
979 },
980 ..RiskConfig::default()
981 };
982 let guard = RiskGuard::new(config);
983 let state = AccountState {
984 windowed_loss: 999_999.0,
985 ..default_state()
986 };
987 let order = make_order("BTC", "buy", 0.01, 30_000.0);
988 assert!(guard.check_order(&order, &state).is_ok());
989 }
990
991 #[test]
994 fn test_all_rules_pass_for_small_order() {
995 let guard = default_guard();
996 let state = default_state();
997 let order = make_order("BTC", "buy", 0.01, 30_000.0); assert!(guard.check_order(&order, &state).is_ok());
999 }
1000
1001 #[test]
1004 fn test_risk_violation_display() {
1005 let v = RiskViolation {
1006 rule: "TestRule".to_string(),
1007 message: "something went wrong".to_string(),
1008 };
1009 assert_eq!(format!("{}", v), "[TestRule] something went wrong");
1010 }
1011
1012 #[test]
1015 fn test_update_config() {
1016 let config = RiskConfig {
1018 position_limits: PositionLimits {
1019 enabled: false,
1020 ..RiskConfig::default().position_limits
1021 },
1022 anomaly_detection: AnomalyDetection {
1023 enabled: true,
1024 max_order_size: 50_000.0,
1025 max_orders_per_minute: 10,
1026 block_duplicate_orders: false,
1027 },
1028 ..RiskConfig::default()
1029 };
1030 let mut guard = RiskGuard::new(config);
1031 let state = default_state();
1032 let big_order = make_order("BTC", "buy", 2.0, 30_000.0); assert!(guard.check_order(&big_order, &state).is_err());
1036
1037 let mut new_config = RiskConfig::default();
1039 new_config.position_limits.enabled = false;
1040 new_config.anomaly_detection.max_order_size = 100_000.0;
1041 new_config.anomaly_detection.block_duplicate_orders = false;
1042 guard.update_config(new_config);
1043
1044 assert!(guard.check_order(&big_order, &state).is_ok());
1045 }
1046
1047 fn make_indicators(
1050 rsi: Option<f64>,
1051 atr: Option<f64>,
1052 bb: Option<(f64, f64, f64)>,
1053 ) -> TechnicalIndicators {
1054 let mut ind = TechnicalIndicators::empty();
1055 ind.rsi_14 = rsi;
1056 ind.atr_14 = atr;
1057 if let Some((upper, middle, lower)) = bb {
1058 ind.bb_upper = Some(upper);
1059 ind.bb_middle = Some(middle);
1060 ind.bb_lower = Some(lower);
1061 }
1062 ind
1063 }
1064
1065 #[test]
1068 fn test_rsi_extreme_below_lower() {
1069 let rules = vec![TaRiskRule::RsiExtreme {
1070 lower: 20.0,
1071 upper: 80.0,
1072 }];
1073 let ind = make_indicators(Some(15.0), None, None);
1074 let violations = check_ta_risk(&ind, 100.0, &rules);
1075 assert_eq!(violations.len(), 1);
1076 assert_eq!(violations[0].rule, "RsiExtreme");
1077 assert!(violations[0].message.contains("below"));
1078 }
1079
1080 #[test]
1081 fn test_rsi_extreme_above_upper() {
1082 let rules = vec![TaRiskRule::RsiExtreme {
1083 lower: 20.0,
1084 upper: 80.0,
1085 }];
1086 let ind = make_indicators(Some(85.0), None, None);
1087 let violations = check_ta_risk(&ind, 100.0, &rules);
1088 assert_eq!(violations.len(), 1);
1089 assert_eq!(violations[0].rule, "RsiExtreme");
1090 assert!(violations[0].message.contains("above"));
1091 }
1092
1093 #[test]
1094 fn test_rsi_within_range_passes() {
1095 let rules = vec![TaRiskRule::RsiExtreme {
1096 lower: 20.0,
1097 upper: 80.0,
1098 }];
1099 let ind = make_indicators(Some(50.0), None, None);
1100 let violations = check_ta_risk(&ind, 100.0, &rules);
1101 assert!(violations.is_empty());
1102 }
1103
1104 #[test]
1105 fn test_rsi_at_boundary_passes() {
1106 let rules = vec![TaRiskRule::RsiExtreme {
1107 lower: 20.0,
1108 upper: 80.0,
1109 }];
1110 let ind_lower = make_indicators(Some(20.0), None, None);
1112 assert!(check_ta_risk(&ind_lower, 100.0, &rules).is_empty());
1113 let ind_upper = make_indicators(Some(80.0), None, None);
1114 assert!(check_ta_risk(&ind_upper, 100.0, &rules).is_empty());
1115 }
1116
1117 #[test]
1118 fn test_rsi_none_skips_check() {
1119 let rules = vec![TaRiskRule::RsiExtreme {
1120 lower: 20.0,
1121 upper: 80.0,
1122 }];
1123 let ind = make_indicators(None, None, None);
1124 assert!(check_ta_risk(&ind, 100.0, &rules).is_empty());
1125 }
1126
1127 #[test]
1130 fn test_atr_surge_exceeds_threshold() {
1131 let rules = vec![TaRiskRule::AtrSurge { threshold_pct: 5.0 }];
1132 let ind = make_indicators(None, Some(60.0), None);
1134 let violations = check_ta_risk(&ind, 1000.0, &rules);
1135 assert_eq!(violations.len(), 1);
1136 assert_eq!(violations[0].rule, "AtrSurge");
1137 assert!(violations[0].message.contains("6.00%"));
1138 }
1139
1140 #[test]
1141 fn test_atr_surge_below_threshold_passes() {
1142 let rules = vec![TaRiskRule::AtrSurge { threshold_pct: 5.0 }];
1143 let ind = make_indicators(None, Some(40.0), None);
1145 assert!(check_ta_risk(&ind, 1000.0, &rules).is_empty());
1146 }
1147
1148 #[test]
1149 fn test_atr_surge_at_threshold_passes() {
1150 let rules = vec![TaRiskRule::AtrSurge { threshold_pct: 5.0 }];
1151 let ind = make_indicators(None, Some(50.0), None);
1153 assert!(check_ta_risk(&ind, 1000.0, &rules).is_empty());
1154 }
1155
1156 #[test]
1157 fn test_atr_surge_zero_price_skips() {
1158 let rules = vec![TaRiskRule::AtrSurge { threshold_pct: 5.0 }];
1159 let ind = make_indicators(None, Some(60.0), None);
1160 assert!(check_ta_risk(&ind, 0.0, &rules).is_empty());
1161 }
1162
1163 #[test]
1164 fn test_atr_none_skips_check() {
1165 let rules = vec![TaRiskRule::AtrSurge { threshold_pct: 5.0 }];
1166 let ind = make_indicators(None, None, None);
1167 assert!(check_ta_risk(&ind, 1000.0, &rules).is_empty());
1168 }
1169
1170 #[test]
1173 fn test_bollinger_breakout_above_upper() {
1174 let rules = vec![TaRiskRule::BollingerBreakout {
1177 sigma_multiplier: 3.0,
1178 }];
1179 let ind = make_indicators(None, None, Some((110.0, 100.0, 90.0)));
1180 let violations = check_ta_risk(&ind, 120.0, &rules);
1181 assert_eq!(violations.len(), 1);
1182 assert_eq!(violations[0].rule, "BollingerBreakout");
1183 assert!(violations[0].message.contains("above"));
1184 }
1185
1186 #[test]
1187 fn test_bollinger_breakout_below_lower() {
1188 let rules = vec![TaRiskRule::BollingerBreakout {
1190 sigma_multiplier: 3.0,
1191 }];
1192 let ind = make_indicators(None, None, Some((110.0, 100.0, 90.0)));
1193 let violations = check_ta_risk(&ind, 80.0, &rules);
1194 assert_eq!(violations.len(), 1);
1195 assert_eq!(violations[0].rule, "BollingerBreakout");
1196 assert!(violations[0].message.contains("below"));
1197 }
1198
1199 #[test]
1200 fn test_bollinger_within_bands_passes() {
1201 let rules = vec![TaRiskRule::BollingerBreakout {
1203 sigma_multiplier: 3.0,
1204 }];
1205 let ind = make_indicators(None, None, Some((110.0, 100.0, 90.0)));
1206 assert!(check_ta_risk(&ind, 100.0, &rules).is_empty());
1207 }
1208
1209 #[test]
1210 fn test_bollinger_none_skips_check() {
1211 let rules = vec![TaRiskRule::BollingerBreakout {
1212 sigma_multiplier: 3.0,
1213 }];
1214 let ind = make_indicators(None, None, None);
1215 assert!(check_ta_risk(&ind, 120.0, &rules).is_empty());
1216 }
1217
1218 #[test]
1221 fn test_multiple_rules_all_triggered() {
1222 let rules = vec![
1223 TaRiskRule::RsiExtreme {
1224 lower: 20.0,
1225 upper: 80.0,
1226 },
1227 TaRiskRule::AtrSurge { threshold_pct: 5.0 },
1228 TaRiskRule::BollingerBreakout {
1229 sigma_multiplier: 3.0,
1230 },
1231 ];
1232 let mut ind = TechnicalIndicators::empty();
1234 ind.rsi_14 = Some(90.0);
1235 ind.atr_14 = Some(60.0);
1236 ind.bb_upper = Some(110.0);
1237 ind.bb_middle = Some(100.0);
1238 ind.bb_lower = Some(90.0);
1239 let violations = check_ta_risk(&ind, 120.0, &rules);
1242 assert_eq!(violations.len(), 3);
1243 let rule_names: Vec<&str> = violations.iter().map(|v| v.rule.as_str()).collect();
1244 assert!(rule_names.contains(&"RsiExtreme"));
1245 assert!(rule_names.contains(&"AtrSurge"));
1246 assert!(rule_names.contains(&"BollingerBreakout"));
1247 }
1248
1249 #[test]
1250 fn test_empty_rules_no_violations() {
1251 let rules: Vec<TaRiskRule> = vec![];
1252 let ind = make_indicators(Some(10.0), Some(100.0), Some((110.0, 100.0, 90.0)));
1253 assert!(check_ta_risk(&ind, 120.0, &rules).is_empty());
1254 }
1255
1256 #[test]
1259 fn test_check_order_with_ta_blocks_on_rsi() {
1260 let ta_rules = vec![TaRiskRule::RsiExtreme {
1261 lower: 20.0,
1262 upper: 80.0,
1263 }];
1264 let guard = RiskGuard::with_ta_rules(RiskConfig::default(), ta_rules);
1265 let state = default_state();
1266 let order = make_order("BTC", "buy", 0.01, 30_000.0);
1267 let ind = make_indicators(Some(85.0), None, None);
1268 let result = guard.check_order_with_ta(&order, &state, &ind);
1269 assert!(result.is_err());
1270 assert_eq!(result.unwrap_err().rule, "RsiExtreme");
1271 }
1272
1273 #[test]
1274 fn test_check_order_with_ta_passes_when_clean() {
1275 let ta_rules = vec![
1276 TaRiskRule::RsiExtreme {
1277 lower: 20.0,
1278 upper: 80.0,
1279 },
1280 TaRiskRule::AtrSurge { threshold_pct: 5.0 },
1281 ];
1282 let guard = RiskGuard::with_ta_rules(RiskConfig::default(), ta_rules);
1283 let state = default_state();
1284 let order = make_order("BTC", "buy", 0.01, 30_000.0);
1285 let ind = make_indicators(Some(50.0), Some(100.0), None); assert!(guard.check_order_with_ta(&order, &state, &ind).is_ok());
1287 }
1288
1289 #[test]
1290 fn test_set_ta_rules() {
1291 let mut guard = default_guard();
1292 let state = default_state();
1293 let order1 = make_order("BTC", "buy", 0.01, 30_000.0);
1294 let ind = make_indicators(Some(85.0), None, None);
1295
1296 assert!(guard.check_order_with_ta(&order1, &state, &ind).is_ok());
1298
1299 guard.set_ta_rules(vec![TaRiskRule::RsiExtreme {
1301 lower: 20.0,
1302 upper: 80.0,
1303 }]);
1304 let order2 = make_order("BTC", "buy", 0.02, 30_000.0);
1305 let result = guard.check_order_with_ta(&order2, &state, &ind);
1306 assert!(result.is_err());
1307 assert_eq!(result.unwrap_err().rule, "RsiExtreme");
1308 }
1309
1310 #[test]
1311 fn test_check_order_with_ta_returns_highest_severity_violation() {
1312 let ta_rules = vec![
1315 TaRiskRule::AtrSurge { threshold_pct: 5.0 },
1316 TaRiskRule::BollingerBreakout {
1317 sigma_multiplier: 3.0,
1318 },
1319 ];
1320 let guard = RiskGuard::with_ta_rules(RiskConfig::default(), ta_rules);
1321 let state = default_state();
1322 let mut ind = TechnicalIndicators::empty();
1325 ind.atr_14 = Some(60.0);
1326 ind.bb_upper = Some(110.0);
1327 ind.bb_middle = Some(100.0);
1328 ind.bb_lower = Some(90.0);
1329 let order = make_order("BTC", "buy", 0.01, 120.0);
1330 let result = guard.check_order_with_ta(&order, &state, &ind);
1331 assert!(result.is_err());
1332 assert_eq!(
1333 result.unwrap_err().rule,
1334 "BollingerBreakout",
1335 "should return the critical-level violation, not the warning-level one"
1336 );
1337 }
1338
1339 #[test]
1340 fn test_check_order_with_ta_circuit_breaker_tripped_ta_clean() {
1341 let config = RiskConfig {
1344 circuit_breaker: CircuitBreaker {
1345 enabled: true,
1346 trigger_loss: 10_000.0,
1347 trigger_window_minutes: 60,
1348 action: "pause_all".to_string(),
1349 cooldown_minutes: 30,
1350 },
1351 ..RiskConfig::default()
1352 };
1353 let ta_rules = vec![
1354 TaRiskRule::RsiExtreme {
1355 lower: 20.0,
1356 upper: 80.0,
1357 },
1358 TaRiskRule::AtrSurge { threshold_pct: 5.0 },
1359 ];
1360 let guard = RiskGuard::with_ta_rules(config, ta_rules);
1361 let state = AccountState {
1362 windowed_loss: 15_000.0, ..default_state()
1364 };
1365 let ind = make_indicators(Some(50.0), Some(10.0), None);
1367 let order = make_order("BTC", "buy", 0.01, 30_000.0);
1368 let result = guard.check_order_with_ta(&order, &state, &ind);
1369 assert!(result.is_err());
1370 assert_eq!(
1371 result.unwrap_err().rule,
1372 "CircuitBreaker",
1373 "standard circuit breaker should still block even when TA is clean"
1374 );
1375 }
1376
1377 #[test]
1378 fn test_bollinger_partial_data_skips_check() {
1379 let rules = vec![TaRiskRule::BollingerBreakout {
1381 sigma_multiplier: 3.0,
1382 }];
1383 let mut ind = TechnicalIndicators::empty();
1384 ind.bb_middle = Some(100.0);
1385 let violations = check_ta_risk(&ind, 120.0, &rules);
1387 assert!(
1388 violations.is_empty(),
1389 "partial Bollinger data (bb_upper=None) should skip the check, not panic"
1390 );
1391 }
1392
1393 #[test]
1396 fn test_check_risk_passes_small_order() {
1397 let order = make_order("BTC", "buy", 0.01, 30_000.0); let state = default_state();
1399 assert!(check_risk(&order, &state).is_ok());
1400 }
1401
1402 #[test]
1403 fn test_check_risk_blocks_on_daily_loss() {
1404 let order = make_order("BTC", "buy", 0.01, 30_000.0);
1405 let state = AccountState {
1406 daily_realized_loss: 10_000.0, ..default_state()
1408 };
1409 let result = check_risk(&order, &state);
1410 assert!(result.is_err());
1411 let err_msg = result.unwrap_err();
1412 assert!(err_msg.contains("Risk violation"), "Error: {}", err_msg);
1413 }
1414}