Skip to main content

hyper_risk/
risk_defaults.rs

1use serde::{Deserialize, Serialize};
2
3use hyper_ta::Candle;
4use hyper_ta::TechnicalIndicators;
5
6// ---------------------------------------------------------------------------
7// #215 — Per-category SL/TP Defaults
8// ---------------------------------------------------------------------------
9
10/// The type of stop-loss calculation.
11#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum SlType {
14    /// Stop is ATR-multiple based.
15    Atr,
16    /// Stop is a fixed percentage of entry price.
17    Percent,
18}
19
20/// Per-category risk defaults for SL, TP, and trailing-stop parameters.
21#[derive(Debug, Clone, Serialize, Deserialize)]
22#[serde(rename_all = "camelCase")]
23pub struct RiskDefaults {
24    pub category: String,
25    pub sl_type: SlType,
26    /// Multiplier applied to ATR (or percentage value for `Percent` type).
27    pub sl_mult: f64,
28    /// Take-profit multiplier (ATR multiples or percentage).
29    pub tp_mult: f64,
30    /// Whether trailing stop is enabled for this category.
31    pub trailing: bool,
32    /// Trailing-stop ATR multiplier (only meaningful when `trailing == true`).
33    pub trailing_mult: f64,
34}
35
36/// Look up the risk defaults for the given strategy category.
37///
38/// Unknown categories fall back to a conservative default
39/// (ATR-based, SL 2.0, TP 3.0, no trailing).
40pub fn get_risk_defaults(category: &str) -> RiskDefaults {
41    match category {
42        "trend_following" => RiskDefaults {
43            category: category.to_string(),
44            sl_type: SlType::Atr,
45            sl_mult: 2.0,
46            tp_mult: 4.0,
47            trailing: true,
48            trailing_mult: 3.0,
49        },
50        "momentum" => RiskDefaults {
51            category: category.to_string(),
52            sl_type: SlType::Atr,
53            sl_mult: 1.5,
54            tp_mult: 3.0,
55            trailing: false,
56            trailing_mult: 0.0,
57        },
58        "mean_reversion" => RiskDefaults {
59            category: category.to_string(),
60            sl_type: SlType::Atr,
61            sl_mult: 1.0,
62            tp_mult: 1.5,
63            trailing: false,
64            trailing_mult: 0.0,
65        },
66        "volatility" => RiskDefaults {
67            category: category.to_string(),
68            sl_type: SlType::Atr,
69            sl_mult: 2.0,
70            tp_mult: 3.5,
71            trailing: false,
72            trailing_mult: 0.0,
73        },
74        "stat_arb" => RiskDefaults {
75            category: category.to_string(),
76            sl_type: SlType::Percent,
77            sl_mult: 2.0,
78            tp_mult: 3.0,
79            trailing: false,
80            trailing_mult: 0.0,
81        },
82        "onchain" => RiskDefaults {
83            category: category.to_string(),
84            sl_type: SlType::Atr,
85            sl_mult: 1.5,
86            tp_mult: 2.5,
87            trailing: false,
88            trailing_mult: 0.0,
89        },
90        "factor" => RiskDefaults {
91            category: category.to_string(),
92            sl_type: SlType::Atr,
93            sl_mult: 1.5,
94            tp_mult: 3.0,
95            trailing: false,
96            trailing_mult: 0.0,
97        },
98        "execution" => RiskDefaults {
99            category: category.to_string(),
100            sl_type: SlType::Percent,
101            sl_mult: 1.0,
102            tp_mult: 2.0,
103            trailing: false,
104            trailing_mult: 0.0,
105        },
106        "snr" => RiskDefaults {
107            category: category.to_string(),
108            sl_type: SlType::Atr,
109            sl_mult: 1.5,
110            tp_mult: 2.5,
111            trailing: false,
112            trailing_mult: 0.0,
113        },
114        "composite" => RiskDefaults {
115            category: category.to_string(),
116            sl_type: SlType::Atr,
117            sl_mult: 2.0,
118            tp_mult: 4.0,
119            trailing: true,
120            trailing_mult: 3.0,
121        },
122        // Unknown category — conservative fallback
123        _ => RiskDefaults {
124            category: category.to_string(),
125            sl_type: SlType::Atr,
126            sl_mult: 2.0,
127            tp_mult: 3.0,
128            trailing: false,
129            trailing_mult: 0.0,
130        },
131    }
132}
133
134// ---------------------------------------------------------------------------
135// #216 — Dynamic Position Sizing
136// ---------------------------------------------------------------------------
137
/// Return the base position percentage for the given strategy category.
///
/// Values are expressed as fractions (0.0–1.0). Categories not in the table receive
/// the conservative fallback of 0.25.
pub fn base_position_pct(category: &str) -> f64 {
    // Fraction of capital allotted to each known category.
    const BASE_PCT: [(&str, f64); 10] = [
        ("trend_following", 0.60),
        ("mean_reversion", 0.30),
        ("momentum", 0.50),
        ("volatility", 0.40),
        ("stat_arb", 0.20),
        ("onchain", 0.40),
        ("factor", 0.30),
        ("execution", 0.15),
        ("snr", 0.40),
        ("composite", 0.50),
    ];
    // Conservative allocation for categories missing from the table.
    const FALLBACK_PCT: f64 = 0.25;

    BASE_PCT
        .iter()
        .find(|(name, _)| *name == category)
        .map_or(FALLBACK_PCT, |&(_, pct)| pct)
}
156
157/// Compute a suggested position size as a percentage of total capital.
158///
159/// `signal_strength` is expected in 0.0–1.0.
160/// Returns a value in 0.0–1.0 (fraction of capital).
161pub fn suggested_qty_pct(category: &str, signal_strength: f64) -> f64 {
162    let base = base_position_pct(category);
163    let clamped_strength = signal_strength.clamp(0.0, 1.0);
164    base * clamped_strength
165}
166
167// ---------------------------------------------------------------------------
168// #217 — Trailing Stop State Management
169// ---------------------------------------------------------------------------
170
171/// Direction of the trailing stop (tracks position side).
172#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
173#[serde(rename_all = "snake_case")]
174pub enum TrailingDirection {
175    Long,
176    Short,
177}
178
179/// State for a trailing stop attached to an open position.
180#[derive(Debug, Clone, Serialize, Deserialize)]
181#[serde(rename_all = "camelCase")]
182pub struct TrailingStopState {
183    pub highest_since_entry: f64,
184    pub lowest_since_entry: f64,
185    pub current_stop: f64,
186    pub direction: TrailingDirection,
187    pub active: bool,
188}
189
190impl TrailingStopState {
191    /// Create a new trailing stop state at the given entry price.
192    pub fn new(entry_price: f64, direction: TrailingDirection) -> Self {
193        Self {
194            highest_since_entry: entry_price,
195            lowest_since_entry: entry_price,
196            current_stop: 0.0, // will be set on first update
197            direction,
198            active: true,
199        }
200    }
201
202    /// Update the trailing stop with the latest price and ATR.
203    ///
204    /// For longs:  track highest, stop = highest − multiplier × ATR
205    /// For shorts: track lowest,  stop = lowest  + multiplier × ATR
206    pub fn update(&mut self, current_price: f64, atr: f64, multiplier: f64) {
207        if !self.active {
208            return;
209        }
210
211        match self.direction {
212            TrailingDirection::Long => {
213                if current_price > self.highest_since_entry {
214                    self.highest_since_entry = current_price;
215                }
216                let new_stop = self.highest_since_entry - multiplier * atr;
217                // Only ratchet the stop upward
218                if new_stop > self.current_stop {
219                    self.current_stop = new_stop;
220                }
221            }
222            TrailingDirection::Short => {
223                if current_price < self.lowest_since_entry {
224                    self.lowest_since_entry = current_price;
225                }
226                let new_stop = self.lowest_since_entry + multiplier * atr;
227                // Only ratchet the stop downward (for shorts, lower stop = tighter)
228                if self.current_stop == 0.0 || new_stop < self.current_stop {
229                    self.current_stop = new_stop;
230                }
231            }
232        }
233    }
234
235    /// Check whether the trailing stop has been triggered.
236    pub fn is_triggered(&self, current_price: f64) -> bool {
237        if !self.active || self.current_stop == 0.0 {
238            return false;
239        }
240
241        match self.direction {
242            TrailingDirection::Long => current_price <= self.current_stop,
243            TrailingDirection::Short => current_price >= self.current_stop,
244        }
245    }
246}
247
248// ---------------------------------------------------------------------------
249// #218 — ADX Regime Filter for Mean Reversion
250// ---------------------------------------------------------------------------
251
/// Apply ADX-based filtering to a mean-reversion signal.
///
/// Strong trends (high ADX) are hostile to mean reversion, so the signal is scaled down
/// as ADX rises:
///
/// - ADX > 30  → suppress completely (return 0.0)
/// - ADX 25–30 (both ends inclusive) → reduce by 50% (return strength × 0.5)
/// - ADX < 25  → pass through unchanged
/// - ADX unavailable (None) → pass through unchanged
///
/// A NaN ADX fails both threshold comparisons and is therefore passed through unchanged.
pub fn apply_adx_filter(signal_strength: f64, adx: Option<f64>) -> f64 {
    if let Some(level) = adx {
        if level > 30.0 {
            return 0.0;
        }
        if level >= 25.0 {
            return signal_strength * 0.5;
        }
    }
    signal_strength
}
265
266// ---------------------------------------------------------------------------
267// #219 — Volume Context Analysis
268// ---------------------------------------------------------------------------
269
270/// Classified volume context for the current market state.
271#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
272#[serde(rename_all = "snake_case")]
273pub enum VolumeContext {
274    /// High volume + with-trend price move: trend likely continues.
275    TrendContinuation,
276    /// Volume spike + counter-trend / reversal candle: exhaustion.
277    ClimaxExhaustion,
278    /// Volume spike + RSI extreme + reversal: reversal confirmed.
279    ReversalConfirm,
280    /// Very large volume spike + ADX breakout: regime changing.
281    RegimeChange,
282    /// No noteworthy volume pattern.
283    Normal,
284}
285
286/// Analyse volume context using candles and technical indicators.
287///
288/// Classification logic (evaluated in priority order):
289///
290/// 1. `volume_zscore > 3.0` AND `adx > 25` AND ADX rising → `RegimeChange`
291/// 2. `volume_zscore > 2.0` AND RSI extreme (>70 or <30) AND reversal candle → `ReversalConfirm`
292/// 3. `volume_zscore > 2.0` AND counter-trend candle → `ClimaxExhaustion`
293/// 4. `volume_zscore > 2.0` AND with-trend candle → `TrendContinuation`
294/// 5. Otherwise → `Normal`
295pub fn analyze_volume_context(
296    candles: &[Candle],
297    indicators: &TechnicalIndicators,
298) -> VolumeContext {
299    let vol_zscore = match indicators.volume_zscore_20 {
300        Some(z) => z,
301        None => return VolumeContext::Normal,
302    };
303
304    // Need at least 2 candles for trend/reversal detection
305    if candles.len() < 2 {
306        return VolumeContext::Normal;
307    }
308
309    let last = &candles[candles.len() - 1];
310    let prev = &candles[candles.len() - 2];
311
312    // Determine price direction of the latest candle
313    let candle_bullish = last.close > last.open;
314    let candle_bearish = last.close < last.open;
315
316    // Simple trend proxy: previous close vs two bars ago close
317    let trend_up = last.close > prev.close;
318
319    // A "reversal candle" is one that moves against the prior trend
320    let is_reversal = (trend_up && candle_bearish) || (!trend_up && candle_bullish);
321
322    // With-trend: candle direction aligns with recent trend
323    let is_with_trend = (trend_up && candle_bullish) || (!trend_up && candle_bearish);
324
325    // 1. RegimeChange: very large spike + ADX breakout
326    if vol_zscore > 3.0 {
327        if let Some(adx) = indicators.adx_14 {
328            if adx > 25.0 {
329                return VolumeContext::RegimeChange;
330            }
331        }
332    }
333
334    // 2. ReversalConfirm: volume spike + RSI extreme + reversal candle
335    if vol_zscore > 2.0 {
336        if let Some(rsi) = indicators.rsi_14 {
337            if (rsi > 70.0 || rsi < 30.0) && is_reversal {
338                return VolumeContext::ReversalConfirm;
339            }
340        }
341    }
342
343    // 3. ClimaxExhaustion: volume spike + against-trend
344    if vol_zscore > 2.0 && is_reversal {
345        return VolumeContext::ClimaxExhaustion;
346    }
347
348    // 4. TrendContinuation: volume spike + with-trend
349    if vol_zscore > 2.0 && is_with_trend {
350        return VolumeContext::TrendContinuation;
351    }
352
353    VolumeContext::Normal
354}
355
356/// Return a strength modifier for a given volume context and strategy category.
357///
358/// Values > 1.0 amplify the signal; values < 1.0 dampen it.
359pub fn volume_strength_modifier(context: &VolumeContext, category: &str) -> f64 {
360    match context {
361        VolumeContext::TrendContinuation => {
362            match category {
363                "trend_following" | "composite" => 1.3,
364                "momentum" => 1.2,
365                "mean_reversion" => 0.7, // continuation is bad for MR
366                _ => 1.1,
367            }
368        }
369        VolumeContext::ClimaxExhaustion => {
370            match category {
371                "trend_following" | "momentum" => 0.5, // exhaustion dampens trend signals
372                "mean_reversion" => 1.3,               // exhaustion supports MR
373                _ => 0.8,
374            }
375        }
376        VolumeContext::ReversalConfirm => {
377            match category {
378                "mean_reversion" => 1.5, // strong confirmation for MR
379                "trend_following" => 0.4,
380                "momentum" => 0.5,
381                _ => 1.0,
382            }
383        }
384        VolumeContext::RegimeChange => {
385            match category {
386                "volatility" => 1.4,
387                "trend_following" | "composite" => 1.2,
388                "mean_reversion" | "stat_arb" => 0.3, // regime change is risky for MR/arb
389                "execution" => 0.5,
390                _ => 0.8,
391            }
392        }
393        VolumeContext::Normal => 1.0,
394    }
395}
396
397// ---------------------------------------------------------------------------
398// Tests
399// ---------------------------------------------------------------------------
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for float results produced by multiplication.
    const EPS: f64 = 1e-10;

    /// ATR fed to every trailing-stop update in these tests.
    const ATR: f64 = 2.0;

    /// ATR multiplier used by the trailing-stop tests (stop offset = `ATR * MULT` = 6.0).
    const MULT: f64 = 3.0;

    /// Assert that `actual` lies strictly within `EPS` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {} within {}, got {}",
            expected,
            EPS,
            actual
        );
    }

    // -----------------------------------------------------------------------
    // #215 — Risk Defaults
    // -----------------------------------------------------------------------

    #[test]
    fn test_risk_defaults_trend_following() {
        let d = get_risk_defaults("trend_following");
        assert_eq!(
            (d.sl_type, d.sl_mult, d.tp_mult, d.trailing, d.trailing_mult),
            (SlType::Atr, 2.0, 4.0, true, 3.0)
        );
    }

    #[test]
    fn test_risk_defaults_momentum() {
        let d = get_risk_defaults("momentum");
        assert_eq!(
            (d.sl_type, d.sl_mult, d.tp_mult, d.trailing),
            (SlType::Atr, 1.5, 3.0, false)
        );
    }

    #[test]
    fn test_risk_defaults_mean_reversion() {
        let d = get_risk_defaults("mean_reversion");
        assert_eq!(
            (d.sl_type, d.sl_mult, d.tp_mult, d.trailing),
            (SlType::Atr, 1.0, 1.5, false)
        );
    }

    #[test]
    fn test_risk_defaults_stat_arb_uses_percent() {
        let d = get_risk_defaults("stat_arb");
        assert_eq!((d.sl_type, d.sl_mult, d.tp_mult), (SlType::Percent, 2.0, 3.0));
    }

    #[test]
    fn test_risk_defaults_execution_uses_percent() {
        let d = get_risk_defaults("execution");
        assert_eq!((d.sl_type, d.sl_mult, d.tp_mult), (SlType::Percent, 1.0, 2.0));
    }

    #[test]
    fn test_risk_defaults_composite_has_trailing() {
        let d = get_risk_defaults("composite");
        assert_eq!((d.trailing, d.trailing_mult), (true, 3.0));
    }

    #[test]
    fn test_risk_defaults_unknown_category_fallback() {
        let name = "unknown_strategy_xyz";
        let d = get_risk_defaults(name);
        assert_eq!(d.category, name);
        assert_eq!(
            (d.sl_type, d.sl_mult, d.tp_mult, d.trailing),
            (SlType::Atr, 2.0, 3.0, false)
        );
    }

    #[test]
    fn test_risk_defaults_all_categories() {
        const KNOWN: [&str; 10] = [
            "trend_following",
            "momentum",
            "mean_reversion",
            "volatility",
            "stat_arb",
            "onchain",
            "factor",
            "execution",
            "snr",
            "composite",
        ];
        for &category in KNOWN.iter() {
            let d = get_risk_defaults(category);
            assert_eq!(d.category, category);
            assert!(d.sl_mult > 0.0, "{} has a non-positive stop multiple", category);
            assert!(d.tp_mult > 0.0, "{} has a non-positive target multiple", category);
        }
    }

    #[test]
    fn test_risk_defaults_serialization_roundtrip() {
        let before = get_risk_defaults("trend_following");
        let encoded = serde_json::to_string(&before).expect("serialize");
        let after: RiskDefaults = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(after.sl_type, before.sl_type);
        assert_eq!(after.sl_mult, before.sl_mult);
        assert_eq!(after.trailing, before.trailing);
    }

    // -----------------------------------------------------------------------
    // #216 — Dynamic Position Sizing
    // -----------------------------------------------------------------------

    #[test]
    fn test_base_position_pct_known_categories() {
        let expected = [
            ("trend_following", 0.60),
            ("mean_reversion", 0.30),
            ("momentum", 0.50),
            ("execution", 0.15),
        ];
        for (category, pct) in expected.iter() {
            assert_eq!(base_position_pct(category), *pct);
        }
    }

    #[test]
    fn test_base_position_pct_unknown_category() {
        assert_eq!(base_position_pct("unknown"), 0.25);
    }

    #[test]
    fn test_suggested_qty_pct_full_strength() {
        assert_close(suggested_qty_pct("trend_following", 1.0), 0.60);
    }

    #[test]
    fn test_suggested_qty_pct_half_strength() {
        assert_close(suggested_qty_pct("trend_following", 0.5), 0.30);
    }

    #[test]
    fn test_suggested_qty_pct_zero_strength() {
        assert_eq!(suggested_qty_pct("momentum", 0.0), 0.0);
    }

    #[test]
    fn test_suggested_qty_pct_clamps_above_one() {
        // Strength 1.5 clamps to 1.0, leaving just the momentum base of 0.50.
        assert_close(suggested_qty_pct("momentum", 1.5), 0.50);
    }

    #[test]
    fn test_suggested_qty_pct_clamps_negative() {
        assert_eq!(suggested_qty_pct("momentum", -0.3), 0.0);
    }

    // -----------------------------------------------------------------------
    // #217 — Trailing Stop State Management
    // -----------------------------------------------------------------------

    #[test]
    fn test_trailing_stop_new_long() {
        let ts = TrailingStopState::new(100.0, TrailingDirection::Long);
        assert_eq!((ts.highest_since_entry, ts.lowest_since_entry), (100.0, 100.0));
        assert_eq!(ts.current_stop, 0.0);
        assert!(ts.active);
    }

    #[test]
    fn test_trailing_stop_long_update_raises_stop() {
        let mut ts = TrailingStopState::new(100.0, TrailingDirection::Long);

        ts.update(100.0, ATR, MULT);
        assert_eq!(ts.current_stop, 94.0, "first stop sits 6.0 below entry");

        ts.update(110.0, ATR, MULT);
        assert_eq!(ts.highest_since_entry, 110.0);
        assert_eq!(ts.current_stop, 104.0, "stop follows the new high");

        ts.update(105.0, ATR, MULT);
        assert_eq!(ts.current_stop, 104.0, "a pullback must not loosen the stop");
    }

    #[test]
    fn test_trailing_stop_long_triggered() {
        let mut ts = TrailingStopState::new(100.0, TrailingDirection::Long);
        ts.update(100.0, ATR, MULT); // stop = 94
        assert!(!ts.is_triggered(95.0));
        for &price in [94.0, 90.0].iter() {
            assert!(ts.is_triggered(price), "long stop at 94 should fire at {}", price);
        }
    }

    #[test]
    fn test_trailing_stop_short_update_lowers_stop() {
        let mut ts = TrailingStopState::new(100.0, TrailingDirection::Short);

        ts.update(100.0, ATR, MULT);
        assert_eq!(ts.current_stop, 106.0, "first stop sits 6.0 above entry");

        ts.update(90.0, ATR, MULT);
        assert_eq!(ts.lowest_since_entry, 90.0);
        assert_eq!(ts.current_stop, 96.0, "stop follows the new low");

        ts.update(95.0, ATR, MULT);
        assert_eq!(ts.current_stop, 96.0, "a bounce must not loosen the stop");
    }

    #[test]
    fn test_trailing_stop_short_triggered() {
        let mut ts = TrailingStopState::new(100.0, TrailingDirection::Short);
        ts.update(100.0, ATR, MULT); // stop = 106
        assert!(!ts.is_triggered(105.0));
        for &price in [106.0, 110.0].iter() {
            assert!(ts.is_triggered(price), "short stop at 106 should fire at {}", price);
        }
    }

    #[test]
    fn test_trailing_stop_inactive_does_not_update() {
        let mut ts = TrailingStopState::new(100.0, TrailingDirection::Long);
        ts.active = false;
        ts.update(200.0, ATR, MULT);
        assert_eq!((ts.current_stop, ts.highest_since_entry), (0.0, 100.0));
    }

    #[test]
    fn test_trailing_stop_inactive_not_triggered() {
        let mut ts = TrailingStopState::new(100.0, TrailingDirection::Long);
        ts.update(100.0, ATR, MULT); // stop = 94
        ts.active = false;
        assert!(!ts.is_triggered(50.0));
    }

    #[test]
    fn test_trailing_stop_zero_stop_not_triggered() {
        // Never updated, so `current_stop` is still the 0.0 placeholder.
        let ts = TrailingStopState::new(100.0, TrailingDirection::Long);
        assert!(!ts.is_triggered(0.0));
    }

    #[test]
    fn test_trailing_stop_serialization_roundtrip() {
        let mut ts = TrailingStopState::new(100.0, TrailingDirection::Long);
        ts.update(105.0, ATR, MULT);
        let encoded = serde_json::to_string(&ts).expect("serialize");
        let decoded: TrailingStopState = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.highest_since_entry, 105.0);
        assert_eq!(decoded.current_stop, ts.current_stop);
        assert_eq!(decoded.direction, TrailingDirection::Long);
    }

    // -----------------------------------------------------------------------
    // #218 — ADX Regime Filter
    // -----------------------------------------------------------------------

    #[test]
    fn test_adx_filter_suppress_high_adx() {
        for &(strength, adx) in [(0.8, 35.0), (1.0, 30.1)].iter() {
            assert_eq!(apply_adx_filter(strength, Some(adx)), 0.0, "ADX {}", adx);
        }
    }

    #[test]
    fn test_adx_filter_reduce_medium_adx() {
        assert_close(apply_adx_filter(0.8, Some(27.0)), 0.4);
    }

    #[test]
    fn test_adx_filter_boundary_25() {
        assert_close(apply_adx_filter(1.0, Some(25.0)), 0.5);
    }

    #[test]
    fn test_adx_filter_boundary_30() {
        // 30.0 is not strictly above 30, so it still lands in the halving bucket.
        assert_close(apply_adx_filter(1.0, Some(30.0)), 0.5);
    }

    #[test]
    fn test_adx_filter_low_adx_passthrough() {
        for &adx in [20.0, 24.9].iter() {
            assert_eq!(apply_adx_filter(0.8, Some(adx)), 0.8, "ADX {}", adx);
        }
    }

    #[test]
    fn test_adx_filter_none_passthrough() {
        assert_eq!(apply_adx_filter(0.8, None), 0.8);
    }

    #[test]
    fn test_adx_filter_zero_strength() {
        assert_eq!(apply_adx_filter(0.0, Some(27.0)), 0.0);
    }

    // -----------------------------------------------------------------------
    // #219 — Volume Context Analysis
    // -----------------------------------------------------------------------

    /// Build a test candle with `time` fixed at 0.
    fn bar(open: f64, close: f64, high: f64, low: f64, volume: f64) -> Candle {
        Candle {
            time: 0,
            open,
            high,
            low,
            close,
            volume,
        }
    }

    /// Start from `TechnicalIndicators::empty()` and let `tweak` set the fields under test.
    fn indicators_with(tweak: impl FnOnce(&mut TechnicalIndicators)) -> TechnicalIndicators {
        let mut ind = TechnicalIndicators::empty();
        tweak(&mut ind);
        ind
    }

    /// A bar followed by a strong bullish bar that also closes higher (with-trend move).
    fn bullish_breakout(last_volume: f64) -> [Candle; 2] {
        [
            bar(100.0, 101.0, 102.0, 99.0, 1000.0),
            bar(101.0, 105.0, 106.0, 100.0, last_volume),
        ]
    }

    /// A bar followed by one that closes above the prior close but below its own open.
    fn bearish_reversal() -> [Candle; 2] {
        [
            bar(100.0, 101.0, 102.0, 99.0, 1000.0),
            bar(105.0, 102.0, 106.0, 101.0, 5000.0),
        ]
    }

    #[test]
    fn test_volume_context_normal_no_spike() {
        let candles = [
            bar(100.0, 101.0, 102.0, 99.0, 1000.0),
            bar(101.0, 102.0, 103.0, 100.0, 1100.0),
        ];
        let ind = indicators_with(|i| i.volume_zscore_20 = Some(0.5));
        assert_eq!(analyze_volume_context(&candles, &ind), VolumeContext::Normal);
    }

    #[test]
    fn test_volume_context_trend_continuation() {
        // Higher close + bullish body + volume spike.
        let candles = bullish_breakout(5000.0);
        let ind = indicators_with(|i| i.volume_zscore_20 = Some(2.5));
        assert_eq!(
            analyze_volume_context(&candles, &ind),
            VolumeContext::TrendContinuation
        );
    }

    #[test]
    fn test_volume_context_climax_exhaustion() {
        // Reversal candle on a spike, but RSI is neutral → exhaustion, not confirmation.
        let candles = bearish_reversal();
        let ind = indicators_with(|i| {
            i.volume_zscore_20 = Some(2.5);
            i.rsi_14 = Some(55.0);
        });
        assert_eq!(
            analyze_volume_context(&candles, &ind),
            VolumeContext::ClimaxExhaustion
        );
    }

    #[test]
    fn test_volume_context_reversal_confirm() {
        // Same reversal candle, now with an overbought RSI.
        let candles = bearish_reversal();
        let ind = indicators_with(|i| {
            i.volume_zscore_20 = Some(2.5);
            i.rsi_14 = Some(75.0);
        });
        assert_eq!(
            analyze_volume_context(&candles, &ind),
            VolumeContext::ReversalConfirm
        );
    }

    #[test]
    fn test_volume_context_regime_change() {
        let candles = bullish_breakout(10000.0);
        let ind = indicators_with(|i| {
            i.volume_zscore_20 = Some(3.5);
            i.adx_14 = Some(30.0); // above the 25 threshold
        });
        assert_eq!(
            analyze_volume_context(&candles, &ind),
            VolumeContext::RegimeChange
        );
    }

    #[test]
    fn test_volume_context_no_zscore_returns_normal() {
        let candles = bullish_breakout(5000.0);
        let ind = TechnicalIndicators::empty();
        assert_eq!(analyze_volume_context(&candles, &ind), VolumeContext::Normal);
    }

    #[test]
    fn test_volume_context_insufficient_candles() {
        let candles = [bar(100.0, 101.0, 102.0, 99.0, 1000.0)];
        let ind = indicators_with(|i| i.volume_zscore_20 = Some(3.0));
        assert_eq!(analyze_volume_context(&candles, &ind), VolumeContext::Normal);
    }

    #[test]
    fn test_volume_context_empty_candles() {
        let candles: [Candle; 0] = [];
        let ind = indicators_with(|i| i.volume_zscore_20 = Some(3.0));
        assert_eq!(analyze_volume_context(&candles, &ind), VolumeContext::Normal);
    }

    #[test]
    fn test_volume_context_regime_change_low_adx_falls_through() {
        // Spike above 3 but ADX below 25: not a regime change, so the with-trend
        // bullish candle makes it a continuation.
        let candles = bullish_breakout(10000.0);
        let ind = indicators_with(|i| {
            i.volume_zscore_20 = Some(3.5);
            i.adx_14 = Some(20.0);
        });
        assert_eq!(
            analyze_volume_context(&candles, &ind),
            VolumeContext::TrendContinuation
        );
    }

    // --- Volume Strength Modifier ---

    #[test]
    fn test_volume_modifier_normal_always_one() {
        for category in ["trend_following", "mean_reversion", "unknown"].iter() {
            assert_eq!(
                volume_strength_modifier(&VolumeContext::Normal, category),
                1.0,
                "{}",
                category
            );
        }
    }

    #[test]
    fn test_volume_modifier_trend_continuation_amplifies_trend() {
        for category in ["trend_following", "momentum"].iter() {
            let m = volume_strength_modifier(&VolumeContext::TrendContinuation, category);
            assert!(m > 1.0, "{} -> {}", category, m);
        }
    }

    #[test]
    fn test_volume_modifier_trend_continuation_dampens_mr() {
        let m = volume_strength_modifier(&VolumeContext::TrendContinuation, "mean_reversion");
        assert!(m < 1.0);
    }

    #[test]
    fn test_volume_modifier_exhaustion_dampens_trend() {
        for category in ["trend_following", "momentum"].iter() {
            let m = volume_strength_modifier(&VolumeContext::ClimaxExhaustion, category);
            assert!(m < 1.0, "{} -> {}", category, m);
        }
    }

    #[test]
    fn test_volume_modifier_exhaustion_amplifies_mr() {
        let m = volume_strength_modifier(&VolumeContext::ClimaxExhaustion, "mean_reversion");
        assert!(m > 1.0);
    }

    #[test]
    fn test_volume_modifier_reversal_confirm_amplifies_mr() {
        let m = volume_strength_modifier(&VolumeContext::ReversalConfirm, "mean_reversion");
        assert!(m > 1.0);
    }

    #[test]
    fn test_volume_modifier_reversal_confirm_dampens_trend() {
        let m = volume_strength_modifier(&VolumeContext::ReversalConfirm, "trend_following");
        assert!(m < 1.0);
    }

    #[test]
    fn test_volume_modifier_regime_change_dampens_mr_arb() {
        for category in ["mean_reversion", "stat_arb"].iter() {
            let m = volume_strength_modifier(&VolumeContext::RegimeChange, category);
            assert!(m < 1.0, "{} -> {}", category, m);
        }
    }

    #[test]
    fn test_volume_modifier_regime_change_amplifies_volatility() {
        let m = volume_strength_modifier(&VolumeContext::RegimeChange, "volatility");
        assert!(m > 1.0);
    }

    #[test]
    fn test_volume_context_serialization_roundtrip() {
        let all = [
            VolumeContext::TrendContinuation,
            VolumeContext::ClimaxExhaustion,
            VolumeContext::ReversalConfirm,
            VolumeContext::RegimeChange,
            VolumeContext::Normal,
        ];
        for ctx in all.iter() {
            let encoded = serde_json::to_string(ctx).expect("serialize");
            let decoded: VolumeContext = serde_json::from_str(&encoded).expect("deserialize");
            assert_eq!(&decoded, ctx);
        }
    }
}