indicators/indicator.rs
//! Core `Indicator` trait and `IndicatorOutput` type, plus the `PriceColumn` helper.
//!
//! Mirrors `indicators/base.py`:
//! - `Indicator` ↔ `class Indicator(Component, ABC)`
//! - `IndicatorOutput` ↔ `pd.DataFrame` return value
//! - `required_columns()` ↔ `@classmethod required_columns()`
//! - `calculate()` ↔ `def calculate(self, data, price_column)`
//!
//! Every indicator in `trend/`, `momentum/`, `volume/`, and `other/`
//! must implement this trait. The registry (`registry.rs`) stores
//! `Box<dyn Indicator>` values so they can be created by name at
//! runtime, matching Python's `@register_indicator` / `IndicatorRegistry`.
13
14use std::collections::HashMap;
15
16use crate::error::IndicatorError;
17use crate::types::Candle;
18
// ── IndicatorOutput ───────────────────────────────────────────────────────────
20
/// Named column output, analogous to the `pd.DataFrame` returned by Python `calculate()`.
///
/// Keys are column names such as `"SMA_20"`, `"MACD_line"`, `"ATR_14"`.
/// Values are `Vec<f64>` series that producers are expected to keep aligned
/// with (and the same length as) the input slice, with leading warm-up
/// entries set to `f64::NAN`. The container itself does not enforce equal
/// column lengths.
#[derive(Debug, Clone, Default)]
pub struct IndicatorOutput {
    /// Column name → series values.
    columns: HashMap<String, Vec<f64>>,
}

impl IndicatorOutput {
    /// Create an empty output.
    pub fn new() -> Self {
        Self::default()
    }

    /// Insert a named column, replacing any existing column with the same name.
    pub fn insert(&mut self, name: impl Into<String>, values: Vec<f64>) {
        self.columns.insert(name.into(), values);
    }

    /// Build from an iterator of `(name, values)` pairs.
    ///
    /// Pairs are inserted in iteration order, so when a name occurs more than
    /// once the last pair wins.
    pub fn from_pairs<K: Into<String>>(pairs: impl IntoIterator<Item = (K, Vec<f64>)>) -> Self {
        let mut out = Self::new();
        for (name, values) in pairs {
            out.insert(name, values);
        }
        out
    }

    /// Get the values for a named column, or `None` if no such column exists.
    pub fn get(&self, name: &str) -> Option<&[f64]> {
        self.columns.get(name).map(Vec::as_slice)
    }

    /// Get the *last* (most recent) value of a named column, skipping `NaN`.
    ///
    /// Returns `None` when the column is missing, empty, or contains only `NaN`.
    ///
    /// Mirrors Python's `indicator.get_value(-1)`.
    pub fn latest(&self, name: &str) -> Option<f64> {
        self.get(name)?
            .iter()
            .rev()
            .copied()
            .find(|value| !value.is_nan())
    }

    /// All column names present in this output, in unspecified order.
    pub fn columns(&self) -> impl Iterator<Item = &str> {
        self.columns.keys().map(String::as_str)
    }

    /// Number of rows.
    ///
    /// Columns are expected to share one length, in which case that length is
    /// returned. Nothing enforces this, so if columns of differing lengths
    /// were inserted the longest one is reported; picking "whichever column
    /// the hash map yields first" would make the result non-deterministic.
    /// Returns `0` when there are no columns.
    pub fn len(&self) -> usize {
        self.columns.values().map(Vec::len).max().unwrap_or(0)
    }

    /// `true` when there are no columns or every column has zero rows.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Consume into the underlying name → values map.
    pub fn into_inner(self) -> HashMap<String, Vec<f64>> {
        self.columns
    }
}
87
// ── Indicator trait ───────────────────────────────────────────────────────────
89
90/// The core trait every indicator must implement.
91///
92/// Analogous to `indicators/base.py :: class Indicator(ABC)`.
93///
94/// # Implementing an indicator
95///
96/// ```rust,ignore
97/// use crate::indicator::{Indicator, IndicatorOutput};
98/// use crate::error::IndicatorError;
99/// use crate::types::Candle;
100///
101/// pub struct Sma {
102/// pub period: usize,
103/// pub column: PriceColumn,
104/// }
105///
106/// impl Indicator for Sma {
107/// fn name(&self) -> &str { "SMA" }
108///
109/// fn required_len(&self) -> usize { self.period }
110///
111/// fn required_columns(&self) -> &[&str] { &["close"] }
112///
113/// fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
114/// // port Python logic here
115/// todo!()
116/// }
117/// }
118/// ```
119pub trait Indicator: Send + Sync + std::fmt::Debug {
120 /// Short canonical name, e.g. `"SMA"`, `"RSI"`, `"MACD"`.
121 fn name(&self) -> &'static str;
122
123 /// Minimum number of candles required before output is non-`NaN`.
124 /// Mirrors Python's implicit warm-up period used for validation.
125 fn required_len(&self) -> usize;
126
127 /// Which OHLCV fields this indicator reads.
128 ///
129 /// Mirrors `@classmethod required_columns()` in Python.
130 /// Valid values: `"open"`, `"high"`, `"low"`, `"close"`, `"volume"`.
131 fn required_columns(&self) -> &[&'static str];
132
133 /// Compute the indicator over a full candle slice (batch mode).
134 ///
135 /// Mirrors `def calculate(self, data: pd.DataFrame, price_column) -> pd.DataFrame`.
136 ///
137 /// - Returns `IndicatorOutput` with one or more named columns.
138 /// - Leading warm-up rows should be `f64::NAN`.
139 /// - Returns `Err(IndicatorError::InsufficientData)` if `candles.len() < required_len()`.
140 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError>;
141
142 /// Validate that enough data was supplied, returning a descriptive error if not.
143 ///
144 /// Call this at the top of every `calculate()` implementation.
145 fn check_len(&self, candles: &[Candle]) -> Result<(), IndicatorError> {
146 let required = self.required_len();
147 if candles.len() < required {
148 Err(IndicatorError::InsufficientData {
149 required,
150 available: candles.len(),
151 })
152 } else {
153 Ok(())
154 }
155 }
156}
157
// ── PriceColumn helper ────────────────────────────────────────────────────────
159
160/// Which single OHLCV field to extract as a price series.
161///
162/// Mirrors the `column` / `price_column` parameter in Python indicators.
163#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
164pub enum PriceColumn {
165 Open,
166 High,
167 Low,
168 #[default]
169 Close,
170 Volume,
171 /// `(High + Low + Close) / 3`
172 TypicalPrice,
173 /// `(High + Low) / 2`
174 HL2,
175}
176
177impl PriceColumn {
178 /// Extract the column as a `Vec<f64>` from a candle slice.
179 pub fn extract(self, candles: &[Candle]) -> Vec<f64> {
180 candles
181 .iter()
182 .map(|c| match self {
183 PriceColumn::Open => c.open,
184 PriceColumn::High => c.high,
185 PriceColumn::Low => c.low,
186 PriceColumn::Close => c.close,
187 PriceColumn::Volume => c.volume,
188 PriceColumn::TypicalPrice => (c.high + c.low + c.close) / 3.0,
189 PriceColumn::HL2 => (c.high + c.low) / 2.0,
190 })
191 .collect()
192 }
193
194 pub fn as_str(self) -> &'static str {
195 match self {
196 PriceColumn::Open => "open",
197 PriceColumn::High => "high",
198 PriceColumn::Low => "low",
199 PriceColumn::Close => "close",
200 PriceColumn::Volume => "volume",
201 PriceColumn::TypicalPrice => "typical_price",
202 PriceColumn::HL2 => "hl2",
203 }
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserting a column makes it retrievable and drives `len` / `latest`.
    #[test]
    fn indicator_output_insert_and_get() {
        let mut out = IndicatorOutput::new();
        out.insert("SMA_20", vec![f64::NAN, f64::NAN, 10.0, 12.0]);
        assert_eq!(out.len(), 4);
        assert!(!out.is_empty());
        assert_eq!(out.latest("SMA_20"), Some(12.0));
        assert_eq!(out.get("SMA_20").map(|values| values.len()), Some(4));
        assert!(out.get("MISSING").is_none());
        assert_eq!(out.latest("MISSING"), None);
    }

    /// `latest` walks backwards past trailing `NaN`s and gives up on all-`NaN` columns.
    #[test]
    fn indicator_output_latest_skips_nan() {
        let out = IndicatorOutput::from_pairs([
            ("trailing", vec![1.0, 2.0, f64::NAN]),
            ("all_nan", vec![f64::NAN, f64::NAN, f64::NAN]),
        ]);
        assert_eq!(out.latest("trailing"), Some(2.0));
        assert_eq!(out.latest("all_nan"), None);
    }

    /// `from_pairs` stores every supplied column under its name.
    #[test]
    fn indicator_output_from_pairs() {
        let out = IndicatorOutput::from_pairs([
            ("MACD_line", vec![1.0, 2.0]),
            ("MACD_signal", vec![0.5, 1.5]),
        ]);
        assert_eq!(out.get("MACD_line"), Some(&[1.0, 2.0][..]));
        assert_eq!(out.get("MACD_signal"), Some(&[0.5, 1.5][..]));
        assert_eq!(out.len(), 2);
        assert_eq!(out.columns().count(), 2);
    }

    /// Raw and derived columns are read from the matching candle fields.
    #[test]
    fn price_column_extract() {
        let candle = Candle {
            time: 0,
            open: 1.0,
            high: 4.0,
            low: 2.0,
            close: 3.0,
            volume: 100.0,
        };
        let candles = vec![candle];
        assert_eq!(PriceColumn::Close.extract(&candles), vec![3.0]);
        assert_eq!(PriceColumn::Volume.extract(&candles), vec![100.0]);
        assert_eq!(
            PriceColumn::TypicalPrice.extract(&candles),
            vec![(4.0 + 2.0 + 3.0) / 3.0]
        );
        assert_eq!(PriceColumn::HL2.extract(&candles), vec![(4.0 + 2.0) / 2.0]);
    }

    /// The default column is `Close`, and names are the lower-case identifiers.
    #[test]
    fn price_column_default_and_names() {
        assert_eq!(PriceColumn::default(), PriceColumn::Close);
        assert_eq!(PriceColumn::Close.as_str(), "close");
        assert_eq!(PriceColumn::TypicalPrice.as_str(), "typical_price");
        assert_eq!(PriceColumn::HL2.as_str(), "hl2");
    }
}