indicators/momentum/
schaff_trend_cycle.rs1use std::collections::HashMap;
9
10use crate::error::IndicatorError;
11use crate::functions::{self};
12use crate::indicator::{Indicator, IndicatorOutput};
13use crate::registry::param_usize;
14use crate::types::Candle;
15
/// Tunable parameters for [`SchaffTrendCycle`].
///
/// All periods are counted in candles.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StcParams {
    /// Period of the first EMA of closes; the MACD line is this EMA minus the
    /// `long_ema` one.
    pub short_ema: usize,
    /// Period of the EMA of closes subtracted from the `short_ema` EMA.
    pub long_ema: usize,
    /// Length of the trailing window whose min/max is used to rescale the MACD
    /// histogram into the 0–100 range.
    pub stoch_period: usize,
    /// Period of the final EMA smoothing of the rescaled series; `0` disables
    /// the smoothing.
    pub signal_period: usize,
}
23impl Default for StcParams {
24 fn default() -> Self {
25 Self {
26 short_ema: 12,
27 long_ema: 26,
28 stoch_period: 10,
29 signal_period: 3,
30 }
31 }
32}
33
34#[derive(Debug, Clone)]
35pub struct SchaffTrendCycle {
36 pub params: StcParams,
37}
38
39impl SchaffTrendCycle {
40 pub fn new(params: StcParams) -> Self {
41 Self { params }
42 }
43}
44
45impl Default for SchaffTrendCycle {
46 fn default() -> Self {
47 Self::new(StcParams::default())
48 }
49}
50
51impl Indicator for SchaffTrendCycle {
52 fn name(&self) -> &'static str {
53 "SchaffTrendCycle"
54 }
55
56 fn required_len(&self) -> usize {
57 self.params.long_ema
62 }
63
64 fn required_columns(&self) -> &[&'static str] {
65 &["close"]
66 }
67
68 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
82 self.check_len(candles)?;
83
84 let close: Vec<f64> = candles.iter().map(|c| c.close).collect();
85 let n = close.len();
86
87 let short_e = functions::ema(&close, self.params.short_ema)?;
89 let long_e = functions::ema(&close, self.params.long_ema)?;
90 let macd_line: Vec<f64> = (0..n)
91 .map(|i| {
92 if short_e[i].is_nan() || long_e[i].is_nan() {
93 f64::NAN
94 } else {
95 short_e[i] - long_e[i]
96 }
97 })
98 .collect();
99
100 let macd_sig = functions::ema_nan_aware(&macd_line, 9)?;
104 let macd_diff: Vec<f64> = (0..n)
105 .map(|i| {
106 if macd_line[i].is_nan() || macd_sig[i].is_nan() {
107 f64::NAN
108 } else {
109 macd_line[i] - macd_sig[i]
110 }
111 })
112 .collect();
113
114 let sp = self.params.stoch_period;
116 let mut stc = vec![f64::NAN; n];
117 for i in (sp - 1)..n {
118 let window = &macd_diff[(i + 1 - sp)..=i];
119 let min_d = window.iter().copied().fold(f64::INFINITY, f64::min);
120 let max_d = window.iter().copied().fold(f64::NEG_INFINITY, f64::max);
121 let range = max_d - min_d;
122 if macd_diff[i].is_nan() || range == 0.0 {
123 stc[i] = f64::NAN;
124 } else {
125 stc[i] = 100.0 * (macd_diff[i] - min_d) / range;
126 }
127 }
128
129 let values = if self.params.signal_period > 0 {
133 functions::ema_nan_aware(&stc, self.params.signal_period)?
134 } else {
135 stc
136 };
137
138 Ok(IndicatorOutput::from_pairs([("STC".to_string(), values)]))
139 }
140}
141
142pub fn factory<S: ::std::hash::BuildHasher>(
143 params: &HashMap<String, String, S>,
144) -> Result<Box<dyn Indicator>, IndicatorError> {
145 Ok(Box::new(SchaffTrendCycle::new(StcParams {
146 short_ema: param_usize(params, "short_ema", 12)?,
147 long_ema: param_usize(params, "long_ema", 26)?,
148 stoch_period: param_usize(params, "stoch_period", 10)?,
149 signal_period: param_usize(params, "signal_period", 3)?,
150 })))
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic synthetic candles whose closes oscillate as `10 + sin(i)`.
    fn candles(n: usize) -> Vec<Candle> {
        (0..n)
            .map(|i| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: 10.0,
                high: 10.0 + (i % 5) as f64,
                low: 10.0 - (i % 3) as f64,
                close: 10.0 + (i as f64).sin(),
                volume: 100.0,
            })
            .collect()
    }

    #[test]
    fn stc_output_column() {
        let p = StcParams::default();
        let needed = p.long_ema + p.stoch_period + p.signal_period + 5;
        let out = SchaffTrendCycle::default()
            .calculate(&candles(needed))
            .unwrap();
        assert!(out.get("STC").is_some());
    }

    #[test]
    fn stc_values_stay_within_0_to_100() {
        let out = SchaffTrendCycle::default()
            .calculate(&candles(200))
            .unwrap();
        let stc = out.get("STC").expect("STC column present");
        let finite: Vec<f64> = stc.iter().copied().filter(|v| !v.is_nan()).collect();
        assert!(
            !finite.is_empty(),
            "200 candles should produce some STC values"
        );
        // Small tolerance for rounding inside the final EMA smoothing.
        assert!(finite
            .iter()
            .all(|v| (-1e-9..=100.0 + 1e-9).contains(v)));
    }

    #[test]
    fn factory_creates_stc() {
        assert_eq!(factory(&HashMap::new()).unwrap().name(), "SchaffTrendCycle");
    }
}