1use std::collections::{HashMap, HashSet};
2
3use crate::indicators::{Indicator, ATR, CCI, DEMA, EMA, ROC, RSI, SMA, StdDev, TEMA, WMA, WilliamsR, ADX};
4use crate::strategy::types::{
5 CompareTarget, Condition, ConditionGroup, ConditionNode, Operator, Strategy,
6};
7use crate::types::{Candle, ExitReason, Side, Signal};
8
9#[allow(clippy::upper_case_acronyms)]
11#[derive(Debug)]
12enum IndicatorInstance {
13 SMA(SMA),
14 EMA(EMA),
15 RSI(RSI),
16 ATR(ATR),
17 WMA(WMA),
18 DEMA(DEMA),
19 TEMA(TEMA),
20 CCI(CCI),
21 WilliamsR(WilliamsR),
22 ROC(ROC),
23 StdDev(StdDev),
24 ADX(ADX),
25}
26
27impl IndicatorInstance {
28 fn next(&mut self, candle: &Candle) -> Option<f64> {
29 match self {
30 IndicatorInstance::SMA(i) => i.next(candle),
31 IndicatorInstance::EMA(i) => i.next(candle),
32 IndicatorInstance::RSI(i) => i.next(candle),
33 IndicatorInstance::ATR(i) => i.next(candle),
34 IndicatorInstance::WMA(i) => i.next(candle),
35 IndicatorInstance::DEMA(i) => i.next(candle),
36 IndicatorInstance::TEMA(i) => i.next(candle),
37 IndicatorInstance::CCI(i) => i.next(candle),
38 IndicatorInstance::WilliamsR(i) => i.next(candle),
39 IndicatorInstance::ROC(i) => i.next(candle),
40 IndicatorInstance::StdDev(i) => i.next(candle),
41 IndicatorInstance::ADX(i) => {
42 i.next(candle).map(|output| output.adx)
44 }
45 }
46 }
47}
48
49fn parse_indicator(name: &str) -> Option<IndicatorInstance> {
64 if let Some(rest) = name.strip_prefix("sma") {
65 if let Ok(p) = rest.parse::<usize>() {
66 return Some(IndicatorInstance::SMA(SMA::new(p)));
67 }
68 }
69 if let Some(rest) = name.strip_prefix("ema") {
70 if let Ok(p) = rest.parse::<usize>() {
71 return Some(IndicatorInstance::EMA(EMA::new(p)));
72 }
73 }
74 if let Some(rest) = name.strip_prefix("rsi") {
75 if let Ok(p) = rest.parse::<usize>() {
76 return Some(IndicatorInstance::RSI(RSI::new(p)));
77 }
78 }
79 if let Some(rest) = name.strip_prefix("atr") {
80 if let Ok(p) = rest.parse::<usize>() {
81 return Some(IndicatorInstance::ATR(ATR::new(p)));
82 }
83 }
84 if let Some(rest) = name.strip_prefix("wma") {
85 if let Ok(p) = rest.parse::<usize>() {
86 return Some(IndicatorInstance::WMA(WMA::new(p)));
87 }
88 }
89 if let Some(rest) = name.strip_prefix("dema") {
90 if let Ok(p) = rest.parse::<usize>() {
91 return Some(IndicatorInstance::DEMA(DEMA::new(p)));
92 }
93 }
94 if let Some(rest) = name.strip_prefix("tema") {
95 if let Ok(p) = rest.parse::<usize>() {
96 return Some(IndicatorInstance::TEMA(TEMA::new(p)));
97 }
98 }
99 if let Some(rest) = name.strip_prefix("cci") {
100 if let Ok(p) = rest.parse::<usize>() {
101 return Some(IndicatorInstance::CCI(CCI::new(p)));
102 }
103 }
104 if let Some(rest) = name.strip_prefix("williams_r") {
105 if let Ok(p) = rest.parse::<usize>() {
106 return Some(IndicatorInstance::WilliamsR(WilliamsR::new(p)));
107 }
108 }
109 if let Some(rest) = name.strip_prefix("roc") {
110 if let Ok(p) = rest.parse::<usize>() {
111 return Some(IndicatorInstance::ROC(ROC::new(p)));
112 }
113 }
114 if let Some(rest) = name.strip_prefix("stddev") {
115 if let Ok(p) = rest.parse::<usize>() {
116 return Some(IndicatorInstance::StdDev(StdDev::new(p)));
117 }
118 }
119 if let Some(rest) = name.strip_prefix("adx") {
120 if let Ok(p) = rest.parse::<usize>() {
121 return Some(IndicatorInstance::ADX(ADX::new(p)));
122 }
123 }
124 None
125}
126
127#[derive(Debug)]
129pub struct StrategyEngine {
130 strategy: Strategy,
131 indicators: HashMap<String, IndicatorInstance>,
132 required: HashSet<String>,
133 last_values: HashMap<String, f64>,
134}
135
136impl StrategyEngine {
137 pub fn new(strategy: Strategy) -> Self {
138 let mut indicators = HashMap::new();
139 collect_indicators_from_node(&strategy.entry, &mut indicators);
140 if let Some(exit) = &strategy.exit {
141 collect_indicators_from_node(exit, &mut indicators);
142 }
143 let required: HashSet<String> = indicators.keys().cloned().collect();
144 let mut instances = HashMap::new();
145 for name in indicators.keys() {
146 if let Some(inst) = parse_indicator(name) {
147 instances.insert(name.clone(), inst);
148 }
149 }
150 Self {
151 strategy,
152 indicators: instances,
153 required,
154 last_values: HashMap::new(),
155 }
156 }
157
158 pub fn next(&mut self, candle: &Candle) -> Signal {
160 let prev_values = self.last_values.clone();
162 self.last_values.clear();
163
164 for (name, inst) in self.indicators.iter_mut() {
166 if let Some(v) = inst.next(candle) {
167 self.last_values.insert(name.clone(), v);
168 }
169 }
170
171 if self
173 .required
174 .iter()
175 .any(|name| !self.last_values.contains_key(name))
176 {
177 return Signal::Hold;
178 }
179
180 let entry = eval_node(&self.strategy.entry, &self.last_values, &prev_values);
182 let exit = self
183 .strategy
184 .exit
185 .as_ref()
186 .and_then(|n| eval_node(n, &self.last_values, &prev_values));
187
188 if exit == Some(true) {
189 Signal::Exit(ExitReason::RuleTriggered)
190 } else if entry == Some(true) {
191 Signal::Entry(Side::Long)
192 } else {
193 Signal::Hold
194 }
195 }
196
197 pub fn evaluate(&mut self, candles: &[Candle]) -> Vec<Signal> {
199 candles.iter().map(|c| self.next(c)).collect()
200 }
201}
202
/// Returns the value recorded for indicator `name` in `values`, if present.
fn get_value(name: &str, values: &HashMap<String, f64>) -> Option<f64> {
    let found = values.get(name)?;
    Some(*found)
}
206
207fn eval_node(
209 node: &ConditionNode,
210 curr: &HashMap<String, f64>,
211 prev: &HashMap<String, f64>,
212) -> Option<bool> {
213 match node {
214 ConditionNode::Condition(c) => eval_condition(c, curr, prev),
215 ConditionNode::Group(g) => match g {
216 ConditionGroup::AllOf(nodes) => {
217 let mut any_none = false;
218 for n in nodes {
219 match eval_node(n, curr, prev) {
220 Some(true) => {}
221 Some(false) => return Some(false),
222 None => any_none = true,
223 }
224 }
225 if any_none {
226 None
227 } else {
228 Some(true)
229 }
230 }
231 ConditionGroup::AnyOf(nodes) => {
232 let mut any_none = false;
233 for n in nodes {
234 match eval_node(n, curr, prev) {
235 Some(true) => return Some(true),
236 Some(false) => {}
237 None => any_none = true,
238 }
239 }
240 if any_none {
241 None
242 } else {
243 Some(false)
244 }
245 }
246 },
247 }
248}
249
/// Absolute tolerance used when testing two floating-point values for equality.
const EPS: f64 = 1.0e-9;
251
252fn get_prev_n(name: &str, prev: &HashMap<String, f64>, n: u32) -> Option<f64> {
253 if n == 1 {
254 get_value(name, prev)
255 } else {
256 None
257 }
258}
259
260fn eval_condition(
261 condition: &Condition,
262 curr: &HashMap<String, f64>,
263 prev: &HashMap<String, f64>,
264) -> Option<bool> {
265 let left = get_value(&condition.left, curr)?;
266 let right_curr = match &condition.right {
267 CompareTarget::Value(v) => Some(*v),
268 CompareTarget::Indicator(name) => get_value(name, curr),
269 CompareTarget::Scaled {
270 indicator,
271 multiplier,
272 } => get_value(indicator, curr).map(|v| v * multiplier),
273 CompareTarget::Range(_, _) => None, CompareTarget::None => None,
275 };
276
277 match condition.operator {
278 Operator::IsAbove => Some(left > right_curr?),
279 Operator::IsBelow => Some(left < right_curr?),
280 Operator::Equals => Some((left - right_curr?).abs() < EPS),
281 Operator::IsBetween => {
282 if let CompareTarget::Range(lower, upper) = condition.right {
283 Some(left >= lower && left <= upper)
284 } else {
285 right_curr.map(|r| left >= r)
286 }
287 }
288 Operator::CrossesAbove => {
289 let prev_left = get_value(&condition.left, prev)?;
290 let prev_right = match &condition.right {
291 CompareTarget::Value(v) => Some(*v),
292 CompareTarget::Indicator(name) => get_value(name, prev),
293 CompareTarget::Scaled {
294 indicator,
295 multiplier,
296 } => get_value(indicator, prev).map(|v| v * multiplier),
297 _ => None,
298 }?;
299 Some(left > right_curr? && prev_left <= prev_right)
300 }
301 Operator::CrossesBelow => {
302 let prev_left = get_value(&condition.left, prev)?;
303 let prev_right = match &condition.right {
304 CompareTarget::Value(v) => Some(*v),
305 CompareTarget::Indicator(name) => get_value(name, prev),
306 CompareTarget::Scaled {
307 indicator,
308 multiplier,
309 } => get_value(indicator, prev).map(|v| v * multiplier),
310 _ => None,
311 }?;
312 Some(left < right_curr? && prev_left >= prev_right)
313 }
314 Operator::IsRising(period) => {
315 let prev_left = get_prev_n(&condition.left, prev, period)?;
316 Some(left > prev_left)
317 }
318 Operator::IsFalling(period) => {
319 let prev_left = get_prev_n(&condition.left, prev, period)?;
320 Some(left < prev_left)
321 }
322 }
323}
324
325fn collect_indicators_from_node(node: &ConditionNode, set: &mut HashMap<String, ()>) {
327 match node {
328 ConditionNode::Condition(c) => {
329 set.insert(c.left.clone(), ());
330 if let CompareTarget::Indicator(name) = &c.right {
331 set.insert(name.clone(), ());
332 }
333 if let CompareTarget::Scaled { indicator, .. } = &c.right {
334 set.insert(indicator.clone(), ());
335 }
336 }
337 ConditionNode::Group(g) => match g {
338 ConditionGroup::AllOf(nodes) | ConditionGroup::AnyOf(nodes) => {
339 for n in nodes {
340 collect_indicators_from_node(n, set);
341 }
342 }
343 },
344 }
345}
346
347pub fn evaluate_strategy_batch(strategy: &Strategy, candles: &[Candle]) -> Vec<Signal> {
349 let mut engine = StrategyEngine::new(strategy.clone());
350 engine.evaluate(candles)
351}
352
353pub fn strategy_engine(strategy: Strategy) -> StrategyEngine {
355 StrategyEngine::new(strategy)
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::strategy::indicator_ref::IndicatorRef;
    use crate::strategy::types::{
        CompareTarget, Condition, ConditionGroup, ConditionNode, Operator,
    };
    use crate::strategy::StopLoss;

    /// Builds flat candles (open = high = low = close) from a list of prices.
    fn make_candles(prices: &[f64]) -> Vec<Candle> {
        prices
            .iter()
            .enumerate()
            .map(|(i, p)| Candle {
                timestamp: i as i64,
                open: *p,
                high: *p,
                low: *p,
                close: *p,
                volume: 0.0,
            })
            .collect()
    }

    #[test]
    fn golden_cross_signals() {
        let entry = IndicatorRef::sma(1).is_above(1.5);
        let exit = IndicatorRef::sma(1).is_below(1.5);
        let strategy = Strategy::builder("gc")
            .entry(entry)
            .exit(exit)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let prices = [1.0, 1.2, 1.6, 1.8, 1.4, 1.2];
        let candles = make_candles(&prices);
        let signals = evaluate_strategy_batch(&strategy, &candles);

        assert_eq!(signals.len(), prices.len());

        // A 1-period SMA equals the close, so each bar is classified directly
        // against the 1.5 threshold.
        let expected = vec![
            Signal::Exit(ExitReason::RuleTriggered),
            Signal::Exit(ExitReason::RuleTriggered),
            Signal::Entry(Side::Long),
            Signal::Entry(Side::Long),
            Signal::Exit(ExitReason::RuleTriggered),
            Signal::Exit(ExitReason::RuleTriggered),
        ];
        assert_eq!(signals, expected);
    }

    #[test]
    fn rsi_mean_reversion_signals() {
        let entry = IndicatorRef::rsi(2).is_below(40.0);
        let exit = IndicatorRef::rsi(2).is_above(60.0);
        let strategy = Strategy::builder("rsi")
            .entry(entry)
            .exit(exit)
            .stop_loss(StopLoss::FixedPercent(2.0))
            .build()
            .unwrap();

        let prices = [10.0, 9.5, 9.0, 8.5, 9.5, 10.5];
        let candles = make_candles(&prices);
        let signals = evaluate_strategy_batch(&strategy, &candles);

        let mut rsi = crate::indicators::RSI::new(2);
        let mut expected = Vec::new();
        for c in &candles {
            let v = rsi.next(c);
            let sig = match v {
                Some(x) if x > 60.0 => Signal::Exit(ExitReason::RuleTriggered),
                Some(x) if x < 40.0 => Signal::Entry(Side::Long),
                _ => Signal::Hold,
            };
            expected.push(sig);
        }

        assert_eq!(signals, expected);

        let entry_idx = signals.iter().position(|s| matches!(s, Signal::Entry(_)));
        let exit_idx = signals.iter().position(|s| matches!(s, Signal::Exit(_)));
        assert!(entry_idx.is_some(), "expected at least one entry signal");
        assert!(exit_idx.is_some(), "expected at least one exit signal");
        if let (Some(ei), Some(xi)) = (entry_idx, exit_idx) {
            assert!(ei < xi, "entry should occur before exit");
        }
    }

    #[test]
    fn edge_single_condition_entry_only() {
        let entry = IndicatorRef::sma(1).is_above(1.0);
        let strategy = Strategy::builder("single")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let prices = [2.0, 2.0, 2.0];
        let candles = make_candles(&prices);
        let signals = evaluate_strategy_batch(&strategy, &candles);

        assert!(signals.iter().all(|s| matches!(s, Signal::Entry(_))));
    }

    #[test]
    fn edge_max_conditions_group_all_of() {
        let cond = || {
            ConditionNode::Condition(Condition::new(
                "sma1",
                Operator::IsAbove,
                CompareTarget::Value(1.0),
            ))
        };
        let entry = ConditionNode::Group(ConditionGroup::AllOf((0..20).map(|_| cond()).collect()));
        let strategy = Strategy::builder("max_group")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let prices = [2.0, 2.0, 2.0];
        let candles = make_candles(&prices);
        let signals = evaluate_strategy_batch(&strategy, &candles);

        assert!(signals.iter().all(|s| matches!(s, Signal::Entry(_))));
    }

    #[test]
    fn edge_nested_groups() {
        let always_true = ConditionNode::Condition(Condition::new(
            "sma1",
            Operator::IsAbove,
            CompareTarget::Value(1.0),
        ));
        let always_false = ConditionNode::Condition(Condition::new(
            "sma1",
            Operator::IsAbove,
            CompareTarget::Value(10.0),
        ));

        let entry = ConditionNode::Group(ConditionGroup::AllOf(vec![
            always_true.clone(),
            ConditionNode::Group(ConditionGroup::AnyOf(vec![always_false, always_true])),
        ]));

        let strategy = Strategy::builder("nested")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let prices = [2.0, 2.0, 2.0];
        let candles = make_candles(&prices);
        let signals = evaluate_strategy_batch(&strategy, &candles);

        assert!(signals.iter().all(|s| matches!(s, Signal::Entry(_))));
    }

    #[test]
    fn streaming_equals_batch() {
        let entry = IndicatorRef::sma(2).crosses_above_indicator(IndicatorRef::sma(3));
        let strategy = Strategy::builder("gc")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let prices = [1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 1.0];
        let candles = make_candles(&prices);

        let batch = evaluate_strategy_batch(&strategy, &candles);
        let mut engine = strategy_engine(strategy);
        let streaming: Vec<_> = candles.iter().map(|c| engine.next(c)).collect();

        assert_eq!(batch, streaming);
    }

    #[test]
    fn golden_cross_manual_verification() {
        let prices = [1.0, 1.0, 1.0, 3.0, 3.0, 0.5];
        let candles = make_candles(&prices);

        // Reference model of the is_above / is_below strategy built below. Once both
        // SMAs have values, a bar is an entry when fast > slow, an exit when
        // fast < slow, and a hold otherwise. No previous-bar value is involved.
        let mut sma1 = crate::indicators::SMA::new(1);
        let mut sma3 = crate::indicators::SMA::new(3);
        let mut expected = Vec::new();

        for c in &candles {
            let fast = sma1.next(c);
            let slow = sma3.next(c);

            let sig = match (fast, slow) {
                (Some(f), Some(s)) if f > s => Signal::Entry(Side::Long),
                (Some(f), Some(s)) if f < s => Signal::Exit(ExitReason::RuleTriggered),
                _ => Signal::Hold,
            };
            expected.push(sig);
        }

        let entry = IndicatorRef::sma(1).is_above_indicator(IndicatorRef::sma(3));
        let exit = IndicatorRef::sma(1).is_below_indicator(IndicatorRef::sma(3));
        let strategy = Strategy::builder("gc_manual")
            .entry(entry)
            .exit(exit)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);

        // The engine must agree with the reference model bar for bar. Previously the
        // model was computed but never compared.
        assert_eq!(signals, expected);

        let entry_idx = signals.iter().position(|s| matches!(s, Signal::Entry(_)));
        let exit_idx = signals.iter().position(|s| matches!(s, Signal::Exit(_)));

        assert!(entry_idx.is_some(), "expected at least one entry signal");
        assert!(exit_idx.is_some(), "expected at least one exit signal");
        if let (Some(ei), Some(xi)) = (entry_idx, exit_idx) {
            assert!(ei < xi, "entry should occur before exit");
        }
    }

    #[test]
    fn batch_a_indicators_in_strategy_flow() {
        let entry = IndicatorRef::wma(3).crosses_above_indicator(IndicatorRef::sma(3));
        let strategy = Strategy::builder("wma_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let prices = [1.0, 2.0, 3.0, 4.0, 5.0];
        let candles = make_candles(&prices);
        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert!(!signals.is_empty());

        let entry = IndicatorRef::roc(2).is_above(0.0);
        let strategy = Strategy::builder("roc_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert!(!signals.is_empty());

        let entry = IndicatorRef::stddev(3).is_above(0.5);
        let strategy = Strategy::builder("stddev_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert!(!signals.is_empty());

        let entry = IndicatorRef::dema(3).crosses_above(2.5);
        let strategy = Strategy::builder("dema_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert!(!signals.is_empty());

        let entry = IndicatorRef::tema(3).is_above(2.0);
        let strategy = Strategy::builder("tema_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert!(!signals.is_empty());

        let entry = IndicatorRef::cci(3).is_above(0.0);
        let strategy = Strategy::builder("cci_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert!(!signals.is_empty());

        let entry = IndicatorRef::williams_r(3).is_below(-50.0);
        let strategy = Strategy::builder("williams_r_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert!(!signals.is_empty());

        let entry = IndicatorRef::adx(3).is_above(20.0);
        let strategy = Strategy::builder("adx_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert!(!signals.is_empty());
    }
}