Skip to main content

indicators/volume/
vwap.rs

1//! Volume-Weighted Average Price (VWAP).
2//!
3//! Python source: `indicators/trend/moving_average.py :: class VWAP`
4//!              + `indicators/volume/vwap.py`
5//!
6//! # Python algorithm (to port)
7//! ```python
8//! typical_price = (data["high"] + data["low"] + data["close"]) / 3
9//! volume_price  = typical_price * data["volume"]
10//!
11//! # Cumulative (period=None):
12//! vwap = volume_price.cumsum() / data["volume"].cumsum()
13//!
14//! # Rolling (period=N):
15//! vwap = volume_price.rolling(N).sum() / data["volume"].rolling(N).sum()
16//! ```
17//!
18//! Output column: `"VWAP"` (cumulative) or `"VWAP_{period}"` (rolling).
19
20use std::collections::HashMap;
21
22use crate::error::IndicatorError;
23use crate::indicator::{Indicator, IndicatorOutput};
24use crate::registry::param_usize;
25use crate::types::Candle;
26
27// ── Params ────────────────────────────────────────────────────────────────────
28
/// Configuration for [`Vwap`].
///
/// The `Default` value leaves `period` unset, which selects the cumulative
/// calculation. `PartialEq`/`Eq` are derived so parameter sets can be compared
/// directly (e.g. in tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct VwapParams {
    /// Rolling window length in bars.
    ///
    /// `None` selects the cumulative VWAP, accumulated over every candle
    /// passed to `calculate`. Python default: `None`.
    pub period: Option<usize>,
}
35
36// ── Indicator struct ──────────────────────────────────────────────────────────
37
38#[derive(Debug, Clone)]
39pub struct Vwap {
40    pub params: VwapParams,
41}
42
43impl Vwap {
44    pub fn new(params: VwapParams) -> Self {
45        Self { params }
46    }
47    pub fn cumulative() -> Self {
48        Self::new(VwapParams { period: None })
49    }
50    pub fn rolling(period: usize) -> Self {
51        Self::new(VwapParams {
52            period: Some(period),
53        })
54    }
55
56    fn output_key(&self) -> String {
57        match self.params.period {
58            None => "VWAP".to_string(),
59            Some(p) => format!("VWAP_{p}"),
60        }
61    }
62}
63
64impl Indicator for Vwap {
65    fn name(&self) -> &'static str {
66        "VWAP"
67    }
68
69    fn required_len(&self) -> usize {
70        self.params.period.unwrap_or(1)
71    }
72
73    fn required_columns(&self) -> &[&'static str] {
74        &["high", "low", "close", "volume"]
75    }
76
77    /// Cumulative VWAP:  `cumsum(tp * vol) / cumsum(vol)` — no NaN warm-up.
78    /// Rolling VWAP:     `rolling_sum(tp * vol, N) / rolling_sum(vol, N)` —
79    ///                   `NaN` for the first `period - 1` positions.
80    ///
81    /// `tp` (typical price) = `(high + low + close) / 3`.
82    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
83        self.check_len(candles)?;
84
85        let n = candles.len();
86        let tp: Vec<f64> = candles
87            .iter()
88            .map(|c| (c.high + c.low + c.close) / 3.0)
89            .collect();
90        let vp: Vec<f64> = candles
91            .iter()
92            .zip(&tp)
93            .map(|(c, &t)| t * c.volume)
94            .collect();
95        let vol: Vec<f64> = candles.iter().map(|c| c.volume).collect();
96
97        let values = match self.params.period {
98            None => {
99                // Cumulative VWAP — produces a value for every bar.
100                let mut cum_vp = 0.0f64;
101                let mut cum_vol = 0.0f64;
102                vp.iter()
103                    .zip(&vol)
104                    .map(|(&v, &vol)| {
105                        cum_vp += v;
106                        cum_vol += vol;
107                        if cum_vol == 0.0 {
108                            f64::NAN
109                        } else {
110                            cum_vp / cum_vol
111                        }
112                    })
113                    .collect()
114            }
115            Some(period) => {
116                // Rolling VWAP — NaN for first `period - 1` bars.
117                let mut values = vec![f64::NAN; n];
118                for i in (period - 1)..n {
119                    let sum_vp: f64 = vp[(i + 1 - period)..=i].iter().sum();
120                    let sum_vol: f64 = vol[(i + 1 - period)..=i].iter().sum();
121                    values[i] = if sum_vol == 0.0 {
122                        f64::NAN
123                    } else {
124                        sum_vp / sum_vol
125                    };
126                }
127                values
128            }
129        };
130
131        Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
132    }
133}
134
135// ── Registry factory ──────────────────────────────────────────────────────────
136
137pub fn factory<S: ::std::hash::BuildHasher>(
138    params: &HashMap<String, String, S>,
139) -> Result<Box<dyn Indicator>, IndicatorError> {
140    let period = if params.contains_key("period") {
141        Some(param_usize(params, "period", 0)?)
142    } else {
143        None
144    };
145    Ok(Box::new(Vwap::new(VwapParams { period })))
146}
147
148// ── Tests ─────────────────────────────────────────────────────────────────────
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Typical prices 9, 11, 13 with volumes 100, 300, 100.
    const SAMPLE: [(f64, f64, f64, f64); 3] = [
        (10.0, 8.0, 9.0, 100.0),
        (12.0, 10.0, 11.0, 300.0),
        (14.0, 12.0, 13.0, 100.0),
    ];

    /// Builds candles from `(high, low, close, volume)` tuples; `open` mirrors
    /// `close` and `time` is the tuple's index.
    fn candles(data: &[(f64, f64, f64, f64)]) -> Vec<Candle> {
        data.iter()
            .enumerate()
            .map(|(i, &(h, l, c, v))| Candle {
                time: i64::try_from(i).expect("time index fits i64"),
                open: c,
                high: h,
                low: l,
                close: c,
                volume: v,
            })
            .collect()
    }

    #[test]
    fn vwap_cumulative_single_bar() {
        let bars = [(10.0, 8.0, 9.0, 100.0)];
        let out = Vwap::cumulative().calculate(&candles(&bars)).unwrap();
        let vals = out.get("VWAP").unwrap();
        // tp = (10+8+9)/3 = 9; vwap = 9*100/100 = 9
        assert!((vals[0] - 9.0).abs() < 1e-9);
    }

    #[test]
    fn vwap_cumulative_accumulates_across_bars() {
        let out = Vwap::cumulative().calculate(&candles(&SAMPLE)).unwrap();
        let vals = out.get("VWAP").unwrap();
        // 900/100, (900+3300)/400, (900+3300+1300)/500
        assert!((vals[0] - 9.0).abs() < 1e-9);
        assert!((vals[1] - 10.5).abs() < 1e-9);
        assert!((vals[2] - 11.0).abs() < 1e-9);
    }

    #[test]
    fn vwap_cumulative_zero_volume_is_nan() {
        let bars = [(10.0, 8.0, 9.0, 0.0)];
        let out = Vwap::cumulative().calculate(&candles(&bars)).unwrap();
        let vals = out.get("VWAP").unwrap();
        assert!(vals[0].is_nan());
    }

    #[test]
    fn vwap_rolling_output_key() {
        let bars = vec![(10.0, 8.0, 9.0, 100.0); 5];
        let out = Vwap::rolling(3).calculate(&candles(&bars)).unwrap();
        assert!(out.get("VWAP_3").is_some());
    }

    #[test]
    fn vwap_rolling_warm_up_and_values() {
        let out = Vwap::rolling(2).calculate(&candles(&SAMPLE)).unwrap();
        let vals = out.get("VWAP_2").unwrap();
        // The first `period - 1` positions have no full window.
        assert!(vals[0].is_nan());
        // (900+3300)/400 and (3300+1300)/400
        assert!((vals[1] - 10.5).abs() < 1e-9);
        assert!((vals[2] - 11.5).abs() < 1e-9);
    }

    #[test]
    fn factory_default_is_cumulative() {
        let ind = factory(&HashMap::new()).unwrap();
        assert_eq!(ind.name(), "VWAP");
        // `name()` is identical for both variants; the output column is what
        // distinguishes cumulative ("VWAP") from rolling ("VWAP_{period}").
        let out = ind.calculate(&candles(&SAMPLE)).unwrap();
        assert!(out.get("VWAP").is_some());
    }
}