1use std::collections::{HashMap, HashSet};
2
3use crate::indicators::{
4 ADX, ATR, CCI, DEMA, EMA, Indicator, ROC, RSI, SMA, StdDev, TEMA, WMA, WilliamsR,
5};
6use crate::strategy::types::{
7 CompareTarget, Condition, ConditionGroup, ConditionNode, Operator, Strategy,
8};
9use crate::types::{Candle, ExitReason, Side, Signal};
10
11#[allow(clippy::upper_case_acronyms)]
13#[derive(Debug)]
14enum IndicatorInstance {
15 SMA(SMA),
16 EMA(EMA),
17 RSI(RSI),
18 ATR(ATR),
19 WMA(WMA),
20 DEMA(DEMA),
21 TEMA(TEMA),
22 CCI(CCI),
23 WilliamsR(WilliamsR),
24 ROC(ROC),
25 StdDev(StdDev),
26 ADX(ADX),
27}
28
29impl IndicatorInstance {
30 fn next(&mut self, candle: &Candle) -> Option<f64> {
31 match self {
32 IndicatorInstance::SMA(i) => i.next(candle),
33 IndicatorInstance::EMA(i) => i.next(candle),
34 IndicatorInstance::RSI(i) => i.next(candle),
35 IndicatorInstance::ATR(i) => i.next(candle),
36 IndicatorInstance::WMA(i) => i.next(candle),
37 IndicatorInstance::DEMA(i) => i.next(candle),
38 IndicatorInstance::TEMA(i) => i.next(candle),
39 IndicatorInstance::CCI(i) => i.next(candle),
40 IndicatorInstance::WilliamsR(i) => i.next(candle),
41 IndicatorInstance::ROC(i) => i.next(candle),
42 IndicatorInstance::StdDev(i) => i.next(candle),
43 IndicatorInstance::ADX(i) => {
44 i.next(candle).map(|output| output.adx)
46 }
47 }
48 }
49}
50
51fn parse_indicator(name: &str) -> Option<IndicatorInstance> {
66 if let Some(rest) = name.strip_prefix("sma")
67 && let Ok(p) = rest.parse::<usize>()
68 {
69 return Some(IndicatorInstance::SMA(SMA::new(p)));
70 }
71 if let Some(rest) = name.strip_prefix("ema")
72 && let Ok(p) = rest.parse::<usize>()
73 {
74 return Some(IndicatorInstance::EMA(EMA::new(p)));
75 }
76 if let Some(rest) = name.strip_prefix("rsi")
77 && let Ok(p) = rest.parse::<usize>()
78 {
79 return Some(IndicatorInstance::RSI(RSI::new(p)));
80 }
81 if let Some(rest) = name.strip_prefix("atr")
82 && let Ok(p) = rest.parse::<usize>()
83 {
84 return Some(IndicatorInstance::ATR(ATR::new(p)));
85 }
86 if let Some(rest) = name.strip_prefix("wma")
87 && let Ok(p) = rest.parse::<usize>()
88 {
89 return Some(IndicatorInstance::WMA(WMA::new(p)));
90 }
91 if let Some(rest) = name.strip_prefix("dema")
92 && let Ok(p) = rest.parse::<usize>()
93 {
94 return Some(IndicatorInstance::DEMA(DEMA::new(p)));
95 }
96 if let Some(rest) = name.strip_prefix("tema")
97 && let Ok(p) = rest.parse::<usize>()
98 {
99 return Some(IndicatorInstance::TEMA(TEMA::new(p)));
100 }
101 if let Some(rest) = name.strip_prefix("cci")
102 && let Ok(p) = rest.parse::<usize>()
103 {
104 return Some(IndicatorInstance::CCI(CCI::new(p)));
105 }
106 if let Some(rest) = name.strip_prefix("williams_r")
107 && let Ok(p) = rest.parse::<usize>()
108 {
109 return Some(IndicatorInstance::WilliamsR(WilliamsR::new(p)));
110 }
111 if let Some(rest) = name.strip_prefix("roc")
112 && let Ok(p) = rest.parse::<usize>()
113 {
114 return Some(IndicatorInstance::ROC(ROC::new(p)));
115 }
116 if let Some(rest) = name.strip_prefix("stddev")
117 && let Ok(p) = rest.parse::<usize>()
118 {
119 return Some(IndicatorInstance::StdDev(StdDev::new(p)));
120 }
121 if let Some(rest) = name.strip_prefix("adx")
122 && let Ok(p) = rest.parse::<usize>()
123 {
124 return Some(IndicatorInstance::ADX(ADX::new(p)));
125 }
126 None
127}
128
129#[derive(Debug)]
131pub struct StrategyEngine {
132 strategy: Strategy,
133 indicators: HashMap<String, IndicatorInstance>,
134 required: HashSet<String>,
135 last_values: HashMap<String, f64>,
136}
137
138impl StrategyEngine {
139 pub fn new(strategy: Strategy) -> Self {
140 let mut indicators = HashMap::new();
141 collect_indicators_from_node(&strategy.entry, &mut indicators);
142 if let Some(exit) = &strategy.exit {
143 collect_indicators_from_node(exit, &mut indicators);
144 }
145 let required: HashSet<String> = indicators.keys().cloned().collect();
146 let mut instances = HashMap::new();
147 for name in indicators.keys() {
148 if let Some(inst) = parse_indicator(name) {
149 instances.insert(name.clone(), inst);
150 }
151 }
152 Self {
153 strategy,
154 indicators: instances,
155 required,
156 last_values: HashMap::new(),
157 }
158 }
159
160 pub fn next(&mut self, candle: &Candle) -> Signal {
162 let prev_values = self.last_values.clone();
164 self.last_values.clear();
165
166 for (name, inst) in self.indicators.iter_mut() {
168 if let Some(v) = inst.next(candle) {
169 self.last_values.insert(name.clone(), v);
170 }
171 }
172
173 if self
175 .required
176 .iter()
177 .any(|name| !self.last_values.contains_key(name))
178 {
179 return Signal::Hold;
180 }
181
182 let entry = eval_node(&self.strategy.entry, &self.last_values, &prev_values);
184 let exit = self
185 .strategy
186 .exit
187 .as_ref()
188 .and_then(|n| eval_node(n, &self.last_values, &prev_values));
189
190 if exit == Some(true) {
191 Signal::Exit(ExitReason::RuleTriggered)
192 } else if entry == Some(true) {
193 Signal::Entry(Side::Long)
194 } else {
195 Signal::Hold
196 }
197 }
198
199 pub fn evaluate(&mut self, candles: &[Candle]) -> Vec<Signal> {
201 candles.iter().map(|c| self.next(c)).collect()
202 }
203}
204
/// Returns the reading stored for indicator `name`, or `None` if `values` lacks it.
fn get_value(name: &str, values: &HashMap<String, f64>) -> Option<f64> {
    let reading = values.get(name)?;
    Some(*reading)
}
208
209fn eval_node(
211 node: &ConditionNode,
212 curr: &HashMap<String, f64>,
213 prev: &HashMap<String, f64>,
214) -> Option<bool> {
215 match node {
216 ConditionNode::Condition(c) => eval_condition(c, curr, prev),
217 ConditionNode::Group(g) => match g {
218 ConditionGroup::AllOf(nodes) => {
219 let mut any_none = false;
220 for n in nodes {
221 match eval_node(n, curr, prev) {
222 Some(true) => {}
223 Some(false) => return Some(false),
224 None => any_none = true,
225 }
226 }
227 if any_none { None } else { Some(true) }
228 }
229 ConditionGroup::AnyOf(nodes) => {
230 let mut any_none = false;
231 for n in nodes {
232 match eval_node(n, curr, prev) {
233 Some(true) => return Some(true),
234 Some(false) => {}
235 None => any_none = true,
236 }
237 }
238 if any_none { None } else { Some(false) }
239 }
240 },
241 }
242}
243
/// Absolute tolerance used by `Operator::Equals` when comparing two `f64` readings.
const EPS: f64 = 1.0e-9;
245
246fn get_prev_n(name: &str, prev: &HashMap<String, f64>, n: u32) -> Option<f64> {
247 if n == 1 { get_value(name, prev) } else { None }
248}
249
250fn eval_condition(
251 condition: &Condition,
252 curr: &HashMap<String, f64>,
253 prev: &HashMap<String, f64>,
254) -> Option<bool> {
255 let left = get_value(&condition.left, curr)?;
256 let right_curr = match &condition.right {
257 CompareTarget::Value(v) => Some(*v),
258 CompareTarget::Indicator(name) => get_value(name, curr),
259 CompareTarget::Scaled {
260 indicator,
261 multiplier,
262 } => get_value(indicator, curr).map(|v| v * multiplier),
263 CompareTarget::Range(_, _) => None, CompareTarget::None => None,
265 };
266
267 match condition.operator {
268 Operator::IsAbove => Some(left > right_curr?),
269 Operator::IsBelow => Some(left < right_curr?),
270 Operator::Equals => Some((left - right_curr?).abs() < EPS),
271 Operator::IsBetween => {
272 if let CompareTarget::Range(lower, upper) = condition.right {
273 Some(left >= lower && left <= upper)
274 } else {
275 right_curr.map(|r| left >= r)
276 }
277 }
278 Operator::CrossesAbove => {
279 let prev_left = get_value(&condition.left, prev)?;
280 let prev_right = match &condition.right {
281 CompareTarget::Value(v) => Some(*v),
282 CompareTarget::Indicator(name) => get_value(name, prev),
283 CompareTarget::Scaled {
284 indicator,
285 multiplier,
286 } => get_value(indicator, prev).map(|v| v * multiplier),
287 _ => None,
288 }?;
289 Some(left > right_curr? && prev_left <= prev_right)
290 }
291 Operator::CrossesBelow => {
292 let prev_left = get_value(&condition.left, prev)?;
293 let prev_right = match &condition.right {
294 CompareTarget::Value(v) => Some(*v),
295 CompareTarget::Indicator(name) => get_value(name, prev),
296 CompareTarget::Scaled {
297 indicator,
298 multiplier,
299 } => get_value(indicator, prev).map(|v| v * multiplier),
300 _ => None,
301 }?;
302 Some(left < right_curr? && prev_left >= prev_right)
303 }
304 Operator::IsRising(period) => {
305 let prev_left = get_prev_n(&condition.left, prev, period)?;
306 Some(left > prev_left)
307 }
308 Operator::IsFalling(period) => {
309 let prev_left = get_prev_n(&condition.left, prev, period)?;
310 Some(left < prev_left)
311 }
312 }
313}
314
315fn collect_indicators_from_node(node: &ConditionNode, set: &mut HashMap<String, ()>) {
317 match node {
318 ConditionNode::Condition(c) => {
319 set.insert(c.left.clone(), ());
320 if let CompareTarget::Indicator(name) = &c.right {
321 set.insert(name.clone(), ());
322 }
323 if let CompareTarget::Scaled { indicator, .. } = &c.right {
324 set.insert(indicator.clone(), ());
325 }
326 }
327 ConditionNode::Group(g) => match g {
328 ConditionGroup::AllOf(nodes) | ConditionGroup::AnyOf(nodes) => {
329 for n in nodes {
330 collect_indicators_from_node(n, set);
331 }
332 }
333 },
334 }
335}
336
337pub fn evaluate_strategy_batch(strategy: &Strategy, candles: &[Candle]) -> Vec<Signal> {
339 let mut engine = StrategyEngine::new(strategy.clone());
340 engine.evaluate(candles)
341}
342
343pub fn strategy_engine(strategy: Strategy) -> StrategyEngine {
345 StrategyEngine::new(strategy)
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::strategy::StopLoss;
    use crate::strategy::indicator_ref::IndicatorRef;
    use crate::strategy::types::{
        CompareTarget, Condition, ConditionGroup, ConditionNode, Operator,
    };

    /// Builds flat candles (open = high = low = close = price) with the index
    /// as timestamp and zero volume.
    fn make_candles(prices: &[f64]) -> Vec<Candle> {
        prices
            .iter()
            .enumerate()
            .map(|(i, p)| Candle {
                timestamp: i as i64,
                open: *p,
                high: *p,
                low: *p,
                close: *p,
                volume: 0.0,
            })
            .collect()
    }

    #[test]
    fn golden_cross_signals() {
        let entry = IndicatorRef::sma(1).is_above(1.5);
        let exit = IndicatorRef::sma(1).is_below(1.5);
        let strategy = Strategy::builder("gc")
            .entry(entry)
            .exit(exit)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let prices = [1.0, 1.2, 1.6, 1.8, 1.4, 1.2];
        let candles = make_candles(&prices);
        let signals = evaluate_strategy_batch(&strategy, &candles);

        assert_eq!(signals.len(), prices.len());

        // Mirror the rules on the raw indicator: below the threshold exits,
        // above it enters, and no reading holds.
        let mut sma = crate::indicators::SMA::new(1);
        let expected: Vec<Signal> = candles
            .iter()
            .map(|c| match sma.next(c) {
                Some(x) if x < 1.5 => Signal::Exit(ExitReason::RuleTriggered),
                Some(x) if x > 1.5 => Signal::Entry(Side::Long),
                _ => Signal::Hold,
            })
            .collect();
        assert_eq!(signals, expected);
    }

    #[test]
    fn rsi_mean_reversion_signals() {
        let entry = IndicatorRef::rsi(2).is_below(40.0);
        let exit = IndicatorRef::rsi(2).is_above(60.0);
        let strategy = Strategy::builder("rsi")
            .entry(entry)
            .exit(exit)
            .stop_loss(StopLoss::FixedPercent(2.0))
            .build()
            .unwrap();

        let prices = [10.0, 9.5, 9.0, 8.5, 9.5, 10.5];
        let candles = make_candles(&prices);
        let signals = evaluate_strategy_batch(&strategy, &candles);

        let mut rsi = crate::indicators::RSI::new(2);
        let mut expected = Vec::new();
        for c in &candles {
            let v = rsi.next(c);
            let sig = match v {
                Some(x) if x > 60.0 => Signal::Exit(ExitReason::RuleTriggered),
                Some(x) if x < 40.0 => Signal::Entry(Side::Long),
                _ => Signal::Hold,
            };
            expected.push(sig);
        }

        assert_eq!(signals, expected);

        let entry_idx = signals.iter().position(|s| matches!(s, Signal::Entry(_)));
        let exit_idx = signals.iter().position(|s| matches!(s, Signal::Exit(_)));
        assert!(entry_idx.is_some(), "expected at least one entry signal");
        assert!(exit_idx.is_some(), "expected at least one exit signal");
        if let (Some(ei), Some(xi)) = (entry_idx, exit_idx) {
            assert!(ei < xi, "entry should occur before exit");
        }
    }

    #[test]
    fn edge_single_condition_entry_only() {
        let entry = IndicatorRef::sma(1).is_above(1.0);
        let strategy = Strategy::builder("single")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let prices = [2.0, 2.0, 2.0];
        let candles = make_candles(&prices);
        let signals = evaluate_strategy_batch(&strategy, &candles);

        assert!(signals.iter().all(|s| matches!(s, Signal::Entry(_))));
    }

    #[test]
    fn edge_max_conditions_group_all_of() {
        let cond = || {
            ConditionNode::Condition(Condition::new(
                "sma1",
                Operator::IsAbove,
                CompareTarget::Value(1.0),
            ))
        };
        let entry = ConditionNode::Group(ConditionGroup::AllOf((0..20).map(|_| cond()).collect()));
        let strategy = Strategy::builder("max_group")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let prices = [2.0, 2.0, 2.0];
        let candles = make_candles(&prices);
        let signals = evaluate_strategy_batch(&strategy, &candles);

        assert!(signals.iter().all(|s| matches!(s, Signal::Entry(_))));
    }

    #[test]
    fn edge_nested_groups() {
        let always_true = ConditionNode::Condition(Condition::new(
            "sma1",
            Operator::IsAbove,
            CompareTarget::Value(1.0),
        ));
        let always_false = ConditionNode::Condition(Condition::new(
            "sma1",
            Operator::IsAbove,
            CompareTarget::Value(10.0),
        ));

        // AllOf(true, AnyOf(false, true)) must evaluate to true.
        let entry = ConditionNode::Group(ConditionGroup::AllOf(vec![
            always_true.clone(),
            ConditionNode::Group(ConditionGroup::AnyOf(vec![always_false, always_true])),
        ]));

        let strategy = Strategy::builder("nested")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let prices = [2.0, 2.0, 2.0];
        let candles = make_candles(&prices);
        let signals = evaluate_strategy_batch(&strategy, &candles);

        assert!(signals.iter().all(|s| matches!(s, Signal::Entry(_))));
    }

    #[test]
    fn streaming_equals_batch() {
        let entry = IndicatorRef::sma(2).crosses_above_indicator(IndicatorRef::sma(3));
        let strategy = Strategy::builder("gc")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let prices = [1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 1.0];
        let candles = make_candles(&prices);

        let batch = evaluate_strategy_batch(&strategy, &candles);
        let mut engine = strategy_engine(strategy);
        let streaming: Vec<_> = candles.iter().map(|c| engine.next(c)).collect();

        assert_eq!(batch, streaming);
    }

    #[test]
    fn golden_cross_manual_verification() {
        let prices = [1.0, 1.0, 1.0, 3.0, 3.0, 0.5];
        let candles = make_candles(&prices);

        // Both rules are level comparisons (not crossovers), so each candle's
        // expected signal depends only on that candle's two readings. Both
        // indicators are advanced on every candle.
        let mut sma1 = crate::indicators::SMA::new(1);
        let mut sma3 = crate::indicators::SMA::new(3);
        let expected: Vec<Signal> = candles
            .iter()
            .map(|c| match (sma1.next(c), sma3.next(c)) {
                (Some(f), Some(s)) if f < s => Signal::Exit(ExitReason::RuleTriggered),
                (Some(f), Some(s)) if f > s => Signal::Entry(Side::Long),
                _ => Signal::Hold,
            })
            .collect();

        let entry = IndicatorRef::sma(1).is_above_indicator(IndicatorRef::sma(3));
        let exit = IndicatorRef::sma(1).is_below_indicator(IndicatorRef::sma(3));
        let strategy = Strategy::builder("gc_manual")
            .entry(entry)
            .exit(exit)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);

        // Previously `expected` was computed but never compared.
        assert_eq!(signals, expected);

        let entry_idx = signals.iter().position(|s| matches!(s, Signal::Entry(_)));
        let exit_idx = signals.iter().position(|s| matches!(s, Signal::Exit(_)));

        assert!(entry_idx.is_some(), "expected at least one entry signal");
        assert!(exit_idx.is_some(), "expected at least one exit signal");
        if let (Some(ei), Some(xi)) = (entry_idx, exit_idx) {
            assert!(ei < xi, "entry should occur before exit");
        }
    }

    #[test]
    fn batch_a_indicators_in_strategy_flow() {
        let entry = IndicatorRef::wma(3).crosses_above_indicator(IndicatorRef::sma(3));
        let strategy = Strategy::builder("wma_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let prices = [1.0, 2.0, 3.0, 4.0, 5.0];
        let candles = make_candles(&prices);
        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert_eq!(signals.len(), candles.len());

        let entry = IndicatorRef::roc(2).is_above(0.0);
        let strategy = Strategy::builder("roc_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert_eq!(signals.len(), candles.len());

        let entry = IndicatorRef::stddev(3).is_above(0.5);
        let strategy = Strategy::builder("stddev_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert_eq!(signals.len(), candles.len());

        let entry = IndicatorRef::dema(3).crosses_above(2.5);
        let strategy = Strategy::builder("dema_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert_eq!(signals.len(), candles.len());

        let entry = IndicatorRef::tema(3).is_above(2.0);
        let strategy = Strategy::builder("tema_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert_eq!(signals.len(), candles.len());

        let entry = IndicatorRef::cci(3).is_above(0.0);
        let strategy = Strategy::builder("cci_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert_eq!(signals.len(), candles.len());

        let entry = IndicatorRef::williams_r(3).is_below(-50.0);
        let strategy = Strategy::builder("williams_r_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert_eq!(signals.len(), candles.len());

        let entry = IndicatorRef::adx(3).is_above(20.0);
        let strategy = Strategy::builder("adx_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert_eq!(signals.len(), candles.len());
    }
}