Skip to main content

mantis_ta/strategy/
indicator_ref.rs

1use super::types::{CompareTarget, Condition, ConditionGroup, ConditionNode, Operator};
2
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5
/// Reference to an indicator within a strategy, with convenience methods for building conditions.
///
/// The reference is identified purely by its `name` (for example `sma20` or
/// `macd_12_26_9_line`), so two references compare equal exactly when their
/// names are equal. `Eq` and `Hash` are derived so references can be
/// deduplicated or used as `HashMap`/`HashSet` keys.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct IndicatorRef {
    /// Name identifying the indicator, e.g. `sma20` or `rsi14`.
    pub name: String,
}
12
13impl IndicatorRef {
14    /// Create a new indicator reference.
15    pub fn new(name: impl Into<String>) -> Self {
16        Self { name: name.into() }
17    }
18
19    /// SMA convenience constructor.
20    pub fn sma(period: usize) -> Self {
21        Self::new(format!("sma{period}"))
22    }
23
24    /// EMA convenience constructor.
25    pub fn ema(period: usize) -> Self {
26        Self::new(format!("ema{period}"))
27    }
28
29    /// MACD convenience constructor.
30    pub fn macd(fast: usize, slow: usize, signal: usize) -> Self {
31        Self::new(format!("macd_{fast}_{slow}_{signal}_line"))
32    }
33
34    /// MACD signal line convenience constructor.
35    pub fn macd_signal(fast: usize, slow: usize, signal: usize) -> Self {
36        Self::new(format!("macd_{fast}_{slow}_{signal}_signal"))
37    }
38
39    /// RSI convenience constructor.
40    pub fn rsi(period: usize) -> Self {
41        Self::new(format!("rsi{period}"))
42    }
43
44    /// Stochastic %K convenience constructor.
45    pub fn stoch_k(k_period: usize, d_period: usize) -> Self {
46        Self::new(format!("stoch_{k_period}_{d_period}_k"))
47    }
48
49    /// Stochastic %D convenience constructor.
50    pub fn stoch_d(k_period: usize, d_period: usize) -> Self {
51        Self::new(format!("stoch_{k_period}_{d_period}_d"))
52    }
53
54    /// Bollinger Bands upper convenience constructor.
55    pub fn bb_upper(period: usize, std_dev: f64) -> Self {
56        Self::new(format!("bb_{period}_{std_dev}_upper"))
57    }
58
59    /// Bollinger Bands middle convenience constructor.
60    pub fn bb_middle(period: usize, std_dev: f64) -> Self {
61        Self::new(format!("bb_{period}_{std_dev}_middle"))
62    }
63
64    /// Bollinger Bands lower convenience constructor.
65    pub fn bb_lower(period: usize, std_dev: f64) -> Self {
66        Self::new(format!("bb_{period}_{std_dev}_lower"))
67    }
68
69    /// ATR convenience constructor.
70    pub fn atr(period: usize) -> Self {
71        Self::new(format!("atr{period}"))
72    }
73
74    /// Volume SMA convenience constructor.
75    pub fn volume_sma(period: usize) -> Self {
76        Self::new(format!("volume_sma_{period}"))
77    }
78
79    /// OBV convenience constructor.
80    pub fn obv() -> Self {
81        Self::new("obv")
82    }
83
84    /// Pivot Points convenience constructor.
85    pub fn pivot_points() -> Self {
86        Self::new("pivot_points")
87    }
88
89    /// ADX convenience constructor.
90    pub fn adx(period: usize) -> Self {
91        Self::new(format!("adx{period}"))
92    }
93
94    /// WMA convenience constructor.
95    pub fn wma(period: usize) -> Self {
96        Self::new(format!("wma{period}"))
97    }
98
99    /// DEMA convenience constructor.
100    pub fn dema(period: usize) -> Self {
101        Self::new(format!("dema{period}"))
102    }
103
104    /// TEMA convenience constructor.
105    pub fn tema(period: usize) -> Self {
106        Self::new(format!("tema{period}"))
107    }
108
109    /// CCI convenience constructor.
110    pub fn cci(period: usize) -> Self {
111        Self::new(format!("cci{period}"))
112    }
113
114    /// Williams %R convenience constructor.
115    pub fn williams_r(period: usize) -> Self {
116        Self::new(format!("williams_r{period}"))
117    }
118
119    /// ROC convenience constructor.
120    pub fn roc(period: usize) -> Self {
121        Self::new(format!("roc{period}"))
122    }
123
124    /// Standard Deviation convenience constructor.
125    pub fn stddev(period: usize) -> Self {
126        Self::new(format!("stddev{period}"))
127    }
128
129    // Condition building methods
130
131    /// Create a condition: this indicator crosses above a value.
132    pub fn crosses_above(self, value: f64) -> ConditionNode {
133        ConditionNode::Condition(Condition::new(
134            self.name,
135            Operator::CrossesAbove,
136            CompareTarget::Value(value),
137        ))
138    }
139
140    /// Create a condition: this indicator crosses above another indicator.
141    pub fn crosses_above_indicator(self, other: IndicatorRef) -> ConditionNode {
142        ConditionNode::Condition(Condition::new(
143            self.name,
144            Operator::CrossesAbove,
145            CompareTarget::Indicator(other.name),
146        ))
147    }
148
149    /// Create a condition: this indicator crosses below a value.
150    pub fn crosses_below(self, value: f64) -> ConditionNode {
151        ConditionNode::Condition(Condition::new(
152            self.name,
153            Operator::CrossesBelow,
154            CompareTarget::Value(value),
155        ))
156    }
157
158    /// Create a condition: this indicator crosses below another indicator.
159    pub fn crosses_below_indicator(self, other: IndicatorRef) -> ConditionNode {
160        ConditionNode::Condition(Condition::new(
161            self.name,
162            Operator::CrossesBelow,
163            CompareTarget::Indicator(other.name),
164        ))
165    }
166
167    /// Create a condition: this indicator is above a value.
168    pub fn is_above(self, value: f64) -> ConditionNode {
169        ConditionNode::Condition(Condition::new(
170            self.name,
171            Operator::IsAbove,
172            CompareTarget::Value(value),
173        ))
174    }
175
176    /// Create a condition: this indicator is above another indicator.
177    pub fn is_above_indicator(self, other: IndicatorRef) -> ConditionNode {
178        ConditionNode::Condition(Condition::new(
179            self.name,
180            Operator::IsAbove,
181            CompareTarget::Indicator(other.name),
182        ))
183    }
184
185    /// Create a condition: this indicator is below a value.
186    pub fn is_below(self, value: f64) -> ConditionNode {
187        ConditionNode::Condition(Condition::new(
188            self.name,
189            Operator::IsBelow,
190            CompareTarget::Value(value),
191        ))
192    }
193
194    /// Create a condition: this indicator is below another indicator.
195    pub fn is_below_indicator(self, other: IndicatorRef) -> ConditionNode {
196        ConditionNode::Condition(Condition::new(
197            self.name,
198            Operator::IsBelow,
199            CompareTarget::Indicator(other.name),
200        ))
201    }
202
203    /// Create a condition: this indicator equals a value (within epsilon).
204    pub fn equals(self, value: f64) -> ConditionNode {
205        ConditionNode::Condition(Condition::new(
206            self.name,
207            Operator::Equals,
208            CompareTarget::Value(value),
209        ))
210    }
211
212    /// Create a condition: this indicator equals another indicator (within epsilon).
213    pub fn equals_indicator(self, other: IndicatorRef) -> ConditionNode {
214        ConditionNode::Condition(Condition::new(
215            self.name,
216            Operator::Equals,
217            CompareTarget::Indicator(other.name),
218        ))
219    }
220
221    /// Create a condition: this indicator is between two values.
222    pub fn is_between(self, lower: f64, upper: f64) -> ConditionNode {
223        ConditionNode::Condition(Condition::new(
224            self.name,
225            Operator::IsBetween,
226            CompareTarget::Range(lower, upper),
227        ))
228    }
229
230    /// Create a condition: this indicator is rising over `bars` bars.
231    pub fn is_rising(self, bars: u32) -> ConditionNode {
232        ConditionNode::Condition(Condition::new(
233            self.name,
234            Operator::IsRising(bars),
235            CompareTarget::None,
236        ))
237    }
238
239    /// Create a condition: this indicator is falling over `bars` bars.
240    pub fn is_falling(self, bars: u32) -> ConditionNode {
241        ConditionNode::Condition(Condition::new(
242            self.name,
243            Operator::IsFalling(bars),
244            CompareTarget::None,
245        ))
246    }
247
248    /// Create a condition: this indicator scaled by a multiplier is above a value.
249    pub fn scaled(self, multiplier: f64) -> ScaledIndicatorRef {
250        ScaledIndicatorRef {
251            name: self.name,
252            multiplier,
253        }
254    }
255}
256
/// An indicator reference paired with a multiplier, used to build conditions
/// on the scaled indicator (for example "ATR × 2 is above price").
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct ScaledIndicatorRef {
    /// Name of the underlying indicator, e.g. `atr14`.
    pub name: String,
    /// Multiplier the indicator is scaled by; rendered into the condition
    /// operand as `name*multiplier`.
    pub multiplier: f64,
}
264
265impl ScaledIndicatorRef {
266    /// Create a condition: this scaled indicator is above a value.
267    pub fn is_above_value(self, value: f64) -> ConditionNode {
268        ConditionNode::Condition(Condition::new(
269            format!("{}*{}", self.name, self.multiplier),
270            Operator::IsAbove,
271            CompareTarget::Value(value),
272        ))
273    }
274
275    /// Create a condition: this scaled indicator is above another indicator.
276    pub fn is_above_indicator(self, other: IndicatorRef) -> ConditionNode {
277        ConditionNode::Condition(Condition::new(
278            format!("{}*{}", self.name, self.multiplier),
279            Operator::IsAbove,
280            CompareTarget::Indicator(other.name),
281        ))
282    }
283
284    /// Create a condition: this scaled indicator is below a value.
285    pub fn is_below_value(self, value: f64) -> ConditionNode {
286        ConditionNode::Condition(Condition::new(
287            format!("{}*{}", self.name, self.multiplier),
288            Operator::IsBelow,
289            CompareTarget::Value(value),
290        ))
291    }
292
293    /// Create a condition: this scaled indicator is below another indicator.
294    pub fn is_below_indicator(self, other: IndicatorRef) -> ConditionNode {
295        ConditionNode::Condition(Condition::new(
296            format!("{}*{}", self.name, self.multiplier),
297            Operator::IsBelow,
298            CompareTarget::Indicator(other.name),
299        ))
300    }
301}
302
303/// Create an AllOf condition group.
304pub fn all_of(conditions: Vec<ConditionNode>) -> ConditionNode {
305    ConditionNode::Group(ConditionGroup::AllOf(conditions))
306}
307
308/// Create an AnyOf condition group.
309pub fn any_of(conditions: Vec<ConditionNode>) -> ConditionNode {
310    ConditionNode::Group(ConditionGroup::AnyOf(conditions))
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn indicator_ref_convenience_constructors() {
        let sma = IndicatorRef::sma(20);
        assert_eq!(sma.name, "sma20");

        let ema = IndicatorRef::ema(14);
        assert_eq!(ema.name, "ema14");

        let rsi = IndicatorRef::rsi(14);
        assert_eq!(rsi.name, "rsi14");

        let obv = IndicatorRef::obv();
        assert_eq!(obv.name, "obv");
    }

    #[test]
    fn batch_a_indicator_ref_convenience_constructors() {
        let adx = IndicatorRef::adx(14);
        assert_eq!(adx.name, "adx14");

        let wma = IndicatorRef::wma(20);
        assert_eq!(wma.name, "wma20");

        let dema = IndicatorRef::dema(10);
        assert_eq!(dema.name, "dema10");

        let tema = IndicatorRef::tema(10);
        assert_eq!(tema.name, "tema10");

        let cci = IndicatorRef::cci(20);
        assert_eq!(cci.name, "cci20");

        let williams_r = IndicatorRef::williams_r(14);
        assert_eq!(williams_r.name, "williams_r14");

        let roc = IndicatorRef::roc(12);
        assert_eq!(roc.name, "roc12");

        let stddev = IndicatorRef::stddev(20);
        assert_eq!(stddev.name, "stddev20");
    }

    #[test]
    fn multi_output_and_misc_indicator_names() {
        assert_eq!(IndicatorRef::macd(12, 26, 9).name, "macd_12_26_9_line");
        assert_eq!(IndicatorRef::macd_signal(12, 26, 9).name, "macd_12_26_9_signal");
        assert_eq!(IndicatorRef::stoch_k(14, 3).name, "stoch_14_3_k");
        assert_eq!(IndicatorRef::stoch_d(14, 3).name, "stoch_14_3_d");
        assert_eq!(IndicatorRef::atr(14).name, "atr14");
        assert_eq!(IndicatorRef::volume_sma(20).name, "volume_sma_20");
        assert_eq!(IndicatorRef::pivot_points().name, "pivot_points");
    }

    #[test]
    fn bollinger_names_use_display_formatting_for_std_dev() {
        // f64's Display omits a trailing ".0", so 2.0 renders as "2".
        assert_eq!(IndicatorRef::bb_upper(20, 2.0).name, "bb_20_2_upper");
        assert_eq!(IndicatorRef::bb_middle(20, 2.0).name, "bb_20_2_middle");
        assert_eq!(IndicatorRef::bb_lower(20, 2.5).name, "bb_20_2.5_lower");
    }

    #[test]
    fn condition_building() {
        let sma = IndicatorRef::sma(20);
        let cond = sma.crosses_above(100.0);
        assert!(matches!(cond, ConditionNode::Condition(_)));
    }

    #[test]
    fn indicator_comparison_sets_right_hand_indicator() {
        let cond = IndicatorRef::ema(12).crosses_below_indicator(IndicatorRef::ema(26));
        match cond {
            ConditionNode::Condition(c) => {
                assert_eq!(c.left, "ema12");
                assert_eq!(c.operator, Operator::CrossesBelow);
                assert_eq!(c.right, CompareTarget::Indicator("ema26".to_string()));
            }
            _ => panic!("expected Condition"),
        }
    }

    #[test]
    fn is_between_builds_range_target() {
        match IndicatorRef::rsi(14).is_between(30.0, 70.0) {
            ConditionNode::Condition(c) => {
                assert_eq!(c.left, "rsi14");
                assert_eq!(c.operator, Operator::IsBetween);
                assert_eq!(c.right, CompareTarget::Range(30.0, 70.0));
            }
            _ => panic!("expected Condition"),
        }
    }

    #[test]
    fn trend_conditions_have_no_compare_target() {
        match IndicatorRef::adx(14).is_rising(3) {
            ConditionNode::Condition(c) => {
                assert_eq!(c.left, "adx14");
                assert_eq!(c.operator, Operator::IsRising(3));
                assert_eq!(c.right, CompareTarget::None);
            }
            _ => panic!("expected Condition"),
        }
        match IndicatorRef::obv().is_falling(5) {
            ConditionNode::Condition(c) => {
                assert_eq!(c.left, "obv");
                assert_eq!(c.operator, Operator::IsFalling(5));
                assert_eq!(c.right, CompareTarget::None);
            }
            _ => panic!("expected Condition"),
        }
    }

    #[test]
    fn condition_grouping() {
        let sma = IndicatorRef::sma(20);
        let rsi = IndicatorRef::rsi(14);

        let cond1 = sma.is_above(100.0);
        let cond2 = rsi.is_below(70.0);

        let group = all_of(vec![cond1, cond2]);
        assert!(matches!(
            group,
            ConditionNode::Group(ConditionGroup::AllOf(_))
        ));
    }

    #[test]
    fn any_of_grouping() {
        let group = any_of(vec![
            IndicatorRef::rsi(14).is_below(30.0),
            IndicatorRef::rsi(14).is_above(70.0),
        ]);
        assert!(matches!(
            group,
            ConditionNode::Group(ConditionGroup::AnyOf(_))
        ));
    }

    #[test]
    fn scaled_indicator_ref() {
        let atr = IndicatorRef::atr(14);
        let scaled = atr.scaled(2.0);
        assert_eq!(scaled.multiplier, 2.0);
    }

    #[test]
    fn scaled_is_above_indicator_has_correct_semantics() {
        // atr.scaled(2.0).is_above_indicator(price) should mean "atr*2 is above price"
        let cond = IndicatorRef::atr(14)
            .scaled(2.0)
            .is_above_indicator(IndicatorRef::new("price"));
        match cond {
            ConditionNode::Condition(c) => {
                assert_eq!(c.left, "atr14*2");
                assert_eq!(c.operator, Operator::IsAbove);
                assert_eq!(c.right, CompareTarget::Indicator("price".to_string()));
            }
            _ => panic!("expected Condition"),
        }
    }

    #[test]
    fn scaled_is_below_indicator_has_correct_semantics() {
        // atr.scaled(1.5).is_below_indicator(price) should mean "atr*1.5 is below price"
        let cond = IndicatorRef::atr(14)
            .scaled(1.5)
            .is_below_indicator(IndicatorRef::new("price"));
        match cond {
            ConditionNode::Condition(c) => {
                assert_eq!(c.left, "atr14*1.5");
                assert_eq!(c.operator, Operator::IsBelow);
                assert_eq!(c.right, CompareTarget::Indicator("price".to_string()));
            }
            _ => panic!("expected Condition"),
        }
    }

    #[test]
    fn scaled_is_above_value_has_correct_semantics() {
        let cond = IndicatorRef::atr(14).scaled(2.0).is_above_value(50.0);
        match cond {
            ConditionNode::Condition(c) => {
                assert_eq!(c.left, "atr14*2");
                assert_eq!(c.operator, Operator::IsAbove);
                assert_eq!(c.right, CompareTarget::Value(50.0));
            }
            _ => panic!("expected Condition"),
        }
    }

    #[test]
    fn scaled_is_below_value_has_correct_semantics() {
        let cond = IndicatorRef::atr(14).scaled(3.0).is_below_value(10.0);
        match cond {
            ConditionNode::Condition(c) => {
                assert_eq!(c.left, "atr14*3");
                assert_eq!(c.operator, Operator::IsBelow);
                assert_eq!(c.right, CompareTarget::Value(10.0));
            }
            _ => panic!("expected Condition"),
        }
    }
}