Skip to main content

indicators/momentum/
schaff_trend_cycle.rs

1//! Schaff Trend Cycle (STC).
2//!
3//! Readings above 75 → overbought; below 25 → oversold.
4//! Oscillates 0–100.
5//!
6//! Output column: `"STC"`.
7
8use std::collections::HashMap;
9
10use crate::error::IndicatorError;
11use crate::functions::{self};
12use crate::indicator::{Indicator, IndicatorOutput};
13use crate::registry::param_usize;
14use crate::types::Candle;
15
/// Periods controlling each stage of the Schaff Trend Cycle calculation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StcParams {
    /// Span of the fast EMA in the MACD stage.
    pub short_ema: usize,
    /// Span of the slow EMA in the MACD stage.
    pub long_ema: usize,
    /// Number of bars in the rolling stochastic window applied to the
    /// MACD-minus-signal difference.
    pub stoch_period: usize,
    /// Span of the final EMA smoothing; `0` disables smoothing entirely.
    pub signal_period: usize,
}

impl Default for StcParams {
    /// 12/26 MACD EMAs, a 10-bar stochastic window and 3-bar final smoothing.
    fn default() -> Self {
        Self {
            short_ema: 12,
            long_ema: 26,
            stoch_period: 10,
            signal_period: 3,
        }
    }
}
33
34#[derive(Debug, Clone)]
35pub struct SchaffTrendCycle {
36    pub params: StcParams,
37}
38
39impl SchaffTrendCycle {
40    pub fn new(params: StcParams) -> Self {
41        Self { params }
42    }
43}
44
45impl Default for SchaffTrendCycle {
46    fn default() -> Self {
47        Self::new(StcParams::default())
48    }
49}
50
51impl Indicator for SchaffTrendCycle {
52    fn name(&self) -> &'static str {
53        "SchaffTrendCycle"
54    }
55
56    fn required_len(&self) -> usize {
57        // The minimum data required for at least some non-NaN output is the
58        // slow EMA warm-up period.  The stochastic and signal stages add
59        // additional latency but do not require extra candles at the input
60        // boundary — they simply produce NaN for their own warm-up bars.
61        self.params.long_ema
62    }
63
64    fn required_columns(&self) -> &[&'static str] {
65        &["close"]
66    }
67
68    /// Ports the three-stage MACD → Stochastic → EMA pipeline.
69    ///
70    /// # EMA seeding difference vs Python
71    /// The Python source calls `ewm(span=...)` with the **default** `adjust=True`,
72    /// which uses decaying weights rather than the recursive formula.
73    /// `functions::ema()` implements the `adjust=False` (recursive) variant.
74    /// For series longer than ~3× the span the two converge; for shorter series
75    /// the warm-up values will differ slightly.
76    ///
77    /// # Zero-range stochastic handling
78    /// When `max_macd_diff == min_macd_diff` across the window, Python produces
79    /// `NaN` via `.replace(0, np.nan)` before division.  The Rust guards the
80    /// same condition with an explicit `range == 0.0` check.
81    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
82        self.check_len(candles)?;
83
84        let close: Vec<f64> = candles.iter().map(|c| c.close).collect();
85        let n = close.len();
86
87        // Step 1: MACD components.
88        let short_e = functions::ema(&close, self.params.short_ema)?;
89        let long_e = functions::ema(&close, self.params.long_ema)?;
90        let macd_line: Vec<f64> = (0..n)
91            .map(|i| {
92                if short_e[i].is_nan() || long_e[i].is_nan() {
93                    f64::NAN
94                } else {
95                    short_e[i] - long_e[i]
96                }
97            })
98            .collect();
99
100        // Signal of MACD (span=9).
101        // macd_line has leading NaN (warm-up from long_ema); use the NaN-aware
102        // EMA so it seeds from the first valid value rather than propagating NaN.
103        let macd_sig = functions::ema_nan_aware(&macd_line, 9)?;
104        let macd_diff: Vec<f64> = (0..n)
105            .map(|i| {
106                if macd_line[i].is_nan() || macd_sig[i].is_nan() {
107                    f64::NAN
108                } else {
109                    macd_line[i] - macd_sig[i]
110                }
111            })
112            .collect();
113
114        // Step 2: Stochastic of MACD diff.
115        let sp = self.params.stoch_period;
116        let mut stc = vec![f64::NAN; n];
117        for i in (sp - 1)..n {
118            let window = &macd_diff[(i + 1 - sp)..=i];
119            let min_d = window.iter().copied().fold(f64::INFINITY, f64::min);
120            let max_d = window.iter().copied().fold(f64::NEG_INFINITY, f64::max);
121            let range = max_d - min_d;
122            if macd_diff[i].is_nan() || range == 0.0 {
123                stc[i] = f64::NAN;
124            } else {
125                stc[i] = 100.0 * (macd_diff[i] - min_d) / range;
126            }
127        }
128
129        // Step 3: optional EMA smoothing.
130        // `stc` has leading NaN from the stochastic warm-up; use the NaN-aware
131        // EMA so it seeds from the first valid stochastic value.
132        let values = if self.params.signal_period > 0 {
133            functions::ema_nan_aware(&stc, self.params.signal_period)?
134        } else {
135            stc
136        };
137
138        Ok(IndicatorOutput::from_pairs([("STC".to_string(), values)]))
139    }
140}
141
142pub fn factory<S: ::std::hash::BuildHasher>(
143    params: &HashMap<String, String, S>,
144) -> Result<Box<dyn Indicator>, IndicatorError> {
145    Ok(Box::new(SchaffTrendCycle::new(StcParams {
146        short_ema: param_usize(params, "short_ema", 12)?,
147        long_ema: param_usize(params, "long_ema", 26)?,
148        stoch_period: param_usize(params, "stoch_period", 10)?,
149        signal_period: param_usize(params, "signal_period", 3)?,
150    })))
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Synthetic candles whose close follows a sine wave around 10.0.
    fn candles(n: usize) -> Vec<Candle> {
        (0..n)
            .map(|i| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: 10.0,
                high: 10.0 + (i % 5) as f64,
                low: 10.0 - (i % 3) as f64,
                close: 10.0 + (i as f64).sin(),
                volume: 100.0,
            })
            .collect()
    }

    #[test]
    fn stc_output_column() {
        let p = StcParams::default();
        let needed = p.long_ema + p.stoch_period + p.signal_period + 5;
        let out = SchaffTrendCycle::default()
            .calculate(&candles(needed))
            .unwrap();
        assert!(out.get("STC").is_some());
    }

    #[test]
    fn stc_length_matches_input_and_stays_in_range() {
        let n = 120;
        let out = SchaffTrendCycle::default().calculate(&candles(n)).unwrap();
        let stc = out.get("STC").expect("STC column present");
        assert_eq!(stc.len(), n);
        assert!(
            stc.iter().any(|v| !v.is_nan()),
            "expected at least one non-NaN reading past warm-up"
        );
        // Small tolerance: EMA smoothing of values in [0, 100] can round a
        // hair past the bounds.
        for &v in stc.iter() {
            assert!(
                v.is_nan() || (-1e-9..=100.0 + 1e-9).contains(&v),
                "STC out of range: {v}"
            );
        }
    }

    #[test]
    fn factory_creates_stc() {
        assert_eq!(factory(&HashMap::new()).unwrap().name(), "SchaffTrendCycle");
    }
}