1use super::types::{CompareTarget, Condition, ConditionGroup, ConditionNode, Operator};
2
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5
/// A reference to a named indicator series, used as the left-hand side (or the
/// comparison target) of a [`Condition`].
///
/// The reference is purely by name: nothing here checks that an indicator with
/// this name is actually computed anywhere.
///
/// `Eq` and `Hash` are derived (the only field is a `String`), so references can
/// be used as keys in hash-based collections.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct IndicatorRef {
    /// Identifier of the indicator, e.g. `sma20` or `macd_12_26_9_line`.
    pub name: String,
}
12
13impl IndicatorRef {
14 pub fn new(name: impl Into<String>) -> Self {
16 Self { name: name.into() }
17 }
18
19 pub fn sma(period: usize) -> Self {
21 Self::new(format!("sma{period}"))
22 }
23
24 pub fn ema(period: usize) -> Self {
26 Self::new(format!("ema{period}"))
27 }
28
29 pub fn macd(fast: usize, slow: usize, signal: usize) -> Self {
31 Self::new(format!("macd_{fast}_{slow}_{signal}_line"))
32 }
33
34 pub fn macd_signal(fast: usize, slow: usize, signal: usize) -> Self {
36 Self::new(format!("macd_{fast}_{slow}_{signal}_signal"))
37 }
38
39 pub fn rsi(period: usize) -> Self {
41 Self::new(format!("rsi{period}"))
42 }
43
44 pub fn stoch_k(k_period: usize, d_period: usize) -> Self {
46 Self::new(format!("stoch_{k_period}_{d_period}_k"))
47 }
48
49 pub fn stoch_d(k_period: usize, d_period: usize) -> Self {
51 Self::new(format!("stoch_{k_period}_{d_period}_d"))
52 }
53
54 pub fn bb_upper(period: usize, std_dev: f64) -> Self {
56 Self::new(format!("bb_{period}_{std_dev}_upper"))
57 }
58
59 pub fn bb_middle(period: usize, std_dev: f64) -> Self {
61 Self::new(format!("bb_{period}_{std_dev}_middle"))
62 }
63
64 pub fn bb_lower(period: usize, std_dev: f64) -> Self {
66 Self::new(format!("bb_{period}_{std_dev}_lower"))
67 }
68
69 pub fn atr(period: usize) -> Self {
71 Self::new(format!("atr{period}"))
72 }
73
74 pub fn volume_sma(period: usize) -> Self {
76 Self::new(format!("volume_sma_{period}"))
77 }
78
79 pub fn obv() -> Self {
81 Self::new("obv")
82 }
83
84 pub fn pivot_points() -> Self {
86 Self::new("pivot_points")
87 }
88
89 pub fn adx(period: usize) -> Self {
91 Self::new(format!("adx{period}"))
92 }
93
94 pub fn wma(period: usize) -> Self {
96 Self::new(format!("wma{period}"))
97 }
98
99 pub fn dema(period: usize) -> Self {
101 Self::new(format!("dema{period}"))
102 }
103
104 pub fn tema(period: usize) -> Self {
106 Self::new(format!("tema{period}"))
107 }
108
109 pub fn cci(period: usize) -> Self {
111 Self::new(format!("cci{period}"))
112 }
113
114 pub fn williams_r(period: usize) -> Self {
116 Self::new(format!("williams_r{period}"))
117 }
118
119 pub fn roc(period: usize) -> Self {
121 Self::new(format!("roc{period}"))
122 }
123
124 pub fn stddev(period: usize) -> Self {
126 Self::new(format!("stddev{period}"))
127 }
128
129 pub fn crosses_above(self, value: f64) -> ConditionNode {
133 ConditionNode::Condition(Condition::new(
134 self.name,
135 Operator::CrossesAbove,
136 CompareTarget::Value(value),
137 ))
138 }
139
140 pub fn crosses_above_indicator(self, other: IndicatorRef) -> ConditionNode {
142 ConditionNode::Condition(Condition::new(
143 self.name,
144 Operator::CrossesAbove,
145 CompareTarget::Indicator(other.name),
146 ))
147 }
148
149 pub fn crosses_below(self, value: f64) -> ConditionNode {
151 ConditionNode::Condition(Condition::new(
152 self.name,
153 Operator::CrossesBelow,
154 CompareTarget::Value(value),
155 ))
156 }
157
158 pub fn crosses_below_indicator(self, other: IndicatorRef) -> ConditionNode {
160 ConditionNode::Condition(Condition::new(
161 self.name,
162 Operator::CrossesBelow,
163 CompareTarget::Indicator(other.name),
164 ))
165 }
166
167 pub fn is_above(self, value: f64) -> ConditionNode {
169 ConditionNode::Condition(Condition::new(
170 self.name,
171 Operator::IsAbove,
172 CompareTarget::Value(value),
173 ))
174 }
175
176 pub fn is_above_indicator(self, other: IndicatorRef) -> ConditionNode {
178 ConditionNode::Condition(Condition::new(
179 self.name,
180 Operator::IsAbove,
181 CompareTarget::Indicator(other.name),
182 ))
183 }
184
185 pub fn is_below(self, value: f64) -> ConditionNode {
187 ConditionNode::Condition(Condition::new(
188 self.name,
189 Operator::IsBelow,
190 CompareTarget::Value(value),
191 ))
192 }
193
194 pub fn is_below_indicator(self, other: IndicatorRef) -> ConditionNode {
196 ConditionNode::Condition(Condition::new(
197 self.name,
198 Operator::IsBelow,
199 CompareTarget::Indicator(other.name),
200 ))
201 }
202
203 pub fn equals(self, value: f64) -> ConditionNode {
205 ConditionNode::Condition(Condition::new(
206 self.name,
207 Operator::Equals,
208 CompareTarget::Value(value),
209 ))
210 }
211
212 pub fn equals_indicator(self, other: IndicatorRef) -> ConditionNode {
214 ConditionNode::Condition(Condition::new(
215 self.name,
216 Operator::Equals,
217 CompareTarget::Indicator(other.name),
218 ))
219 }
220
221 pub fn is_between(self, lower: f64, upper: f64) -> ConditionNode {
223 ConditionNode::Condition(Condition::new(
224 self.name,
225 Operator::IsBetween,
226 CompareTarget::Range(lower, upper),
227 ))
228 }
229
230 pub fn is_rising(self, bars: u32) -> ConditionNode {
232 ConditionNode::Condition(Condition::new(
233 self.name,
234 Operator::IsRising(bars),
235 CompareTarget::None,
236 ))
237 }
238
239 pub fn is_falling(self, bars: u32) -> ConditionNode {
241 ConditionNode::Condition(Condition::new(
242 self.name,
243 Operator::IsFalling(bars),
244 CompareTarget::None,
245 ))
246 }
247
248 pub fn scaled(self, multiplier: f64) -> ScaledIndicatorRef {
250 ScaledIndicatorRef {
251 name: self.name,
252 multiplier,
253 }
254 }
255}
256
/// An indicator reference paired with a constant multiplier.
///
/// Created via `IndicatorRef::scaled`. Conditions built from it encode the scaling
/// in the left operand's name rather than in a separate field.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ScaledIndicatorRef {
    /// Name of the underlying indicator, e.g. `atr14`.
    pub name: String,
    /// Constant factor applied to the indicator.
    pub multiplier: f64,
}
264
265impl ScaledIndicatorRef {
266 pub fn is_above_value(self, value: f64) -> ConditionNode {
268 ConditionNode::Condition(Condition::new(
269 format!("{}*{}", self.name, self.multiplier),
270 Operator::IsAbove,
271 CompareTarget::Value(value),
272 ))
273 }
274
275 pub fn is_above_indicator(self, other: IndicatorRef) -> ConditionNode {
277 ConditionNode::Condition(Condition::new(
278 format!("{}*{}", self.name, self.multiplier),
279 Operator::IsAbove,
280 CompareTarget::Indicator(other.name),
281 ))
282 }
283
284 pub fn is_below_value(self, value: f64) -> ConditionNode {
286 ConditionNode::Condition(Condition::new(
287 format!("{}*{}", self.name, self.multiplier),
288 Operator::IsBelow,
289 CompareTarget::Value(value),
290 ))
291 }
292
293 pub fn is_below_indicator(self, other: IndicatorRef) -> ConditionNode {
295 ConditionNode::Condition(Condition::new(
296 format!("{}*{}", self.name, self.multiplier),
297 Operator::IsBelow,
298 CompareTarget::Indicator(other.name),
299 ))
300 }
301}
302
303pub fn all_of(conditions: Vec<ConditionNode>) -> ConditionNode {
305 ConditionNode::Group(ConditionGroup::AllOf(conditions))
306}
307
308pub fn any_of(conditions: Vec<ConditionNode>) -> ConditionNode {
310 ConditionNode::Group(ConditionGroup::AnyOf(conditions))
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps a leaf condition node, failing the test for any other node kind.
    fn expect_condition(node: ConditionNode) -> Condition {
        match node {
            ConditionNode::Condition(c) => c,
            _ => panic!("expected Condition"),
        }
    }

    #[test]
    fn indicator_ref_convenience_constructors() {
        let sma = IndicatorRef::sma(20);
        assert_eq!(sma.name, "sma20");

        let ema = IndicatorRef::ema(14);
        assert_eq!(ema.name, "ema14");

        let rsi = IndicatorRef::rsi(14);
        assert_eq!(rsi.name, "rsi14");

        let obv = IndicatorRef::obv();
        assert_eq!(obv.name, "obv");
    }

    #[test]
    fn batch_a_indicator_ref_convenience_constructors() {
        let adx = IndicatorRef::adx(14);
        assert_eq!(adx.name, "adx14");

        let wma = IndicatorRef::wma(20);
        assert_eq!(wma.name, "wma20");

        let dema = IndicatorRef::dema(10);
        assert_eq!(dema.name, "dema10");

        let tema = IndicatorRef::tema(10);
        assert_eq!(tema.name, "tema10");

        let cci = IndicatorRef::cci(20);
        assert_eq!(cci.name, "cci20");

        let williams_r = IndicatorRef::williams_r(14);
        assert_eq!(williams_r.name, "williams_r14");

        let roc = IndicatorRef::roc(12);
        assert_eq!(roc.name, "roc12");

        let stddev = IndicatorRef::stddev(20);
        assert_eq!(stddev.name, "stddev20");
    }

    #[test]
    fn multi_parameter_indicator_names() {
        assert_eq!(IndicatorRef::macd(12, 26, 9).name, "macd_12_26_9_line");
        assert_eq!(IndicatorRef::macd_signal(12, 26, 9).name, "macd_12_26_9_signal");
        assert_eq!(IndicatorRef::stoch_k(14, 3).name, "stoch_14_3_k");
        assert_eq!(IndicatorRef::stoch_d(14, 3).name, "stoch_14_3_d");
        assert_eq!(IndicatorRef::atr(14).name, "atr14");
        assert_eq!(IndicatorRef::volume_sma(20).name, "volume_sma_20");
        assert_eq!(IndicatorRef::pivot_points().name, "pivot_points");
    }

    #[test]
    fn bollinger_names_render_std_dev_with_f64_display() {
        // Whole-number std-devs lose their fractional part under `Display`.
        assert_eq!(IndicatorRef::bb_upper(20, 2.0).name, "bb_20_2_upper");
        assert_eq!(IndicatorRef::bb_middle(20, 2.0).name, "bb_20_2_middle");
        assert_eq!(IndicatorRef::bb_lower(20, 2.5).name, "bb_20_2.5_lower");
    }

    #[test]
    fn condition_building() {
        let sma = IndicatorRef::sma(20);
        let cond = sma.crosses_above(100.0);
        assert!(matches!(cond, ConditionNode::Condition(_)));
    }

    #[test]
    fn indicator_comparison_targets_other_indicator_name() {
        let c = expect_condition(
            IndicatorRef::ema(12).crosses_below_indicator(IndicatorRef::ema(26)),
        );
        assert_eq!(c.left, "ema12");
        assert_eq!(c.operator, Operator::CrossesBelow);
        assert_eq!(c.right, CompareTarget::Indicator("ema26".to_string()));
    }

    #[test]
    fn range_and_trend_conditions() {
        let between = expect_condition(IndicatorRef::rsi(14).is_between(30.0, 70.0));
        assert_eq!(between.left, "rsi14");
        assert_eq!(between.operator, Operator::IsBetween);
        assert_eq!(between.right, CompareTarget::Range(30.0, 70.0));

        let rising = expect_condition(IndicatorRef::obv().is_rising(3));
        assert_eq!(rising.left, "obv");
        assert_eq!(rising.operator, Operator::IsRising(3));
        assert_eq!(rising.right, CompareTarget::None);

        let falling = expect_condition(IndicatorRef::obv().is_falling(2));
        assert_eq!(falling.operator, Operator::IsFalling(2));
        assert_eq!(falling.right, CompareTarget::None);
    }

    #[test]
    fn condition_grouping() {
        let sma = IndicatorRef::sma(20);
        let rsi = IndicatorRef::rsi(14);

        let cond1 = sma.is_above(100.0);
        let cond2 = rsi.is_below(70.0);

        let group = all_of(vec![cond1, cond2]);
        assert!(matches!(
            group,
            ConditionNode::Group(ConditionGroup::AllOf(_))
        ));
    }

    #[test]
    fn any_of_builds_any_of_group_preserving_members() {
        let group = any_of(vec![
            IndicatorRef::rsi(14).is_below(30.0),
            IndicatorRef::rsi(14).is_above(70.0),
        ]);
        assert!(matches!(
            &group,
            ConditionNode::Group(ConditionGroup::AnyOf(members)) if members.len() == 2
        ));
    }

    #[test]
    fn scaled_indicator_ref() {
        let atr = IndicatorRef::atr(14);
        let scaled = atr.scaled(2.0);
        assert_eq!(scaled.multiplier, 2.0);
    }

    #[test]
    fn scaled_is_above_indicator_has_correct_semantics() {
        let cond = IndicatorRef::atr(14)
            .scaled(2.0)
            .is_above_indicator(IndicatorRef::new("price"));
        match cond {
            ConditionNode::Condition(c) => {
                assert_eq!(c.left, "atr14*2");
                assert_eq!(c.operator, Operator::IsAbove);
                assert_eq!(c.right, CompareTarget::Indicator("price".to_string()));
            }
            _ => panic!("expected Condition"),
        }
    }

    #[test]
    fn scaled_is_below_indicator_has_correct_semantics() {
        let cond = IndicatorRef::atr(14)
            .scaled(1.5)
            .is_below_indicator(IndicatorRef::new("price"));
        match cond {
            ConditionNode::Condition(c) => {
                assert_eq!(c.left, "atr14*1.5");
                assert_eq!(c.operator, Operator::IsBelow);
                assert_eq!(c.right, CompareTarget::Indicator("price".to_string()));
            }
            _ => panic!("expected Condition"),
        }
    }

    #[test]
    fn scaled_is_above_value_has_correct_semantics() {
        let cond = IndicatorRef::atr(14).scaled(2.0).is_above_value(50.0);
        match cond {
            ConditionNode::Condition(c) => {
                assert_eq!(c.left, "atr14*2");
                assert_eq!(c.operator, Operator::IsAbove);
                assert_eq!(c.right, CompareTarget::Value(50.0));
            }
            _ => panic!("expected Condition"),
        }
    }

    #[test]
    fn scaled_is_below_value_has_correct_semantics() {
        let c = expect_condition(IndicatorRef::atr(14).scaled(2.0).is_below_value(50.0));
        assert_eq!(c.left, "atr14*2");
        assert_eq!(c.operator, Operator::IsBelow);
        assert_eq!(c.right, CompareTarget::Value(50.0));
    }
}