// hyper_risk/risk.rs

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use std::sync::Mutex;
use std::time::Instant;

use crate::alerts::{alert_history_path, AlertHistoryEntry};
use hyper_ta::TechnicalIndicators;
10
// ---------------------------------------------------------------------------
// Risk Config types (existing)
// ---------------------------------------------------------------------------
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16#[serde(rename_all = "camelCase")]
17pub struct PositionLimits {
18    pub enabled: bool,
19    pub max_total_position: f64,
20    pub max_per_symbol: f64,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24#[serde(rename_all = "camelCase")]
25pub struct DailyLossLimits {
26    pub enabled: bool,
27    pub max_daily_loss: f64,
28    pub max_daily_loss_percent: f64,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
32#[serde(rename_all = "camelCase")]
33pub struct AnomalyDetection {
34    pub enabled: bool,
35    pub max_order_size: f64,
36    pub max_orders_per_minute: u32,
37    pub block_duplicate_orders: bool,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41#[serde(rename_all = "camelCase")]
42pub struct CircuitBreaker {
43    pub enabled: bool,
44    pub trigger_loss: f64,
45    pub trigger_window_minutes: u32,
46    pub action: String,
47    pub cooldown_minutes: u32,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51#[serde(rename_all = "camelCase")]
52pub struct RiskConfig {
53    pub position_limits: PositionLimits,
54    pub daily_loss_limits: DailyLossLimits,
55    pub anomaly_detection: AnomalyDetection,
56    pub circuit_breaker: CircuitBreaker,
57}
58
59impl Default for RiskConfig {
60    fn default() -> Self {
61        Self {
62            position_limits: PositionLimits {
63                enabled: true,
64                max_total_position: 100_000.0,
65                max_per_symbol: 25_000.0,
66            },
67            daily_loss_limits: DailyLossLimits {
68                enabled: true,
69                max_daily_loss: 5_000.0,
70                max_daily_loss_percent: 5.0,
71            },
72            anomaly_detection: AnomalyDetection {
73                enabled: true,
74                max_order_size: 50_000.0,
75                max_orders_per_minute: 10,
76                block_duplicate_orders: true,
77            },
78            circuit_breaker: CircuitBreaker {
79                enabled: false,
80                trigger_loss: 10_000.0,
81                trigger_window_minutes: 60,
82                action: "pause_all".to_string(),
83                cooldown_minutes: 30,
84            },
85        }
86    }
87}
88
89pub fn risk_config_path() -> PathBuf {
90    let mut path = dirs::config_dir().unwrap_or_else(|| PathBuf::from("."));
91    path.push("hyper-agent");
92    let _ = fs::create_dir_all(&path);
93    path.push("risk-config.json");
94    path
95}
96
97/// Synchronous version for loading risk config at startup.
98pub fn get_risk_config_sync() -> RiskConfig {
99    let path = risk_config_path();
100    if path.exists() {
101        fs::read_to_string(&path)
102            .ok()
103            .and_then(|data| serde_json::from_str(&data).ok())
104            .unwrap_or_default()
105    } else {
106        RiskConfig::default()
107    }
108}
109
110pub fn save_risk_config_to_disk(config: &RiskConfig) -> Result<(), String> {
111    let path = risk_config_path();
112    let json = serde_json::to_string_pretty(config).map_err(|e| e.to_string())?;
113    fs::write(&path, json).map_err(|e| e.to_string())?;
114    Ok(())
115}
116
// ---------------------------------------------------------------------------
// RiskGuard types
// ---------------------------------------------------------------------------
120
/// Represents an incoming order to be checked against risk rules.
///
/// The risk checks derive the order's notional value as `size * price`.
#[derive(Debug, Clone)]
pub struct OrderRequest {
    /// Symbol/asset identifier (e.g. asset index as string).
    pub symbol: String,
    /// "buy" or "sell".
    pub side: String,
    /// Order quantity. This is multiplied by `price` to obtain the notional
    /// value, so it must NOT already be a notional (USD) amount.
    pub size: f64,
    /// Order price.
    pub price: f64,
}
133
/// Snapshot of current account state used for risk evaluation.
///
/// The guard never mutates this value; callers supply a fresh snapshot with
/// every check.
#[derive(Debug, Clone)]
pub struct AccountState {
    /// Total position value across all symbols (in USD).
    pub total_position_value: f64,
    /// Per-symbol position values (in USD). Symbols absent from the map are
    /// treated as having no position.
    pub position_by_symbol: HashMap<String, f64>,
    /// Realized loss for the current day (positive value = loss).
    pub daily_realized_loss: f64,
    /// Starting equity for the day (for percentage calculation). The
    /// percentage check is skipped unless this is greater than zero.
    pub daily_starting_equity: f64,
    /// Losses within the circuit breaker window (positive = loss).
    pub windowed_loss: f64,
}

impl Default for AccountState {
    /// An account with no positions and no losses, and a starting equity of
    /// 100,000.
    fn default() -> Self {
        Self {
            total_position_value: 0.0,
            position_by_symbol: HashMap::new(),
            daily_realized_loss: 0.0,
            daily_starting_equity: 100_000.0,
            windowed_loss: 0.0,
        }
    }
}
160
161/// Describes a risk violation that blocks order execution.
162#[derive(Debug, Clone, Serialize, Deserialize)]
163pub struct RiskViolation {
164    pub rule: String,
165    pub message: String,
166}
167
168impl std::fmt::Display for RiskViolation {
169    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170        write!(f, "[{}] {}", self.rule, self.message)
171    }
172}
173
/// Tracks recent order timestamps and signatures for anomaly detection.
///
/// One entry is stored per accepted order; `symbol`, `side` and `size`
/// together form the signature used for duplicate detection.
#[derive(Debug, Clone)]
struct RecentOrder {
    /// Symbol of the accepted order.
    symbol: String,
    /// Side of the accepted order.
    side: String,
    /// Size of the accepted order (quantity, not notional).
    size: f64,
    /// Monotonic time at which the order was recorded.
    timestamp: Instant,
}
182
// ---------------------------------------------------------------------------
// TA-based risk rules
// ---------------------------------------------------------------------------
186
187/// Technical-analysis-based risk rule definitions.
188#[derive(Debug, Clone, Serialize, Deserialize)]
189#[serde(tag = "type", rename_all = "snake_case")]
190pub enum TaRiskRule {
191    /// Block opening new positions when RSI is outside the safe range.
192    RsiExtreme {
193        /// RSI below this value is considered oversold extreme (default 20).
194        lower: f64,
195        /// RSI above this value is considered overbought extreme (default 80).
196        upper: f64,
197    },
198    /// Warn / reduce position when ATR as a percentage of price exceeds threshold.
199    AtrSurge {
200        /// ATR / price * 100 threshold (e.g. 5.0 means 5%).
201        threshold_pct: f64,
202    },
203    /// Block trading when price is outside Bollinger Bands by the given sigma multiplier.
204    BollingerBreakout {
205        /// Number of standard deviations (e.g. 3.0 for 3σ bands).
206        sigma_multiplier: f64,
207    },
208}
209
210/// Check a set of TA-based risk rules against current indicators and price.
211///
212/// Returns a list of violations (empty if all rules pass).
213pub fn check_ta_risk(
214    indicators: &TechnicalIndicators,
215    price: f64,
216    rules: &[TaRiskRule],
217) -> Vec<RiskViolation> {
218    let mut violations = Vec::new();
219
220    for rule in rules {
221        match rule {
222            TaRiskRule::RsiExtreme { lower, upper } => {
223                if let Some(rsi) = indicators.rsi_14 {
224                    if rsi < *lower {
225                        violations.push(RiskViolation {
226                            rule: "RsiExtreme".to_string(),
227                            message: format!(
228                                "RSI ({:.2}) is below the lower extreme threshold ({:.2}). Opening new positions is restricted.",
229                                rsi, lower,
230                            ),
231                        });
232                    } else if rsi > *upper {
233                        violations.push(RiskViolation {
234                            rule: "RsiExtreme".to_string(),
235                            message: format!(
236                                "RSI ({:.2}) is above the upper extreme threshold ({:.2}). Opening new positions is restricted.",
237                                rsi, upper,
238                            ),
239                        });
240                    }
241                }
242            }
243            TaRiskRule::AtrSurge { threshold_pct } => {
244                if let Some(atr) = indicators.atr_14 {
245                    if price > 0.0 {
246                        let atr_pct = (atr / price) * 100.0;
247                        if atr_pct > *threshold_pct {
248                            violations.push(RiskViolation {
249                                rule: "AtrSurge".to_string(),
250                                message: format!(
251                                    "ATR as percentage of price ({:.2}%) exceeds threshold ({:.2}%). Consider reducing position size.",
252                                    atr_pct, threshold_pct,
253                                ),
254                            });
255                        }
256                    }
257                }
258            }
259            TaRiskRule::BollingerBreakout { sigma_multiplier } => {
260                // Use bb_middle and bb_upper/bb_lower to derive the standard deviation,
261                // then compute custom bands at the requested sigma multiplier.
262                if let (Some(bb_middle), Some(bb_upper)) =
263                    (indicators.bb_middle, indicators.bb_upper)
264                {
265                    // Default BB uses 2σ, so σ = (bb_upper - bb_middle) / 2
266                    let default_sigma = 2.0_f64;
267                    let std_dev = (bb_upper - bb_middle) / default_sigma;
268                    if std_dev > 0.0 {
269                        let custom_upper = bb_middle + sigma_multiplier * std_dev;
270                        let custom_lower = bb_middle - sigma_multiplier * std_dev;
271                        if price > custom_upper {
272                            violations.push(RiskViolation {
273                                rule: "BollingerBreakout".to_string(),
274                                message: format!(
275                                    "Price ({:.2}) is above the {:.1}σ Bollinger upper band ({:.2}). Trading is paused.",
276                                    price, sigma_multiplier, custom_upper,
277                                ),
278                            });
279                        } else if price < custom_lower {
280                            violations.push(RiskViolation {
281                                rule: "BollingerBreakout".to_string(),
282                                message: format!(
283                                    "Price ({:.2}) is below the {:.1}σ Bollinger lower band ({:.2}). Trading is paused.",
284                                    price, sigma_multiplier, custom_lower,
285                                ),
286                            });
287                        }
288                    }
289                }
290            }
291        }
292    }
293
294    violations
295}
296
// ---------------------------------------------------------------------------
// RiskGuard
// ---------------------------------------------------------------------------
300
301pub struct RiskGuard {
302    config: RiskConfig,
303    /// TA-based risk rules evaluated against technical indicators.
304    ta_rules: Vec<TaRiskRule>,
305    /// Circuit breaker tripped flag. When true, all orders are blocked.
306    circuit_breaker_tripped: Mutex<bool>,
307    /// Timestamp when the circuit breaker was tripped.
308    circuit_breaker_tripped_at: Mutex<Option<Instant>>,
309    /// Recent orders for anomaly detection (orders per minute, duplicate detection).
310    recent_orders: Mutex<Vec<RecentOrder>>,
311}
312
313impl RiskGuard {
314    pub fn new(config: RiskConfig) -> Self {
315        Self {
316            config,
317            ta_rules: Vec::new(),
318            circuit_breaker_tripped: Mutex::new(false),
319            circuit_breaker_tripped_at: Mutex::new(None),
320            recent_orders: Mutex::new(Vec::new()),
321        }
322    }
323
324    /// Create a RiskGuard with both standard config and TA-based risk rules.
325    pub fn with_ta_rules(config: RiskConfig, ta_rules: Vec<TaRiskRule>) -> Self {
326        Self {
327            config,
328            ta_rules,
329            circuit_breaker_tripped: Mutex::new(false),
330            circuit_breaker_tripped_at: Mutex::new(None),
331            recent_orders: Mutex::new(Vec::new()),
332        }
333    }
334
335    /// Replace the current TA risk rules.
336    pub fn set_ta_rules(&mut self, rules: Vec<TaRiskRule>) {
337        self.ta_rules = rules;
338    }
339
340    /// Update the risk config (e.g. when user saves new settings).
341    #[allow(dead_code)]
342    pub fn update_config(&mut self, config: RiskConfig) {
343        self.config = config;
344    }
345
346    /// Check an order against all enabled risk rules **including** TA-based rules.
347    ///
348    /// Returns `Ok(())` if the order is allowed, or `Err` with the first
349    /// blocking violation.  If you only want TA violations without blocking
350    /// the order, call [`check_ta_risk`] directly.
351    pub fn check_order_with_ta(
352        &self,
353        order: &OrderRequest,
354        current_state: &AccountState,
355        indicators: &TechnicalIndicators,
356    ) -> Result<(), RiskViolation> {
357        // Run standard checks first
358        self.check_order(order, current_state)?;
359
360        // Then check TA rules — return the highest-severity violation
361        // (critical > warning) so that e.g. a BollingerBreakout is not masked
362        // by an AtrSurge that happens to appear first in the rules list.
363        if !self.ta_rules.is_empty() {
364            let ta_violations = check_ta_risk(indicators, order.price, &self.ta_rules);
365            if !ta_violations.is_empty() {
366                // Pick the highest-severity violation.
367                // Severity: critical > warning (same mapping as record_risk_alert).
368                let highest = ta_violations
369                    .into_iter()
370                    .max_by_key(|v| match v.rule.as_str() {
371                        "BollingerBreakout" | "CircuitBreaker" | "DailyLossGuard" => 2,
372                        _ => 1, // warning-level: RsiExtreme, AtrSurge, etc.
373                    })
374                    .unwrap(); // safe: we checked !is_empty()
375                return Err(highest);
376            }
377        }
378
379        Ok(())
380    }
381
382    /// Check an order against all enabled risk rules.
383    /// Returns Ok(()) if the order is allowed, Err(RiskViolation) if blocked.
384    pub fn check_order(
385        &self,
386        order: &OrderRequest,
387        current_state: &AccountState,
388    ) -> Result<(), RiskViolation> {
389        // Check circuit breaker first (blocks everything when tripped)
390        self.check_circuit_breaker(current_state)?;
391
392        // Check position limits
393        self.check_position_limits(order, current_state)?;
394
395        // Check daily loss limits
396        self.check_daily_loss(current_state)?;
397
398        // Check anomaly detection
399        self.check_anomaly(order)?;
400
401        // Order passed all checks - record it for future anomaly detection
402        self.record_order(order);
403
404        Ok(())
405    }
406
407    /// Check if circuit breaker is active and should block orders.
408    fn check_circuit_breaker(&self, current_state: &AccountState) -> Result<(), RiskViolation> {
409        if !self.config.circuit_breaker.enabled {
410            return Ok(());
411        }
412
413        // Check if cooldown has expired
414        {
415            let mut tripped = self.circuit_breaker_tripped.lock().unwrap();
416            let tripped_at = self.circuit_breaker_tripped_at.lock().unwrap();
417            if *tripped {
418                if let Some(at) = *tripped_at {
419                    let cooldown = std::time::Duration::from_secs(
420                        self.config.circuit_breaker.cooldown_minutes as u64 * 60,
421                    );
422                    if at.elapsed() >= cooldown {
423                        // Cooldown expired, reset circuit breaker
424                        *tripped = false;
425                    }
426                }
427            }
428
429            if *tripped {
430                return Err(RiskViolation {
431                    rule: "CircuitBreaker".to_string(),
432                    message: format!(
433                        "Circuit breaker is active. All trading is paused for {} minutes.",
434                        self.config.circuit_breaker.cooldown_minutes
435                    ),
436                });
437            }
438        }
439
440        // Check if windowed loss exceeds trigger threshold
441        if current_state.windowed_loss >= self.config.circuit_breaker.trigger_loss {
442            self.trip_circuit_breaker();
443            return Err(RiskViolation {
444                rule: "CircuitBreaker".to_string(),
445                message: format!(
446                    "Circuit breaker triggered: loss ${:.2} in the last {} minutes exceeds threshold ${:.2}. All trading paused.",
447                    current_state.windowed_loss,
448                    self.config.circuit_breaker.trigger_window_minutes,
449                    self.config.circuit_breaker.trigger_loss,
450                ),
451            });
452        }
453
454        Ok(())
455    }
456
457    /// Trip the circuit breaker, blocking all subsequent orders.
458    fn trip_circuit_breaker(&self) {
459        let mut tripped = self.circuit_breaker_tripped.lock().unwrap();
460        let mut tripped_at = self.circuit_breaker_tripped_at.lock().unwrap();
461        *tripped = true;
462        *tripped_at = Some(Instant::now());
463    }
464
465    /// Returns whether the circuit breaker is currently tripped.
466    pub fn is_circuit_breaker_tripped(&self) -> bool {
467        *self.circuit_breaker_tripped.lock().unwrap()
468    }
469
470    /// Check position limits (total and per-symbol).
471    fn check_position_limits(
472        &self,
473        order: &OrderRequest,
474        current_state: &AccountState,
475    ) -> Result<(), RiskViolation> {
476        if !self.config.position_limits.enabled {
477            return Ok(());
478        }
479
480        let order_notional = order.size * order.price;
481
482        // Check total position limit
483        let new_total = current_state.total_position_value + order_notional;
484        if new_total > self.config.position_limits.max_total_position {
485            return Err(RiskViolation {
486                rule: "MaxPositionGuard".to_string(),
487                message: format!(
488                    "Order would exceed total position limit. Current: ${:.2}, Order: ${:.2}, Limit: ${:.2}",
489                    current_state.total_position_value,
490                    order_notional,
491                    self.config.position_limits.max_total_position,
492                ),
493            });
494        }
495
496        // Check per-symbol position limit
497        let current_symbol_pos = current_state
498            .position_by_symbol
499            .get(&order.symbol)
500            .copied()
501            .unwrap_or(0.0);
502        let new_symbol_pos = current_symbol_pos + order_notional;
503        if new_symbol_pos > self.config.position_limits.max_per_symbol {
504            return Err(RiskViolation {
505                rule: "MaxPositionGuard".to_string(),
506                message: format!(
507                    "Order would exceed per-symbol position limit for '{}'. Current: ${:.2}, Order: ${:.2}, Limit: ${:.2}",
508                    order.symbol,
509                    current_symbol_pos,
510                    order_notional,
511                    self.config.position_limits.max_per_symbol,
512                ),
513            });
514        }
515
516        Ok(())
517    }
518
519    /// Check daily loss limits.
520    fn check_daily_loss(&self, current_state: &AccountState) -> Result<(), RiskViolation> {
521        if !self.config.daily_loss_limits.enabled {
522            return Ok(());
523        }
524
525        // Check absolute daily loss
526        if current_state.daily_realized_loss >= self.config.daily_loss_limits.max_daily_loss {
527            return Err(RiskViolation {
528                rule: "DailyLossGuard".to_string(),
529                message: format!(
530                    "Daily loss limit reached. Loss today: ${:.2}, Limit: ${:.2}",
531                    current_state.daily_realized_loss, self.config.daily_loss_limits.max_daily_loss,
532                ),
533            });
534        }
535
536        // Check percentage daily loss
537        if current_state.daily_starting_equity > 0.0 {
538            let loss_pct =
539                (current_state.daily_realized_loss / current_state.daily_starting_equity) * 100.0;
540            if loss_pct >= self.config.daily_loss_limits.max_daily_loss_percent {
541                return Err(RiskViolation {
542                    rule: "DailyLossGuard".to_string(),
543                    message: format!(
544                        "Daily loss percentage limit reached. Loss: {:.2}%, Limit: {:.2}%",
545                        loss_pct, self.config.daily_loss_limits.max_daily_loss_percent,
546                    ),
547                });
548            }
549        }
550
551        Ok(())
552    }
553
554    /// Check anomaly detection rules (order size, rate, duplicates).
555    fn check_anomaly(&self, order: &OrderRequest) -> Result<(), RiskViolation> {
556        if !self.config.anomaly_detection.enabled {
557            return Ok(());
558        }
559
560        let order_notional = order.size * order.price;
561
562        // 1. Max order size check
563        if order_notional > self.config.anomaly_detection.max_order_size {
564            return Err(RiskViolation {
565                rule: "AnomalyDetection".to_string(),
566                message: format!(
567                    "Order size ${:.2} exceeds maximum allowed order size ${:.2}",
568                    order_notional, self.config.anomaly_detection.max_order_size,
569                ),
570            });
571        }
572
573        let mut recent = self.recent_orders.lock().unwrap();
574
575        // Prune orders older than 1 minute
576        let one_minute_ago = Instant::now() - std::time::Duration::from_secs(60);
577        recent.retain(|o| o.timestamp > one_minute_ago);
578
579        // 2. Orders per minute check
580        if recent.len() >= self.config.anomaly_detection.max_orders_per_minute as usize {
581            return Err(RiskViolation {
582                rule: "AnomalyDetection".to_string(),
583                message: format!(
584                    "Rate limit exceeded: {} orders in the last minute (limit: {})",
585                    recent.len(),
586                    self.config.anomaly_detection.max_orders_per_minute,
587                ),
588            });
589        }
590
591        // 3. Duplicate order detection
592        if self.config.anomaly_detection.block_duplicate_orders {
593            let is_duplicate = recent.iter().any(|o| {
594                o.symbol == order.symbol
595                    && o.side == order.side
596                    && (o.size - order.size).abs() < f64::EPSILON
597            });
598            if is_duplicate {
599                return Err(RiskViolation {
600                    rule: "AnomalyDetection".to_string(),
601                    message: format!(
602                        "Duplicate order detected: {} {} {:.4} on symbol '{}'",
603                        order.side, order.size, order.price, order.symbol,
604                    ),
605                });
606            }
607        }
608
609        Ok(())
610    }
611
612    /// Record an order for future anomaly detection checks.
613    fn record_order(&self, order: &OrderRequest) {
614        let mut recent = self.recent_orders.lock().unwrap();
615        recent.push(RecentOrder {
616            symbol: order.symbol.clone(),
617            side: order.side.clone(),
618            size: order.size,
619            timestamp: Instant::now(),
620        });
621    }
622}
623
// ---------------------------------------------------------------------------
// Alert integration helper
// ---------------------------------------------------------------------------
627
628/// Record a risk violation as an alert history entry.
629pub fn record_risk_alert(violation: &RiskViolation) {
630    let entry = AlertHistoryEntry {
631        id: uuid::Uuid::new_v4().to_string(),
632        timestamp: chrono::Utc::now().to_rfc3339(),
633        severity: match violation.rule.as_str() {
634            "CircuitBreaker" | "DailyLossGuard" | "BollingerBreakout" => "critical".to_string(),
635            "RsiExtreme" | "AtrSurge" => "warning".to_string(),
636            _ => "warning".to_string(),
637        },
638        message: violation.to_string(),
639        channel: "risk_guard".to_string(),
640        status: "fired".to_string(),
641    };
642
643    let path = alert_history_path();
644    let mut history: Vec<AlertHistoryEntry> = if path.exists() {
645        fs::read_to_string(&path)
646            .ok()
647            .and_then(|data| serde_json::from_str(&data).ok())
648            .unwrap_or_default()
649    } else {
650        Vec::new()
651    };
652
653    history.push(entry);
654
655    if let Ok(json) = serde_json::to_string_pretty(&history) {
656        let _ = fs::write(&path, json);
657    }
658}
659
// ---------------------------------------------------------------------------
// Convenience free function for pre-trade risk checks
// ---------------------------------------------------------------------------
663
664/// Perform a pre-trade risk check using the on-disk risk config.
665///
666/// Loads the current `RiskConfig` from disk (or defaults), constructs a
667/// temporary `RiskGuard`, and validates the order against all enabled rules.
668///
669/// Returns `Ok(())` if the order passes all checks, or `Err(String)` with a
670/// descriptive message if any risk rule blocks the order. Violations are
671/// automatically recorded to the alert history.
672///
673/// This is a convenience wrapper around `RiskGuard::check_order` for callers
674/// that do not hold a long-lived `RiskGuard` instance.
675pub fn check_risk(order: &OrderRequest, account_state: &AccountState) -> Result<(), String> {
676    let config = get_risk_config_sync();
677    let guard = RiskGuard::new(config);
678    match guard.check_order(order, account_state) {
679        Ok(()) => Ok(()),
680        Err(violation) => {
681            record_risk_alert(&violation);
682            Err(format!("Risk violation: {}", violation))
683        }
684    }
685}
686
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
690
691#[cfg(test)]
692mod tests {
693    use super::*;
694
695    fn default_guard() -> RiskGuard {
696        RiskGuard::new(RiskConfig::default())
697    }
698
699    fn default_state() -> AccountState {
700        AccountState::default()
701    }
702
703    fn make_order(symbol: &str, side: &str, size: f64, price: f64) -> OrderRequest {
704        OrderRequest {
705            symbol: symbol.to_string(),
706            side: side.to_string(),
707            size,
708            price,
709        }
710    }
711
712    // ---- Position limits ----
713
714    #[test]
715    fn test_order_within_position_limits_passes() {
716        let guard = default_guard();
717        let state = default_state();
718        let order = make_order("BTC", "buy", 0.1, 30_000.0); // $3,000
719        assert!(guard.check_order(&order, &state).is_ok());
720    }
721
722    #[test]
723    fn test_order_exceeds_total_position_limit() {
724        let guard = default_guard();
725        let state = AccountState {
726            total_position_value: 99_000.0,
727            ..default_state()
728        };
729        // $99,000 existing + $2,000 new = $101,000 > $100,000 limit
730        let order = make_order("BTC", "buy", 1.0, 2_000.0);
731        let result = guard.check_order(&order, &state);
732        assert!(result.is_err());
733        let violation = result.unwrap_err();
734        assert_eq!(violation.rule, "MaxPositionGuard");
735        assert!(violation.message.contains("total position limit"));
736    }
737
738    #[test]
739    fn test_order_exceeds_per_symbol_position_limit() {
740        let guard = default_guard();
741        let mut position_by_symbol = HashMap::new();
742        position_by_symbol.insert("BTC".to_string(), 24_000.0);
743        let state = AccountState {
744            total_position_value: 24_000.0,
745            position_by_symbol,
746            ..default_state()
747        };
748        // $24,000 existing + $2,000 new = $26,000 > $25,000 per-symbol limit
749        let order = make_order("BTC", "buy", 1.0, 2_000.0);
750        let result = guard.check_order(&order, &state);
751        assert!(result.is_err());
752        let violation = result.unwrap_err();
753        assert_eq!(violation.rule, "MaxPositionGuard");
754        assert!(violation.message.contains("per-symbol"));
755    }
756
757    #[test]
758    fn test_position_limits_disabled() {
759        let config = RiskConfig {
760            position_limits: PositionLimits {
761                enabled: false,
762                max_total_position: 1.0,
763                max_per_symbol: 1.0,
764            },
765            anomaly_detection: AnomalyDetection {
766                enabled: true,
767                max_order_size: 1_000_000.0, // high enough to not trigger
768                max_orders_per_minute: 10,
769                block_duplicate_orders: false,
770            },
771            ..RiskConfig::default()
772        };
773        let guard = RiskGuard::new(config);
774        let state = AccountState {
775            total_position_value: 999_999.0,
776            ..default_state()
777        };
778        let order = make_order("BTC", "buy", 1.0, 100_000.0);
779        assert!(guard.check_order(&order, &state).is_ok());
780    }
781
782    // ---- Daily loss limits ----
783
784    #[test]
785    fn test_daily_loss_limit_absolute() {
786        let guard = default_guard();
787        let state = AccountState {
788            daily_realized_loss: 5_000.0, // equals limit
789            ..default_state()
790        };
791        let order = make_order("BTC", "buy", 0.01, 30_000.0);
792        let result = guard.check_order(&order, &state);
793        assert!(result.is_err());
794        let violation = result.unwrap_err();
795        assert_eq!(violation.rule, "DailyLossGuard");
796    }
797
798    #[test]
799    fn test_daily_loss_limit_percentage() {
800        let guard = default_guard();
801        let state = AccountState {
802            daily_realized_loss: 5_500.0,
803            daily_starting_equity: 100_000.0, // 5.5% > 5%
804            ..default_state()
805        };
806        let order = make_order("BTC", "buy", 0.01, 30_000.0);
807        let result = guard.check_order(&order, &state);
808        assert!(result.is_err());
809        let violation = result.unwrap_err();
810        assert_eq!(violation.rule, "DailyLossGuard");
811    }
812
/// A 1,000 loss against 100,000 of starting equity must not block a small order.
#[test]
fn test_daily_loss_below_limit_passes() {
    let guard = default_guard();

    let mut account = default_state();
    account.daily_realized_loss = 1_000.0;
    account.daily_starting_equity = 100_000.0;

    let small_buy = make_order("BTC", "buy", 0.01, 30_000.0);
    let outcome = guard.check_order(&small_buy, &account);
    assert!(outcome.is_ok());
}
824
825    // ---- Anomaly detection ----
826
/// An order whose value is above `max_order_size` must be rejected by
/// `AnomalyDetection` with a message mentioning the maximum allowed size.
#[test]
fn test_max_order_size_exceeded() {
    // Position limits are switched off so that anomaly detection is the only
    // rule left that could object to the order size.
    let position_limits = PositionLimits {
        enabled: false,
        ..RiskConfig::default().position_limits
    };
    let anomaly_detection = AnomalyDetection {
        enabled: true,
        max_order_size: 50_000.0,
        max_orders_per_minute: 10,
        block_duplicate_orders: true,
    };
    let guard = RiskGuard::new(RiskConfig {
        position_limits,
        anomaly_detection,
        ..RiskConfig::default()
    });
    let account = default_state();

    // Size 2.0 at 30,000 is a $60,000 order, above the $50,000 cap.
    let oversized = make_order("BTC", "buy", 2.0, 30_000.0);
    let outcome = guard.check_order(&oversized, &account);
    assert!(outcome.is_err());

    let violation = outcome.unwrap_err();
    assert_eq!(violation.rule, "AnomalyDetection");
    assert!(violation.message.contains("maximum allowed order size"));
}
853
/// With `max_orders_per_minute` set to 3, three orders are accepted and the
/// fourth is rejected by `AnomalyDetection` with a "Rate limit" message.
#[test]
fn test_orders_per_minute_exceeded() {
    let anomaly_detection = AnomalyDetection {
        enabled: true,
        max_order_size: 1_000_000.0,
        max_orders_per_minute: 3,
        block_duplicate_orders: false,
    };
    let guard = RiskGuard::new(RiskConfig {
        anomaly_detection,
        ..RiskConfig::default()
    });
    let account = default_state();

    // Fill the allowance: three accepted orders with sizes 0.01 * 1, 2, 3.
    for size in (1..=3_u32).map(|k| 0.01 * f64::from(k)) {
        let accepted = make_order("BTC", "buy", size, 30_000.0);
        assert!(guard.check_order(&accepted, &account).is_ok());
    }

    // One order past the allowance must be rate limited.
    let one_too_many = make_order("ETH", "buy", 0.1, 2_000.0);
    let outcome = guard.check_order(&one_too_many, &account);
    assert!(outcome.is_err());

    let violation = outcome.unwrap_err();
    assert_eq!(violation.rule, "AnomalyDetection");
    assert!(violation.message.contains("Rate limit"));
}
882
/// Re-submitting an order with the same symbol, side and size must be
/// rejected by `AnomalyDetection` as a duplicate, even though the price differs.
#[test]
fn test_duplicate_order_detection() {
    let guard = default_guard();
    let account = default_state();

    let first = make_order("BTC", "buy", 0.1, 30_000.0);
    assert!(guard.check_order(&first, &account).is_ok());

    // Same symbol, side and size as `first`; only the price (31,000) changes.
    let repeat = make_order("BTC", "buy", 0.1, 31_000.0);
    let outcome = guard.check_order(&repeat, &account);
    assert!(outcome.is_err());

    let violation = outcome.unwrap_err();
    assert_eq!(violation.rule, "AnomalyDetection");
    assert!(violation.message.contains("Duplicate"));
}
899
/// A sell that mirrors a preceding buy (same symbol, size and price) is not
/// treated as a duplicate and must be accepted.
#[test]
fn test_non_duplicate_different_side_passes() {
    let guard = default_guard();
    let account = default_state();

    let buy = make_order("BTC", "buy", 0.1, 30_000.0);
    let buy_outcome = guard.check_order(&buy, &account);
    assert!(buy_outcome.is_ok());

    // Opposite side, so the duplicate rule must not fire.
    let sell = make_order("BTC", "sell", 0.1, 30_000.0);
    let sell_outcome = guard.check_order(&sell, &account);
    assert!(sell_outcome.is_ok());
}
912
913    // ---- Circuit breaker ----
914
/// A windowed loss equal to `trigger_loss` must reject the order under the
/// `CircuitBreaker` rule and leave the guard reporting the breaker as tripped.
#[test]
fn test_circuit_breaker_triggers_on_windowed_loss() {
    let breaker = CircuitBreaker {
        enabled: true,
        trigger_loss: 10_000.0,
        trigger_window_minutes: 60,
        action: String::from("pause_all"),
        cooldown_minutes: 30,
    };
    let guard = RiskGuard::new(RiskConfig {
        circuit_breaker: breaker,
        ..RiskConfig::default()
    });

    let mut account = default_state();
    account.windowed_loss = 10_000.0; // exactly the trigger value

    let small_buy = make_order("BTC", "buy", 0.01, 30_000.0);
    let outcome = guard.check_order(&small_buy, &account);
    assert!(outcome.is_err());
    assert_eq!(outcome.unwrap_err().rule, "CircuitBreaker");
    assert!(guard.is_circuit_breaker_tripped());
}
940
/// After a check with an excessive windowed loss, a later check against a
/// clean account state must still be rejected under the `CircuitBreaker` rule.
#[test]
fn test_circuit_breaker_blocks_subsequent_orders() {
    let breaker = CircuitBreaker {
        enabled: true,
        trigger_loss: 10_000.0,
        trigger_window_minutes: 60,
        action: String::from("pause_all"),
        cooldown_minutes: 30,
    };
    let guard = RiskGuard::new(RiskConfig {
        circuit_breaker: breaker,
        ..RiskConfig::default()
    });
    let small_buy = make_order("BTC", "buy", 0.01, 30_000.0);

    // First check: a 15,000 windowed loss is above the 10,000 trigger.
    // The outcome itself is deliberately ignored here.
    let mut losing_account = default_state();
    losing_account.windowed_loss = 15_000.0;
    let _ = guard.check_order(&small_buy, &losing_account);

    // Second check: the account state is clean, yet the order must stay blocked.
    let healthy_account = default_state();
    let outcome = guard.check_order(&small_buy, &healthy_account);
    assert!(outcome.is_err());

    let violation = outcome.unwrap_err();
    assert_eq!(violation.rule, "CircuitBreaker");
}
969
/// With the circuit breaker disabled, even a windowed loss far above
/// `trigger_loss` must not block the order.
#[test]
fn test_circuit_breaker_disabled() {
    let breaker = CircuitBreaker {
        enabled: false,
        trigger_loss: 1.0,
        trigger_window_minutes: 60,
        action: String::from("pause_all"),
        cooldown_minutes: 30,
    };
    let guard = RiskGuard::new(RiskConfig {
        circuit_breaker: breaker,
        ..RiskConfig::default()
    });

    let mut account = default_state();
    account.windowed_loss = 999_999.0;

    let small_buy = make_order("BTC", "buy", 0.01, 30_000.0);
    let outcome = guard.check_order(&small_buy, &account);
    assert!(outcome.is_ok());
}
990
991    // ---- Integration: all rules pass ----
992
/// A $300 order checked by the default guard against the default account
/// state must be accepted.
#[test]
fn test_all_rules_pass_for_small_order() {
    let account = default_state();
    let guard = default_guard();
    // Size 0.01 at 30,000 is a $300 order.
    let small_buy = make_order("BTC", "buy", 0.01, 30_000.0);

    let outcome = guard.check_order(&small_buy, &account);
    assert!(outcome.is_ok());
}
1000
1001    // ---- RiskViolation display ----
1002
/// The `Display` output of a `RiskViolation` is `[<rule>] <message>`.
#[test]
fn test_risk_violation_display() {
    let violation = RiskViolation {
        rule: String::from("TestRule"),
        message: String::from("something went wrong"),
    };
    let rendered = violation.to_string();
    assert_eq!(rendered, "[TestRule] something went wrong");
}
1011
1012    // ---- Config update ----
1013
/// `update_config` must take effect for later checks: an order rejected under
/// a 50k `max_order_size` is accepted once the guard is given a 100k limit.
#[test]
fn test_update_config() {
    // Initial setup: position limits off, so anomaly detection is what blocks.
    let strict_anomaly = AnomalyDetection {
        enabled: true,
        max_order_size: 50_000.0,
        max_orders_per_minute: 10,
        block_duplicate_orders: false,
    };
    let position_limits_off = PositionLimits {
        enabled: false,
        ..RiskConfig::default().position_limits
    };
    let mut guard = RiskGuard::new(RiskConfig {
        position_limits: position_limits_off,
        anomaly_detection: strict_anomaly,
        ..RiskConfig::default()
    });
    let account = default_state();
    let big_order = make_order("BTC", "buy", 2.0, 30_000.0); // $60k

    // $60k is over the $50k cap.
    assert!(guard.check_order(&big_order, &account).is_err());

    // Defaults with position limits off, a 100k order cap and duplicate
    // blocking off (the same order is submitted a second time below).
    let relaxed = RiskConfig {
        position_limits: PositionLimits {
            enabled: false,
            ..RiskConfig::default().position_limits
        },
        anomaly_detection: AnomalyDetection {
            max_order_size: 100_000.0,
            block_duplicate_orders: false,
            ..RiskConfig::default().anomaly_detection
        },
        ..RiskConfig::default()
    };
    guard.update_config(relaxed);

    assert!(guard.check_order(&big_order, &account).is_ok());
}
1046
1047    // ---- TA risk rules ----
1048
1049    fn make_indicators(
1050        rsi: Option<f64>,
1051        atr: Option<f64>,
1052        bb: Option<(f64, f64, f64)>,
1053    ) -> TechnicalIndicators {
1054        let mut ind = TechnicalIndicators::empty();
1055        ind.rsi_14 = rsi;
1056        ind.atr_14 = atr;
1057        if let Some((upper, middle, lower)) = bb {
1058            ind.bb_upper = Some(upper);
1059            ind.bb_middle = Some(middle);
1060            ind.bb_lower = Some(lower);
1061        }
1062        ind
1063    }
1064
1065    // -- RSI extreme tests --
1066
/// An RSI of 15 with a lower bound of 20 yields exactly one `RsiExtreme`
/// violation whose message mentions "below".
#[test]
fn test_rsi_extreme_below_lower() {
    let low_rsi = make_indicators(Some(15.0), None, None);
    let rsi_rule = TaRiskRule::RsiExtreme {
        lower: 20.0,
        upper: 80.0,
    };
    let rules = vec![rsi_rule];

    let violations = check_ta_risk(&low_rsi, 100.0, &rules);
    assert_eq!(violations.len(), 1);

    let only = &violations[0];
    assert_eq!(only.rule, "RsiExtreme");
    assert!(only.message.contains("below"));
}
1079
/// An RSI of 85 with an upper bound of 80 yields exactly one `RsiExtreme`
/// violation whose message mentions "above".
#[test]
fn test_rsi_extreme_above_upper() {
    let high_rsi = make_indicators(Some(85.0), None, None);
    let rsi_rule = TaRiskRule::RsiExtreme {
        lower: 20.0,
        upper: 80.0,
    };
    let rules = vec![rsi_rule];

    let violations = check_ta_risk(&high_rsi, 100.0, &rules);
    assert_eq!(violations.len(), 1);

    let only = &violations[0];
    assert_eq!(only.rule, "RsiExtreme");
    assert!(only.message.contains("above"));
}
1092
/// An RSI of 50 inside the 20..80 band produces no violations.
#[test]
fn test_rsi_within_range_passes() {
    let mid_rsi = make_indicators(Some(50.0), None, None);
    let rules = vec![TaRiskRule::RsiExtreme {
        lower: 20.0,
        upper: 80.0,
    }];

    let violations = check_ta_risk(&mid_rsi, 100.0, &rules);
    assert!(violations.is_empty());
}
1103
/// RSI values sitting exactly on either bound (20 or 80) produce no
/// violations; the test expects only strictly-outside values to trigger.
#[test]
fn test_rsi_at_boundary_passes() {
    let rules = vec![TaRiskRule::RsiExtreme {
        lower: 20.0,
        upper: 80.0,
    }];

    // Lower bound first, then upper bound.
    for &edge in &[20.0, 80.0] {
        let on_the_edge = make_indicators(Some(edge), None, None);
        assert!(check_ta_risk(&on_the_edge, 100.0, &rules).is_empty());
    }
}
1116
/// With no RSI value available, the `RsiExtreme` rule reports nothing.
#[test]
fn test_rsi_none_skips_check() {
    let no_rsi = make_indicators(None, None, None);
    let rules = vec![TaRiskRule::RsiExtreme {
        lower: 20.0,
        upper: 80.0,
    }];

    let violations = check_ta_risk(&no_rsi, 100.0, &rules);
    assert!(violations.is_empty());
}
1126
1127    // -- ATR surge tests --
1128
/// ATR 60 at price 1000 is 6%, above the 5% threshold: exactly one `AtrSurge`
/// violation whose message contains "6.00%".
#[test]
fn test_atr_surge_exceeds_threshold() {
    let surge_rule = TaRiskRule::AtrSurge { threshold_pct: 5.0 };
    let rules = vec![surge_rule];
    let high_atr = make_indicators(None, Some(60.0), None);

    let violations = check_ta_risk(&high_atr, 1000.0, &rules);
    assert_eq!(violations.len(), 1);

    let only = &violations[0];
    assert_eq!(only.rule, "AtrSurge");
    assert!(only.message.contains("6.00%"));
}
1139
/// ATR 40 at price 1000 is 4%, under the 5% threshold: no violations.
#[test]
fn test_atr_surge_below_threshold_passes() {
    let calm_atr = make_indicators(None, Some(40.0), None);
    let rules = vec![TaRiskRule::AtrSurge { threshold_pct: 5.0 }];

    let violations = check_ta_risk(&calm_atr, 1000.0, &rules);
    assert!(violations.is_empty());
}
1147
/// ATR 50 at price 1000 is exactly 5%; the test expects equality with the
/// threshold not to count as a surge.
#[test]
fn test_atr_surge_at_threshold_passes() {
    let edge_atr = make_indicators(None, Some(50.0), None);
    let rules = vec![TaRiskRule::AtrSurge { threshold_pct: 5.0 }];

    let violations = check_ta_risk(&edge_atr, 1000.0, &rules);
    assert!(violations.is_empty());
}
1155
/// With a price of zero, the `AtrSurge` rule reports nothing even though an
/// ATR value is present.
#[test]
fn test_atr_surge_zero_price_skips() {
    let high_atr = make_indicators(None, Some(60.0), None);
    let rules = vec![TaRiskRule::AtrSurge { threshold_pct: 5.0 }];

    let violations = check_ta_risk(&high_atr, 0.0, &rules);
    assert!(violations.is_empty());
}
1162
/// With no ATR value available, the `AtrSurge` rule reports nothing.
#[test]
fn test_atr_none_skips_check() {
    let no_atr = make_indicators(None, None, None);
    let rules = vec![TaRiskRule::AtrSurge { threshold_pct: 5.0 }];

    let violations = check_ta_risk(&no_atr, 1000.0, &rules);
    assert!(violations.is_empty());
}
1169
1170    // -- Bollinger breakout tests --
1171
/// A price above the 3σ upper band yields exactly one `BollingerBreakout`
/// violation whose message mentions "above".
#[test]
fn test_bollinger_breakout_above_upper() {
    // The test takes the supplied bands as 2σ wide: middle 100 and upper 110
    // give σ = 5, so the 3σ upper bound is 115 and a price of 120 is outside.
    let bands = make_indicators(None, None, Some((110.0, 100.0, 90.0)));
    let breakout_rule = TaRiskRule::BollingerBreakout {
        sigma_multiplier: 3.0,
    };
    let rules = vec![breakout_rule];

    let violations = check_ta_risk(&bands, 120.0, &rules);
    assert_eq!(violations.len(), 1);

    let only = &violations[0];
    assert_eq!(only.rule, "BollingerBreakout");
    assert!(only.message.contains("above"));
}
1185
/// A price below the 3σ lower band yields exactly one `BollingerBreakout`
/// violation whose message mentions "below".
#[test]
fn test_bollinger_breakout_below_lower() {
    // With σ = 5 around a middle of 100 the 3σ lower bound is 85; 80 is outside.
    let bands = make_indicators(None, None, Some((110.0, 100.0, 90.0)));
    let breakout_rule = TaRiskRule::BollingerBreakout {
        sigma_multiplier: 3.0,
    };
    let rules = vec![breakout_rule];

    let violations = check_ta_risk(&bands, 80.0, &rules);
    assert_eq!(violations.len(), 1);

    let only = &violations[0];
    assert_eq!(only.rule, "BollingerBreakout");
    assert!(only.message.contains("below"));
}
1198
/// A price of 100, inside the 3σ range of [85, 115], produces no violations.
#[test]
fn test_bollinger_within_bands_passes() {
    let bands = make_indicators(None, None, Some((110.0, 100.0, 90.0)));
    let rules = vec![TaRiskRule::BollingerBreakout {
        sigma_multiplier: 3.0,
    }];

    let violations = check_ta_risk(&bands, 100.0, &rules);
    assert!(violations.is_empty());
}
1208
/// With no Bollinger values supplied, the `BollingerBreakout` rule reports
/// nothing.
#[test]
fn test_bollinger_none_skips_check() {
    let no_bands = make_indicators(None, None, None);
    let rules = vec![TaRiskRule::BollingerBreakout {
        sigma_multiplier: 3.0,
    }];

    let violations = check_ta_risk(&no_bands, 120.0, &rules);
    assert!(violations.is_empty());
}
1217
1218    // -- Multiple rules at once --
1219
/// When RSI, ATR and Bollinger conditions are all breached at once, each of
/// the three rules contributes a violation.
#[test]
fn test_multiple_rules_all_triggered() {
    let rules = vec![
        TaRiskRule::RsiExtreme {
            lower: 20.0,
            upper: 80.0,
        },
        TaRiskRule::AtrSurge { threshold_pct: 5.0 },
        TaRiskRule::BollingerBreakout {
            sigma_multiplier: 3.0,
        },
    ];

    // A single price of 120 is used for every rule:
    //   - RSI 90 is above the upper bound of 80;
    //   - ATR 60 / 120 * 100 = 50%, above the 5% threshold;
    //   - bands give σ = 5, so the 3σ upper bound is 115 and 120 is outside.
    let stressed = make_indicators(Some(90.0), Some(60.0), Some((110.0, 100.0, 90.0)));

    let violations = check_ta_risk(&stressed, 120.0, &rules);
    assert_eq!(violations.len(), 3);

    for expected in &["RsiExtreme", "AtrSurge", "BollingerBreakout"] {
        assert!(violations.iter().any(|v| v.rule.as_str() == *expected));
    }
}
1248
/// With an empty rule list nothing is reported, whatever the indicators say.
#[test]
fn test_empty_rules_no_violations() {
    let stressed = make_indicators(Some(10.0), Some(100.0), Some((110.0, 100.0, 90.0)));
    let no_rules: Vec<TaRiskRule> = Vec::new();

    let violations = check_ta_risk(&stressed, 120.0, &no_rules);
    assert!(violations.is_empty());
}
1255
1256    // -- Integration: check_order_with_ta --
1257
/// `check_order_with_ta` rejects an otherwise acceptable order with an
/// `RsiExtreme` violation when the RSI (85) is above the rule's upper bound.
#[test]
fn test_check_order_with_ta_blocks_on_rsi() {
    let rsi_rule = TaRiskRule::RsiExtreme {
        lower: 20.0,
        upper: 80.0,
    };
    let guard = RiskGuard::with_ta_rules(RiskConfig::default(), vec![rsi_rule]);

    let account = default_state();
    let small_buy = make_order("BTC", "buy", 0.01, 30_000.0);
    let high_rsi = make_indicators(Some(85.0), None, None);

    let outcome = guard.check_order_with_ta(&small_buy, &account, &high_rsi);
    assert!(outcome.is_err());

    let violation = outcome.unwrap_err();
    assert_eq!(violation.rule, "RsiExtreme");
}
1272
/// `check_order_with_ta` accepts the order when neither the RSI nor the ATR
/// rule is breached.
#[test]
fn test_check_order_with_ta_passes_when_clean() {
    let rsi_rule = TaRiskRule::RsiExtreme {
        lower: 20.0,
        upper: 80.0,
    };
    let atr_rule = TaRiskRule::AtrSurge { threshold_pct: 5.0 };
    let guard = RiskGuard::with_ta_rules(RiskConfig::default(), vec![rsi_rule, atr_rule]);

    let account = default_state();
    let small_buy = make_order("BTC", "buy", 0.01, 30_000.0);
    // RSI 50 is mid-range; ATR 100 on a 30,000 price is about 0.33%.
    let calm = make_indicators(Some(50.0), Some(100.0), None);

    let outcome = guard.check_order_with_ta(&small_buy, &account, &calm);
    assert!(outcome.is_ok());
}
1288
/// `set_ta_rules` takes effect for later checks: the same indicators pass
/// before the RSI rule is installed and are rejected afterwards.
#[test]
fn test_set_ta_rules() {
    let mut guard = default_guard();
    let account = default_state();
    let high_rsi = make_indicators(Some(85.0), None, None);

    // The guard has no TA rules yet, so an RSI of 85 is not a problem.
    let before = make_order("BTC", "buy", 0.01, 30_000.0);
    assert!(guard
        .check_order_with_ta(&before, &account, &high_rsi)
        .is_ok());

    let rsi_rule = TaRiskRule::RsiExtreme {
        lower: 20.0,
        upper: 80.0,
    };
    guard.set_ta_rules(vec![rsi_rule]);

    // A different size (0.02) keeps duplicate detection from firing instead.
    let after = make_order("BTC", "buy", 0.02, 30_000.0);
    let outcome = guard.check_order_with_ta(&after, &account, &high_rsi);
    assert!(outcome.is_err());

    let violation = outcome.unwrap_err();
    assert_eq!(violation.rule, "RsiExtreme");
}
1309
/// When both `AtrSurge` and `BollingerBreakout` trigger, `check_order_with_ta`
/// must return the `BollingerBreakout` violation even though `AtrSurge` is
/// listed first. The test treats the former as critical and the latter as a
/// warning.
#[test]
fn test_check_order_with_ta_returns_highest_severity_violation() {
    let atr_rule = TaRiskRule::AtrSurge { threshold_pct: 5.0 };
    let breakout_rule = TaRiskRule::BollingerBreakout {
        sigma_multiplier: 3.0,
    };
    // Order matters for this test: the warning-level rule goes first.
    let guard = RiskGuard::with_ta_rules(RiskConfig::default(), vec![atr_rule, breakout_rule]);
    let account = default_state();

    // At a price of 120:
    //   - ATR 60 / 120 * 100 = 50%, above 5%, so AtrSurge triggers;
    //   - σ = 5 puts the 3σ upper bound at 115, so BollingerBreakout triggers.
    let mut stressed = TechnicalIndicators::empty();
    stressed.atr_14 = Some(60.0);
    stressed.bb_upper = Some(110.0);
    stressed.bb_middle = Some(100.0);
    stressed.bb_lower = Some(90.0);

    let small_buy = make_order("BTC", "buy", 0.01, 120.0);
    let outcome = guard.check_order_with_ta(&small_buy, &account, &stressed);
    assert!(outcome.is_err());

    let violation = outcome.unwrap_err();
    assert_eq!(
        violation.rule,
        "BollingerBreakout",
        "should return the critical-level violation, not the warning-level one"
    );
}
1338
/// `check_order_with_ta` must still reject with `CircuitBreaker` when the
/// standard checks fail, even if none of the TA rules is breached.
#[test]
fn test_check_order_with_ta_circuit_breaker_tripped_ta_clean() {
    let breaker = CircuitBreaker {
        enabled: true,
        trigger_loss: 10_000.0,
        trigger_window_minutes: 60,
        action: String::from("pause_all"),
        cooldown_minutes: 30,
    };
    let config = RiskConfig {
        circuit_breaker: breaker,
        ..RiskConfig::default()
    };
    let rsi_rule = TaRiskRule::RsiExtreme {
        lower: 20.0,
        upper: 80.0,
    };
    let atr_rule = TaRiskRule::AtrSurge { threshold_pct: 5.0 };
    let guard = RiskGuard::with_ta_rules(config, vec![rsi_rule, atr_rule]);

    // A 15,000 windowed loss is above the 10,000 trigger.
    let mut account = default_state();
    account.windowed_loss = 15_000.0;

    // Indicators that breach neither TA rule: mid-range RSI, low ATR.
    let calm = make_indicators(Some(50.0), Some(10.0), None);
    let small_buy = make_order("BTC", "buy", 0.01, 30_000.0);

    let outcome = guard.check_order_with_ta(&small_buy, &account, &calm);
    assert!(outcome.is_err());

    let violation = outcome.unwrap_err();
    assert_eq!(
        violation.rule,
        "CircuitBreaker",
        "standard circuit breaker should still block even when TA is clean"
    );
}
1376
/// When only the middle Bollinger band is set, the `BollingerBreakout` rule
/// must report nothing rather than panic.
#[test]
fn test_bollinger_partial_data_skips_check() {
    // Only `bb_middle` is assigned; the upper and lower bands are deliberately
    // left as `TechnicalIndicators::empty()` produced them.
    let mut partial = TechnicalIndicators::empty();
    partial.bb_middle = Some(100.0);

    let breakout_rule = TaRiskRule::BollingerBreakout {
        sigma_multiplier: 3.0,
    };
    let rules = vec![breakout_rule];

    let violations = check_ta_risk(&partial, 120.0, &rules);
    assert!(
        violations.is_empty(),
        "partial Bollinger data (bb_upper=None) should skip the check, not panic"
    );
}
1392
1393    // ---- check_risk free function ----
1394
/// The free function `check_risk` accepts a $300 order against the default
/// account state.
#[test]
fn test_check_risk_passes_small_order() {
    let account = default_state();
    // Size 0.01 at 30,000 is a $300 order.
    let small_buy = make_order("BTC", "buy", 0.01, 30_000.0);

    let outcome = check_risk(&small_buy, &account);
    assert!(outcome.is_ok());
}
1401
/// The free function `check_risk` rejects an order when the realized daily
/// loss is 10,000, and its error text contains "Risk violation".
#[test]
fn test_check_risk_blocks_on_daily_loss() {
    let small_buy = make_order("BTC", "buy", 0.01, 30_000.0);

    let mut account = default_state();
    account.daily_realized_loss = 10_000.0; // over the default $5,000 limit

    let outcome = check_risk(&small_buy, &account);
    assert!(outcome.is_err());

    let err_msg = outcome.unwrap_err();
    assert!(err_msg.contains("Risk violation"), "Error: {}", err_msg);
}
1414}