indicators/indicator.rs
1//! Core `Indicator` trait and `IndicatorOutput` type.
2//!
3//! Mirrors `indicators/base.py`:
4//! - `Indicator` ↔ `class Indicator(Component, ABC)`
5//! - `IndicatorOutput` ↔ `pd.DataFrame` return value
6//! - `required_columns()` ↔ `@classmethod required_columns()`
7//! - `calculate()` ↔ `def calculate(self, data, price_column)`
8//!
9//! Every indicator in `trend/`, `momentum/`, `volume/`, and `other/`
10//! must implement this trait. The registry (`registry.rs`) stores
11//! `Box<dyn Indicator>` values so they can be created by name at
12//! runtime, matching Python's `@register_indicator` / `IndicatorRegistry`.
13
14use std::collections::HashMap;
15
16use crate::error::IndicatorError;
17use crate::types::Candle;
18
19// ── IndicatorOutput ───────────────────────────────────────────────────────────
20
/// Named column output, analogous to the `pd.DataFrame` returned by Python `calculate()`.
///
/// Keys are column names such as `"SMA_20"`, `"MACD_line"`, `"ATR_14"`.
/// Values are expected to be aligned `Vec<f64>` of the same length as the input
/// slice, with leading warm-up entries set to `f64::NAN`.
///
/// Alignment is a convention for producers: neither [`IndicatorOutput::insert`]
/// nor [`IndicatorOutput::from_pairs`] validates column lengths.
#[derive(Debug, Clone, Default)]
pub struct IndicatorOutput {
    columns: HashMap<String, Vec<f64>>,
}

impl IndicatorOutput {
    /// Create an empty output with no columns.
    pub fn new() -> Self {
        Self::default()
    }

    /// Insert a named column, replacing any existing column with the same name.
    pub fn insert(&mut self, name: impl Into<String>, values: Vec<f64>) {
        self.columns.insert(name.into(), values);
    }

    /// Build from an iterator of `(name, values)` pairs.
    ///
    /// If a name occurs more than once, the last pair wins.
    pub fn from_pairs<K: Into<String>>(pairs: impl IntoIterator<Item = (K, Vec<f64>)>) -> Self {
        Self {
            columns: pairs
                .into_iter()
                .map(|(name, values)| (name.into(), values))
                .collect(),
        }
    }

    /// Get the values for a named column, or `None` if no such column exists.
    pub fn get(&self, name: &str) -> Option<&[f64]> {
        self.columns.get(name).map(Vec::as_slice)
    }

    /// Get the *last* (most recent) value of a named column, skipping `NaN`.
    ///
    /// Returns `None` when the column is missing, empty, or contains only `NaN`.
    ///
    /// Mirrors Python's `indicator.get_value(-1)`.
    pub fn latest(&self, name: &str) -> Option<f64> {
        self.get(name)?
            .iter()
            .rev()
            .copied()
            .find(|value| !value.is_nan())
    }

    /// All column names present in this output, in unspecified order.
    pub fn columns(&self) -> impl Iterator<Item = &str> {
        self.columns.keys().map(String::as_str)
    }

    /// Number of rows, or `0` when there are no columns.
    ///
    /// Columns are expected to share one length, in which case that length is
    /// returned. Should they differ, the longest column is reported so the
    /// result never depends on `HashMap` iteration order.
    pub fn len(&self) -> usize {
        self.columns.values().map(Vec::len).max().unwrap_or(0)
    }

    /// `true` when the output holds no rows, i.e. [`IndicatorOutput::len`] is `0`.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Consume into the underlying name-to-values map.
    pub fn into_inner(self) -> HashMap<String, Vec<f64>> {
        self.columns
    }
}
87
88// ── Indicator trait ───────────────────────────────────────────────────────────
89
90/// The core trait every indicator must implement.
91///
92/// Analogous to `indicators/base.py :: class Indicator(ABC)`.
93///
94/// # Implementing an indicator
95///
96/// ```rust,ignore
97/// use crate::indicator::{Indicator, IndicatorOutput};
98/// use crate::error::IndicatorError;
99/// use crate::types::Candle;
100///
101/// pub struct Sma {
102/// pub period: usize,
103/// pub column: PriceColumn,
104/// }
105///
106/// impl Indicator for Sma {
107/// fn name(&self) -> &str { "SMA" }
108///
109/// fn required_len(&self) -> usize { self.period }
110///
111/// fn required_columns(&self) -> &[&str] { &["close"] }
112///
113/// fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
114/// }
115/// }
116/// ```
117pub trait Indicator: Send + Sync + std::fmt::Debug {
118 /// Short canonical name, e.g. `"SMA"`, `"RSI"`, `"MACD"`.
119 fn name(&self) -> &'static str;
120
121 /// Minimum number of candles required before output is non-`NaN`.
122 /// Mirrors Python's implicit warm-up period used for validation.
123 fn required_len(&self) -> usize;
124
125 /// Which OHLCV fields this indicator reads.
126 ///
127 /// Mirrors `@classmethod required_columns()` in Python.
128 /// Valid values: `"open"`, `"high"`, `"low"`, `"close"`, `"volume"`.
129 fn required_columns(&self) -> &[&'static str];
130
131 /// Compute the indicator over a full candle slice (batch mode).
132 ///
133 /// Mirrors `def calculate(self, data: pd.DataFrame, price_column) -> pd.DataFrame`.
134 ///
135 /// - Returns `IndicatorOutput` with one or more named columns.
136 /// - Leading warm-up rows should be `f64::NAN`.
137 /// - Returns `Err(IndicatorError::InsufficientData)` if `candles.len() < required_len()`.
138 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError>;
139
140 /// Validate that enough data was supplied, returning a descriptive error if not.
141 ///
142 /// Call this at the top of every `calculate()` implementation.
143 fn check_len(&self, candles: &[Candle]) -> Result<(), IndicatorError> {
144 let required = self.required_len();
145 if candles.len() < required {
146 Err(IndicatorError::InsufficientData {
147 required,
148 available: candles.len(),
149 })
150 } else {
151 Ok(())
152 }
153 }
154}
155
156// ── PriceColumn helper ────────────────────────────────────────────────────────
157
158/// Which single OHLCV field to extract as a price series.
159///
160/// Mirrors the `column` / `price_column` parameter in Python indicators.
161#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
162pub enum PriceColumn {
163 Open,
164 High,
165 Low,
166 #[default]
167 Close,
168 Volume,
169 /// `(High + Low + Close) / 3`
170 TypicalPrice,
171 /// `(High + Low) / 2`
172 HL2,
173}
174
175impl PriceColumn {
176 /// Extract the column as a `Vec<f64>` from a candle slice.
177 pub fn extract(self, candles: &[Candle]) -> Vec<f64> {
178 candles
179 .iter()
180 .map(|c| match self {
181 PriceColumn::Open => c.open,
182 PriceColumn::High => c.high,
183 PriceColumn::Low => c.low,
184 PriceColumn::Close => c.close,
185 PriceColumn::Volume => c.volume,
186 PriceColumn::TypicalPrice => (c.high + c.low + c.close) / 3.0,
187 PriceColumn::HL2 => (c.high + c.low) / 2.0,
188 })
189 .collect()
190 }
191
192 pub fn as_str(self) -> &'static str {
193 match self {
194 PriceColumn::Open => "open",
195 PriceColumn::High => "high",
196 PriceColumn::Low => "low",
197 PriceColumn::Close => "close",
198 PriceColumn::Volume => "volume",
199 PriceColumn::TypicalPrice => "typical_price",
200 PriceColumn::HL2 => "hl2",
201 }
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// An inserted column reports its row count and latest value; unknown names miss.
    #[test]
    fn indicator_output_insert_and_get() {
        let sma = vec![f64::NAN, f64::NAN, 10.0, 12.0];
        let mut output = IndicatorOutput::new();
        output.insert("SMA_20", sma);

        assert_eq!(output.len(), 4);
        assert_eq!(output.latest("SMA_20"), Some(12.0));
        assert!(output.get("MISSING").is_none());
    }

    /// Every pair handed to `from_pairs` ends up as a retrievable column.
    #[test]
    fn indicator_output_from_pairs() {
        let pairs = vec![
            ("MACD_line", vec![1.0, 2.0]),
            ("MACD_signal", vec![0.5, 1.5]),
        ];
        let output = IndicatorOutput::from_pairs(pairs);

        for name in ["MACD_line", "MACD_signal"] {
            assert!(output.get(name).is_some());
        }
    }

    /// `extract` reads plain fields and computes the derived typical price.
    #[test]
    fn price_column_extract() {
        let candles = [Candle {
            time: 0,
            open: 1.0,
            high: 4.0,
            low: 2.0,
            close: 3.0,
            volume: 100.0,
        }];

        let closes = PriceColumn::Close.extract(&candles);
        assert_eq!(closes, vec![3.0]);

        let expected_typical = (4.0 + 2.0 + 3.0) / 3.0;
        let typical = PriceColumn::TypicalPrice.extract(&candles);
        assert_eq!(typical, vec![expected_typical]);
    }
}
245}