1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5use hyper_ta::technical_analysis::TechnicalIndicators;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
14#[serde(rename_all = "camelCase")]
15pub struct StrategyIndicatorConfig {
16 pub indicators: Vec<String>,
19 pub thresholds: HashMap<String, (f64, f64)>,
22}
23
24pub fn get_strategy_indicator_config(template: &str) -> StrategyIndicatorConfig {
38 match template.to_lowercase().as_str() {
39 "trendfollowing" => trend_following_config(),
40 "meanreversion" => mean_reversion_config(),
41 "scalping" => scalping_config(),
42 "conservative" => conservative_config(),
43 _ => conservative_config(),
44 }
45}
46
47fn trend_following_config() -> StrategyIndicatorConfig {
50 let mut thresholds = HashMap::new();
51 thresholds.insert("MACD".to_string(), (0.0, 0.0)); thresholds.insert("ADX".to_string(), (20.0, 50.0)); thresholds.insert("EMA".to_string(), (0.0, 0.0)); StrategyIndicatorConfig {
56 indicators: vec!["MACD".to_string(), "ADX".to_string(), "EMA".to_string()],
57 thresholds,
58 }
59}
60
61fn mean_reversion_config() -> StrategyIndicatorConfig {
62 let mut thresholds = HashMap::new();
63 thresholds.insert("RSI".to_string(), (30.0, 70.0));
64 thresholds.insert("BB".to_string(), (-2.0, 2.0)); thresholds.insert("SMA".to_string(), (0.0, 0.0)); StrategyIndicatorConfig {
68 indicators: vec!["RSI".to_string(), "BB".to_string(), "SMA".to_string()],
69 thresholds,
70 }
71}
72
73fn scalping_config() -> StrategyIndicatorConfig {
74 let mut thresholds = HashMap::new();
75 thresholds.insert("Stochastic".to_string(), (20.0, 80.0));
76 thresholds.insert("ATR".to_string(), (0.0, 100.0)); thresholds.insert("EMA".to_string(), (0.0, 0.0)); StrategyIndicatorConfig {
80 indicators: vec![
81 "Stochastic".to_string(),
82 "ATR".to_string(),
83 "EMA".to_string(),
84 ],
85 thresholds,
86 }
87}
88
89fn conservative_config() -> StrategyIndicatorConfig {
90 let mut thresholds = HashMap::new();
91 thresholds.insert("RSI".to_string(), (25.0, 75.0));
93 thresholds.insert("MACD".to_string(), (0.0, 0.0));
94 thresholds.insert("BB".to_string(), (-2.5, 2.5));
95 thresholds.insert("ADX".to_string(), (25.0, 50.0));
96 thresholds.insert("EMA".to_string(), (0.0, 0.0));
97 thresholds.insert("SMA".to_string(), (0.0, 0.0));
98 thresholds.insert("Stochastic".to_string(), (20.0, 80.0));
99 thresholds.insert("ATR".to_string(), (0.0, 100.0));
100 thresholds.insert("CCI".to_string(), (-100.0, 100.0));
101 thresholds.insert("WilliamsR".to_string(), (-80.0, -20.0));
102 thresholds.insert("OBV".to_string(), (0.0, 0.0));
103 thresholds.insert("MFI".to_string(), (20.0, 80.0));
104
105 StrategyIndicatorConfig {
106 indicators: vec![
107 "RSI".to_string(),
108 "MACD".to_string(),
109 "BB".to_string(),
110 "ADX".to_string(),
111 "EMA".to_string(),
112 "SMA".to_string(),
113 "Stochastic".to_string(),
114 "ATR".to_string(),
115 "CCI".to_string(),
116 "WilliamsR".to_string(),
117 "OBV".to_string(),
118 "MFI".to_string(),
119 ],
120 thresholds,
121 }
122}
123
124pub fn filter_indicators_for_prompt(
135 indicators: &TechnicalIndicators,
136 config: &StrategyIndicatorConfig,
137) -> String {
138 let allowed: std::collections::HashSet<&str> =
139 config.indicators.iter().map(|s| s.as_str()).collect();
140
141 let mut lines: Vec<String> = Vec::new();
142
143 let mut trend_parts: Vec<String> = Vec::new();
145
146 if allowed.contains("SMA") {
147 if let Some(v) = indicators.sma_20 {
148 trend_parts.push(format!("SMA20={:.2}", v));
149 }
150 if let Some(v) = indicators.sma_50 {
151 trend_parts.push(format!("SMA50={:.2}", v));
152 }
153 }
154
155 if allowed.contains("EMA") {
156 if let Some(v) = indicators.ema_12 {
157 trend_parts.push(format!("EMA12={:.2}", v));
158 }
159 if let Some(v) = indicators.ema_26 {
160 trend_parts.push(format!("EMA26={:.2}", v));
161 }
162 }
163
164 if allowed.contains("MACD") {
165 if let Some(hist) = indicators.macd_histogram {
166 let sign = if hist >= 0.0 { "+" } else { "" };
167 let label = if hist > 0.0 {
168 "bullish"
169 } else if hist < 0.0 {
170 "bearish"
171 } else {
172 "neutral"
173 };
174 trend_parts.push(format!("MACD={}{:.4} ({})", sign, hist, label));
175 }
176 }
177
178 if allowed.contains("ADX") {
179 if let Some(v) = indicators.adx_14 {
180 let (low, _high) = config
181 .thresholds
182 .get("ADX")
183 .copied()
184 .unwrap_or((25.0, 50.0));
185 let strength = if v >= low { "strong" } else { "weak" };
186 trend_parts.push(format!("ADX={:.0} ({})", v, strength));
187 }
188 }
189
190 if !trend_parts.is_empty() {
191 lines.push(format!("Trend: {}", trend_parts.join(" ")));
192 }
193
194 let mut mom_parts: Vec<String> = Vec::new();
196
197 if allowed.contains("RSI") {
198 if let Some(v) = indicators.rsi_14 {
199 let (low, high) = config
200 .thresholds
201 .get("RSI")
202 .copied()
203 .unwrap_or((30.0, 70.0));
204 let zone = zone_label(v, low, high);
205 mom_parts.push(format!("RSI={:.0} ({})", v, zone));
206 }
207 }
208
209 if allowed.contains("Stochastic") {
210 if let Some(k) = indicators.stoch_k {
211 let (low, high) = config
212 .thresholds
213 .get("Stochastic")
214 .copied()
215 .unwrap_or((20.0, 80.0));
216 let zone = zone_label(k, low, high);
217 mom_parts.push(format!("Stoch={:.0} ({})", k, zone));
218 }
219 }
220
221 if allowed.contains("CCI") {
222 if let Some(v) = indicators.cci_20 {
223 let (low, high) = config
224 .thresholds
225 .get("CCI")
226 .copied()
227 .unwrap_or((-100.0, 100.0));
228 let zone = zone_label(v, low, high);
229 mom_parts.push(format!("CCI={:.0} ({})", v, zone));
230 }
231 }
232
233 if allowed.contains("WilliamsR") {
234 if let Some(v) = indicators.williams_r_14 {
235 let (low, high) = config
236 .thresholds
237 .get("WilliamsR")
238 .copied()
239 .unwrap_or((-80.0, -20.0));
240 let zone = zone_label(v, low, high);
242 mom_parts.push(format!("WR={:.0} ({})", v, zone));
243 }
244 }
245
246 if allowed.contains("MFI") {
247 if let Some(v) = indicators.mfi_14 {
248 let (low, high) = config
249 .thresholds
250 .get("MFI")
251 .copied()
252 .unwrap_or((20.0, 80.0));
253 let zone = zone_label(v, low, high);
254 mom_parts.push(format!("MFI={:.0} ({})", v, zone));
255 }
256 }
257
258 if !mom_parts.is_empty() {
259 lines.push(format!("Momentum: {}", mom_parts.join(" ")));
260 }
261
262 let mut vol_parts: Vec<String> = Vec::new();
264
265 if allowed.contains("BB") {
266 if let (Some(bl), Some(bm), Some(bu)) = (
267 indicators.bb_lower,
268 indicators.bb_middle,
269 indicators.bb_upper,
270 ) {
271 vol_parts.push(format!("BB[{:.2} - {:.2} - {:.2}]", bl, bm, bu));
272 }
273 }
274
275 if allowed.contains("ATR") {
276 if let Some(v) = indicators.atr_14 {
277 vol_parts.push(format!("ATR={:.2}", v));
278 }
279 }
280
281 if !vol_parts.is_empty() {
282 lines.push(format!("Volatility: {}", vol_parts.join(" ")));
283 }
284
285 let mut vol_line_parts: Vec<String> = Vec::new();
287
288 if allowed.contains("OBV") {
289 if let Some(v) = indicators.obv {
290 vol_line_parts.push(format!("OBV={:.0}", v));
291 }
292 }
293
294 if !vol_line_parts.is_empty() {
295 lines.push(format!("Volume: {}", vol_line_parts.join(" ")));
296 }
297
298 lines.join("\n")
299}
300
/// Classifies `value` against an inclusive `(low, high)` band.
///
/// Returns `"oversold"` when `value <= low`, `"overbought"` when
/// `value >= high`, and `"neutral"` otherwise. The low check runs first, so an
/// inverted band (`low > high`) favours `"oversold"`. A NaN value compares
/// false against both bounds and is reported as `"neutral"`.
fn zone_label(value: f64, low: f64, high: f64) -> &'static str {
    match value {
        v if v <= low => "oversold",
        v if v >= high => "overbought",
        _ => "neutral",
    }
}
315
/// Unit tests for the strategy profiles and the prompt renderer.
#[cfg(test)]
mod tests {
    use super::*;

    /// Every indicator name the renderer understands.
    const ALL_INDICATORS: [&str; 12] = [
        "RSI",
        "MACD",
        "BB",
        "ADX",
        "EMA",
        "SMA",
        "Stochastic",
        "ATR",
        "CCI",
        "WilliamsR",
        "OBV",
        "MFI",
    ];

    /// The template names with a dedicated profile.
    const TEMPLATES: [&str; 4] = ["TrendFollowing", "MeanReversion", "Scalping", "Conservative"];

    /// A snapshot with every indicator field populated, so each test controls
    /// the rendered output purely through the strategy config.
    fn full_indicators() -> TechnicalIndicators {
        TechnicalIndicators {
            sma_20: Some(64000.0),
            sma_50: Some(62000.0),
            ema_12: Some(64500.0),
            ema_20: Some(64000.0),
            ema_26: Some(63500.0),
            ema_50: Some(62500.0),
            rsi_14: Some(55.0),
            macd_line: Some(500.0),
            macd_signal: Some(400.0),
            macd_histogram: Some(100.0),
            bb_upper: Some(68000.0),
            bb_middle: Some(65000.0),
            bb_lower: Some(62000.0),
            atr_14: Some(350.0),
            adx_14: Some(30.0),
            stoch_k: Some(65.0),
            stoch_d: Some(60.0),
            cci_20: Some(50.0),
            williams_r_14: Some(-45.0),
            obv: Some(12345678.0),
            mfi_14: Some(55.0),
            roc_12: Some(5.0),
            donchian_upper_20: Some(66000.0),
            donchian_lower_20: Some(60000.0),
            donchian_upper_10: Some(65500.0),
            donchian_lower_10: Some(60500.0),
            close_zscore_20: Some(0.5),
            volume_zscore_20: Some(0.3),
            hv_20: Some(0.25),
            hv_60: Some(0.30),
            kc_upper_20: Some(66000.0),
            kc_lower_20: Some(62000.0),
            supertrend_value: Some(63000.0),
            supertrend_direction: Some(1.0),
            vwap: Some(64000.0),
            plus_di_14: Some(25.0),
            minus_di_14: Some(20.0),
        }
    }

    #[test]
    fn test_trend_following_config_indicators() {
        let cfg = get_strategy_indicator_config("TrendFollowing");
        assert!(cfg.indicators.contains(&"MACD".to_string()));
        assert!(cfg.indicators.contains(&"ADX".to_string()));
        assert!(cfg.indicators.contains(&"EMA".to_string()));
        assert_eq!(cfg.indicators.len(), 3);
    }

    #[test]
    fn test_mean_reversion_config_indicators() {
        let cfg = get_strategy_indicator_config("MeanReversion");
        assert!(cfg.indicators.contains(&"RSI".to_string()));
        assert!(cfg.indicators.contains(&"BB".to_string()));
        assert!(cfg.indicators.contains(&"SMA".to_string()));
        assert_eq!(cfg.indicators.len(), 3);
    }

    #[test]
    fn test_scalping_config_indicators() {
        let cfg = get_strategy_indicator_config("Scalping");
        assert!(cfg.indicators.contains(&"Stochastic".to_string()));
        assert!(cfg.indicators.contains(&"ATR".to_string()));
        assert!(cfg.indicators.contains(&"EMA".to_string()));
        assert_eq!(cfg.indicators.len(), 3);
    }

    #[test]
    fn test_conservative_config_includes_all_indicators() {
        let cfg = get_strategy_indicator_config("Conservative");
        assert!(cfg.indicators.len() >= 10);
        for name in ALL_INDICATORS {
            assert!(
                cfg.indicators.iter().any(|i| i == name),
                "conservative profile is missing {}",
                name
            );
        }
    }

    #[test]
    fn test_unknown_template_falls_back_to_conservative() {
        let cfg = get_strategy_indicator_config("UnknownStrategy");
        let conservative = get_strategy_indicator_config("Conservative");
        // Compare contents, not just sizes, so a divergent fallback is caught.
        assert_eq!(cfg.indicators, conservative.indicators);
        assert_eq!(cfg.thresholds, conservative.thresholds);
    }

    #[test]
    fn test_case_insensitive_lookup() {
        let reference = get_strategy_indicator_config("TrendFollowing");
        for spelling in ["trendfollowing", "TRENDFOLLOWING", "TrEnDfOlLoWiNg"] {
            let cfg = get_strategy_indicator_config(spelling);
            assert_eq!(cfg.indicators, reference.indicators, "spelling {}", spelling);
            assert_eq!(cfg.thresholds, reference.thresholds, "spelling {}", spelling);
        }
    }

    #[test]
    fn test_every_enabled_indicator_has_threshold() {
        for template in TEMPLATES {
            let cfg = get_strategy_indicator_config(template);
            for name in &cfg.indicators {
                assert!(
                    cfg.thresholds.contains_key(name),
                    "{} has no threshold for {}",
                    template,
                    name
                );
            }
        }
    }

    #[test]
    fn test_trend_following_thresholds() {
        let cfg = get_strategy_indicator_config("TrendFollowing");
        let adx = cfg.thresholds.get("ADX").expect("ADX threshold missing");
        assert_eq!(*adx, (20.0, 50.0));
    }

    #[test]
    fn test_conservative_stricter_rsi_thresholds() {
        let cfg = get_strategy_indicator_config("Conservative");
        let rsi = cfg.thresholds.get("RSI").expect("RSI threshold missing");
        assert_eq!(*rsi, (25.0, 75.0));
    }

    #[test]
    fn test_filter_trend_following_only_includes_trend_indicators() {
        let indicators = full_indicators();
        let cfg = get_strategy_indicator_config("TrendFollowing");
        let result = filter_indicators_for_prompt(&indicators, &cfg);

        assert!(result.contains("EMA12="));
        assert!(result.contains("MACD="));
        assert!(result.contains("ADX="));

        assert!(!result.contains("RSI="));
        assert!(!result.contains("BB["));
        assert!(!result.contains("Stoch="));
        assert!(!result.contains("OBV="));
    }

    #[test]
    fn test_filter_trend_following_exact_output() {
        let result = filter_indicators_for_prompt(
            &full_indicators(),
            &get_strategy_indicator_config("TrendFollowing"),
        );
        assert_eq!(
            result,
            "Trend: EMA12=64500.00 EMA26=63500.00 MACD=+100.0000 (bullish) ADX=30 (strong)"
        );
    }

    #[test]
    fn test_filter_mean_reversion_includes_rsi_bb_sma() {
        let indicators = full_indicators();
        let cfg = get_strategy_indicator_config("MeanReversion");
        let result = filter_indicators_for_prompt(&indicators, &cfg);

        assert!(result.contains("RSI="));
        assert!(result.contains("BB["));
        assert!(result.contains("SMA20="));

        assert!(!result.contains("MACD="));
        assert!(!result.contains("ADX="));
        assert!(!result.contains("Stoch="));
    }

    #[test]
    fn test_filter_scalping_includes_stochastic_atr_ema() {
        let indicators = full_indicators();
        let cfg = get_strategy_indicator_config("Scalping");
        let result = filter_indicators_for_prompt(&indicators, &cfg);

        assert!(result.contains("Stoch="));
        assert!(result.contains("ATR="));
        assert!(result.contains("EMA12="));

        assert!(!result.contains("RSI="));
        assert!(!result.contains("BB["));
        assert!(!result.contains("ADX="));
    }

    #[test]
    fn test_filter_conservative_includes_everything() {
        let indicators = full_indicators();
        let cfg = get_strategy_indicator_config("Conservative");
        let result = filter_indicators_for_prompt(&indicators, &cfg);

        for needle in [
            "SMA20=", "EMA12=", "MACD=", "ADX=", "RSI=", "Stoch=", "BB[", "ATR=", "CCI=", "WR=",
            "OBV=", "MFI=",
        ] {
            assert!(result.contains(needle), "missing {} in {:?}", needle, result);
        }
    }

    #[test]
    fn test_filter_conservative_section_order() {
        let result = filter_indicators_for_prompt(
            &full_indicators(),
            &get_strategy_indicator_config("Conservative"),
        );
        let headers: Vec<&str> = result
            .lines()
            .filter_map(|line| line.split(':').next())
            .collect();
        assert_eq!(headers, ["Trend", "Momentum", "Volatility", "Volume"]);
    }

    #[test]
    fn test_filter_empty_indicators_returns_empty_string() {
        let indicators = TechnicalIndicators::empty();
        let cfg = get_strategy_indicator_config("TrendFollowing");
        let result = filter_indicators_for_prompt(&indicators, &cfg);
        assert!(result.is_empty());
    }

    #[test]
    fn test_filter_respects_custom_thresholds() {
        let mut indicators = TechnicalIndicators::empty();
        indicators.rsi_14 = Some(28.0);

        // Mean reversion uses 30/70, so 28 is oversold.
        let cfg = get_strategy_indicator_config("MeanReversion");
        let result = filter_indicators_for_prompt(&indicators, &cfg);
        assert_eq!(result, "Momentum: RSI=28 (oversold)");

        // Conservative uses 25/75, so the same reading is neutral.
        let cfg2 = get_strategy_indicator_config("Conservative");
        let result2 = filter_indicators_for_prompt(&indicators, &cfg2);
        assert_eq!(result2, "Momentum: RSI=28 (neutral)");
    }

    #[test]
    fn test_zone_label_boundaries() {
        assert_eq!(zone_label(30.0, 30.0, 70.0), "oversold");
        assert_eq!(zone_label(70.0, 30.0, 70.0), "overbought");
        assert_eq!(zone_label(50.0, 30.0, 70.0), "neutral");
        assert_eq!(zone_label(29.9, 30.0, 70.0), "oversold");
        assert_eq!(zone_label(70.1, 30.0, 70.0), "overbought");
    }

    #[test]
    fn test_serialization_roundtrip() {
        for template in TEMPLATES {
            let cfg = get_strategy_indicator_config(template);
            let json = serde_json::to_string(&cfg).unwrap();
            let deserialized: StrategyIndicatorConfig = serde_json::from_str(&json).unwrap();
            assert_eq!(deserialized.indicators, cfg.indicators, "{}", template);
            assert_eq!(deserialized.thresholds, cfg.thresholds, "{}", template);
        }
    }

    #[test]
    fn test_filter_macd_bearish_label() {
        let mut indicators = TechnicalIndicators::empty();
        indicators.macd_histogram = Some(-150.0);
        let cfg = get_strategy_indicator_config("TrendFollowing");
        let result = filter_indicators_for_prompt(&indicators, &cfg);
        assert_eq!(result, "Trend: MACD=-150.0000 (bearish)");
    }

    #[test]
    fn test_filter_adx_weak_vs_strong() {
        let mut indicators = TechnicalIndicators::empty();
        indicators.adx_14 = Some(15.0);
        let cfg = get_strategy_indicator_config("TrendFollowing");
        let result = filter_indicators_for_prompt(&indicators, &cfg);
        assert_eq!(result, "Trend: ADX=15 (weak)");

        indicators.adx_14 = Some(35.0);
        let result2 = filter_indicators_for_prompt(&indicators, &cfg);
        assert_eq!(result2, "Trend: ADX=35 (strong)");
    }
}
579}