// mantis_ta/strategy/evaluator.rs

use std::collections::{HashMap, HashSet};
use std::mem;

use crate::indicators::{
    ADX, ATR, CCI, DEMA, EMA, Indicator, ROC, RSI, SMA, StdDev, TEMA, WMA, WilliamsR,
};
use crate::strategy::types::{
    CompareTarget, Condition, ConditionGroup, ConditionNode, Operator, Strategy,
};
use crate::types::{Candle, ExitReason, Side, Signal};

11/// Wrapper over supported indicator implementations that produce scalar outputs.
12#[allow(clippy::upper_case_acronyms)]
13#[derive(Debug)]
14enum IndicatorInstance {
15    SMA(SMA),
16    EMA(EMA),
17    RSI(RSI),
18    ATR(ATR),
19    WMA(WMA),
20    DEMA(DEMA),
21    TEMA(TEMA),
22    CCI(CCI),
23    WilliamsR(WilliamsR),
24    ROC(ROC),
25    StdDev(StdDev),
26    ADX(ADX),
27}
28
29impl IndicatorInstance {
30    fn next(&mut self, candle: &Candle) -> Option<f64> {
31        match self {
32            IndicatorInstance::SMA(i) => i.next(candle),
33            IndicatorInstance::EMA(i) => i.next(candle),
34            IndicatorInstance::RSI(i) => i.next(candle),
35            IndicatorInstance::ATR(i) => i.next(candle),
36            IndicatorInstance::WMA(i) => i.next(candle),
37            IndicatorInstance::DEMA(i) => i.next(candle),
38            IndicatorInstance::TEMA(i) => i.next(candle),
39            IndicatorInstance::CCI(i) => i.next(candle),
40            IndicatorInstance::WilliamsR(i) => i.next(candle),
41            IndicatorInstance::ROC(i) => i.next(candle),
42            IndicatorInstance::StdDev(i) => i.next(candle),
43            IndicatorInstance::ADX(i) => {
44                // ADX returns AdxOutput, extract the adx value
45                i.next(candle).map(|output| output.adx)
46            }
47        }
48    }
49}
50
51/// Parse an indicator reference name into a concrete indicator instance.
52/// Supported forms:
53/// - "sma{period}"
54/// - "ema{period}"
55/// - "rsi{period}"
56/// - "atr{period}"
57/// - "wma{period}"
58/// - "dema{period}"
59/// - "tema{period}"
60/// - "cci{period}"
61/// - "williams_r{period}"
62/// - "roc{period}"
63/// - "stddev{period}"
64/// - "adx{period}"
65fn parse_indicator(name: &str) -> Option<IndicatorInstance> {
66    if let Some(rest) = name.strip_prefix("sma")
67        && let Ok(p) = rest.parse::<usize>()
68    {
69        return Some(IndicatorInstance::SMA(SMA::new(p)));
70    }
71    if let Some(rest) = name.strip_prefix("ema")
72        && let Ok(p) = rest.parse::<usize>()
73    {
74        return Some(IndicatorInstance::EMA(EMA::new(p)));
75    }
76    if let Some(rest) = name.strip_prefix("rsi")
77        && let Ok(p) = rest.parse::<usize>()
78    {
79        return Some(IndicatorInstance::RSI(RSI::new(p)));
80    }
81    if let Some(rest) = name.strip_prefix("atr")
82        && let Ok(p) = rest.parse::<usize>()
83    {
84        return Some(IndicatorInstance::ATR(ATR::new(p)));
85    }
86    if let Some(rest) = name.strip_prefix("wma")
87        && let Ok(p) = rest.parse::<usize>()
88    {
89        return Some(IndicatorInstance::WMA(WMA::new(p)));
90    }
91    if let Some(rest) = name.strip_prefix("dema")
92        && let Ok(p) = rest.parse::<usize>()
93    {
94        return Some(IndicatorInstance::DEMA(DEMA::new(p)));
95    }
96    if let Some(rest) = name.strip_prefix("tema")
97        && let Ok(p) = rest.parse::<usize>()
98    {
99        return Some(IndicatorInstance::TEMA(TEMA::new(p)));
100    }
101    if let Some(rest) = name.strip_prefix("cci")
102        && let Ok(p) = rest.parse::<usize>()
103    {
104        return Some(IndicatorInstance::CCI(CCI::new(p)));
105    }
106    if let Some(rest) = name.strip_prefix("williams_r")
107        && let Ok(p) = rest.parse::<usize>()
108    {
109        return Some(IndicatorInstance::WilliamsR(WilliamsR::new(p)));
110    }
111    if let Some(rest) = name.strip_prefix("roc")
112        && let Ok(p) = rest.parse::<usize>()
113    {
114        return Some(IndicatorInstance::ROC(ROC::new(p)));
115    }
116    if let Some(rest) = name.strip_prefix("stddev")
117        && let Ok(p) = rest.parse::<usize>()
118    {
119        return Some(IndicatorInstance::StdDev(StdDev::new(p)));
120    }
121    if let Some(rest) = name.strip_prefix("adx")
122        && let Ok(p) = rest.parse::<usize>()
123    {
124        return Some(IndicatorInstance::ADX(ADX::new(p)));
125    }
126    None
127}
128
129/// Strategy evaluation engine for streaming signals.
130#[derive(Debug)]
131pub struct StrategyEngine {
132    strategy: Strategy,
133    indicators: HashMap<String, IndicatorInstance>,
134    required: HashSet<String>,
135    last_values: HashMap<String, f64>,
136}
137
138impl StrategyEngine {
139    pub fn new(strategy: Strategy) -> Self {
140        let mut indicators = HashMap::new();
141        collect_indicators_from_node(&strategy.entry, &mut indicators);
142        if let Some(exit) = &strategy.exit {
143            collect_indicators_from_node(exit, &mut indicators);
144        }
145        let required: HashSet<String> = indicators.keys().cloned().collect();
146        let mut instances = HashMap::new();
147        for name in indicators.keys() {
148            if let Some(inst) = parse_indicator(name) {
149                instances.insert(name.clone(), inst);
150            }
151        }
152        Self {
153            strategy,
154            indicators: instances,
155            required,
156            last_values: HashMap::new(),
157        }
158    }
159
160    /// Evaluate one candle and emit a signal.
161    pub fn next(&mut self, candle: &Candle) -> Signal {
162        // Capture previous values for cross/rising/falling detection
163        let prev_values = self.last_values.clone();
164        self.last_values.clear();
165
166        // Advance indicators
167        for (name, inst) in self.indicators.iter_mut() {
168            if let Some(v) = inst.next(candle) {
169                self.last_values.insert(name.clone(), v);
170            }
171        }
172
173        // Warmup: if any required indicator has not produced a value yet, hold
174        if self
175            .required
176            .iter()
177            .any(|name| !self.last_values.contains_key(name))
178        {
179            return Signal::Hold;
180        }
181
182        // Evaluate entry/exit; if required indicator values are missing, return Hold
183        let entry = eval_node(&self.strategy.entry, &self.last_values, &prev_values);
184        let exit = self
185            .strategy
186            .exit
187            .as_ref()
188            .and_then(|n| eval_node(n, &self.last_values, &prev_values));
189
190        if exit == Some(true) {
191            Signal::Exit(ExitReason::RuleTriggered)
192        } else if entry == Some(true) {
193            Signal::Entry(Side::Long)
194        } else {
195            Signal::Hold
196        }
197    }
198
199    /// Batch evaluation over a candle slice.
200    pub fn evaluate(&mut self, candles: &[Candle]) -> Vec<Signal> {
201        candles.iter().map(|c| self.next(c)).collect()
202    }
203}
204
/// Look up the reading stored for indicator `name`, if there is one.
fn get_value(name: &str, values: &HashMap<String, f64>) -> Option<f64> {
    let value = values.get(name)?;
    Some(*value)
}

209/// Evaluate a condition tree. Returns None if data is insufficient (warmup).
210fn eval_node(
211    node: &ConditionNode,
212    curr: &HashMap<String, f64>,
213    prev: &HashMap<String, f64>,
214) -> Option<bool> {
215    match node {
216        ConditionNode::Condition(c) => eval_condition(c, curr, prev),
217        ConditionNode::Group(g) => match g {
218            ConditionGroup::AllOf(nodes) => {
219                let mut any_none = false;
220                for n in nodes {
221                    match eval_node(n, curr, prev) {
222                        Some(true) => {}
223                        Some(false) => return Some(false),
224                        None => any_none = true,
225                    }
226                }
227                if any_none { None } else { Some(true) }
228            }
229            ConditionGroup::AnyOf(nodes) => {
230                let mut any_none = false;
231                for n in nodes {
232                    match eval_node(n, curr, prev) {
233                        Some(true) => return Some(true),
234                        Some(false) => {}
235                        None => any_none = true,
236                    }
237                }
238                if any_none { None } else { Some(false) }
239            }
240        },
241    }
242}
243
/// Absolute tolerance used when testing two readings for equality (`Operator::Equals`).
const EPS: f64 = 1.0e-9;

246fn get_prev_n(name: &str, prev: &HashMap<String, f64>, n: u32) -> Option<f64> {
247    if n == 1 { get_value(name, prev) } else { None }
248}
249
250fn eval_condition(
251    condition: &Condition,
252    curr: &HashMap<String, f64>,
253    prev: &HashMap<String, f64>,
254) -> Option<bool> {
255    let left = get_value(&condition.left, curr)?;
256    let right_curr = match &condition.right {
257        CompareTarget::Value(v) => Some(*v),
258        CompareTarget::Indicator(name) => get_value(name, curr),
259        CompareTarget::Scaled {
260            indicator,
261            multiplier,
262        } => get_value(indicator, curr).map(|v| v * multiplier),
263        CompareTarget::Range(_, _) => None, // handled per-operator
264        CompareTarget::None => None,
265    };
266
267    match condition.operator {
268        Operator::IsAbove => Some(left > right_curr?),
269        Operator::IsBelow => Some(left < right_curr?),
270        Operator::Equals => Some((left - right_curr?).abs() < EPS),
271        Operator::IsBetween => {
272            if let CompareTarget::Range(lower, upper) = condition.right {
273                Some(left >= lower && left <= upper)
274            } else {
275                right_curr.map(|r| left >= r)
276            }
277        }
278        Operator::CrossesAbove => {
279            let prev_left = get_value(&condition.left, prev)?;
280            let prev_right = match &condition.right {
281                CompareTarget::Value(v) => Some(*v),
282                CompareTarget::Indicator(name) => get_value(name, prev),
283                CompareTarget::Scaled {
284                    indicator,
285                    multiplier,
286                } => get_value(indicator, prev).map(|v| v * multiplier),
287                _ => None,
288            }?;
289            Some(left > right_curr? && prev_left <= prev_right)
290        }
291        Operator::CrossesBelow => {
292            let prev_left = get_value(&condition.left, prev)?;
293            let prev_right = match &condition.right {
294                CompareTarget::Value(v) => Some(*v),
295                CompareTarget::Indicator(name) => get_value(name, prev),
296                CompareTarget::Scaled {
297                    indicator,
298                    multiplier,
299                } => get_value(indicator, prev).map(|v| v * multiplier),
300                _ => None,
301            }?;
302            Some(left < right_curr? && prev_left >= prev_right)
303        }
304        Operator::IsRising(period) => {
305            let prev_left = get_prev_n(&condition.left, prev, period)?;
306            Some(left > prev_left)
307        }
308        Operator::IsFalling(period) => {
309            let prev_left = get_prev_n(&condition.left, prev, period)?;
310            Some(left < prev_left)
311        }
312    }
313}
314
315/// Walk a condition tree to collect indicator names referenced.
316fn collect_indicators_from_node(node: &ConditionNode, set: &mut HashMap<String, ()>) {
317    match node {
318        ConditionNode::Condition(c) => {
319            set.insert(c.left.clone(), ());
320            if let CompareTarget::Indicator(name) = &c.right {
321                set.insert(name.clone(), ());
322            }
323            if let CompareTarget::Scaled { indicator, .. } = &c.right {
324                set.insert(indicator.clone(), ());
325            }
326        }
327        ConditionNode::Group(g) => match g {
328            ConditionGroup::AllOf(nodes) | ConditionGroup::AnyOf(nodes) => {
329                for n in nodes {
330                    collect_indicators_from_node(n, set);
331                }
332            }
333        },
334    }
335}
336
337/// Batch evaluation helper: convenience wrapper over StrategyEngine.
338pub fn evaluate_strategy_batch(strategy: &Strategy, candles: &[Candle]) -> Vec<Signal> {
339    let mut engine = StrategyEngine::new(strategy.clone());
340    engine.evaluate(candles)
341}
342
343/// Streaming evaluation helper: create an engine from a strategy.
344pub fn strategy_engine(strategy: Strategy) -> StrategyEngine {
345    StrategyEngine::new(strategy)
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::strategy::StopLoss;
    use crate::strategy::indicator_ref::IndicatorRef;
    use crate::strategy::types::{
        CompareTarget, Condition, ConditionGroup, ConditionNode, Operator,
    };

    /// Build flat candles (open = high = low = close) with sequential timestamps.
    fn make_candles(prices: &[f64]) -> Vec<Candle> {
        prices
            .iter()
            .enumerate()
            .map(|(i, p)| Candle {
                timestamp: i as i64,
                open: *p,
                high: *p,
                low: *p,
                close: *p,
                volume: 0.0,
            })
            .collect()
    }

    #[test]
    fn golden_cross_signals() {
        // SMA(1) is just the latest close, so the rules reduce to "price vs 1.5".
        let entry = IndicatorRef::sma(1).is_above(1.5);
        let exit = IndicatorRef::sma(1).is_below(1.5);
        let strategy = Strategy::builder("gc")
            .entry(entry)
            .exit(exit)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        // Prices start below the threshold, rise above it, then fall back below.
        let prices = [1.0, 1.2, 1.6, 1.8, 1.4, 1.2];
        let candles = make_candles(&prices);
        let signals = evaluate_strategy_batch(&strategy, &candles);

        let expected: Vec<Signal> = prices
            .iter()
            .map(|&p| {
                if p < 1.5 {
                    Signal::Exit(ExitReason::RuleTriggered)
                } else if p > 1.5 {
                    Signal::Entry(Side::Long)
                } else {
                    Signal::Hold
                }
            })
            .collect();

        assert_eq!(signals.len(), prices.len());
        assert_eq!(signals, expected);
    }

    #[test]
    fn rsi_mean_reversion_signals() {
        let entry = IndicatorRef::rsi(2).is_below(40.0);
        let exit = IndicatorRef::rsi(2).is_above(60.0);
        let strategy = Strategy::builder("rsi")
            .entry(entry)
            .exit(exit)
            .stop_loss(StopLoss::FixedPercent(2.0))
            .build()
            .unwrap();

        // Construct prices to push RSI below 40 then above 60
        let prices = [10.0, 9.5, 9.0, 8.5, 9.5, 10.5];
        let candles = make_candles(&prices);
        let signals = evaluate_strategy_batch(&strategy, &candles);

        // Manual RSI verification: compute RSI and derive expected signals from thresholds.
        let mut rsi = crate::indicators::RSI::new(2);
        let mut expected = Vec::new();
        for c in &candles {
            let v = rsi.next(c);
            let sig = match v {
                Some(x) if x > 60.0 => Signal::Exit(ExitReason::RuleTriggered),
                Some(x) if x < 40.0 => Signal::Entry(Side::Long),
                _ => Signal::Hold,
            };
            expected.push(sig);
        }

        assert_eq!(signals, expected);

        let entry_idx = signals.iter().position(|s| matches!(s, Signal::Entry(_)));
        let exit_idx = signals.iter().position(|s| matches!(s, Signal::Exit(_)));
        assert!(entry_idx.is_some(), "expected at least one entry signal");
        assert!(exit_idx.is_some(), "expected at least one exit signal");
        if let (Some(ei), Some(xi)) = (entry_idx, exit_idx) {
            assert!(ei < xi, "entry should occur before exit");
        }
    }

    #[test]
    fn edge_single_condition_entry_only() {
        let entry = IndicatorRef::sma(1).is_above(1.0);
        let strategy = Strategy::builder("single")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let prices = [2.0, 2.0, 2.0];
        let candles = make_candles(&prices);
        let signals = evaluate_strategy_batch(&strategy, &candles);

        assert!(signals.iter().all(|s| matches!(s, Signal::Entry(_))));
    }

    #[test]
    fn edge_max_conditions_group_all_of() {
        let cond = || {
            ConditionNode::Condition(Condition::new(
                "sma1",
                Operator::IsAbove,
                CompareTarget::Value(1.0),
            ))
        };
        let entry = ConditionNode::Group(ConditionGroup::AllOf((0..20).map(|_| cond()).collect()));
        let strategy = Strategy::builder("max_group")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let prices = [2.0, 2.0, 2.0];
        let candles = make_candles(&prices);
        let signals = evaluate_strategy_batch(&strategy, &candles);

        assert!(signals.iter().all(|s| matches!(s, Signal::Entry(_))));
    }

    #[test]
    fn edge_nested_groups() {
        let always_true = ConditionNode::Condition(Condition::new(
            "sma1",
            Operator::IsAbove,
            CompareTarget::Value(1.0),
        ));
        let always_false = ConditionNode::Condition(Condition::new(
            "sma1",
            Operator::IsAbove,
            CompareTarget::Value(10.0),
        ));

        // Entry: sma1 > 1 AND (sma1 > 10 OR sma1 > 1)
        let entry = ConditionNode::Group(ConditionGroup::AllOf(vec![
            always_true.clone(),
            ConditionNode::Group(ConditionGroup::AnyOf(vec![always_false, always_true])),
        ]));

        let strategy = Strategy::builder("nested")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let prices = [2.0, 2.0, 2.0];
        let candles = make_candles(&prices);
        let signals = evaluate_strategy_batch(&strategy, &candles);

        assert!(signals.iter().all(|s| matches!(s, Signal::Entry(_))));
    }

    #[test]
    fn streaming_equals_batch() {
        let entry = IndicatorRef::sma(2).crosses_above_indicator(IndicatorRef::sma(3));
        let strategy = Strategy::builder("gc")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let prices = [1.0, 1.0, 1.0, 2.0, 3.0, 2.0, 1.0];
        let candles = make_candles(&prices);

        let batch = evaluate_strategy_batch(&strategy, &candles);
        let mut engine = strategy_engine(strategy);
        let streaming: Vec<_> = candles.iter().map(|c| engine.next(c)).collect();

        assert_eq!(batch, streaming);
    }

    #[test]
    fn golden_cross_manual_verification() {
        // Deterministic sequence: fast SMA(1) vs slow SMA(3).
        // - Until SMA(3) has a value the engine holds; while fast == slow it also holds.
        // - Index 3: fast = 3.0, slow ≈ 1.67 -> above (Entry)
        // - Index 5: fast = 0.5, slow ≈ 2.17 -> below (Exit)
        let prices = [1.0, 1.0, 1.0, 3.0, 3.0, 0.5];
        let candles = make_candles(&prices);

        // Compute both SMAs independently and derive the signal the above/below rules
        // should produce on every candle. These operators do not look at history, so
        // only the current readings matter.
        let mut sma1 = crate::indicators::SMA::new(1);
        let mut sma3 = crate::indicators::SMA::new(3);
        let expected: Vec<Signal> = candles
            .iter()
            .map(|c| match (sma1.next(c), sma3.next(c)) {
                (Some(fast), Some(slow)) if fast > slow => Signal::Entry(Side::Long),
                (Some(fast), Some(slow)) if fast < slow => {
                    Signal::Exit(ExitReason::RuleTriggered)
                }
                _ => Signal::Hold,
            })
            .collect();

        let entry = IndicatorRef::sma(1).is_above_indicator(IndicatorRef::sma(3));
        let exit = IndicatorRef::sma(1).is_below_indicator(IndicatorRef::sma(3));
        let strategy = Strategy::builder("gc_manual")
            .entry(entry)
            .exit(exit)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);

        // The manually derived sequence used to be computed but never checked.
        assert_eq!(signals, expected);

        let entry_idx = signals.iter().position(|s| matches!(s, Signal::Entry(_)));
        let exit_idx = signals.iter().position(|s| matches!(s, Signal::Exit(_)));

        assert!(entry_idx.is_some(), "expected at least one entry signal");
        assert!(exit_idx.is_some(), "expected at least one exit signal");
        if let (Some(ei), Some(xi)) = (entry_idx, exit_idx) {
            assert!(ei < xi, "entry should occur before exit");
        }
    }

    #[test]
    fn batch_a_indicators_in_strategy_flow() {
        // Strictly rising prices shared by every sub-case below.
        let prices = [1.0, 2.0, 3.0, 4.0, 5.0];
        let candles = make_candles(&prices);

        // Test WMA indicator in strategy
        let entry = IndicatorRef::wma(3).crosses_above_indicator(IndicatorRef::sma(3));
        let strategy = Strategy::builder("wma_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert_eq!(signals.len(), candles.len());

        // Test ROC indicator in strategy: rate of change is positive on rising prices.
        let entry = IndicatorRef::roc(2).is_above(0.0);
        let strategy = Strategy::builder("roc_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert_eq!(signals.len(), candles.len());
        assert!(
            signals.iter().any(|s| matches!(s, Signal::Entry(_))),
            "expected ROC(2) > 0 on strictly rising prices"
        );

        // Test StdDev indicator in strategy: three consecutive integers spread > 0.5.
        let entry = IndicatorRef::stddev(3).is_above(0.5);
        let strategy = Strategy::builder("stddev_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert_eq!(signals.len(), candles.len());
        assert!(
            signals.iter().any(|s| matches!(s, Signal::Entry(_))),
            "expected StdDev(3) > 0.5 on consecutive integer prices"
        );

        // Test DEMA indicator in strategy
        let entry = IndicatorRef::dema(3).crosses_above(2.5);
        let strategy = Strategy::builder("dema_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert_eq!(signals.len(), candles.len());

        // Test TEMA indicator in strategy
        let entry = IndicatorRef::tema(3).is_above(2.0);
        let strategy = Strategy::builder("tema_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert_eq!(signals.len(), candles.len());

        // Test CCI indicator in strategy
        let entry = IndicatorRef::cci(3).is_above(0.0);
        let strategy = Strategy::builder("cci_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert_eq!(signals.len(), candles.len());

        // Test Williams %R indicator in strategy
        let entry = IndicatorRef::williams_r(3).is_below(-50.0);
        let strategy = Strategy::builder("williams_r_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert_eq!(signals.len(), candles.len());

        // Test ADX indicator in strategy
        let entry = IndicatorRef::adx(3).is_above(20.0);
        let strategy = Strategy::builder("adx_test")
            .entry(entry)
            .stop_loss(StopLoss::FixedPercent(1.0))
            .build()
            .unwrap();

        let signals = evaluate_strategy_batch(&strategy, &candles);
        assert_eq!(signals.len(), candles.len());
    }
}