indicators/volume/
vwap.rs1use std::collections::HashMap;
21
22use crate::error::IndicatorError;
23use crate::indicator::{Indicator, IndicatorOutput};
24use crate::registry::param_usize;
25use crate::types::Candle;
26
/// Parameters for the VWAP indicator.
///
/// `period: None` selects a cumulative VWAP over the whole input, while
/// `Some(p)` selects a rolling VWAP over the trailing `p` candles.
///
/// The derived `Default` yields `period: None`, i.e. cumulative mode.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct VwapParams {
    /// Rolling window length in candles; `None` means cumulative.
    pub period: Option<usize>,
}
41
42#[derive(Debug, Clone)]
45pub struct Vwap {
46 pub params: VwapParams,
47}
48
49impl Vwap {
50 pub fn new(params: VwapParams) -> Self {
51 Self { params }
52 }
53 pub fn cumulative() -> Self {
54 Self::new(VwapParams { period: None })
55 }
56 pub fn rolling(period: usize) -> Self {
57 Self::new(VwapParams {
58 period: Some(period),
59 })
60 }
61
62 fn output_key(&self) -> String {
63 match self.params.period {
64 None => "VWAP".to_string(),
65 Some(p) => format!("VWAP_{p}"),
66 }
67 }
68}
69
70impl Indicator for Vwap {
71 fn name(&self) -> &str {
72 "VWAP"
73 }
74
75 fn required_len(&self) -> usize {
76 self.params.period.unwrap_or(1)
77 }
78
79 fn required_columns(&self) -> &[&'static str] {
80 &["high", "low", "close", "volume"]
81 }
82
83 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
85 self.check_len(candles)?;
86
87 let n = candles.len();
88 let tp: Vec<f64> = candles.iter().map(|c| (c.high + c.low + c.close) / 3.0).collect();
89 let vp: Vec<f64> = candles.iter().zip(&tp).map(|(c, &t)| t * c.volume).collect();
90 let vol: Vec<f64> = candles.iter().map(|c| c.volume).collect();
91
92 let values = match self.params.period {
93 None => {
94 let mut cum_vp = 0.0f64;
96 let mut cum_vol = 0.0f64;
97 vp.iter().zip(&vol).map(|(&v, &vol)| {
98 cum_vp += v;
99 cum_vol += vol;
100 if cum_vol == 0.0 {
101 f64::NAN
102 } else {
103 cum_vp / cum_vol
104 }
105 }).collect()
106 }
107 Some(period) => {
108 let mut values = vec![f64::NAN; n];
110 for i in (period - 1)..n {
111 let sum_vp: f64 = vp[(i + 1 - period)..=i].iter().sum();
112 let sum_vol: f64 = vol[(i + 1 - period)..=i].iter().sum();
113 values[i] = if sum_vol == 0.0 {
114 f64::NAN
115 } else {
116 sum_vp / sum_vol
117 };
118 }
119 values
120 }
121 };
122
123 Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
124 }
125}
126
127pub fn factory(params: &HashMap<String, String>) -> Result<Box<dyn Indicator>, IndicatorError> {
130 let period = if params.contains_key("period") {
131 Some(param_usize(params, "period", 0)?)
132 } else {
133 None
134 };
135 Ok(Box::new(Vwap::new(VwapParams { period })))
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds candles from `(high, low, close, volume)` tuples; `open` mirrors
    /// `close` and `time` is the index.
    fn candles(data: &[(f64, f64, f64, f64)]) -> Vec<Candle> {
        data.iter()
            .enumerate()
            .map(|(i, &(h, l, c, v))| Candle {
                time: i as i64,
                open: c,
                high: h,
                low: l,
                close: c,
                volume: v,
            })
            .collect()
    }

    fn approx(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-9
    }

    #[test]
    fn vwap_cumulative_single_bar() {
        let bars = [(10.0, 8.0, 9.0, 100.0)];
        let out = Vwap::cumulative().calculate(&candles(&bars)).unwrap();
        let vals = out.get("VWAP").unwrap();
        assert!(approx(vals[0], 9.0));
    }

    #[test]
    fn vwap_cumulative_weights_by_volume() {
        // tp = 9 @ vol 100, tp = 19 @ vol 300 -> (900 + 5700) / 400 = 16.5
        let bars = [(10.0, 8.0, 9.0, 100.0), (20.0, 18.0, 19.0, 300.0)];
        let out = Vwap::cumulative().calculate(&candles(&bars)).unwrap();
        let vals = out.get("VWAP").unwrap();
        assert!(approx(vals[0], 9.0));
        assert!(approx(vals[1], 16.5));
    }

    #[test]
    fn vwap_cumulative_zero_volume_is_nan() {
        let bars = [(10.0, 8.0, 9.0, 0.0), (20.0, 18.0, 19.0, 300.0)];
        let out = Vwap::cumulative().calculate(&candles(&bars)).unwrap();
        let vals = out.get("VWAP").unwrap();
        assert!(vals[0].is_nan());
        assert!(approx(vals[1], 19.0));
    }

    #[test]
    fn vwap_rolling_values() {
        // vp = [900, 5700, 2900], vol = [100, 300, 100]
        let bars = [
            (10.0, 8.0, 9.0, 100.0),
            (20.0, 18.0, 19.0, 300.0),
            (30.0, 28.0, 29.0, 100.0),
        ];
        let out = Vwap::rolling(2).calculate(&candles(&bars)).unwrap();
        let vals = out.get("VWAP_2").unwrap();
        assert!(vals[0].is_nan());
        assert!(approx(vals[1], 16.5));
        assert!(approx(vals[2], 21.5));
    }

    #[test]
    fn vwap_rolling_zero_volume_window_is_nan() {
        let bars = [(10.0, 8.0, 9.0, 0.0), (20.0, 18.0, 19.0, 0.0)];
        let out = Vwap::rolling(2).calculate(&candles(&bars)).unwrap();
        let vals = out.get("VWAP_2").unwrap();
        assert!(vals[1].is_nan());
    }

    #[test]
    fn vwap_rolling_output_key() {
        let bars = vec![(10.0, 8.0, 9.0, 100.0); 5];
        let out = Vwap::rolling(3).calculate(&candles(&bars)).unwrap();
        assert!(out.get("VWAP_3").is_some());
    }

    #[test]
    fn factory_default_is_cumulative() {
        let ind = factory(&HashMap::new()).unwrap();
        assert_eq!(ind.name(), "VWAP");
        let bars = [(10.0, 8.0, 9.0, 100.0)];
        let out = ind.calculate(&candles(&bars)).unwrap();
        assert!(out.get("VWAP").is_some());
    }

    #[test]
    fn factory_with_period_is_rolling() {
        let params = HashMap::from([("period".to_string(), "3".to_string())]);
        let ind = factory(&params).unwrap();
        let bars = vec![(10.0, 8.0, 9.0, 100.0); 5];
        let out = ind.calculate(&candles(&bars)).unwrap();
        assert!(out.get("VWAP_3").is_some());
    }
}