indicators/volume/
vwap.rs1use std::collections::HashMap;
21
22use crate::error::IndicatorError;
23use crate::indicator::{Indicator, IndicatorOutput};
24use crate::registry::param_usize;
25use crate::types::Candle;
26
/// Configuration for the [`Vwap`] indicator.
#[derive(Debug, Clone, Default)]
pub struct VwapParams {
    /// Rolling window length in candles; `None` selects the cumulative VWAP.
    pub period: Option<usize>,
}

/// Volume-weighted average price indicator, either cumulative or rolling.
#[derive(Debug, Clone)]
pub struct Vwap {
    /// Settings that choose between the cumulative and rolling variants.
    pub params: VwapParams,
}

impl Vwap {
    /// Wraps an explicit parameter set.
    pub fn new(params: VwapParams) -> Self {
        Vwap { params }
    }

    /// VWAP accumulated over every candle supplied (no window).
    pub fn cumulative() -> Self {
        // The derived default leaves `period` as `None`, i.e. cumulative.
        Self::new(VwapParams::default())
    }

    /// VWAP over a sliding window of the most recent `period` candles.
    pub fn rolling(period: usize) -> Self {
        let params = VwapParams {
            period: Some(period),
        };
        Self::new(params)
    }

    /// Name of the output column: `VWAP` when cumulative, `VWAP_<period>` when rolling.
    fn output_key(&self) -> String {
        self.params
            .period
            .map_or_else(|| String::from("VWAP"), |window| format!("VWAP_{window}"))
    }
}
63
64impl Indicator for Vwap {
65 fn name(&self) -> &'static str {
66 "VWAP"
67 }
68
69 fn required_len(&self) -> usize {
70 self.params.period.unwrap_or(1)
71 }
72
73 fn required_columns(&self) -> &[&'static str] {
74 &["high", "low", "close", "volume"]
75 }
76
77 fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
79 self.check_len(candles)?;
80
81 let n = candles.len();
82 let tp: Vec<f64> = candles
83 .iter()
84 .map(|c| (c.high + c.low + c.close) / 3.0)
85 .collect();
86 let vp: Vec<f64> = candles
87 .iter()
88 .zip(&tp)
89 .map(|(c, &t)| t * c.volume)
90 .collect();
91 let vol: Vec<f64> = candles.iter().map(|c| c.volume).collect();
92
93 let values = match self.params.period {
94 None => {
95 let mut cum_vp = 0.0f64;
97 let mut cum_vol = 0.0f64;
98 vp.iter()
99 .zip(&vol)
100 .map(|(&v, &vol)| {
101 cum_vp += v;
102 cum_vol += vol;
103 if cum_vol == 0.0 {
104 f64::NAN
105 } else {
106 cum_vp / cum_vol
107 }
108 })
109 .collect()
110 }
111 Some(period) => {
112 let mut values = vec![f64::NAN; n];
114 for i in (period - 1)..n {
115 let sum_vp: f64 = vp[(i + 1 - period)..=i].iter().sum();
116 let sum_vol: f64 = vol[(i + 1 - period)..=i].iter().sum();
117 values[i] = if sum_vol == 0.0 {
118 f64::NAN
119 } else {
120 sum_vp / sum_vol
121 };
122 }
123 values
124 }
125 };
126
127 Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
128 }
129}
130
131pub fn factory<S: ::std::hash::BuildHasher>(params: &HashMap<String, String, S>) -> Result<Box<dyn Indicator>, IndicatorError> {
134 let period = if params.contains_key("period") {
135 Some(param_usize(params, "period", 0)?)
136 } else {
137 None
138 };
139 Ok(Box::new(Vwap::new(VwapParams { period })))
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds candles from `(high, low, close, volume)` tuples; `open` mirrors
    /// `close` and `time` is the tuple's position.
    fn candles(data: &[(f64, f64, f64, f64)]) -> Vec<Candle> {
        data.iter()
            .enumerate()
            .map(|(i, &(h, l, c, v))| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: h,
                low: l,
                close: c,
                volume: v,
            })
            .collect()
    }

    /// Asserts two floats agree within a tight absolute tolerance.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn vwap_cumulative_single_bar() {
        let bars = [(10.0, 8.0, 9.0, 100.0)];
        let out = Vwap::cumulative().calculate(&candles(&bars)).unwrap();
        let vals = out.get("VWAP").unwrap();
        assert_close(vals[0], 9.0);
    }

    #[test]
    fn vwap_cumulative_accumulates_across_bars() {
        // Typical prices 9 and 11, weighted by volumes 100 and 300.
        let bars = [(10.0, 8.0, 9.0, 100.0), (12.0, 10.0, 11.0, 300.0)];
        let out = Vwap::cumulative().calculate(&candles(&bars)).unwrap();
        let vals = out.get("VWAP").unwrap();
        assert_close(vals[0], 9.0);
        assert_close(vals[1], 10.5);
    }

    #[test]
    fn vwap_cumulative_zero_volume_is_nan() {
        let bars = [(10.0, 8.0, 9.0, 0.0), (12.0, 10.0, 11.0, 50.0)];
        let out = Vwap::cumulative().calculate(&candles(&bars)).unwrap();
        let vals = out.get("VWAP").unwrap();
        assert!(vals[0].is_nan(), "zero cumulative volume must yield NaN");
        assert_close(vals[1], 11.0);
    }

    #[test]
    fn vwap_rolling_output_key() {
        let bars = vec![(10.0, 8.0, 9.0, 100.0); 5];
        let out = Vwap::rolling(3).calculate(&candles(&bars)).unwrap();
        assert!(out.get("VWAP_3").is_some());
    }

    #[test]
    fn vwap_rolling_window_values() {
        // Typical prices 9, 11, 13 with volumes 100, 300, 100.
        let bars = [
            (10.0, 8.0, 9.0, 100.0),
            (12.0, 10.0, 11.0, 300.0),
            (14.0, 12.0, 13.0, 100.0),
        ];
        let out = Vwap::rolling(2).calculate(&candles(&bars)).unwrap();
        let vals = out.get("VWAP_2").unwrap();
        assert!(vals[0].is_nan(), "warm-up value must be NaN");
        assert_close(vals[1], 10.5);
        assert_close(vals[2], 11.5);
    }

    #[test]
    fn vwap_rolling_zero_volume_window_is_nan() {
        let bars = [
            (10.0, 8.0, 9.0, 0.0),
            (12.0, 10.0, 11.0, 0.0),
            (14.0, 12.0, 13.0, 10.0),
        ];
        let out = Vwap::rolling(2).calculate(&candles(&bars)).unwrap();
        let vals = out.get("VWAP_2").unwrap();
        assert!(vals[0].is_nan());
        assert!(vals[1].is_nan(), "window with zero volume must yield NaN");
        assert_close(vals[2], 13.0);
    }

    #[test]
    fn factory_default_is_cumulative() {
        let ind = factory(&HashMap::new()).unwrap();
        assert_eq!(ind.name(), "VWAP");
        // `name()` is shared by both variants; the output key distinguishes them.
        let bars = [(10.0, 8.0, 9.0, 100.0)];
        let out = ind.calculate(&candles(&bars)).unwrap();
        assert!(out.get("VWAP").is_some());
    }

    #[test]
    fn factory_with_period_is_rolling() {
        let params = HashMap::from([("period".to_string(), "3".to_string())]);
        let ind = factory(&params).unwrap();
        let bars = vec![(10.0, 8.0, 9.0, 100.0); 3];
        let out = ind.calculate(&candles(&bars)).unwrap();
        assert!(out.get("VWAP_3").is_some());
    }
}