Skip to main content

indicators/
indicator.rs

1//! Core `Indicator` trait and `IndicatorOutput` type.
2//!
3//! Mirrors `indicators/base.py`:
4//! - `Indicator` ↔ `class Indicator(Component, ABC)`
5//! - `IndicatorOutput` ↔ `pd.DataFrame` return value
6//! - `required_columns()` ↔ `@classmethod required_columns()`
7//! - `calculate()` ↔ `def calculate(self, data, price_column)`
8//!
9//! Every indicator in `trend/`, `momentum/`, `volume/`, and `other/`
10//! must implement this trait.  The registry (`registry.rs`) stores
11//! `Box<dyn Indicator>` values so they can be created by name at
12//! runtime, matching Python's `@register_indicator` / `IndicatorRegistry`.
13
14use std::collections::HashMap;
15
16use crate::error::IndicatorError;
17use crate::types::Candle;
18
19// ── IndicatorOutput ───────────────────────────────────────────────────────────
20
/// Named column output, analogous to the `pd.DataFrame` returned by Python `calculate()`.
///
/// Keys are column names such as `"SMA_20"`, `"MACD_line"`, `"ATR_14"`.
/// Values are expected to be aligned `Vec<f64>`s of the same length as the
/// input slice, with leading warm-up entries set to `f64::NAN`.
///
/// Equal column lengths are a convention for producers, not an invariant this
/// type enforces: [`IndicatorOutput::insert`] accepts a column of any length.
#[derive(Debug, Clone, Default)]
pub struct IndicatorOutput {
    columns: HashMap<String, Vec<f64>>,
}

impl IndicatorOutput {
    /// Create an empty output with no columns.
    pub fn new() -> Self {
        Self::default()
    }

    /// Insert a named column.
    ///
    /// If a column with the same name already exists it is replaced.
    pub fn insert(&mut self, name: impl Into<String>, values: Vec<f64>) {
        self.columns.insert(name.into(), values);
    }

    /// Build from an iterator of `(name, values)` pairs.
    ///
    /// When the same name appears more than once, the last pair wins.
    pub fn from_pairs<K: Into<String>>(pairs: impl IntoIterator<Item = (K, Vec<f64>)>) -> Self {
        Self {
            columns: pairs
                .into_iter()
                .map(|(name, values)| (name.into(), values))
                .collect(),
        }
    }

    /// Get the values for a named column, or `None` if no such column exists.
    pub fn get(&self, name: &str) -> Option<&[f64]> {
        self.columns.get(name).map(Vec::as_slice)
    }

    /// Get the *last* (most recent) value of a named column, skipping `NaN`.
    ///
    /// Returns `None` when the column does not exist, is empty, or contains
    /// only `NaN` values.
    ///
    /// Mirrors Python's `indicator.get_value(-1)`.
    pub fn latest(&self, name: &str) -> Option<f64> {
        self.get(name)?
            .iter()
            .rev()
            .copied()
            .find(|value| !value.is_nan())
    }

    /// All column names present in this output, in unspecified order.
    pub fn columns(&self) -> impl Iterator<Item = &str> {
        self.columns.keys().map(String::as_str)
    }

    /// Number of rows, or `0` when there are no columns.
    ///
    /// Columns are expected to share one length, in which case this is that
    /// length. Because `insert` does not enforce alignment, the longest column
    /// is used so the result never depends on `HashMap` iteration order.
    pub fn len(&self) -> usize {
        self.columns.values().map(Vec::len).max().unwrap_or(0)
    }

    /// `true` when there are no columns or every column has zero rows.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Consume into the underlying name → values map.
    pub fn into_inner(self) -> HashMap<String, Vec<f64>> {
        self.columns
    }
}
87
88// ── Indicator trait ───────────────────────────────────────────────────────────
89
90/// The core trait every indicator must implement.
91///
92/// Analogous to `indicators/base.py :: class Indicator(ABC)`.
93///
94/// # Implementing an indicator
95///
96/// ```rust,ignore
97/// use crate::indicator::{Indicator, IndicatorOutput};
98/// use crate::error::IndicatorError;
99/// use crate::types::Candle;
100///
101/// pub struct Sma {
102///     pub period: usize,
103///     pub column: PriceColumn,
104/// }
105///
106/// impl Indicator for Sma {
107///     fn name(&self) -> &str { "SMA" }
108///
109///     fn required_len(&self) -> usize { self.period }
110///
111///     fn required_columns(&self) -> &[&str] { &["close"] }
112///
113///     fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
114///         // port Python logic here
115///         todo!()
116///     }
117/// }
118/// ```
119pub trait Indicator: Send + Sync + std::fmt::Debug {
120    /// Short canonical name, e.g. `"SMA"`, `"RSI"`, `"MACD"`.
121    fn name(&self) -> &'static str;
122
123    /// Minimum number of candles required before output is non-`NaN`.
124    /// Mirrors Python's implicit warm-up period used for validation.
125    fn required_len(&self) -> usize;
126
127    /// Which OHLCV fields this indicator reads.
128    ///
129    /// Mirrors `@classmethod required_columns()` in Python.
130    /// Valid values: `"open"`, `"high"`, `"low"`, `"close"`, `"volume"`.
131    fn required_columns(&self) -> &[&'static str];
132
133    /// Compute the indicator over a full candle slice (batch mode).
134    ///
135    /// Mirrors `def calculate(self, data: pd.DataFrame, price_column) -> pd.DataFrame`.
136    ///
137    /// - Returns `IndicatorOutput` with one or more named columns.
138    /// - Leading warm-up rows should be `f64::NAN`.
139    /// - Returns `Err(IndicatorError::InsufficientData)` if `candles.len() < required_len()`.
140    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError>;
141
142    /// Validate that enough data was supplied, returning a descriptive error if not.
143    ///
144    /// Call this at the top of every `calculate()` implementation.
145    fn check_len(&self, candles: &[Candle]) -> Result<(), IndicatorError> {
146        let required = self.required_len();
147        if candles.len() < required {
148            Err(IndicatorError::InsufficientData {
149                required,
150                available: candles.len(),
151            })
152        } else {
153            Ok(())
154        }
155    }
156}
157
158// ── PriceColumn helper ────────────────────────────────────────────────────────
159
160/// Which single OHLCV field to extract as a price series.
161///
162/// Mirrors the `column` / `price_column` parameter in Python indicators.
163#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
164pub enum PriceColumn {
165    Open,
166    High,
167    Low,
168    #[default]
169    Close,
170    Volume,
171    /// `(High + Low + Close) / 3`
172    TypicalPrice,
173    /// `(High + Low) / 2`
174    HL2,
175}
176
177impl PriceColumn {
178    /// Extract the column as a `Vec<f64>` from a candle slice.
179    pub fn extract(self, candles: &[Candle]) -> Vec<f64> {
180        candles
181            .iter()
182            .map(|c| match self {
183                PriceColumn::Open => c.open,
184                PriceColumn::High => c.high,
185                PriceColumn::Low => c.low,
186                PriceColumn::Close => c.close,
187                PriceColumn::Volume => c.volume,
188                PriceColumn::TypicalPrice => (c.high + c.low + c.close) / 3.0,
189                PriceColumn::HL2 => (c.high + c.low) / 2.0,
190            })
191            .collect()
192    }
193
194    pub fn as_str(self) -> &'static str {
195        match self {
196            PriceColumn::Open => "open",
197            PriceColumn::High => "high",
198            PriceColumn::Low => "low",
199            PriceColumn::Close => "close",
200            PriceColumn::Volume => "volume",
201            PriceColumn::TypicalPrice => "typical_price",
202            PriceColumn::HL2 => "hl2",
203        }
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// A single inserted column is visible through `len`, `latest` and `get`.
    #[test]
    fn indicator_output_insert_and_get() {
        let warmed_up = vec![f64::NAN, f64::NAN, 10.0, 12.0];

        let mut output = IndicatorOutput::new();
        output.insert("SMA_20", warmed_up);

        assert_eq!(output.len(), 4);
        assert_eq!(output.latest("SMA_20"), Some(12.0));
        assert!(output.get("MISSING").is_none());
    }

    /// Every pair handed to `from_pairs` ends up as a retrievable column.
    #[test]
    fn indicator_output_from_pairs() {
        let pairs = [
            ("MACD_line", vec![1.0, 2.0]),
            ("MACD_signal", vec![0.5, 1.5]),
        ];

        let output = IndicatorOutput::from_pairs(pairs);

        for name in ["MACD_line", "MACD_signal"] {
            assert!(output.get(name).is_some());
        }
    }

    /// `extract` reads raw fields and computes the derived typical price.
    #[test]
    fn price_column_extract() {
        let candles = vec![Candle {
            time: 0,
            open: 1.0,
            high: 4.0,
            low: 2.0,
            close: 3.0,
            volume: 100.0,
        }];

        let closes = PriceColumn::Close.extract(&candles);
        assert_eq!(closes, vec![3.0]);

        let typical = PriceColumn::TypicalPrice.extract(&candles);
        assert_eq!(typical, vec![(4.0 + 2.0 + 3.0) / 3.0]);
    }
}