use std::collections::HashMap;
use crate::error::IndicatorError;
use crate::indicator::{Indicator, IndicatorOutput};
use crate::registry::param_usize;
use crate::types::Candle;
#[derive(Debug, Clone, Default)]
pub struct VwapParams {
pub period: Option<usize>,
}
#[derive(Debug, Clone)]
pub struct Vwap {
pub params: VwapParams,
}
impl Vwap {
pub fn new(params: VwapParams) -> Self {
Self { params }
}
pub fn cumulative() -> Self {
Self::new(VwapParams { period: None })
}
pub fn rolling(period: usize) -> Self {
Self::new(VwapParams {
period: Some(period),
})
}
fn output_key(&self) -> String {
match self.params.period {
None => "VWAP".to_string(),
Some(p) => format!("VWAP_{p}"),
}
}
}
impl Indicator for Vwap {
fn name(&self) -> &'static str {
"VWAP"
}
fn required_len(&self) -> usize {
self.params.period.unwrap_or(1)
}
fn required_columns(&self) -> &[&'static str] {
&["high", "low", "close", "volume"]
}
fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
self.check_len(candles)?;
let n = candles.len();
let tp: Vec<f64> = candles
.iter()
.map(|c| (c.high + c.low + c.close) / 3.0)
.collect();
let vp: Vec<f64> = candles
.iter()
.zip(&tp)
.map(|(c, &t)| t * c.volume)
.collect();
let vol: Vec<f64> = candles.iter().map(|c| c.volume).collect();
let values = match self.params.period {
None => {
let mut cum_vp = 0.0f64;
let mut cum_vol = 0.0f64;
vp.iter()
.zip(&vol)
.map(|(&v, &vol)| {
cum_vp += v;
cum_vol += vol;
if cum_vol == 0.0 {
f64::NAN
} else {
cum_vp / cum_vol
}
})
.collect()
}
Some(period) => {
let mut values = vec![f64::NAN; n];
for i in (period - 1)..n {
let sum_vp: f64 = vp[(i + 1 - period)..=i].iter().sum();
let sum_vol: f64 = vol[(i + 1 - period)..=i].iter().sum();
values[i] = if sum_vol == 0.0 {
f64::NAN
} else {
sum_vp / sum_vol
};
}
values
}
};
Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
}
}
pub fn factory<S: ::std::hash::BuildHasher>(params: &HashMap<String, String, S>) -> Result<Box<dyn Indicator>, IndicatorError> {
let period = if params.contains_key("period") {
Some(param_usize(params, "period", 0)?)
} else {
None
};
Ok(Box::new(Vwap::new(VwapParams { period })))
}
#[cfg(test)]
mod tests {
use super::*;
fn candles(data: &[(f64, f64, f64, f64)]) -> Vec<Candle> {
data.iter()
.enumerate()
.map(|(i, &(h, l, c, v))| Candle {
time: i64::try_from(i).expect("time index fits i64"),
open: c,
high: h,
low: l,
close: c,
volume: v,
})
.collect()
}
#[test]
fn vwap_cumulative_single_bar() {
let bars = [(10.0, 8.0, 9.0, 100.0)];
let out = Vwap::cumulative().calculate(&candles(&bars)).unwrap();
let vals = out.get("VWAP").unwrap();
assert!((vals[0] - 9.0).abs() < 1e-9);
}
#[test]
fn vwap_rolling_output_key() {
let bars = vec![(10.0, 8.0, 9.0, 100.0); 5];
let out = Vwap::rolling(3).calculate(&candles(&bars)).unwrap();
assert!(out.get("VWAP_3").is_some());
}
#[test]
fn factory_default_is_cumulative() {
let ind = factory(&HashMap::new()).unwrap();
assert_eq!(ind.name(), "VWAP");
}
}