indicators/volume/
vwap.rs1use std::collections::HashMap;
21
22use crate::error::IndicatorError;
23use crate::indicator::{Indicator, IndicatorOutput};
24use crate::registry::param_usize;
25use crate::types::Candle;
26
/// Configuration for the VWAP indicator.
///
/// Derives the common comparison traits so parameter sets can be compared,
/// deduplicated, or used as map keys.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct VwapParams {
    /// Trailing window length in candles.
    ///
    /// `None` (the `Default`) selects a cumulative VWAP over every candle;
    /// `Some(p)` selects a rolling VWAP over the last `p` candles.
    pub period: Option<usize>,
}
35
36#[derive(Debug, Clone)]
39pub struct Vwap {
40 pub params: VwapParams,
41}
42
43impl Vwap {
44 pub fn new(params: VwapParams) -> Self {
45 Self { params }
46 }
47 pub fn cumulative() -> Self {
48 Self::new(VwapParams { period: None })
49 }
50 pub fn rolling(period: usize) -> Self {
51 Self::new(VwapParams {
52 period: Some(period),
53 })
54 }
55
56 fn output_key(&self) -> String {
57 match self.params.period {
58 None => "VWAP".to_string(),
59 Some(p) => format!("VWAP_{p}"),
60 }
61 }
62}
63
64impl Indicator for Vwap {
65 fn name(&self) -> &'static str {
66 "VWAP"
67 }
68
69 fn required_len(&self) -> usize {
70 self.params.period.unwrap_or(1)
71 }
72
73 fn required_columns(&self) -> &[&'static str] {
74 &["high", "low", "close", "volume"]
75 }
76
77 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
83 self.check_len(candles)?;
84
85 let n = candles.len();
86 let tp: Vec<f64> = candles
87 .iter()
88 .map(|c| (c.high + c.low + c.close) / 3.0)
89 .collect();
90 let vp: Vec<f64> = candles
91 .iter()
92 .zip(&tp)
93 .map(|(c, &t)| t * c.volume)
94 .collect();
95 let vol: Vec<f64> = candles.iter().map(|c| c.volume).collect();
96
97 let values = match self.params.period {
98 None => {
99 let mut cum_vp = 0.0f64;
101 let mut cum_vol = 0.0f64;
102 vp.iter()
103 .zip(&vol)
104 .map(|(&v, &vol)| {
105 cum_vp += v;
106 cum_vol += vol;
107 if cum_vol == 0.0 {
108 f64::NAN
109 } else {
110 cum_vp / cum_vol
111 }
112 })
113 .collect()
114 }
115 Some(period) => {
116 let mut values = vec![f64::NAN; n];
118 for i in (period - 1)..n {
119 let sum_vp: f64 = vp[(i + 1 - period)..=i].iter().sum();
120 let sum_vol: f64 = vol[(i + 1 - period)..=i].iter().sum();
121 values[i] = if sum_vol == 0.0 {
122 f64::NAN
123 } else {
124 sum_vp / sum_vol
125 };
126 }
127 values
128 }
129 };
130
131 Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
132 }
133}
134
135pub fn factory<S: ::std::hash::BuildHasher>(
138 params: &HashMap<String, String, S>,
139) -> Result<Box<dyn Indicator>, IndicatorError> {
140 let period = if params.contains_key("period") {
141 Some(param_usize(params, "period", 0)?)
142 } else {
143 None
144 };
145 Ok(Box::new(Vwap::new(VwapParams { period })))
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds candles from `(high, low, close, volume)` tuples; `open`
    /// mirrors `close` and `time` is the tuple's index.
    fn candles(data: &[(f64, f64, f64, f64)]) -> Vec<Candle> {
        data.iter()
            .enumerate()
            .map(|(i, &(h, l, c, v))| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: h,
                low: l,
                close: c,
                volume: v,
            })
            .collect()
    }

    #[test]
    fn vwap_cumulative_single_bar() {
        let bars = [(10.0, 8.0, 9.0, 100.0)];
        let out = Vwap::cumulative().calculate(&candles(&bars)).unwrap();
        let vals = out.get("VWAP").unwrap();
        assert!((vals[0] - 9.0).abs() < 1e-9);
    }

    #[test]
    fn vwap_cumulative_multi_bar() {
        // Typical prices 9 and 12 with volumes 100 and 300:
        // bar 1 = (900 + 3600) / 400 = 11.25.
        let bars = [(10.0, 8.0, 9.0, 100.0), (13.0, 11.0, 12.0, 300.0)];
        let out = Vwap::cumulative().calculate(&candles(&bars)).unwrap();
        let vals = out.get("VWAP").unwrap();
        assert!((vals[0] - 9.0).abs() < 1e-9);
        assert!((vals[1] - 11.25).abs() < 1e-9);
    }

    #[test]
    fn vwap_zero_volume_is_nan() {
        let bars = [(10.0, 8.0, 9.0, 0.0)];
        let out = Vwap::cumulative().calculate(&candles(&bars)).unwrap();
        let vals = out.get("VWAP").unwrap();
        assert!(vals[0].is_nan());
    }

    #[test]
    fn vwap_rolling_output_key() {
        let bars = vec![(10.0, 8.0, 9.0, 100.0); 5];
        let out = Vwap::rolling(3).calculate(&candles(&bars)).unwrap();
        assert!(out.get("VWAP_3").is_some());
    }

    #[test]
    fn vwap_rolling_warmup_and_values() {
        // Typical prices 9, 12, 15 with volumes 100, 300, 100; window of 2.
        let bars = [
            (10.0, 8.0, 9.0, 100.0),
            (13.0, 11.0, 12.0, 300.0),
            (16.0, 14.0, 15.0, 100.0),
        ];
        let out = Vwap::rolling(2).calculate(&candles(&bars)).unwrap();
        let vals = out.get("VWAP_2").unwrap();
        assert!(vals[0].is_nan());
        assert!((vals[1] - 11.25).abs() < 1e-9);
        assert!((vals[2] - 12.75).abs() < 1e-9);
    }

    #[test]
    fn factory_default_is_cumulative() {
        let ind = factory(&HashMap::new()).unwrap();
        assert_eq!(ind.name(), "VWAP");
        // `name()` is "VWAP" for both variants; the output key tells them apart.
        let out = ind.calculate(&candles(&[(10.0, 8.0, 9.0, 100.0)])).unwrap();
        let vals = out.get("VWAP").unwrap();
        assert!((vals[0] - 9.0).abs() < 1e-9);
    }

    #[test]
    fn factory_with_period_is_rolling() {
        let mut params = HashMap::new();
        params.insert("period".to_string(), "2".to_string());
        let ind = factory(&params).unwrap();
        let bars = vec![(10.0, 8.0, 9.0, 100.0); 3];
        let out = ind.calculate(&candles(&bars)).unwrap();
        assert!(out.get("VWAP_2").is_some());
        assert!(out.get("VWAP").is_none());
    }
}