use std::collections::HashMap;
use crate::error::IndicatorError;
use crate::types::Candle;
/// Named output columns produced by an indicator calculation.
///
/// Each column is a `Vec<f64>` keyed by its name (for example `"SMA_20"`).
/// Values may be `NaN` (for example, positions where an indicator has no
/// value yet); [`IndicatorOutput::latest`] skips over them.
#[derive(Debug, Clone, Default)]
pub struct IndicatorOutput {
    columns: HashMap<String, Vec<f64>>,
}

impl IndicatorOutput {
    /// Creates an empty output with no columns.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `values` under `name`.
    ///
    /// An existing column with the same name is replaced.
    pub fn insert(&mut self, name: impl Into<String>, values: Vec<f64>) {
        self.columns.insert(name.into(), values);
    }

    /// Builds an output from `(name, values)` pairs.
    ///
    /// Pairs are inserted in iteration order, so a later pair with a
    /// duplicate name replaces an earlier one.
    pub fn from_pairs<K: Into<String>>(pairs: impl IntoIterator<Item = (K, Vec<f64>)>) -> Self {
        let mut out = Self::new();
        for (name, values) in pairs {
            out.insert(name, values);
        }
        out
    }

    /// Returns the column called `name`, or `None` if it does not exist.
    pub fn get(&self, name: &str) -> Option<&[f64]> {
        self.columns.get(name).map(Vec::as_slice)
    }

    /// Returns the last non-`NaN` value of column `name`.
    ///
    /// Returns `None` if the column is missing, empty, or contains only `NaN`s.
    pub fn latest(&self, name: &str) -> Option<f64> {
        self.columns
            .get(name)?
            .iter()
            .rev()
            .find(|v| !v.is_nan())
            .copied()
    }

    /// Iterates over the column names, in unspecified order.
    pub fn columns(&self) -> impl Iterator<Item = &str> {
        self.columns.keys().map(String::as_str)
    }

    /// Returns the number of rows, defined as the length of the longest column.
    ///
    /// Columns are normally all the same length. Taking the maximum keeps the
    /// answer deterministic if they are not: the map's iteration order is
    /// unspecified, so "the first column" is not a stable choice. Returns `0`
    /// when there are no columns.
    pub fn len(&self) -> usize {
        self.columns.values().map(Vec::len).max().unwrap_or(0)
    }

    /// Returns `true` when [`IndicatorOutput::len`] is zero, i.e. there are
    /// no columns or every column is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Consumes the output and returns the underlying name-to-values map.
    pub fn into_inner(self) -> HashMap<String, Vec<f64>> {
        self.columns
    }
}
/// A technical indicator computed over a slice of [`Candle`]s.
///
/// The `Send + Sync` bounds let implementors be shared across threads, and
/// the `Debug` bound lets them be formatted for diagnostics.
pub trait Indicator: Send + Sync + std::fmt::Debug {
    /// Static identifier for this indicator.
    fn name(&self) -> &'static str;

    /// Minimum number of candles this indicator needs; used by
    /// [`Indicator::check_len`].
    fn required_len(&self) -> usize;

    /// Column names this indicator reports as required (meaning defined by
    /// implementors).
    fn required_columns(&self) -> &[&'static str];

    /// Computes the indicator over `candles`.
    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError>;

    /// Verifies that `candles` has at least [`Indicator::required_len`] entries.
    ///
    /// # Errors
    ///
    /// Returns `IndicatorError::InsufficientData` carrying both the required
    /// and the available count when there are too few candles.
    fn check_len(&self, candles: &[Candle]) -> Result<(), IndicatorError> {
        let required = self.required_len();
        let available = candles.len();
        if available >= required {
            return Ok(());
        }
        Err(IndicatorError::InsufficientData {
            required,
            available,
        })
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PriceColumn {
Open,
High,
Low,
#[default]
Close,
Volume,
TypicalPrice,
HL2,
}
impl PriceColumn {
pub fn extract(self, candles: &[Candle]) -> Vec<f64> {
candles
.iter()
.map(|c| match self {
PriceColumn::Open => c.open,
PriceColumn::High => c.high,
PriceColumn::Low => c.low,
PriceColumn::Close => c.close,
PriceColumn::Volume => c.volume,
PriceColumn::TypicalPrice => (c.high + c.low + c.close) / 3.0,
PriceColumn::HL2 => (c.high + c.low) / 2.0,
})
.collect()
}
pub fn as_str(self) -> &'static str {
match self {
PriceColumn::Open => "open",
PriceColumn::High => "high",
PriceColumn::Low => "low",
PriceColumn::Close => "close",
PriceColumn::Volume => "volume",
PriceColumn::TypicalPrice => "typical_price",
PriceColumn::HL2 => "hl2",
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a candle at time 0 with the given OHLCV values.
    fn candle(open: f64, high: f64, low: f64, close: f64, volume: f64) -> Candle {
        Candle {
            time: 0,
            open,
            high,
            low,
            close,
            volume,
        }
    }

    #[test]
    fn indicator_output_insert_and_get() {
        let mut out = IndicatorOutput::new();
        out.insert("SMA_20", vec![f64::NAN, f64::NAN, 10.0, 12.0]);
        assert_eq!(out.len(), 4);
        assert_eq!(out.latest("SMA_20"), Some(12.0));
        assert!(out.get("MISSING").is_none());
    }

    #[test]
    fn indicator_output_new_is_empty() {
        let out = IndicatorOutput::new();
        assert!(out.is_empty());
        assert_eq!(out.len(), 0);
        assert_eq!(out.columns().count(), 0);
    }

    #[test]
    fn indicator_output_latest_skips_nan() {
        let mut out = IndicatorOutput::new();
        out.insert("trailing_nan", vec![1.0, 2.0, f64::NAN]);
        out.insert("all_nan", vec![f64::NAN, f64::NAN]);
        assert_eq!(out.latest("trailing_nan"), Some(2.0));
        assert_eq!(out.latest("all_nan"), None);
        assert_eq!(out.latest("MISSING"), None);
    }

    #[test]
    fn indicator_output_from_pairs() {
        let out = IndicatorOutput::from_pairs([
            ("MACD_line", vec![1.0, 2.0]),
            ("MACD_signal", vec![0.5, 1.5]),
        ]);
        assert_eq!(out.get("MACD_line"), Some(&[1.0, 2.0][..]));
        assert_eq!(out.get("MACD_signal"), Some(&[0.5, 1.5][..]));
    }

    #[test]
    fn price_column_extract() {
        let candles = vec![candle(1.0, 4.0, 2.0, 3.0, 100.0)];
        assert_eq!(PriceColumn::Open.extract(&candles), vec![1.0]);
        assert_eq!(PriceColumn::High.extract(&candles), vec![4.0]);
        assert_eq!(PriceColumn::Low.extract(&candles), vec![2.0]);
        assert_eq!(PriceColumn::Close.extract(&candles), vec![3.0]);
        assert_eq!(PriceColumn::Volume.extract(&candles), vec![100.0]);
        assert_eq!(
            PriceColumn::TypicalPrice.extract(&candles),
            vec![(4.0 + 2.0 + 3.0) / 3.0]
        );
        assert_eq!(PriceColumn::HL2.extract(&candles), vec![(4.0 + 2.0) / 2.0]);
        assert!(PriceColumn::Close.extract(&[]).is_empty());
    }

    #[test]
    fn price_column_as_str_and_default() {
        assert_eq!(PriceColumn::default(), PriceColumn::Close);
        assert_eq!(PriceColumn::Close.as_str(), "close");
        assert_eq!(PriceColumn::TypicalPrice.as_str(), "typical_price");
        assert_eq!(PriceColumn::HL2.as_str(), "hl2");
    }
}