1use super::types::{CompareTarget, Condition, ConditionGroup, ConditionNode, Operator};
2
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5
/// A reference to an indicator by name.
///
/// Used as the left operand of a condition (`self.name` becomes the condition's `left`), or
/// as an indicator target via `CompareTarget::Indicator`. Only the name is stored; resolving
/// it to values is left to whatever evaluates the resulting conditions.
///
/// Derives `Eq` and `Hash` (the only field is a `String`) so references can be deduplicated
/// or used as map keys.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct IndicatorRef {
    /// Indicator key, e.g. `sma_20` or `rsi_14`.
    pub name: String,
}
12
13impl IndicatorRef {
14 pub fn new(name: impl Into<String>) -> Self {
16 Self { name: name.into() }
17 }
18
19 pub fn sma(period: usize) -> Self {
21 Self::new(format!("sma_{}", period))
22 }
23
24 pub fn ema(period: usize) -> Self {
26 Self::new(format!("ema_{}", period))
27 }
28
29 pub fn macd(fast: usize, slow: usize, signal: usize) -> Self {
31 Self::new(format!("macd_{}_{}_{}_line", fast, slow, signal))
32 }
33
34 pub fn macd_signal(fast: usize, slow: usize, signal: usize) -> Self {
36 Self::new(format!("macd_{}_{}_{}_signal", fast, slow, signal))
37 }
38
39 pub fn rsi(period: usize) -> Self {
41 Self::new(format!("rsi_{}", period))
42 }
43
44 pub fn stoch_k(k_period: usize, d_period: usize) -> Self {
46 Self::new(format!("stoch_{}_{}_k", k_period, d_period))
47 }
48
49 pub fn stoch_d(k_period: usize, d_period: usize) -> Self {
51 Self::new(format!("stoch_{}_{}_d", k_period, d_period))
52 }
53
54 pub fn bb_upper(period: usize, std_dev: f64) -> Self {
56 Self::new(format!("bb_{}_{}_upper", period, std_dev))
57 }
58
59 pub fn bb_middle(period: usize, std_dev: f64) -> Self {
61 Self::new(format!("bb_{}_{}_middle", period, std_dev))
62 }
63
64 pub fn bb_lower(period: usize, std_dev: f64) -> Self {
66 Self::new(format!("bb_{}_{}_lower", period, std_dev))
67 }
68
69 pub fn atr(period: usize) -> Self {
71 Self::new(format!("atr_{}", period))
72 }
73
74 pub fn volume_sma(period: usize) -> Self {
76 Self::new(format!("volume_sma_{}", period))
77 }
78
79 pub fn obv() -> Self {
81 Self::new("obv")
82 }
83
84 pub fn pivot_points() -> Self {
86 Self::new("pivot_points")
87 }
88
89 pub fn crosses_above(self, value: f64) -> ConditionNode {
93 ConditionNode::Condition(Condition::new(
94 self.name,
95 Operator::CrossesAbove,
96 CompareTarget::Value(value),
97 ))
98 }
99
100 pub fn crosses_above_indicator(self, other: IndicatorRef) -> ConditionNode {
102 ConditionNode::Condition(Condition::new(
103 self.name,
104 Operator::CrossesAbove,
105 CompareTarget::Indicator(other.name),
106 ))
107 }
108
109 pub fn crosses_below(self, value: f64) -> ConditionNode {
111 ConditionNode::Condition(Condition::new(
112 self.name,
113 Operator::CrossesBelow,
114 CompareTarget::Value(value),
115 ))
116 }
117
118 pub fn crosses_below_indicator(self, other: IndicatorRef) -> ConditionNode {
120 ConditionNode::Condition(Condition::new(
121 self.name,
122 Operator::CrossesBelow,
123 CompareTarget::Indicator(other.name),
124 ))
125 }
126
127 pub fn is_above(self, value: f64) -> ConditionNode {
129 ConditionNode::Condition(Condition::new(
130 self.name,
131 Operator::IsAbove,
132 CompareTarget::Value(value),
133 ))
134 }
135
136 pub fn is_above_indicator(self, other: IndicatorRef) -> ConditionNode {
138 ConditionNode::Condition(Condition::new(
139 self.name,
140 Operator::IsAbove,
141 CompareTarget::Indicator(other.name),
142 ))
143 }
144
145 pub fn is_below(self, value: f64) -> ConditionNode {
147 ConditionNode::Condition(Condition::new(
148 self.name,
149 Operator::IsBelow,
150 CompareTarget::Value(value),
151 ))
152 }
153
154 pub fn is_below_indicator(self, other: IndicatorRef) -> ConditionNode {
156 ConditionNode::Condition(Condition::new(
157 self.name,
158 Operator::IsBelow,
159 CompareTarget::Indicator(other.name),
160 ))
161 }
162
163 pub fn equals(self, value: f64) -> ConditionNode {
165 ConditionNode::Condition(Condition::new(
166 self.name,
167 Operator::Equals,
168 CompareTarget::Value(value),
169 ))
170 }
171
172 pub fn equals_indicator(self, other: IndicatorRef) -> ConditionNode {
174 ConditionNode::Condition(Condition::new(
175 self.name,
176 Operator::Equals,
177 CompareTarget::Indicator(other.name),
178 ))
179 }
180
181 pub fn is_between(self, lower: f64, upper: f64) -> ConditionNode {
183 ConditionNode::Condition(Condition::new(
184 self.name,
185 Operator::IsBetween,
186 CompareTarget::Range(lower, upper),
187 ))
188 }
189
190 pub fn is_rising(self, bars: u32) -> ConditionNode {
192 ConditionNode::Condition(Condition::new(
193 self.name,
194 Operator::IsRising(bars),
195 CompareTarget::None,
196 ))
197 }
198
199 pub fn is_falling(self, bars: u32) -> ConditionNode {
201 ConditionNode::Condition(Condition::new(
202 self.name,
203 Operator::IsFalling(bars),
204 CompareTarget::None,
205 ))
206 }
207
208 pub fn scaled(self, multiplier: f64) -> ScaledIndicatorRef {
210 ScaledIndicatorRef {
211 name: self.name,
212 multiplier,
213 }
214 }
215}
216
/// An indicator name paired with a constant multiplier.
///
/// Produced by `IndicatorRef::scaled`. Its comparison methods encode the pair as the
/// left-operand name `"{name}*{multiplier}"`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ScaledIndicatorRef {
    /// Name of the underlying indicator, e.g. `atr_14`.
    pub name: String,
    /// Factor written after the `*` in the encoded left-operand name.
    pub multiplier: f64,
}
224
225impl ScaledIndicatorRef {
226 pub fn is_above_value(self, value: f64) -> ConditionNode {
228 ConditionNode::Condition(Condition::new(
229 format!("{}*{}", self.name, self.multiplier),
230 Operator::IsAbove,
231 CompareTarget::Value(value),
232 ))
233 }
234
235 pub fn is_above_indicator(self, other: IndicatorRef) -> ConditionNode {
237 ConditionNode::Condition(Condition::new(
238 format!("{}*{}", self.name, self.multiplier),
239 Operator::IsAbove,
240 CompareTarget::Indicator(other.name),
241 ))
242 }
243
244 pub fn is_below_value(self, value: f64) -> ConditionNode {
246 ConditionNode::Condition(Condition::new(
247 format!("{}*{}", self.name, self.multiplier),
248 Operator::IsBelow,
249 CompareTarget::Value(value),
250 ))
251 }
252
253 pub fn is_below_indicator(self, other: IndicatorRef) -> ConditionNode {
255 ConditionNode::Condition(Condition::new(
256 format!("{}*{}", self.name, self.multiplier),
257 Operator::IsBelow,
258 CompareTarget::Indicator(other.name),
259 ))
260 }
261}
262
263pub fn all_of(conditions: Vec<ConditionNode>) -> ConditionNode {
265 ConditionNode::Group(ConditionGroup::AllOf(conditions))
266}
267
268pub fn any_of(conditions: Vec<ConditionNode>) -> ConditionNode {
270 ConditionNode::Group(ConditionGroup::AnyOf(conditions))
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps a leaf condition, failing the test for any other node kind.
    fn expect_condition(node: ConditionNode) -> Condition {
        match node {
            ConditionNode::Condition(c) => c,
            _ => panic!("expected Condition"),
        }
    }

    #[test]
    fn indicator_ref_convenience_constructors() {
        assert_eq!(IndicatorRef::sma(20).name, "sma_20");
        assert_eq!(IndicatorRef::ema(14).name, "ema_14");
        assert_eq!(IndicatorRef::rsi(14).name, "rsi_14");
        assert_eq!(IndicatorRef::atr(14).name, "atr_14");
        assert_eq!(IndicatorRef::volume_sma(20).name, "volume_sma_20");
        assert_eq!(IndicatorRef::obv().name, "obv");
        assert_eq!(IndicatorRef::pivot_points().name, "pivot_points");
    }

    #[test]
    fn multi_parameter_constructor_names() {
        assert_eq!(IndicatorRef::macd(12, 26, 9).name, "macd_12_26_9_line");
        assert_eq!(IndicatorRef::macd_signal(12, 26, 9).name, "macd_12_26_9_signal");
        assert_eq!(IndicatorRef::stoch_k(14, 3).name, "stoch_14_3_k");
        assert_eq!(IndicatorRef::stoch_d(14, 3).name, "stoch_14_3_d");
    }

    #[test]
    fn bollinger_names_format_std_dev_with_display() {
        // A whole-number f64 loses its ".0" under `Display`; fractional values keep it.
        assert_eq!(IndicatorRef::bb_upper(20, 2.0).name, "bb_20_2_upper");
        assert_eq!(IndicatorRef::bb_middle(20, 2.0).name, "bb_20_2_middle");
        assert_eq!(IndicatorRef::bb_lower(20, 2.5).name, "bb_20_2.5_lower");
    }

    #[test]
    fn condition_building() {
        let c = expect_condition(IndicatorRef::sma(20).crosses_above(100.0));
        assert_eq!(c.left, "sma_20");
        assert_eq!(c.operator, Operator::CrossesAbove);
        assert_eq!(c.right, CompareTarget::Value(100.0));
    }

    #[test]
    fn indicator_to_indicator_comparisons() {
        let node = IndicatorRef::ema(12).crosses_below_indicator(IndicatorRef::ema(26));
        let c = expect_condition(node);
        assert_eq!(c.left, "ema_12");
        assert_eq!(c.operator, Operator::CrossesBelow);
        assert_eq!(c.right, CompareTarget::Indicator("ema_26".to_string()));

        let node = IndicatorRef::rsi(14).equals_indicator(IndicatorRef::new("rsi_7"));
        let c = expect_condition(node);
        assert_eq!(c.left, "rsi_14");
        assert_eq!(c.operator, Operator::Equals);
        assert_eq!(c.right, CompareTarget::Indicator("rsi_7".to_string()));
    }

    #[test]
    fn range_and_trend_conditions() {
        let c = expect_condition(IndicatorRef::rsi(14).is_between(30.0, 70.0));
        assert_eq!(c.left, "rsi_14");
        assert_eq!(c.operator, Operator::IsBetween);
        assert_eq!(c.right, CompareTarget::Range(30.0, 70.0));

        let c = expect_condition(IndicatorRef::obv().is_rising(3));
        assert_eq!(c.operator, Operator::IsRising(3));
        assert_eq!(c.right, CompareTarget::None);

        let c = expect_condition(IndicatorRef::obv().is_falling(2));
        assert_eq!(c.operator, Operator::IsFalling(2));
        assert_eq!(c.right, CompareTarget::None);
    }

    #[test]
    fn condition_grouping() {
        let group = all_of(vec![
            IndicatorRef::sma(20).is_above(100.0),
            IndicatorRef::rsi(14).is_below(70.0),
        ]);
        match group {
            ConditionNode::Group(ConditionGroup::AllOf(children)) => {
                assert_eq!(children.len(), 2)
            }
            _ => panic!("expected AllOf group"),
        }

        let group = any_of(vec![
            IndicatorRef::rsi(14).is_below(30.0),
            IndicatorRef::rsi(14).is_above(70.0),
        ]);
        match group {
            ConditionNode::Group(ConditionGroup::AnyOf(children)) => {
                assert_eq!(children.len(), 2)
            }
            _ => panic!("expected AnyOf group"),
        }
    }

    #[test]
    fn scaled_indicator_ref() {
        let scaled = IndicatorRef::atr(14).scaled(2.0);
        assert_eq!(scaled.name, "atr_14");
        assert_eq!(scaled.multiplier, 2.0);
    }

    #[test]
    fn scaled_is_above_indicator_has_correct_semantics() {
        let c = expect_condition(
            IndicatorRef::atr(14)
                .scaled(2.0)
                .is_above_indicator(IndicatorRef::new("price")),
        );
        assert_eq!(c.left, "atr_14*2");
        assert_eq!(c.operator, Operator::IsAbove);
        assert_eq!(c.right, CompareTarget::Indicator("price".to_string()));
    }

    #[test]
    fn scaled_is_below_indicator_has_correct_semantics() {
        let c = expect_condition(
            IndicatorRef::atr(14)
                .scaled(1.5)
                .is_below_indicator(IndicatorRef::new("price")),
        );
        assert_eq!(c.left, "atr_14*1.5");
        assert_eq!(c.operator, Operator::IsBelow);
        assert_eq!(c.right, CompareTarget::Indicator("price".to_string()));
    }

    #[test]
    fn scaled_is_above_value_has_correct_semantics() {
        let c = expect_condition(IndicatorRef::atr(14).scaled(2.0).is_above_value(50.0));
        assert_eq!(c.left, "atr_14*2");
        assert_eq!(c.operator, Operator::IsAbove);
        assert_eq!(c.right, CompareTarget::Value(50.0));
    }

    #[test]
    fn scaled_is_below_value_has_correct_semantics() {
        let c = expect_condition(IndicatorRef::atr(14).scaled(0.5).is_below_value(10.0));
        assert_eq!(c.left, "atr_14*0.5");
        assert_eq!(c.operator, Operator::IsBelow);
        assert_eq!(c.right, CompareTarget::Value(10.0));
    }
}