Skip to main content

indicators/volume/
vwap.rs

//! Volume-Weighted Average Price (VWAP).
//!
//! Python source: `indicators/trend/moving_average.py :: class VWAP`
//!              + `indicators/volume/vwap.py`
//!
//! # Python algorithm (ported below)
//! ```python
//! typical_price = (data["high"] + data["low"] + data["close"]) / 3
//! volume_price  = typical_price * data["volume"]
//!
//! # Cumulative (period=None):
//! vwap = volume_price.cumsum() / data["volume"].cumsum()
//!
//! # Rolling (period=N):
//! vwap = volume_price.rolling(N).sum() / data["volume"].rolling(N).sum()
//! ```
//!
//! Output column: `"VWAP"` (cumulative) or `"VWAP_{period}"` (rolling).

20use std::collections::HashMap;
21
22use crate::error::IndicatorError;
23use crate::indicator::{Indicator, IndicatorOutput};
24use crate::registry::param_usize;
25use crate::types::Candle;
26
27// ── Params ────────────────────────────────────────────────────────────────────
28
/// Parameters for the VWAP indicator.
///
/// The derived `Default` gives `period: None` (cumulative VWAP), matching the
/// Python default.
#[derive(Debug, Clone, Default)]
pub struct VwapParams {
    /// Rolling window length in bars.
    ///
    /// `None` = cumulative VWAP over every supplied candle (session-based only
    /// if the caller passes a single session's candles). Python default: `None`.
    pub period: Option<usize>,
}
41
42// ── Indicator struct ──────────────────────────────────────────────────────────
43
44#[derive(Debug, Clone)]
45pub struct Vwap {
46    pub params: VwapParams,
47}
48
49impl Vwap {
50    pub fn new(params: VwapParams) -> Self {
51        Self { params }
52    }
53    pub fn cumulative() -> Self {
54        Self::new(VwapParams { period: None })
55    }
56    pub fn rolling(period: usize) -> Self {
57        Self::new(VwapParams {
58            period: Some(period),
59        })
60    }
61
62    fn output_key(&self) -> String {
63        match self.params.period {
64            None => "VWAP".to_string(),
65            Some(p) => format!("VWAP_{p}"),
66        }
67    }
68}
69
70impl Indicator for Vwap {
71    fn name(&self) -> &str {
72        "VWAP"
73    }
74
75    fn required_len(&self) -> usize {
76        self.params.period.unwrap_or(1)
77    }
78
79    fn required_columns(&self) -> &[&'static str] {
80        &["high", "low", "close", "volume"]
81    }
82
83    /// TODO: port Python cumulative / rolling VWAP.
84    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
85        self.check_len(candles)?;
86
87        let n = candles.len();
88        let tp: Vec<f64> = candles.iter().map(|c| (c.high + c.low + c.close) / 3.0).collect();
89        let vp: Vec<f64> = candles.iter().zip(&tp).map(|(c, &t)| t * c.volume).collect();
90        let vol: Vec<f64> = candles.iter().map(|c| c.volume).collect();
91
92        let values = match self.params.period {
93            None => {
94                // TODO: cumulative VWAP
95                let mut cum_vp = 0.0f64;
96                let mut cum_vol = 0.0f64;
97                vp.iter().zip(&vol).map(|(&v, &vol)| {
98                    cum_vp += v;
99                    cum_vol += vol;
100                    if cum_vol == 0.0 {
101                            f64::NAN
102                        } else {
103                            cum_vp / cum_vol
104                        }
105                }).collect()
106            }
107            Some(period) => {
108                // TODO: rolling VWAP
109                let mut values = vec![f64::NAN; n];
110                for i in (period - 1)..n {
111                    let sum_vp: f64 = vp[(i + 1 - period)..=i].iter().sum();
112                    let sum_vol: f64 = vol[(i + 1 - period)..=i].iter().sum();
113                    values[i] = if sum_vol == 0.0 {
114                        f64::NAN
115                    } else {
116                        sum_vp / sum_vol
117                    };
118                }
119                values
120            }
121        };
122
123        Ok(IndicatorOutput::from_pairs([(self.output_key(), values)]))
124    }
125}
126
127// ── Registry factory ──────────────────────────────────────────────────────────
128
129pub fn factory(params: &HashMap<String, String>) -> Result<Box<dyn Indicator>, IndicatorError> {
130    let period = if params.contains_key("period") {
131        Some(param_usize(params, "period", 0)?)
132    } else {
133        None
134    };
135    Ok(Box::new(Vwap::new(VwapParams { period })))
136}
137
138// ── Tests ─────────────────────────────────────────────────────────────────────
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds candles from `(high, low, close, volume)` tuples. `open` mirrors
    /// `close` and `time` is the bar index.
    fn candles(data: &[(f64, f64, f64, f64)]) -> Vec<Candle> {
        data.iter()
            .enumerate()
            .map(|(i, &(h, l, c, v))| Candle {
                time: i as i64,
                open: c,
                high: h,
                low: l,
                close: c,
                volume: v,
            })
            .collect()
    }

    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn vwap_cumulative_single_bar() {
        let bars = [(10.0, 8.0, 9.0, 100.0)];
        let out = Vwap::cumulative().calculate(&candles(&bars)).unwrap();
        let vals = out.get("VWAP").unwrap();
        // tp = (10+8+9)/3 = 9; vwap = 9*100/100 = 9
        assert_close(vals[0], 9.0);
    }

    #[test]
    fn vwap_cumulative_multi_bar() {
        // tp = 9, 12; vp = 900, 3600; cumulative vol = 100, 400
        let bars = [(10.0, 8.0, 9.0, 100.0), (13.0, 11.0, 12.0, 300.0)];
        let out = Vwap::cumulative().calculate(&candles(&bars)).unwrap();
        let vals = out.get("VWAP").unwrap();
        assert_eq!(vals.len(), 2);
        assert_close(vals[0], 9.0);
        assert_close(vals[1], 4500.0 / 400.0);
    }

    #[test]
    fn vwap_cumulative_zero_volume_is_nan() {
        let bars = [(10.0, 8.0, 9.0, 0.0)];
        let out = Vwap::cumulative().calculate(&candles(&bars)).unwrap();
        assert!(out.get("VWAP").unwrap()[0].is_nan());
    }

    #[test]
    fn vwap_rolling_output_key() {
        let bars = vec![(10.0, 8.0, 9.0, 100.0); 5];
        let out = Vwap::rolling(3).calculate(&candles(&bars)).unwrap();
        assert!(out.get("VWAP_3").is_some());
    }

    #[test]
    fn vwap_rolling_values_and_warmup() {
        // tp = 9, 12, 15; vp = 900, 3600, 1500; vol = 100, 300, 100
        let bars = [
            (10.0, 8.0, 9.0, 100.0),
            (13.0, 11.0, 12.0, 300.0),
            (16.0, 14.0, 15.0, 100.0),
        ];
        let out = Vwap::rolling(2).calculate(&candles(&bars)).unwrap();
        let vals = out.get("VWAP_2").unwrap();
        assert_eq!(vals.len(), 3);
        assert!(vals[0].is_nan());
        assert_close(vals[1], 4500.0 / 400.0);
        assert_close(vals[2], 5100.0 / 400.0);
    }

    #[test]
    fn factory_default_is_cumulative() {
        let ind = factory(&HashMap::new()).unwrap();
        assert_eq!(ind.name(), "VWAP");
    }
}