Skip to main content

hyper_strategy/
strategy_indicator_config.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5use hyper_ta::technical_analysis::TechnicalIndicators;
6
7// ---------------------------------------------------------------------------
8// StrategyIndicatorConfig
9// ---------------------------------------------------------------------------
10
11/// Per-strategy indicator configuration. Determines which indicators to include
12/// in the Claude prompt and what thresholds define signal zones.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14#[serde(rename_all = "camelCase")]
15pub struct StrategyIndicatorConfig {
16    /// Which indicators to include (e.g. "RSI", "MACD", "BB", "ADX", "EMA", "SMA",
17    /// "Stochastic", "ATR", "CCI", "WilliamsR", "OBV", "MFI").
18    pub indicators: Vec<String>,
19    /// Thresholds per indicator: `(low, high)`.
20    /// For example, RSI: (30.0, 70.0) means oversold below 30, overbought above 70.
21    pub thresholds: HashMap<String, (f64, f64)>,
22}
23
24// ---------------------------------------------------------------------------
25// Preset configs per PromptTemplate
26// ---------------------------------------------------------------------------
27
28/// Return the indicator config for a given strategy template name.
29///
30/// Recognised template strings (case-insensitive):
31/// - `"TrendFollowing"` / `"trendfollowing"`
32/// - `"MeanReversion"` / `"meanreversion"`
33/// - `"Scalping"` / `"scalping"`
34/// - `"Conservative"` / `"conservative"`
35///
36/// Unknown templates fall back to the Conservative (all-indicators) config.
37pub fn get_strategy_indicator_config(template: &str) -> StrategyIndicatorConfig {
38    match template.to_lowercase().as_str() {
39        "trendfollowing" => trend_following_config(),
40        "meanreversion" => mean_reversion_config(),
41        "scalping" => scalping_config(),
42        "conservative" => conservative_config(),
43        _ => conservative_config(),
44    }
45}
46
47// --- Preset builders ---
48
49fn trend_following_config() -> StrategyIndicatorConfig {
50    let mut thresholds = HashMap::new();
51    thresholds.insert("MACD".to_string(), (0.0, 0.0)); // histogram sign is the signal
52    thresholds.insert("ADX".to_string(), (20.0, 50.0)); // weak < 20, strong > 50
53    thresholds.insert("EMA".to_string(), (0.0, 0.0)); // cross-based, no numeric threshold
54
55    StrategyIndicatorConfig {
56        indicators: vec!["MACD".to_string(), "ADX".to_string(), "EMA".to_string()],
57        thresholds,
58    }
59}
60
61fn mean_reversion_config() -> StrategyIndicatorConfig {
62    let mut thresholds = HashMap::new();
63    thresholds.insert("RSI".to_string(), (30.0, 70.0));
64    thresholds.insert("BB".to_string(), (-2.0, 2.0)); // std-dev bands
65    thresholds.insert("SMA".to_string(), (0.0, 0.0)); // deviation from SMA
66
67    StrategyIndicatorConfig {
68        indicators: vec!["RSI".to_string(), "BB".to_string(), "SMA".to_string()],
69        thresholds,
70    }
71}
72
73fn scalping_config() -> StrategyIndicatorConfig {
74    let mut thresholds = HashMap::new();
75    thresholds.insert("Stochastic".to_string(), (20.0, 80.0));
76    thresholds.insert("ATR".to_string(), (0.0, 100.0)); // absolute range
77    thresholds.insert("EMA".to_string(), (0.0, 0.0)); // short MA crossover
78
79    StrategyIndicatorConfig {
80        indicators: vec![
81            "Stochastic".to_string(),
82            "ATR".to_string(),
83            "EMA".to_string(),
84        ],
85        thresholds,
86    }
87}
88
89fn conservative_config() -> StrategyIndicatorConfig {
90    let mut thresholds = HashMap::new();
91    // Stricter thresholds for conservative mode
92    thresholds.insert("RSI".to_string(), (25.0, 75.0));
93    thresholds.insert("MACD".to_string(), (0.0, 0.0));
94    thresholds.insert("BB".to_string(), (-2.5, 2.5));
95    thresholds.insert("ADX".to_string(), (25.0, 50.0));
96    thresholds.insert("EMA".to_string(), (0.0, 0.0));
97    thresholds.insert("SMA".to_string(), (0.0, 0.0));
98    thresholds.insert("Stochastic".to_string(), (20.0, 80.0));
99    thresholds.insert("ATR".to_string(), (0.0, 100.0));
100    thresholds.insert("CCI".to_string(), (-100.0, 100.0));
101    thresholds.insert("WilliamsR".to_string(), (-80.0, -20.0));
102    thresholds.insert("OBV".to_string(), (0.0, 0.0));
103    thresholds.insert("MFI".to_string(), (20.0, 80.0));
104
105    StrategyIndicatorConfig {
106        indicators: vec![
107            "RSI".to_string(),
108            "MACD".to_string(),
109            "BB".to_string(),
110            "ADX".to_string(),
111            "EMA".to_string(),
112            "SMA".to_string(),
113            "Stochastic".to_string(),
114            "ATR".to_string(),
115            "CCI".to_string(),
116            "WilliamsR".to_string(),
117            "OBV".to_string(),
118            "MFI".to_string(),
119        ],
120        thresholds,
121    }
122}
123
124// ---------------------------------------------------------------------------
125// filter_indicators_for_prompt
126// ---------------------------------------------------------------------------
127
128/// Filter a `TechnicalIndicators` struct according to the given strategy config,
129/// producing a compact string suitable for inclusion in a Claude prompt.
130///
131/// Only indicators listed in `config.indicators` are emitted. For indicators
132/// that have thresholds, a zone label (e.g. "oversold", "overbought", "neutral")
133/// is appended.
134pub fn filter_indicators_for_prompt(
135    indicators: &TechnicalIndicators,
136    config: &StrategyIndicatorConfig,
137) -> String {
138    let allowed: std::collections::HashSet<&str> =
139        config.indicators.iter().map(|s| s.as_str()).collect();
140
141    let mut lines: Vec<String> = Vec::new();
142
143    // --- Trend ---
144    let mut trend_parts: Vec<String> = Vec::new();
145
146    if allowed.contains("SMA") {
147        if let Some(v) = indicators.sma_20 {
148            trend_parts.push(format!("SMA20={:.2}", v));
149        }
150        if let Some(v) = indicators.sma_50 {
151            trend_parts.push(format!("SMA50={:.2}", v));
152        }
153    }
154
155    if allowed.contains("EMA") {
156        if let Some(v) = indicators.ema_12 {
157            trend_parts.push(format!("EMA12={:.2}", v));
158        }
159        if let Some(v) = indicators.ema_26 {
160            trend_parts.push(format!("EMA26={:.2}", v));
161        }
162    }
163
164    if allowed.contains("MACD") {
165        if let Some(hist) = indicators.macd_histogram {
166            let sign = if hist >= 0.0 { "+" } else { "" };
167            let label = if hist > 0.0 {
168                "bullish"
169            } else if hist < 0.0 {
170                "bearish"
171            } else {
172                "neutral"
173            };
174            trend_parts.push(format!("MACD={}{:.4} ({})", sign, hist, label));
175        }
176    }
177
178    if allowed.contains("ADX") {
179        if let Some(v) = indicators.adx_14 {
180            let (low, _high) = config
181                .thresholds
182                .get("ADX")
183                .copied()
184                .unwrap_or((25.0, 50.0));
185            let strength = if v >= low { "strong" } else { "weak" };
186            trend_parts.push(format!("ADX={:.0} ({})", v, strength));
187        }
188    }
189
190    if !trend_parts.is_empty() {
191        lines.push(format!("Trend: {}", trend_parts.join(" ")));
192    }
193
194    // --- Momentum ---
195    let mut mom_parts: Vec<String> = Vec::new();
196
197    if allowed.contains("RSI") {
198        if let Some(v) = indicators.rsi_14 {
199            let (low, high) = config
200                .thresholds
201                .get("RSI")
202                .copied()
203                .unwrap_or((30.0, 70.0));
204            let zone = zone_label(v, low, high);
205            mom_parts.push(format!("RSI={:.0} ({})", v, zone));
206        }
207    }
208
209    if allowed.contains("Stochastic") {
210        if let Some(k) = indicators.stoch_k {
211            let (low, high) = config
212                .thresholds
213                .get("Stochastic")
214                .copied()
215                .unwrap_or((20.0, 80.0));
216            let zone = zone_label(k, low, high);
217            mom_parts.push(format!("Stoch={:.0} ({})", k, zone));
218        }
219    }
220
221    if allowed.contains("CCI") {
222        if let Some(v) = indicators.cci_20 {
223            let (low, high) = config
224                .thresholds
225                .get("CCI")
226                .copied()
227                .unwrap_or((-100.0, 100.0));
228            let zone = zone_label(v, low, high);
229            mom_parts.push(format!("CCI={:.0} ({})", v, zone));
230        }
231    }
232
233    if allowed.contains("WilliamsR") {
234        if let Some(v) = indicators.williams_r_14 {
235            let (low, high) = config
236                .thresholds
237                .get("WilliamsR")
238                .copied()
239                .unwrap_or((-80.0, -20.0));
240            // Williams %R: below low => oversold, above high => overbought
241            let zone = zone_label(v, low, high);
242            mom_parts.push(format!("WR={:.0} ({})", v, zone));
243        }
244    }
245
246    if allowed.contains("MFI") {
247        if let Some(v) = indicators.mfi_14 {
248            let (low, high) = config
249                .thresholds
250                .get("MFI")
251                .copied()
252                .unwrap_or((20.0, 80.0));
253            let zone = zone_label(v, low, high);
254            mom_parts.push(format!("MFI={:.0} ({})", v, zone));
255        }
256    }
257
258    if !mom_parts.is_empty() {
259        lines.push(format!("Momentum: {}", mom_parts.join(" ")));
260    }
261
262    // --- Volatility ---
263    let mut vol_parts: Vec<String> = Vec::new();
264
265    if allowed.contains("BB") {
266        if let (Some(bl), Some(bm), Some(bu)) = (
267            indicators.bb_lower,
268            indicators.bb_middle,
269            indicators.bb_upper,
270        ) {
271            vol_parts.push(format!("BB[{:.2} - {:.2} - {:.2}]", bl, bm, bu));
272        }
273    }
274
275    if allowed.contains("ATR") {
276        if let Some(v) = indicators.atr_14 {
277            vol_parts.push(format!("ATR={:.2}", v));
278        }
279    }
280
281    if !vol_parts.is_empty() {
282        lines.push(format!("Volatility: {}", vol_parts.join(" ")));
283    }
284
285    // --- Volume ---
286    let mut vol_line_parts: Vec<String> = Vec::new();
287
288    if allowed.contains("OBV") {
289        if let Some(v) = indicators.obv {
290            vol_line_parts.push(format!("OBV={:.0}", v));
291        }
292    }
293
294    if !vol_line_parts.is_empty() {
295        lines.push(format!("Volume: {}", vol_line_parts.join(" ")));
296    }
297
298    lines.join("\n")
299}
300
301// ---------------------------------------------------------------------------
302// Helpers
303// ---------------------------------------------------------------------------
304
/// Classify `value` against a `(low, high)` band.
///
/// Returns `"oversold"` when `value <= low`, otherwise `"overbought"` when
/// `value >= high`, otherwise `"neutral"`. Both bounds are inclusive. If the
/// band is inverted (`low > high`), the `low` check takes precedence. NaN
/// compares false against both bounds and is therefore reported as `"neutral"`.
fn zone_label(value: f64, low: f64, high: f64) -> &'static str {
    match value {
        v if v <= low => "oversold",
        v if v >= high => "overbought",
        _ => "neutral",
    }
}
315
316// ---------------------------------------------------------------------------
317// Tests
318// ---------------------------------------------------------------------------
319
#[cfg(test)]
mod tests {
    use super::*;

    // Helper: build a TechnicalIndicators with all fields set.
    fn full_indicators() -> TechnicalIndicators {
        TechnicalIndicators {
            sma_20: Some(64000.0),
            sma_50: Some(62000.0),
            ema_12: Some(64500.0),
            ema_20: Some(64000.0),
            ema_26: Some(63500.0),
            ema_50: Some(62500.0),
            rsi_14: Some(55.0),
            macd_line: Some(500.0),
            macd_signal: Some(400.0),
            macd_histogram: Some(100.0),
            bb_upper: Some(68000.0),
            bb_middle: Some(65000.0),
            bb_lower: Some(62000.0),
            atr_14: Some(350.0),
            adx_14: Some(30.0),
            stoch_k: Some(65.0),
            stoch_d: Some(60.0),
            cci_20: Some(50.0),
            williams_r_14: Some(-45.0),
            obv: Some(12345678.0),
            mfi_14: Some(55.0),
            roc_12: Some(5.0),
            donchian_upper_20: Some(66000.0),
            donchian_lower_20: Some(60000.0),
            donchian_upper_10: Some(65500.0),
            donchian_lower_10: Some(60500.0),
            close_zscore_20: Some(0.5),
            volume_zscore_20: Some(0.3),
            hv_20: Some(0.25),
            hv_60: Some(0.30),
            kc_upper_20: Some(66000.0),
            kc_lower_20: Some(62000.0),
            supertrend_value: Some(63000.0),
            supertrend_direction: Some(1.0),
            vwap: Some(64000.0),
            plus_di_14: Some(25.0),
            minus_di_14: Some(20.0),
        }
    }

    // Helper: build a config directly (bypassing the presets) so tests can pin
    // exactly which indicators and thresholds are in effect.
    fn custom_config(
        indicators: &[&str],
        thresholds: &[(&str, f64, f64)],
    ) -> StrategyIndicatorConfig {
        StrategyIndicatorConfig {
            indicators: indicators.iter().map(|s| s.to_string()).collect(),
            thresholds: thresholds
                .iter()
                .map(|&(name, low, high)| (name.to_string(), (low, high)))
                .collect(),
        }
    }

    // ---- get_strategy_indicator_config ----

    #[test]
    fn test_trend_following_config_indicators() {
        let cfg = get_strategy_indicator_config("TrendFollowing");
        assert!(cfg.indicators.contains(&"MACD".to_string()));
        assert!(cfg.indicators.contains(&"ADX".to_string()));
        assert!(cfg.indicators.contains(&"EMA".to_string()));
        assert_eq!(cfg.indicators.len(), 3);
    }

    #[test]
    fn test_mean_reversion_config_indicators() {
        let cfg = get_strategy_indicator_config("MeanReversion");
        assert!(cfg.indicators.contains(&"RSI".to_string()));
        assert!(cfg.indicators.contains(&"BB".to_string()));
        assert!(cfg.indicators.contains(&"SMA".to_string()));
        assert_eq!(cfg.indicators.len(), 3);
    }

    #[test]
    fn test_scalping_config_indicators() {
        let cfg = get_strategy_indicator_config("Scalping");
        assert!(cfg.indicators.contains(&"Stochastic".to_string()));
        assert!(cfg.indicators.contains(&"ATR".to_string()));
        assert!(cfg.indicators.contains(&"EMA".to_string()));
        assert_eq!(cfg.indicators.len(), 3);
    }

    #[test]
    fn test_conservative_config_includes_all_indicators() {
        let cfg = get_strategy_indicator_config("Conservative");
        assert_eq!(cfg.indicators.len(), 12);
        // Should include everything
        for name in [
            "RSI", "MACD", "BB", "ADX", "EMA", "SMA", "Stochastic", "ATR", "CCI", "WilliamsR",
            "OBV", "MFI",
        ]
        .iter()
        {
            assert!(
                cfg.indicators.contains(&name.to_string()),
                "missing {}",
                name
            );
        }
    }

    #[test]
    fn test_every_preset_indicator_has_threshold() {
        for template in ["TrendFollowing", "MeanReversion", "Scalping", "Conservative"].iter() {
            let cfg = get_strategy_indicator_config(template);
            assert_eq!(cfg.indicators.len(), cfg.thresholds.len(), "{}", template);
            for name in &cfg.indicators {
                assert!(
                    cfg.thresholds.contains_key(name),
                    "{} has no threshold for {}",
                    template,
                    name
                );
            }
        }
    }

    #[test]
    fn test_unknown_template_falls_back_to_conservative() {
        let cfg = get_strategy_indicator_config("UnknownStrategy");
        let conservative = get_strategy_indicator_config("Conservative");
        assert_eq!(cfg.indicators, conservative.indicators);
        assert_eq!(cfg.thresholds, conservative.thresholds);
    }

    #[test]
    fn test_case_insensitive_lookup() {
        let cfg_lower = get_strategy_indicator_config("trendfollowing");
        let cfg_mixed = get_strategy_indicator_config("TrendFollowing");
        let cfg_upper = get_strategy_indicator_config("TRENDFOLLOWING");
        assert_eq!(cfg_lower.indicators, cfg_mixed.indicators);
        assert_eq!(cfg_upper.indicators, cfg_mixed.indicators);
        assert_eq!(cfg_upper.thresholds, cfg_mixed.thresholds);
    }

    // ---- Threshold values ----

    #[test]
    fn test_trend_following_thresholds() {
        let cfg = get_strategy_indicator_config("TrendFollowing");
        let adx = cfg.thresholds.get("ADX").expect("ADX threshold missing");
        assert_eq!(*adx, (20.0, 50.0));
    }

    #[test]
    fn test_conservative_stricter_rsi_thresholds() {
        let cfg = get_strategy_indicator_config("Conservative");
        let rsi = cfg.thresholds.get("RSI").expect("RSI threshold missing");
        // Conservative RSI thresholds are stricter: (25, 75) vs default (30, 70)
        assert_eq!(*rsi, (25.0, 75.0));
    }

    // ---- filter_indicators_for_prompt ----

    #[test]
    fn test_filter_trend_following_only_includes_trend_indicators() {
        let indicators = full_indicators();
        let cfg = get_strategy_indicator_config("TrendFollowing");
        let result = filter_indicators_for_prompt(&indicators, &cfg);

        // Should include EMA, MACD, ADX
        assert!(result.contains("EMA12="));
        assert!(result.contains("MACD="));
        assert!(result.contains("ADX="));

        // Should NOT include RSI, BB, Stochastic, etc.
        assert!(!result.contains("RSI="));
        assert!(!result.contains("BB["));
        assert!(!result.contains("Stoch="));
        assert!(!result.contains("OBV="));
    }

    #[test]
    fn test_filter_trend_following_exact_output() {
        let result = filter_indicators_for_prompt(
            &full_indicators(),
            &get_strategy_indicator_config("TrendFollowing"),
        );
        assert_eq!(
            result,
            "Trend: EMA12=64500.00 EMA26=63500.00 MACD=+100.0000 (bullish) ADX=30 (strong)"
        );
    }

    #[test]
    fn test_filter_mean_reversion_includes_rsi_bb_sma() {
        let indicators = full_indicators();
        let cfg = get_strategy_indicator_config("MeanReversion");
        let result = filter_indicators_for_prompt(&indicators, &cfg);

        assert!(result.contains("RSI="));
        assert!(result.contains("BB["));
        assert!(result.contains("SMA20="));

        // Should NOT include MACD, ADX, Stochastic
        assert!(!result.contains("MACD="));
        assert!(!result.contains("ADX="));
        assert!(!result.contains("Stoch="));
    }

    #[test]
    fn test_filter_mean_reversion_exact_output() {
        let result = filter_indicators_for_prompt(
            &full_indicators(),
            &get_strategy_indicator_config("MeanReversion"),
        );
        assert_eq!(
            result,
            "Trend: SMA20=64000.00 SMA50=62000.00\n\
             Momentum: RSI=55 (neutral)\n\
             Volatility: BB[62000.00 - 65000.00 - 68000.00]"
        );
    }

    #[test]
    fn test_filter_scalping_includes_stochastic_atr_ema() {
        let indicators = full_indicators();
        let cfg = get_strategy_indicator_config("Scalping");
        let result = filter_indicators_for_prompt(&indicators, &cfg);

        assert!(result.contains("Stoch="));
        assert!(result.contains("ATR="));
        assert!(result.contains("EMA12="));

        // Should NOT include RSI, BB, ADX
        assert!(!result.contains("RSI="));
        assert!(!result.contains("BB["));
        assert!(!result.contains("ADX="));
    }

    #[test]
    fn test_filter_conservative_includes_everything() {
        let indicators = full_indicators();
        let cfg = get_strategy_indicator_config("Conservative");
        let result = filter_indicators_for_prompt(&indicators, &cfg);

        assert!(result.contains("SMA20="));
        assert!(result.contains("EMA12="));
        assert!(result.contains("MACD="));
        assert!(result.contains("ADX="));
        assert!(result.contains("RSI="));
        assert!(result.contains("Stoch="));
        assert!(result.contains("BB["));
        assert!(result.contains("ATR="));
        assert!(result.contains("CCI="));
        assert!(result.contains("WR="));
        assert!(result.contains("OBV="));
        assert!(result.contains("MFI="));
    }

    #[test]
    fn test_filter_empty_indicators_returns_empty_string() {
        let indicators = TechnicalIndicators::empty();
        let cfg = get_strategy_indicator_config("TrendFollowing");
        let result = filter_indicators_for_prompt(&indicators, &cfg);
        assert!(result.is_empty());
    }

    #[test]
    fn test_filter_respects_custom_thresholds() {
        let mut indicators = TechnicalIndicators::empty();
        indicators.rsi_14 = Some(28.0);

        // Default mean-reversion: RSI thresholds (30, 70) => 28 is oversold
        let cfg = get_strategy_indicator_config("MeanReversion");
        let result = filter_indicators_for_prompt(&indicators, &cfg);
        assert!(result.contains("oversold"));

        // Conservative: RSI thresholds (25, 75) => 28 is neutral (above 25)
        let cfg2 = get_strategy_indicator_config("Conservative");
        let result2 = filter_indicators_for_prompt(&indicators, &cfg2);
        assert!(result2.contains("neutral"));
    }

    #[test]
    fn test_filter_uses_explicit_config_thresholds() {
        let cfg = custom_config(&["RSI"], &[("RSI", 40.0, 60.0)]);
        let mut indicators = TechnicalIndicators::empty();

        indicators.rsi_14 = Some(45.0);
        assert_eq!(
            filter_indicators_for_prompt(&indicators, &cfg),
            "Momentum: RSI=45 (neutral)"
        );

        indicators.rsi_14 = Some(38.0);
        assert_eq!(
            filter_indicators_for_prompt(&indicators, &cfg),
            "Momentum: RSI=38 (oversold)"
        );

        indicators.rsi_14 = Some(61.0);
        assert_eq!(
            filter_indicators_for_prompt(&indicators, &cfg),
            "Momentum: RSI=61 (overbought)"
        );
    }

    #[test]
    fn test_filter_missing_threshold_uses_default() {
        // No RSI entry in the thresholds map: the built-in (30, 70) band applies.
        let cfg = custom_config(&["RSI"], &[]);
        let mut indicators = TechnicalIndicators::empty();
        indicators.rsi_14 = Some(28.0);
        assert_eq!(
            filter_indicators_for_prompt(&indicators, &cfg),
            "Momentum: RSI=28 (oversold)"
        );
    }

    #[test]
    fn test_filter_indicator_names_are_case_sensitive() {
        let cfg = custom_config(&["rsi", "macd"], &[]);
        let result = filter_indicators_for_prompt(&full_indicators(), &cfg);
        assert!(result.is_empty());
    }

    #[test]
    fn test_zone_label_boundaries() {
        assert_eq!(zone_label(30.0, 30.0, 70.0), "oversold"); // exactly at low
        assert_eq!(zone_label(70.0, 30.0, 70.0), "overbought"); // exactly at high
        assert_eq!(zone_label(50.0, 30.0, 70.0), "neutral");
        assert_eq!(zone_label(29.9, 30.0, 70.0), "oversold");
        assert_eq!(zone_label(70.1, 30.0, 70.0), "overbought");
        assert_eq!(zone_label(f64::NAN, 30.0, 70.0), "neutral");
    }

    #[test]
    fn test_serialization_roundtrip() {
        let cfg = get_strategy_indicator_config("TrendFollowing");
        let json = serde_json::to_string(&cfg).unwrap();
        let deserialized: StrategyIndicatorConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.indicators, cfg.indicators);
        assert_eq!(deserialized.thresholds, cfg.thresholds);
    }

    #[test]
    fn test_filter_macd_bearish_label() {
        let mut indicators = TechnicalIndicators::empty();
        indicators.macd_histogram = Some(-150.0);
        let cfg = get_strategy_indicator_config("TrendFollowing");
        let result = filter_indicators_for_prompt(&indicators, &cfg);
        assert!(result.contains("bearish"));
        assert!(result.contains("MACD="));
        assert_eq!(result, "Trend: MACD=-150.0000 (bearish)");
    }

    #[test]
    fn test_filter_adx_weak_vs_strong() {
        let mut indicators = TechnicalIndicators::empty();
        indicators.adx_14 = Some(15.0);
        let cfg = get_strategy_indicator_config("TrendFollowing");
        let result = filter_indicators_for_prompt(&indicators, &cfg);
        assert!(result.contains("weak"));

        indicators.adx_14 = Some(35.0);
        let result2 = filter_indicators_for_prompt(&indicators, &cfg);
        assert!(result2.contains("strong"));
    }
}