Skip to main content

indicators/momentum/
schaff_trend_cycle.rs

//! Schaff Trend Cycle (STC).
//!
//! Python source: `indicators/other/schaff_trend_cycle.py :: class SchaffTrendCycle`
//!
//! # Python reference algorithm
//!
//! ```python
//! short_ema  = df["Close"].ewm(span=self.short_ema).mean()
//! long_ema   = df["Close"].ewm(span=self.long_ema).mean()
//! macd       = short_ema - long_ema
//! macd_sig   = macd.ewm(span=9).mean()
//! macd_diff  = macd - macd_sig
//!
//! lowest  = macd_diff.rolling(self.stoch_period).min()
//! highest = macd_diff.rolling(self.stoch_period).max()
//! stc     = 100 * (macd_diff - lowest) / (highest - lowest)
//!
//! if self.signal_period > 0:
//!     stc = stc.ewm(span=self.signal_period).mean()
//! ```
//!
//! The oscillator ranges from 0 to 100: readings above 75 are treated as
//! overbought and readings below 25 as oversold.
//!
//! Output column: `"STC"`.
25
26use std::collections::HashMap;
27
28use crate::error::IndicatorError;
29use crate::functions::{self};
30use crate::indicator::{Indicator, IndicatorOutput};
31use crate::registry::param_usize;
32use crate::types::Candle;
33
/// Tunable periods for the Schaff Trend Cycle indicator.
///
/// Setting `signal_period` to `0` disables the final EMA smoothing step.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StcParams {
    /// Span of the fast EMA of the close price (MACD fast leg).
    pub short_ema: usize,
    /// Span of the slow EMA of the close price (MACD slow leg).
    pub long_ema: usize,
    /// Look-back window, in bars, of the stochastic applied to the MACD histogram.
    pub stoch_period: usize,
    /// Span of the EMA that smooths the stochastic output; `0` disables it.
    pub signal_period: usize,
}

impl Default for StcParams {
    /// Defaults: fast EMA 12, slow EMA 26, stochastic 10, smoothing 3.
    fn default() -> Self {
        Self {
            short_ema: 12,
            long_ema: 26,
            stoch_period: 10,
            signal_period: 3,
        }
    }
}
51
52#[derive(Debug, Clone)]
53pub struct SchaffTrendCycle {
54    pub params: StcParams,
55}
56
57impl SchaffTrendCycle {
58    pub fn new(params: StcParams) -> Self {
59        Self { params }
60    }
61}
62
63impl Default for SchaffTrendCycle {
64    fn default() -> Self {
65        Self::new(StcParams::default())
66    }
67}
68
69impl Indicator for SchaffTrendCycle {
70    fn name(&self) -> &'static str {
71        "SchaffTrendCycle"
72    }
73
74    fn required_len(&self) -> usize {
75        self.params.long_ema + self.params.stoch_period + self.params.signal_period
76    }
77
78    fn required_columns(&self) -> &[&'static str] {
79        &["close"]
80    }
81
82    /// TODO: port Python MACD-then-Stochastic-then-EMA pipeline.
83    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
84        self.check_len(candles)?;
85
86        let close: Vec<f64> = candles.iter().map(|c| c.close).collect();
87        let n = close.len();
88
89        // Step 1: MACD components.
90        let short_e = functions::ema(&close, self.params.short_ema)?;
91        let long_e = functions::ema(&close, self.params.long_ema)?;
92        let macd_line: Vec<f64> = (0..n)
93            .map(|i| {
94                if short_e[i].is_nan() || long_e[i].is_nan() {
95                    f64::NAN
96                } else {
97                    short_e[i] - long_e[i]
98                }
99            })
100            .collect();
101
102        // Signal of MACD (span=9).
103        let macd_sig = functions::ema(&macd_line, 9)?;
104        let macd_diff: Vec<f64> = (0..n)
105            .map(|i| {
106                if macd_line[i].is_nan() || macd_sig[i].is_nan() {
107                    f64::NAN
108                } else {
109                    macd_line[i] - macd_sig[i]
110                }
111            })
112            .collect();
113
114        // Step 2: Stochastic of MACD diff.
115        let sp = self.params.stoch_period;
116        let mut stc = vec![f64::NAN; n];
117        for i in (sp - 1)..n {
118            let window = &macd_diff[(i + 1 - sp)..=i];
119            let min_d = window.iter().copied().fold(f64::INFINITY, f64::min);
120            let max_d = window.iter().copied().fold(f64::NEG_INFINITY, f64::max);
121            let range = max_d - min_d;
122            if macd_diff[i].is_nan() || range == 0.0 {
123                stc[i] = f64::NAN;
124            } else {
125                stc[i] = 100.0 * (macd_diff[i] - min_d) / range;
126            }
127        }
128
129        // Step 3: optional EMA smoothing.
130        let values = if self.params.signal_period > 0 {
131            functions::ema(&stc, self.params.signal_period)?
132        } else {
133            stc
134        };
135
136        Ok(IndicatorOutput::from_pairs([("STC".to_string(), values)]))
137    }
138}
139
140pub fn factory<S: ::std::hash::BuildHasher>(params: &HashMap<String, String, S>) -> Result<Box<dyn Indicator>, IndicatorError> {
141    Ok(Box::new(SchaffTrendCycle::new(StcParams {
142        short_ema: param_usize(params, "short_ema", 12)?,
143        long_ema: param_usize(params, "long_ema", 26)?,
144        stoch_period: param_usize(params, "stoch_period", 10)?,
145        signal_period: param_usize(params, "signal_period", 3)?,
146    })))
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `len` synthetic candles whose close follows `10 + sin(i)`.
    fn sample_candles(len: usize) -> Vec<Candle> {
        let mut series = Vec::with_capacity(len);
        for idx in 0..len {
            let step = idx as f64;
            series.push(Candle {
                time: i64::try_from(idx).expect("time index fits i64"),
                open: 10.0,
                high: 10.0 + (idx % 5) as f64,
                low: 10.0 - (idx % 3) as f64,
                close: 10.0 + step.sin(),
                volume: 100.0,
            });
        }
        series
    }

    #[test]
    fn stc_output_column() {
        let defaults = StcParams::default();
        let len = 5 + defaults.signal_period + defaults.stoch_period + defaults.long_ema;
        let indicator = SchaffTrendCycle::default();
        let output = indicator.calculate(&sample_candles(len)).unwrap();
        assert!(output.get("STC").is_some());
    }

    #[test]
    fn factory_creates_stc() {
        let raw = HashMap::new();
        let built = factory(&raw).unwrap();
        assert_eq!(built.name(), "SchaffTrendCycle");
    }
}