// Source file: mantis_ta/strategy/indicator_ref.rs

1use super::types::{CompareTarget, Condition, ConditionGroup, ConditionNode, Operator};
2
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5
/// Reference to an indicator within a strategy, with convenience methods for building conditions.
///
/// The reference is nothing more than the indicator's name (for example `sma_20`), which is
/// used verbatim as an operand in the conditions built from it. Because it wraps only a
/// `String`, it is `Eq + Hash`, so references can be deduplicated or used as map/set keys.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct IndicatorRef {
    /// Indicator name, used as the operand string in built conditions.
    pub name: String,
}
12
13impl IndicatorRef {
14    /// Create a new indicator reference.
15    pub fn new(name: impl Into<String>) -> Self {
16        Self { name: name.into() }
17    }
18
19    /// SMA convenience constructor.
20    pub fn sma(period: usize) -> Self {
21        Self::new(format!("sma_{}", period))
22    }
23
24    /// EMA convenience constructor.
25    pub fn ema(period: usize) -> Self {
26        Self::new(format!("ema_{}", period))
27    }
28
29    /// MACD convenience constructor.
30    pub fn macd(fast: usize, slow: usize, signal: usize) -> Self {
31        Self::new(format!("macd_{}_{}_{}_line", fast, slow, signal))
32    }
33
34    /// MACD signal line convenience constructor.
35    pub fn macd_signal(fast: usize, slow: usize, signal: usize) -> Self {
36        Self::new(format!("macd_{}_{}_{}_signal", fast, slow, signal))
37    }
38
39    /// RSI convenience constructor.
40    pub fn rsi(period: usize) -> Self {
41        Self::new(format!("rsi_{}", period))
42    }
43
44    /// Stochastic %K convenience constructor.
45    pub fn stoch_k(k_period: usize, d_period: usize) -> Self {
46        Self::new(format!("stoch_{}_{}_k", k_period, d_period))
47    }
48
49    /// Stochastic %D convenience constructor.
50    pub fn stoch_d(k_period: usize, d_period: usize) -> Self {
51        Self::new(format!("stoch_{}_{}_d", k_period, d_period))
52    }
53
54    /// Bollinger Bands upper convenience constructor.
55    pub fn bb_upper(period: usize, std_dev: f64) -> Self {
56        Self::new(format!("bb_{}_{}_upper", period, std_dev))
57    }
58
59    /// Bollinger Bands middle convenience constructor.
60    pub fn bb_middle(period: usize, std_dev: f64) -> Self {
61        Self::new(format!("bb_{}_{}_middle", period, std_dev))
62    }
63
64    /// Bollinger Bands lower convenience constructor.
65    pub fn bb_lower(period: usize, std_dev: f64) -> Self {
66        Self::new(format!("bb_{}_{}_lower", period, std_dev))
67    }
68
69    /// ATR convenience constructor.
70    pub fn atr(period: usize) -> Self {
71        Self::new(format!("atr_{}", period))
72    }
73
74    /// Volume SMA convenience constructor.
75    pub fn volume_sma(period: usize) -> Self {
76        Self::new(format!("volume_sma_{}", period))
77    }
78
79    /// OBV convenience constructor.
80    pub fn obv() -> Self {
81        Self::new("obv")
82    }
83
84    /// Pivot Points convenience constructor.
85    pub fn pivot_points() -> Self {
86        Self::new("pivot_points")
87    }
88
89    // Condition building methods
90
91    /// Create a condition: this indicator crosses above a value.
92    pub fn crosses_above(self, value: f64) -> ConditionNode {
93        ConditionNode::Condition(Condition::new(
94            self.name,
95            Operator::CrossesAbove,
96            CompareTarget::Value(value),
97        ))
98    }
99
100    /// Create a condition: this indicator crosses above another indicator.
101    pub fn crosses_above_indicator(self, other: IndicatorRef) -> ConditionNode {
102        ConditionNode::Condition(Condition::new(
103            self.name,
104            Operator::CrossesAbove,
105            CompareTarget::Indicator(other.name),
106        ))
107    }
108
109    /// Create a condition: this indicator crosses below a value.
110    pub fn crosses_below(self, value: f64) -> ConditionNode {
111        ConditionNode::Condition(Condition::new(
112            self.name,
113            Operator::CrossesBelow,
114            CompareTarget::Value(value),
115        ))
116    }
117
118    /// Create a condition: this indicator crosses below another indicator.
119    pub fn crosses_below_indicator(self, other: IndicatorRef) -> ConditionNode {
120        ConditionNode::Condition(Condition::new(
121            self.name,
122            Operator::CrossesBelow,
123            CompareTarget::Indicator(other.name),
124        ))
125    }
126
127    /// Create a condition: this indicator is above a value.
128    pub fn is_above(self, value: f64) -> ConditionNode {
129        ConditionNode::Condition(Condition::new(
130            self.name,
131            Operator::IsAbove,
132            CompareTarget::Value(value),
133        ))
134    }
135
136    /// Create a condition: this indicator is above another indicator.
137    pub fn is_above_indicator(self, other: IndicatorRef) -> ConditionNode {
138        ConditionNode::Condition(Condition::new(
139            self.name,
140            Operator::IsAbove,
141            CompareTarget::Indicator(other.name),
142        ))
143    }
144
145    /// Create a condition: this indicator is below a value.
146    pub fn is_below(self, value: f64) -> ConditionNode {
147        ConditionNode::Condition(Condition::new(
148            self.name,
149            Operator::IsBelow,
150            CompareTarget::Value(value),
151        ))
152    }
153
154    /// Create a condition: this indicator is below another indicator.
155    pub fn is_below_indicator(self, other: IndicatorRef) -> ConditionNode {
156        ConditionNode::Condition(Condition::new(
157            self.name,
158            Operator::IsBelow,
159            CompareTarget::Indicator(other.name),
160        ))
161    }
162
163    /// Create a condition: this indicator equals a value (within epsilon).
164    pub fn equals(self, value: f64) -> ConditionNode {
165        ConditionNode::Condition(Condition::new(
166            self.name,
167            Operator::Equals,
168            CompareTarget::Value(value),
169        ))
170    }
171
172    /// Create a condition: this indicator equals another indicator (within epsilon).
173    pub fn equals_indicator(self, other: IndicatorRef) -> ConditionNode {
174        ConditionNode::Condition(Condition::new(
175            self.name,
176            Operator::Equals,
177            CompareTarget::Indicator(other.name),
178        ))
179    }
180
181    /// Create a condition: this indicator is between two values.
182    pub fn is_between(self, lower: f64, upper: f64) -> ConditionNode {
183        ConditionNode::Condition(Condition::new(
184            self.name,
185            Operator::IsBetween,
186            CompareTarget::Range(lower, upper),
187        ))
188    }
189
190    /// Create a condition: this indicator is rising over `bars` bars.
191    pub fn is_rising(self, bars: u32) -> ConditionNode {
192        ConditionNode::Condition(Condition::new(
193            self.name,
194            Operator::IsRising(bars),
195            CompareTarget::None,
196        ))
197    }
198
199    /// Create a condition: this indicator is falling over `bars` bars.
200    pub fn is_falling(self, bars: u32) -> ConditionNode {
201        ConditionNode::Condition(Condition::new(
202            self.name,
203            Operator::IsFalling(bars),
204            CompareTarget::None,
205        ))
206    }
207
208    /// Create a condition: this indicator scaled by a multiplier is above a value.
209    pub fn scaled(self, multiplier: f64) -> ScaledIndicatorRef {
210        ScaledIndicatorRef {
211            name: self.name,
212            multiplier,
213        }
214    }
215}
216
/// A named indicator paired with a constant multiplier.
///
/// Used as the left-hand side of comparison conditions, where it is written out as
/// `name*multiplier`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq)]
pub struct ScaledIndicatorRef {
    /// Name of the underlying indicator, e.g. `atr_14`.
    pub name: String,
    /// Constant factor applied to the indicator.
    pub multiplier: f64,
}
224
225impl ScaledIndicatorRef {
226    /// Create a condition: this scaled indicator is above a value.
227    pub fn is_above_value(self, value: f64) -> ConditionNode {
228        ConditionNode::Condition(Condition::new(
229            format!("{}*{}", self.name, self.multiplier),
230            Operator::IsAbove,
231            CompareTarget::Value(value),
232        ))
233    }
234
235    /// Create a condition: this scaled indicator is above another indicator.
236    pub fn is_above_indicator(self, other: IndicatorRef) -> ConditionNode {
237        ConditionNode::Condition(Condition::new(
238            format!("{}*{}", self.name, self.multiplier),
239            Operator::IsAbove,
240            CompareTarget::Indicator(other.name),
241        ))
242    }
243
244    /// Create a condition: this scaled indicator is below a value.
245    pub fn is_below_value(self, value: f64) -> ConditionNode {
246        ConditionNode::Condition(Condition::new(
247            format!("{}*{}", self.name, self.multiplier),
248            Operator::IsBelow,
249            CompareTarget::Value(value),
250        ))
251    }
252
253    /// Create a condition: this scaled indicator is below another indicator.
254    pub fn is_below_indicator(self, other: IndicatorRef) -> ConditionNode {
255        ConditionNode::Condition(Condition::new(
256            format!("{}*{}", self.name, self.multiplier),
257            Operator::IsBelow,
258            CompareTarget::Indicator(other.name),
259        ))
260    }
261}
262
263/// Create an AllOf condition group.
264pub fn all_of(conditions: Vec<ConditionNode>) -> ConditionNode {
265    ConditionNode::Group(ConditionGroup::AllOf(conditions))
266}
267
268/// Create an AnyOf condition group.
269pub fn any_of(conditions: Vec<ConditionNode>) -> ConditionNode {
270    ConditionNode::Group(ConditionGroup::AnyOf(conditions))
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Extract the leaf condition from a node, failing the test if it is a group.
    fn expect_condition(node: ConditionNode) -> Condition {
        match node {
            ConditionNode::Condition(c) => c,
            _ => panic!("expected Condition"),
        }
    }

    #[test]
    fn indicator_ref_convenience_constructors() {
        assert_eq!(IndicatorRef::sma(20).name, "sma_20");
        assert_eq!(IndicatorRef::ema(14).name, "ema_14");
        assert_eq!(IndicatorRef::rsi(14).name, "rsi_14");
        assert_eq!(IndicatorRef::obv().name, "obv");
    }

    #[test]
    fn multi_parameter_constructors_encode_every_parameter() {
        assert_eq!(IndicatorRef::macd(12, 26, 9).name, "macd_12_26_9_line");
        assert_eq!(IndicatorRef::macd_signal(12, 26, 9).name, "macd_12_26_9_signal");
        assert_eq!(IndicatorRef::stoch_k(14, 3).name, "stoch_14_3_k");
        assert_eq!(IndicatorRef::stoch_d(14, 3).name, "stoch_14_3_d");
        assert_eq!(IndicatorRef::atr(14).name, "atr_14");
        assert_eq!(IndicatorRef::volume_sma(20).name, "volume_sma_20");
        assert_eq!(IndicatorRef::pivot_points().name, "pivot_points");
    }

    #[test]
    fn bollinger_names_use_display_formatting_for_std_dev() {
        // `f64`'s Display drops a trailing `.0`, but keeps real fractional digits.
        assert_eq!(IndicatorRef::bb_upper(20, 2.0).name, "bb_20_2_upper");
        assert_eq!(IndicatorRef::bb_middle(20, 2.0).name, "bb_20_2_middle");
        assert_eq!(IndicatorRef::bb_lower(20, 2.5).name, "bb_20_2.5_lower");
    }

    #[test]
    fn condition_building() {
        let c = expect_condition(IndicatorRef::sma(20).crosses_above(100.0));
        assert_eq!(c.left, "sma_20");
        assert_eq!(c.operator, Operator::CrossesAbove);
        assert_eq!(c.right, CompareTarget::Value(100.0));
    }

    #[test]
    fn indicator_targets_use_the_other_indicators_name() {
        let c = expect_condition(
            IndicatorRef::ema(12).crosses_below_indicator(IndicatorRef::ema(26)),
        );
        assert_eq!(c.left, "ema_12");
        assert_eq!(c.operator, Operator::CrossesBelow);
        assert_eq!(c.right, CompareTarget::Indicator("ema_26".to_string()));

        let eq = expect_condition(IndicatorRef::sma(10).equals_indicator(IndicatorRef::sma(20)));
        assert_eq!(eq.operator, Operator::Equals);
        assert_eq!(eq.right, CompareTarget::Indicator("sma_20".to_string()));
    }

    #[test]
    fn range_and_trend_conditions() {
        let between = expect_condition(IndicatorRef::rsi(14).is_between(30.0, 70.0));
        assert_eq!(between.left, "rsi_14");
        assert_eq!(between.operator, Operator::IsBetween);
        assert_eq!(between.right, CompareTarget::Range(30.0, 70.0));

        let rising = expect_condition(IndicatorRef::obv().is_rising(3));
        assert_eq!(rising.operator, Operator::IsRising(3));
        assert_eq!(rising.right, CompareTarget::None);

        let falling = expect_condition(IndicatorRef::obv().is_falling(5));
        assert_eq!(falling.operator, Operator::IsFalling(5));
        assert_eq!(falling.right, CompareTarget::None);
    }

    #[test]
    fn condition_grouping() {
        let cond1 = IndicatorRef::sma(20).is_above(100.0);
        let cond2 = IndicatorRef::rsi(14).is_below(70.0);

        let group = all_of(vec![cond1, cond2]);
        assert!(matches!(
            group,
            ConditionNode::Group(ConditionGroup::AllOf(_))
        ));
    }

    #[test]
    fn any_of_builds_any_of_group() {
        let group = any_of(vec![
            IndicatorRef::rsi(14).is_above(70.0),
            IndicatorRef::rsi(14).is_below(30.0),
        ]);
        match group {
            ConditionNode::Group(ConditionGroup::AnyOf(members)) => assert_eq!(members.len(), 2),
            _ => panic!("expected AnyOf group"),
        }
    }

    #[test]
    fn scaled_indicator_ref() {
        let scaled = IndicatorRef::atr(14).scaled(2.0);
        assert_eq!(scaled.name, "atr_14");
        assert_eq!(scaled.multiplier, 2.0);
    }

    #[test]
    fn scaled_is_above_indicator_has_correct_semantics() {
        // atr.scaled(2.0).is_above_indicator(price) should mean "atr*2 is above price"
        let c = expect_condition(
            IndicatorRef::atr(14)
                .scaled(2.0)
                .is_above_indicator(IndicatorRef::new("price")),
        );
        assert_eq!(c.left, "atr_14*2");
        assert_eq!(c.operator, Operator::IsAbove);
        assert_eq!(c.right, CompareTarget::Indicator("price".to_string()));
    }

    #[test]
    fn scaled_is_below_indicator_has_correct_semantics() {
        // atr.scaled(1.5).is_below_indicator(price) should mean "atr*1.5 is below price"
        let c = expect_condition(
            IndicatorRef::atr(14)
                .scaled(1.5)
                .is_below_indicator(IndicatorRef::new("price")),
        );
        assert_eq!(c.left, "atr_14*1.5");
        assert_eq!(c.operator, Operator::IsBelow);
        assert_eq!(c.right, CompareTarget::Indicator("price".to_string()));
    }

    #[test]
    fn scaled_is_above_value_has_correct_semantics() {
        let c = expect_condition(IndicatorRef::atr(14).scaled(2.0).is_above_value(50.0));
        assert_eq!(c.left, "atr_14*2");
        assert_eq!(c.operator, Operator::IsAbove);
        assert_eq!(c.right, CompareTarget::Value(50.0));
    }

    #[test]
    fn scaled_is_below_value_has_correct_semantics() {
        let c = expect_condition(IndicatorRef::atr(14).scaled(0.5).is_below_value(10.0));
        assert_eq!(c.left, "atr_14*0.5");
        assert_eq!(c.operator, Operator::IsBelow);
        assert_eq!(c.right, CompareTarget::Value(10.0));
    }
}