// indicators/indicator.rs

//! Core `Indicator` trait and `IndicatorOutput` type.
//!
//! Mirrors `indicators/base.py`:
//! - `Indicator` ↔ `class Indicator(Component, ABC)`
//! - `IndicatorOutput` ↔ `pd.DataFrame` return value
//! - `required_columns()` ↔ `@classmethod required_columns()`
//! - `calculate()` ↔ `def calculate(self, data, price_column)`
//!
//! Every indicator in `trend/`, `momentum/`, `volume/`, and `other/`
//! must implement this trait.  The registry (`registry.rs`) stores
//! `Box<dyn Indicator>` values so they can be created by name at
//! runtime, matching Python's `@register_indicator` / `IndicatorRegistry`.
13
14use std::collections::HashMap;
15
16use crate::error::IndicatorError;
17use crate::types::Candle;
18
19// ── IndicatorOutput ───────────────────────────────────────────────────────────
20
/// Named column output, analogous to `pd.DataFrame` returned by Python `calculate()`.
///
/// Keys are column names such as `"SMA_20"`, `"MACD_line"`, `"ATR_14"`.
/// Values are expected to be aligned `Vec<f64>` of the same length as the
/// input slice, with leading warm-up entries set to `f64::NAN`.
///
/// Alignment is a convention that producers must uphold: neither `insert`
/// nor `from_pairs` checks column lengths.
#[derive(Debug, Clone, Default)]
pub struct IndicatorOutput {
    columns: HashMap<String, Vec<f64>>,
}

impl IndicatorOutput {
    /// Create an empty output with no columns.
    pub fn new() -> Self {
        Self::default()
    }

    /// Insert a named column, replacing any existing column with the same name.
    ///
    /// The length of `values` is not validated against other columns.
    pub fn insert(&mut self, name: impl Into<String>, values: Vec<f64>) {
        self.columns.insert(name.into(), values);
    }

    /// Build from an iterator of `(name, values)` pairs.
    ///
    /// If a name occurs more than once, the last pair wins.
    pub fn from_pairs<K: Into<String>>(pairs: impl IntoIterator<Item = (K, Vec<f64>)>) -> Self {
        Self {
            columns: pairs
                .into_iter()
                .map(|(name, values)| (name.into(), values))
                .collect(),
        }
    }

    /// Get the values for a named column, or `None` if no such column exists.
    pub fn get(&self, name: &str) -> Option<&[f64]> {
        self.columns.get(name).map(Vec::as_slice)
    }

    /// Get the *last* (most recent) value of a named column, skipping `NaN`.
    ///
    /// Returns `None` when the column does not exist or holds no non-`NaN` value.
    ///
    /// Mirrors Python's `indicator.get_value(-1)`.
    pub fn latest(&self, name: &str) -> Option<f64> {
        self.columns
            .get(name)?
            .iter()
            .rev()
            .copied()
            .find(|value| !value.is_nan())
    }

    /// All column names present in this output, in unspecified order.
    pub fn columns(&self) -> impl Iterator<Item = &str> {
        self.columns.keys().map(String::as_str)
    }

    /// Number of rows.
    ///
    /// Columns are expected to share one length, in which case that length is
    /// returned.  Should they disagree, the longest column is reported, so the
    /// answer never depends on hash-map iteration order.  Returns `0` when
    /// there are no columns.
    pub fn len(&self) -> usize {
        self.columns.values().map(Vec::len).max().unwrap_or(0)
    }

    /// `true` when there are no columns or every column has zero rows.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Consume into the underlying map.
    pub fn into_inner(self) -> HashMap<String, Vec<f64>> {
        self.columns
    }
}
87
88// ── Indicator trait ───────────────────────────────────────────────────────────
89
90/// The core trait every indicator must implement.
91///
92/// Analogous to `indicators/base.py :: class Indicator(ABC)`.
93///
94/// # Implementing an indicator
95///
96/// ```rust,ignore
97/// use crate::indicator::{Indicator, IndicatorOutput};
98/// use crate::error::IndicatorError;
99/// use crate::types::Candle;
100///
101/// pub struct Sma {
102///     pub period: usize,
103///     pub column: PriceColumn,
104/// }
105///
106/// impl Indicator for Sma {
107///     fn name(&self) -> &str { "SMA" }
108///
109///     fn required_len(&self) -> usize { self.period }
110///
111///     fn required_columns(&self) -> &[&str] { &["close"] }
112///
113///     fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
114///     }
115/// }
116/// ```
117pub trait Indicator: Send + Sync + std::fmt::Debug {
118    /// Short canonical name, e.g. `"SMA"`, `"RSI"`, `"MACD"`.
119    fn name(&self) -> &'static str;
120
121    /// Minimum number of candles required before output is non-`NaN`.
122    /// Mirrors Python's implicit warm-up period used for validation.
123    fn required_len(&self) -> usize;
124
125    /// Which OHLCV fields this indicator reads.
126    ///
127    /// Mirrors `@classmethod required_columns()` in Python.
128    /// Valid values: `"open"`, `"high"`, `"low"`, `"close"`, `"volume"`.
129    fn required_columns(&self) -> &[&'static str];
130
131    /// Compute the indicator over a full candle slice (batch mode).
132    ///
133    /// Mirrors `def calculate(self, data: pd.DataFrame, price_column) -> pd.DataFrame`.
134    ///
135    /// - Returns `IndicatorOutput` with one or more named columns.
136    /// - Leading warm-up rows should be `f64::NAN`.
137    /// - Returns `Err(IndicatorError::InsufficientData)` if `candles.len() < required_len()`.
138    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError>;
139
140    /// Validate that enough data was supplied, returning a descriptive error if not.
141    ///
142    /// Call this at the top of every `calculate()` implementation.
143    fn check_len(&self, candles: &[Candle]) -> Result<(), IndicatorError> {
144        let required = self.required_len();
145        if candles.len() < required {
146            Err(IndicatorError::InsufficientData {
147                required,
148                available: candles.len(),
149            })
150        } else {
151            Ok(())
152        }
153    }
154}
155
156// ── PriceColumn helper ────────────────────────────────────────────────────────
157
158/// Which single OHLCV field to extract as a price series.
159///
160/// Mirrors the `column` / `price_column` parameter in Python indicators.
161#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
162pub enum PriceColumn {
163    Open,
164    High,
165    Low,
166    #[default]
167    Close,
168    Volume,
169    /// `(High + Low + Close) / 3`
170    TypicalPrice,
171    /// `(High + Low) / 2`
172    HL2,
173}
174
175impl PriceColumn {
176    /// Extract the column as a `Vec<f64>` from a candle slice.
177    pub fn extract(self, candles: &[Candle]) -> Vec<f64> {
178        candles
179            .iter()
180            .map(|c| match self {
181                PriceColumn::Open => c.open,
182                PriceColumn::High => c.high,
183                PriceColumn::Low => c.low,
184                PriceColumn::Close => c.close,
185                PriceColumn::Volume => c.volume,
186                PriceColumn::TypicalPrice => (c.high + c.low + c.close) / 3.0,
187                PriceColumn::HL2 => (c.high + c.low) / 2.0,
188            })
189            .collect()
190    }
191
192    pub fn as_str(self) -> &'static str {
193        match self {
194            PriceColumn::Open => "open",
195            PriceColumn::High => "high",
196            PriceColumn::Low => "low",
197            PriceColumn::Close => "close",
198            PriceColumn::Volume => "volume",
199            PriceColumn::TypicalPrice => "typical_price",
200            PriceColumn::HL2 => "hl2",
201        }
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// An inserted column reports its length and latest non-NaN value;
    /// an unknown column name yields `None`.
    #[test]
    fn indicator_output_insert_and_get() {
        let warmup_then_values = vec![f64::NAN, f64::NAN, 10.0, 12.0];

        let mut output = IndicatorOutput::new();
        output.insert("SMA_20", warmup_then_values);

        assert_eq!(output.len(), 4);
        assert_eq!(output.latest("SMA_20"), Some(12.0));
        assert!(output.get("MISSING").is_none());
    }

    /// Every pair handed to `from_pairs` becomes a retrievable column.
    #[test]
    fn indicator_output_from_pairs() {
        let pairs = [
            ("MACD_line", vec![1.0, 2.0]),
            ("MACD_signal", vec![0.5, 1.5]),
        ];

        let output = IndicatorOutput::from_pairs(pairs);

        assert!(output.get("MACD_line").is_some());
        assert!(output.get("MACD_signal").is_some());
    }

    /// `extract` reads the raw close and derives the typical price from H/L/C.
    #[test]
    fn price_column_extract() {
        let candles = vec![Candle {
            time: 0,
            open: 1.0,
            high: 4.0,
            low: 2.0,
            close: 3.0,
            volume: 100.0,
        }];

        let closes = PriceColumn::Close.extract(&candles);
        assert_eq!(closes, vec![3.0]);

        let typical = PriceColumn::TypicalPrice.extract(&candles);
        assert_eq!(typical, vec![(4.0 + 2.0 + 3.0) / 3.0]);
    }
}