Skip to main content

finance_query/backtesting/condition/
comparison.rs

1//! Comparison conditions for indicator references.
2//!
3//! This module provides conditions that compare indicator values
4//! against thresholds or other indicators.
5
6use crate::backtesting::refs::IndicatorRef;
7use crate::backtesting::strategy::StrategyContext;
8use crate::indicators::Indicator;
9
10use super::Condition;
11
12// ============================================================================
13// THRESHOLD COMPARISONS
14// ============================================================================
15
16/// Condition: indicator is above a threshold.
17#[derive(Clone)]
18pub struct Above<R: IndicatorRef> {
19    indicator: R,
20    threshold: f64,
21}
22
23impl<R: IndicatorRef> Above<R> {
24    /// Create a new Above condition.
25    pub fn new(indicator: R, threshold: f64) -> Self {
26        Self {
27            indicator,
28            threshold,
29        }
30    }
31}
32
33impl<R: IndicatorRef> Condition for Above<R> {
34    fn evaluate(&self, ctx: &StrategyContext) -> bool {
35        self.indicator
36            .value(ctx)
37            .map(|v| v > self.threshold)
38            .unwrap_or(false)
39    }
40
41    fn required_indicators(&self) -> Vec<(String, Indicator)> {
42        self.indicator.required_indicators()
43    }
44
45    fn description(&self) -> String {
46        format!("{} > {:.2}", self.indicator.key(), self.threshold)
47    }
48}
49
50/// Condition: indicator is below a threshold.
51#[derive(Clone)]
52pub struct Below<R: IndicatorRef> {
53    indicator: R,
54    threshold: f64,
55}
56
57impl<R: IndicatorRef> Below<R> {
58    /// Create a new Below condition.
59    pub fn new(indicator: R, threshold: f64) -> Self {
60        Self {
61            indicator,
62            threshold,
63        }
64    }
65}
66
67impl<R: IndicatorRef> Condition for Below<R> {
68    fn evaluate(&self, ctx: &StrategyContext) -> bool {
69        self.indicator
70            .value(ctx)
71            .map(|v| v < self.threshold)
72            .unwrap_or(false)
73    }
74
75    fn required_indicators(&self) -> Vec<(String, Indicator)> {
76        self.indicator.required_indicators()
77    }
78
79    fn description(&self) -> String {
80        format!("{} < {:.2}", self.indicator.key(), self.threshold)
81    }
82}
83
84/// Condition: indicator crosses above a threshold.
85///
86/// True when the previous value was **at or below** the threshold (`prev <=
87/// threshold`) and the current value is **strictly above** it (`curr >
88/// threshold`). The inclusive previous-bar test prevents missing a crossover
89/// when the value touches the threshold exactly before rising.
90#[derive(Clone)]
91pub struct CrossesAbove<R: IndicatorRef> {
92    indicator: R,
93    threshold: f64,
94}
95
96impl<R: IndicatorRef> CrossesAbove<R> {
97    /// Create a new CrossesAbove condition.
98    pub fn new(indicator: R, threshold: f64) -> Self {
99        Self {
100            indicator,
101            threshold,
102        }
103    }
104}
105
106impl<R: IndicatorRef> Condition for CrossesAbove<R> {
107    fn evaluate(&self, ctx: &StrategyContext) -> bool {
108        let current = self.indicator.value(ctx);
109        let prev = self.indicator.prev_value(ctx);
110
111        match (current, prev) {
112            (Some(curr), Some(p)) => p <= self.threshold && curr > self.threshold,
113            _ => false,
114        }
115    }
116
117    fn required_indicators(&self) -> Vec<(String, Indicator)> {
118        self.indicator.required_indicators()
119    }
120
121    fn description(&self) -> String {
122        format!(
123            "{} crosses above {:.2}",
124            self.indicator.key(),
125            self.threshold
126        )
127    }
128}
129
130/// Condition: indicator crosses below a threshold.
131///
132/// True when the previous value was **at or above** the threshold (`prev >=
133/// threshold`) and the current value is **strictly below** it (`curr <
134/// threshold`). The inclusive previous-bar test prevents missing a crossover
135/// when the value touches the threshold exactly before falling.
136#[derive(Clone)]
137pub struct CrossesBelow<R: IndicatorRef> {
138    indicator: R,
139    threshold: f64,
140}
141
142impl<R: IndicatorRef> CrossesBelow<R> {
143    /// Create a new CrossesBelow condition.
144    pub fn new(indicator: R, threshold: f64) -> Self {
145        Self {
146            indicator,
147            threshold,
148        }
149    }
150}
151
152impl<R: IndicatorRef> Condition for CrossesBelow<R> {
153    fn evaluate(&self, ctx: &StrategyContext) -> bool {
154        let current = self.indicator.value(ctx);
155        let prev = self.indicator.prev_value(ctx);
156
157        match (current, prev) {
158            (Some(curr), Some(p)) => p >= self.threshold && curr < self.threshold,
159            _ => false,
160        }
161    }
162
163    fn required_indicators(&self) -> Vec<(String, Indicator)> {
164        self.indicator.required_indicators()
165    }
166
167    fn description(&self) -> String {
168        format!(
169            "{} crosses below {:.2}",
170            self.indicator.key(),
171            self.threshold
172        )
173    }
174}
175
176/// Condition: indicator is between two thresholds.
177///
178/// True when `low < value < high`.
179#[derive(Clone)]
180pub struct Between<R: IndicatorRef> {
181    indicator: R,
182    low: f64,
183    high: f64,
184}
185
186impl<R: IndicatorRef> Between<R> {
187    /// Create a new Between condition.
188    pub fn new(indicator: R, low: f64, high: f64) -> Self {
189        Self {
190            indicator,
191            low,
192            high,
193        }
194    }
195}
196
197impl<R: IndicatorRef> Condition for Between<R> {
198    fn evaluate(&self, ctx: &StrategyContext) -> bool {
199        self.indicator
200            .value(ctx)
201            .map(|v| v > self.low && v < self.high)
202            .unwrap_or(false)
203    }
204
205    fn required_indicators(&self) -> Vec<(String, Indicator)> {
206        self.indicator.required_indicators()
207    }
208
209    fn description(&self) -> String {
210        format!(
211            "{:.2} < {} < {:.2}",
212            self.low,
213            self.indicator.key(),
214            self.high
215        )
216    }
217}
218
219/// Condition: indicator equals a value (within tolerance).
220#[derive(Clone)]
221pub struct Equals<R: IndicatorRef> {
222    indicator: R,
223    value: f64,
224    tolerance: f64,
225}
226
227impl<R: IndicatorRef> Equals<R> {
228    /// Create a new Equals condition.
229    pub fn new(indicator: R, value: f64, tolerance: f64) -> Self {
230        Self {
231            indicator,
232            value,
233            tolerance,
234        }
235    }
236}
237
238impl<R: IndicatorRef> Condition for Equals<R> {
239    fn evaluate(&self, ctx: &StrategyContext) -> bool {
240        self.indicator
241            .value(ctx)
242            .map(|v| (v - self.value).abs() <= self.tolerance)
243            .unwrap_or(false)
244    }
245
246    fn required_indicators(&self) -> Vec<(String, Indicator)> {
247        self.indicator.required_indicators()
248    }
249
250    fn description(&self) -> String {
251        format!(
252            "{} ≈ {:.2} (±{:.4})",
253            self.indicator.key(),
254            self.value,
255            self.tolerance
256        )
257    }
258}
259
260// ============================================================================
261// INDICATOR VS INDICATOR COMPARISONS
262// ============================================================================
263
264/// Condition: indicator is above another indicator.
265#[derive(Clone)]
266pub struct AboveRef<R1: IndicatorRef, R2: IndicatorRef> {
267    indicator: R1,
268    other: R2,
269}
270
271impl<R1: IndicatorRef, R2: IndicatorRef> AboveRef<R1, R2> {
272    /// Create a new AboveRef condition.
273    pub fn new(indicator: R1, other: R2) -> Self {
274        Self { indicator, other }
275    }
276}
277
278impl<R1: IndicatorRef, R2: IndicatorRef> Condition for AboveRef<R1, R2> {
279    fn evaluate(&self, ctx: &StrategyContext) -> bool {
280        let v1 = self.indicator.value(ctx);
281        let v2 = self.other.value(ctx);
282
283        match (v1, v2) {
284            (Some(a), Some(b)) => a > b,
285            _ => false,
286        }
287    }
288
289    fn required_indicators(&self) -> Vec<(String, Indicator)> {
290        let mut indicators = self.indicator.required_indicators();
291        indicators.extend(self.other.required_indicators());
292        indicators
293    }
294
295    fn description(&self) -> String {
296        format!("{} > {}", self.indicator.key(), self.other.key())
297    }
298}
299
300/// Condition: indicator is below another indicator.
301#[derive(Clone)]
302pub struct BelowRef<R1: IndicatorRef, R2: IndicatorRef> {
303    indicator: R1,
304    other: R2,
305}
306
307impl<R1: IndicatorRef, R2: IndicatorRef> BelowRef<R1, R2> {
308    /// Create a new BelowRef condition.
309    pub fn new(indicator: R1, other: R2) -> Self {
310        Self { indicator, other }
311    }
312}
313
314impl<R1: IndicatorRef, R2: IndicatorRef> Condition for BelowRef<R1, R2> {
315    fn evaluate(&self, ctx: &StrategyContext) -> bool {
316        let v1 = self.indicator.value(ctx);
317        let v2 = self.other.value(ctx);
318
319        match (v1, v2) {
320            (Some(a), Some(b)) => a < b,
321            _ => false,
322        }
323    }
324
325    fn required_indicators(&self) -> Vec<(String, Indicator)> {
326        let mut indicators = self.indicator.required_indicators();
327        indicators.extend(self.other.required_indicators());
328        indicators
329    }
330
331    fn description(&self) -> String {
332        format!("{} < {}", self.indicator.key(), self.other.key())
333    }
334}
335
336/// Condition: indicator crosses above another indicator.
337///
338/// True when the fast indicator was **at or below** the slow indicator on the
339/// previous bar (`prev_fast <= prev_slow`) and is **strictly above** it on the
340/// current bar (`curr_fast > curr_slow`).
341///
342/// # Inclusive Previous Bar
343///
344/// The previous-bar test is inclusive (`<=`). This means the crossover fires
345/// even if fast == slow on the prior bar, treating that touch as "not yet
346/// crossed". This is the most common convention in technical analysis and
347/// avoids missing a crossover when the lines converge exactly.
348#[derive(Clone)]
349pub struct CrossesAboveRef<R1: IndicatorRef, R2: IndicatorRef> {
350    fast: R1,
351    slow: R2,
352}
353
354impl<R1: IndicatorRef, R2: IndicatorRef> CrossesAboveRef<R1, R2> {
355    /// Create a new CrossesAboveRef condition.
356    pub fn new(fast: R1, slow: R2) -> Self {
357        Self { fast, slow }
358    }
359}
360
361impl<R1: IndicatorRef, R2: IndicatorRef> Condition for CrossesAboveRef<R1, R2> {
362    fn evaluate(&self, ctx: &StrategyContext) -> bool {
363        let fast_now = self.fast.value(ctx);
364        let slow_now = self.slow.value(ctx);
365        let fast_prev = self.fast.prev_value(ctx);
366        let slow_prev = self.slow.prev_value(ctx);
367
368        match (fast_now, slow_now, fast_prev, slow_prev) {
369            (Some(fn_), Some(sn), Some(fp), Some(sp)) => fp <= sp && fn_ > sn,
370            _ => false,
371        }
372    }
373
374    fn required_indicators(&self) -> Vec<(String, Indicator)> {
375        let mut indicators = self.fast.required_indicators();
376        indicators.extend(self.slow.required_indicators());
377        indicators
378    }
379
380    fn description(&self) -> String {
381        format!("{} crosses above {}", self.fast.key(), self.slow.key())
382    }
383}
384
385/// Condition: indicator crosses below another indicator.
386///
387/// True when the fast indicator was **at or above** the slow indicator on the
388/// previous bar (`prev_fast >= prev_slow`) and is **strictly below** it on the
389/// current bar (`curr_fast < curr_slow`).
390///
391/// # Inclusive Previous Bar
392///
393/// The previous-bar test is inclusive (`>=`). See [`CrossesAboveRef`] for
394/// rationale.
395#[derive(Clone)]
396pub struct CrossesBelowRef<R1: IndicatorRef, R2: IndicatorRef> {
397    fast: R1,
398    slow: R2,
399}
400
401impl<R1: IndicatorRef, R2: IndicatorRef> CrossesBelowRef<R1, R2> {
402    /// Create a new CrossesBelowRef condition.
403    pub fn new(fast: R1, slow: R2) -> Self {
404        Self { fast, slow }
405    }
406}
407
408impl<R1: IndicatorRef, R2: IndicatorRef> Condition for CrossesBelowRef<R1, R2> {
409    fn evaluate(&self, ctx: &StrategyContext) -> bool {
410        let fast_now = self.fast.value(ctx);
411        let slow_now = self.slow.value(ctx);
412        let fast_prev = self.fast.prev_value(ctx);
413        let slow_prev = self.slow.prev_value(ctx);
414
415        match (fast_now, slow_now, fast_prev, slow_prev) {
416            (Some(fn_), Some(sn), Some(fp), Some(sp)) => fp >= sp && fn_ < sn,
417            _ => false,
418        }
419    }
420
421    fn required_indicators(&self) -> Vec<(String, Indicator)> {
422        let mut indicators = self.fast.required_indicators();
423        indicators.extend(self.slow.required_indicators());
424        indicators
425    }
426
427    fn description(&self) -> String {
428        format!("{} crosses below {}", self.fast.key(), self.slow.key())
429    }
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backtesting::refs::{rsi, sma};

    #[test]
    fn test_above_description() {
        let cond = Above::new(rsi(14), 70.0);
        assert_eq!(cond.description(), "rsi_14 > 70.00");
    }

    #[test]
    fn test_below_description() {
        let cond = Below::new(rsi(14), 30.0);
        assert_eq!(cond.description(), "rsi_14 < 30.00");
    }

    #[test]
    fn test_crosses_above_description() {
        let cond = CrossesAbove::new(rsi(14), 30.0);
        assert_eq!(cond.description(), "rsi_14 crosses above 30.00");
    }

    #[test]
    fn test_crosses_below_description() {
        let cond = CrossesBelow::new(rsi(14), 70.0);
        assert_eq!(cond.description(), "rsi_14 crosses below 70.00");
    }

    #[test]
    fn test_between_description() {
        let cond = Between::new(rsi(14), 30.0, 70.0);
        assert_eq!(cond.description(), "30.00 < rsi_14 < 70.00");
    }

    #[test]
    fn test_equals_description() {
        let cond = Equals::new(rsi(14), 50.0, 0.01);
        assert_eq!(cond.description(), "rsi_14 ≈ 50.00 (±0.0100)");
    }

    #[test]
    fn test_above_ref_description() {
        let cond = AboveRef::new(sma(10), sma(20));
        assert_eq!(cond.description(), "sma_10 > sma_20");
    }

    #[test]
    fn test_below_ref_description() {
        let cond = BelowRef::new(sma(10), sma(20));
        assert_eq!(cond.description(), "sma_10 < sma_20");
    }

    #[test]
    fn test_crosses_above_ref_description() {
        let cond = CrossesAboveRef::new(sma(50), sma(200));
        assert_eq!(cond.description(), "sma_50 crosses above sma_200");
    }

    #[test]
    fn test_crosses_below_ref_description() {
        let cond = CrossesBelowRef::new(sma(50), sma(200));
        assert_eq!(cond.description(), "sma_50 crosses below sma_200");
    }

    #[test]
    fn test_required_indicators() {
        let cond = Above::new(rsi(14), 70.0);
        let indicators = cond.required_indicators();
        assert_eq!(indicators.len(), 1);
        assert_eq!(indicators[0].0, "rsi_14");
    }

    #[test]
    fn test_cross_ref_required_indicators() {
        let cond = CrossesAboveRef::new(sma(10), sma(20));
        let indicators = cond.required_indicators();
        assert_eq!(indicators.len(), 2);
    }

    #[test]
    fn test_above_ref_required_indicators() {
        // Both sides of a ref-vs-ref comparison must be requested.
        let cond = AboveRef::new(sma(10), sma(20));
        let indicators = cond.required_indicators();
        assert_eq!(indicators.len(), 2);
    }
}