indicators/momentum/
schaff_trend_cycle.rs1use std::collections::HashMap;
27
28use crate::error::IndicatorError;
29use crate::functions::{self};
30use crate::indicator::{Indicator, IndicatorOutput};
31use crate::registry::param_usize;
32use crate::types::Candle;
33
/// Tuning parameters for [`SchaffTrendCycle`].
///
/// `PartialEq`/`Eq` are derived so parameter sets can be compared directly
/// (e.g. in tests or when de-duplicating indicator configurations).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StcParams {
    /// Period of the fast EMA of closes that forms the MACD line.
    pub short_ema: usize,
    /// Period of the slow EMA of closes that forms the MACD line.
    pub long_ema: usize,
    /// Look-back window of the rolling min-max normalisation applied to the
    /// MACD histogram.
    pub stoch_period: usize,
    /// Period of the final EMA smoothing of the oscillator; `0` disables smoothing.
    pub signal_period: usize,
}

impl Default for StcParams {
    /// Defaults: `short_ema = 12`, `long_ema = 26`, `stoch_period = 10`,
    /// `signal_period = 3`.
    fn default() -> Self {
        Self {
            short_ema: 12,
            long_ema: 26,
            stoch_period: 10,
            signal_period: 3,
        }
    }
}
51
52#[derive(Debug, Clone)]
53pub struct SchaffTrendCycle {
54 pub params: StcParams,
55}
56
57impl SchaffTrendCycle {
58 pub fn new(params: StcParams) -> Self {
59 Self { params }
60 }
61 pub fn default() -> Self {
62 Self::new(StcParams::default())
63 }
64}
65
66impl Indicator for SchaffTrendCycle {
67 fn name(&self) -> &str {
68 "SchaffTrendCycle"
69 }
70
71 fn required_len(&self) -> usize {
72 self.params.long_ema + self.params.stoch_period + self.params.signal_period
73 }
74
75 fn required_columns(&self) -> &[&'static str] {
76 &["close"]
77 }
78
79 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
81 self.check_len(candles)?;
82
83 let close: Vec<f64> = candles.iter().map(|c| c.close).collect();
84 let n = close.len();
85
86 let short_e = functions::ema(&close, self.params.short_ema)?;
88 let long_e = functions::ema(&close, self.params.long_ema)?;
89 let macd_line: Vec<f64> = (0..n).map(|i| {
90 if short_e[i].is_nan() || long_e[i].is_nan() { f64::NAN }
91 else { short_e[i] - long_e[i] }
92 }).collect();
93
94 let macd_sig = functions::ema(&macd_line, 9)?;
96 let macd_diff: Vec<f64> = (0..n).map(|i| {
97 if macd_line[i].is_nan() || macd_sig[i].is_nan() { f64::NAN }
98 else { macd_line[i] - macd_sig[i] }
99 }).collect();
100
101 let sp = self.params.stoch_period;
103 let mut stc = vec![f64::NAN; n];
104 for i in (sp - 1)..n {
105 let window = &macd_diff[(i + 1 - sp)..=i];
106 let min_d = window.iter().cloned().fold(f64::INFINITY, f64::min);
107 let max_d = window.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
108 let range = max_d - min_d;
109 if macd_diff[i].is_nan() || range == 0.0 {
110 stc[i] = f64::NAN;
111 } else {
112 stc[i] = 100.0 * (macd_diff[i] - min_d) / range;
113 }
114 }
115
116 let values = if self.params.signal_period > 0 {
118 functions::ema(&stc, self.params.signal_period)?
119 } else {
120 stc
121 };
122
123 Ok(IndicatorOutput::from_pairs([("STC".to_string(), values)]))
124 }
125}
126
127pub fn factory(params: &HashMap<String, String>) -> Result<Box<dyn Indicator>, IndicatorError> {
128 Ok(Box::new(SchaffTrendCycle::new(StcParams {
129 short_ema: param_usize(params, "short_ema", 12)?,
130 long_ema: param_usize(params, "long_ema", 26)?,
131 stoch_period: param_usize(params, "stoch_period", 10)?,
132 signal_period: param_usize(params, "signal_period", 3)?,
133 })))
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic synthetic candles: constant open, saw-tooth high/low,
    /// sine-wave close around 10.0.
    fn synthetic_candles(count: usize) -> Vec<Candle> {
        let mut series = Vec::with_capacity(count);
        for idx in 0..count {
            let phase = idx as f64;
            series.push(Candle {
                time: idx as i64,
                open: 10.0,
                high: 10.0 + (idx % 5) as f64,
                low: 10.0 - (idx % 3) as f64,
                close: 10.0 + phase.sin(),
                volume: 100.0,
            });
        }
        series
    }

    #[test]
    fn stc_output_column() {
        let defaults = StcParams::default();
        let len = defaults.long_ema + defaults.stoch_period + defaults.signal_period + 5;
        let output = SchaffTrendCycle::default()
            .calculate(&synthetic_candles(len))
            .unwrap();
        assert!(output.get("STC").is_some());
    }

    #[test]
    fn factory_creates_stc() {
        let indicator = factory(&HashMap::new()).unwrap();
        assert_eq!(indicator.name(), "SchaffTrendCycle");
    }
}