Skip to main content

indicators/signal/
vol_regime.rs

//! Volatility-regime helpers (the `VolumeRegime` indicator ranks ATR, not traded volume):
//! a rolling percentile tracker, a volatility regime classifier, and a simple MA-slope
//! market regime classifier.
3//!
4//! These are ported from the Python `VolatilityPercentile`, `PercentileTracker`,
5//! and `MarketRegime` classes in `indicators.py`.
6//!
7//! Note: `MarketRegimeTracker` is distinct from the statistical `MarketRegime` enum
8//! in `types.rs` — it is a simpler slope-based classifier used by the signal engine.
9
10use std::collections::{HashMap, VecDeque};
11
12use crate::error::IndicatorError;
13use crate::indicator::{Indicator, IndicatorOutput};
14use crate::registry::param_usize;
15use crate::types::Candle;
16
17// ── Params ────────────────────────────────────────────────────────────────────
18
/// Tunable settings for the [`VolumeRegime`] indicator.
#[derive(Debug, Clone)]
pub struct VolumeRegimeParams {
    /// Smoothing period (in bars) of the ATR whose values are ranked by the
    /// percentile tracker.
    pub atr_period: usize,
    /// Number of ATR readings retained by the rolling [`PercentileTracker`] window.
    pub pct_window: usize,
}

impl Default for VolumeRegimeParams {
    /// `atr_period = 14`, `pct_window = 100`.
    fn default() -> Self {
        let (atr_period, pct_window) = (14, 100);
        Self {
            atr_period,
            pct_window,
        }
    }
}
35
36// ── Indicator struct ──────────────────────────────────────────────────────────
37
38/// Batch `Indicator` wrapping [`VolatilityPercentile`].
39///
40/// Computes a rolling ATR, feeds it into the percentile tracker, and outputs
41/// `vol_pct` (0–1) and `vol_regime` (encoded as 0=VERY_LOW … 4=VERY_HIGH).
42#[derive(Debug, Clone)]
43pub struct VolumeRegime {
44    pub params: VolumeRegimeParams,
45}
46
47impl VolumeRegime {
48    pub fn new(params: VolumeRegimeParams) -> Self {
49        Self { params }
50    }
51    pub fn with_defaults() -> Self {
52        Self::new(VolumeRegimeParams::default())
53    }
54}
55
56impl Indicator for VolumeRegime {
57    fn name(&self) -> &'static str {
58        "VolumeRegime"
59    }
60    fn required_len(&self) -> usize {
61        self.params.atr_period + 1
62    }
63    fn required_columns(&self) -> &[&'static str] {
64        &["high", "low", "close"]
65    }
66
67    fn calculate(&self, candles: &[Candle]) -> Result<IndicatorOutput, IndicatorError> {
68        self.check_len(candles)?;
69        let p = &self.params;
70        let mut tracker = VolatilityPercentile::new(p.pct_window);
71
72        // Incremental ATR (RMA / Wilder smoothing).
73        let mut prev_close: Option<f64> = None;
74        let mut atr_rma: Option<f64> = None;
75        let alpha = 1.0 / p.atr_period as f64;
76
77        let n = candles.len();
78        let mut vol_pct = vec![f64::NAN; n];
79        let mut vol_regime = vec![f64::NAN; n];
80
81        for (i, c) in candles.iter().enumerate() {
82            let tr = match prev_close {
83                None => c.high - c.low,
84                Some(pc) => (c.high - c.low)
85                    .max((c.high - pc).abs())
86                    .max((c.low - pc).abs()),
87            };
88            atr_rma = Some(match atr_rma {
89                None => tr,
90                Some(a) => alpha * tr + (1.0 - alpha) * a,
91            });
92            prev_close = Some(c.close);
93
94            tracker.update(atr_rma);
95            vol_pct[i] = tracker.vol_pct;
96            vol_regime[i] = match tracker.vol_regime {
97                "VERY LOW" => 0.0,
98                "LOW" => 1.0,
99                "HIGH" => 3.0,
100                "VERY HIGH" => 4.0,
101                _ => 2.0, // MED
102            };
103        }
104
105        Ok(IndicatorOutput::from_pairs([
106            ("vol_pct", vol_pct),
107            ("vol_regime", vol_regime),
108        ]))
109    }
110}
111
112// ── Registry factory ──────────────────────────────────────────────────────────
113
114pub fn factory<S: ::std::hash::BuildHasher>(
115    params: &HashMap<String, String, S>,
116) -> Result<Box<dyn Indicator>, IndicatorError> {
117    let atr_period = param_usize(params, "atr_period", 14)?;
118    let pct_window = param_usize(params, "pct_window", 100)?;
119    Ok(Box::new(VolumeRegime::new(VolumeRegimeParams {
120        atr_period,
121        pct_window,
122    })))
123}
124
125// ── PercentileTracker ─────────────────────────────────────────────────────────
126
/// Rolling percentile calculator over a fixed-size window.
///
/// Retains at most `maxlen` values. Once the window is full, each `push` evicts
/// the oldest value first. With `maxlen == 0` nothing is retained.
#[derive(Debug, Clone)]
pub struct PercentileTracker {
    buf: VecDeque<f64>,
    /// Window size. Stored explicitly because `VecDeque::capacity` is only a
    /// lower bound and may exceed the requested size, which would silently widen
    /// the window.
    maxlen: usize,
}

impl PercentileTracker {
    /// Creates an empty tracker whose window holds at most `maxlen` values.
    pub fn new(maxlen: usize) -> Self {
        Self {
            buf: VecDeque::with_capacity(maxlen),
            maxlen,
        }
    }

    /// Creates a tracker pre-filled with `maxlen / 2` values alternating
    /// `seed_lo`, `seed_hi`, `seed_lo`, …, so early percentiles are measured
    /// against a fixed baseline. For `maxlen < 2` no seed values are added.
    pub fn seeded(maxlen: usize, seed_lo: f64, seed_hi: f64) -> Self {
        let mut t = Self::new(maxlen);
        t.buf
            .extend((0..maxlen / 2).map(|i| if i % 2 == 0 { seed_lo } else { seed_hi }));
        t
    }

    /// Appends `val`, evicting the oldest value(s) so the window never exceeds
    /// `maxlen`. A zero-length window ignores the value.
    pub fn push(&mut self, val: f64) {
        if self.maxlen == 0 {
            return;
        }
        while self.buf.len() >= self.maxlen {
            self.buf.pop_front();
        }
        self.buf.push_back(val);
    }

    /// Fraction of buffered values strictly less than `val`, in `[0, 1]`.
    ///
    /// Returns `0.5` when the window is empty. A NaN `val` yields `0.0`, because
    /// no comparison against NaN is true.
    pub fn pct(&self, val: f64) -> f64 {
        match self.buf.len() {
            0 => 0.5,
            n => self.buf.iter().filter(|&&v| v < val).count() as f64 / n as f64,
        }
    }
}

// ── VolatilityPercentile ──────────────────────────────────────────────────────

/// Classifies ATR into a volatility regime by comparing the current ATR to its
/// own rolling percentile history.
#[derive(Debug, Clone)]
pub struct VolatilityPercentile {
    tracker: PercentileTracker,
    /// Percentile (0–1) of the most recently accepted ATR within the window.
    pub vol_pct: f64,
    /// Regime label: `"VERY LOW"`, `"LOW"`, `"MED"`, `"HIGH"` or `"VERY HIGH"`.
    pub vol_regime: &'static str,
    /// Multiplier paired with the current regime (see the table on `update`).
    pub vol_mult: f64,
    /// Confidence score adjustment applied to `conf_min_score`.
    pub conf_adj: f64,
}

impl VolatilityPercentile {
    /// Creates a classifier with a `maxlen`-value window.
    ///
    /// The window is pre-seeded with `maxlen / 2` alternating 20.0 / 200.0
    /// readings. The classifier starts in the `"MED"` regime with `vol_pct = 0.5`.
    pub fn new(maxlen: usize) -> Self {
        let tracker = PercentileTracker::seeded(maxlen, 20.0, 200.0);
        Self {
            tracker,
            vol_pct: 0.5,
            vol_regime: "MED",
            vol_mult: 1.2,
            conf_adj: 1.0,
        }
    }

    /// Feeds one ATR reading and re-classifies.
    ///
    /// `None`, non-positive, NaN and infinite readings are ignored and leave the
    /// previous classification unchanged. Otherwise the reading is pushed into
    /// the window, and its percentile `p` (which counts itself as not-less)
    /// selects the regime:
    ///
    /// | `p`        | regime    | `vol_mult` | `conf_adj` |
    /// |------------|-----------|------------|------------|
    /// | ≥ 0.8      | VERY HIGH | 1.8        | 1.15       |
    /// | ≥ 0.6      | HIGH      | 1.5        | 1.05       |
    /// | ≤ 0.2      | VERY LOW  | 0.8        | 0.90       |
    /// | ≤ 0.4      | LOW       | 1.0        | 0.95       |
    /// | otherwise  | MED       | 1.2        | 1.00       |
    pub fn update(&mut self, atr: Option<f64>) {
        // Non-finite readings are rejected explicitly. A NaN would otherwise sit
        // in the window and distort every later percentile.
        let Some(v) = atr.filter(|x| x.is_finite() && *x > 0.0) else {
            return;
        };
        self.tracker.push(v);
        let p = self.tracker.pct(v);
        self.vol_pct = p;
        (self.vol_regime, self.vol_mult, self.conf_adj) = if p >= 0.8 {
            ("VERY HIGH", 1.8, 1.15)
        } else if p >= 0.6 {
            ("HIGH", 1.5, 1.05)
        } else if p <= 0.2 {
            ("VERY LOW", 0.8, 0.9)
        } else if p <= 0.4 {
            ("LOW", 1.0, 0.95)
        } else {
            ("MED", 1.2, 1.0)
        };
    }
}
211
212// ── MarketRegimeTracker ───────────────────────────────────────────────────────
213
/// Simple slope + volatility regime tracker (ported from Python `MarketRegime` class).
///
/// Call `update` once per bar with the close. The tracker keeps a 200-bar SMA of
/// closes and a history of simple close-to-close returns. It classifies each bar
/// as follows (first match wins):
///
/// - `"TRENDING↑"` / `"TRENDING↓"`: the normalised SMA slope is above `1.0` /
///   below `-1.0`. The slope is the SMA change over the last 20 bars divided by
///   20× the mean absolute per-bar SMA change over (up to) the last 100 bars.
/// - `"VOLATILE"`: the population standard deviation of (up to) the last 100
///   returns is more than 1.5× the mean absolute return of the last 50.
/// - `"RANGING"`: that ratio is below `0.8`.
/// - `"NEUTRAL"`: anything else.
///
/// `"NEUTRAL"` is also the value until warm-up completes. Classification starts
/// on the 250th close: 200 closes for the first SMA, then 21 SMA values and 51
/// returns of history. The `is_*` flags mirror `regime`.
#[derive(Debug, Clone)]
pub struct MarketRegimeTracker {
    /// Most recent closes (bounded; only the newest 200 feed the SMA).
    closes: VecDeque<f64>,
    /// Most recent 200-bar SMA values (bounded).
    ma200_hist: VecDeque<f64>,
    /// Most recent simple returns (bounded).
    ret_hist: VecDeque<f64>,

    pub regime: &'static str,
    pub is_trending_u: bool,
    pub is_trending_d: bool,
    pub is_ranging: bool,
    pub is_volatile: bool,
}

impl MarketRegimeTracker {
    /// SMA length in bars.
    const MA_LEN: usize = 200;
    /// Maximum number of closes retained.
    const CLOSE_CAP: usize = 220;
    /// Maximum number of SMA values retained.
    const MA_HIST_CAP: usize = 120;
    /// Maximum number of returns retained.
    const RET_HIST_CAP: usize = 110;
    /// Bars over which the SMA slope is measured.
    const SLOPE_BARS: usize = 20;
    /// Maximum number of per-bar SMA changes averaged to normalise the slope.
    const AVG_CHG_BARS: usize = 100;
    /// Maximum number of returns used for the return standard deviation.
    const RET_STD_BARS: usize = 100;
    /// Number of returns used for the mean absolute return.
    const RET_ABS_BARS: usize = 50;
    /// SMA history required before classifying.
    const MIN_MA_HIST: usize = Self::SLOPE_BARS + 1;
    /// Return history required before classifying.
    const MIN_RET_HIST: usize = Self::RET_ABS_BARS + 1;
    /// Floor for the mean-absolute-return denominator.
    const EPS: f64 = 1e-9;

    /// Creates a tracker with empty history in the `"NEUTRAL"` regime.
    pub fn new() -> Self {
        Self {
            closes: VecDeque::with_capacity(Self::CLOSE_CAP),
            ma200_hist: VecDeque::with_capacity(Self::MA_HIST_CAP),
            ret_hist: VecDeque::with_capacity(Self::RET_HIST_CAP),
            regime: "NEUTRAL",
            is_trending_u: false,
            is_trending_d: false,
            is_ranging: false,
            is_volatile: false,
        }
    }

    /// Feeds one close and, once warmed up, re-classifies the regime.
    pub fn update(&mut self, close: f64) {
        let prev_cl = self.closes.back().copied().unwrap_or(close);
        Self::push_capped(&mut self.closes, close, Self::CLOSE_CAP);

        if self.closes.len() < Self::MA_LEN {
            return;
        }

        let ma200 =
            self.closes.iter().rev().take(Self::MA_LEN).sum::<f64>() / Self::MA_LEN as f64;
        Self::push_capped(&mut self.ma200_hist, ma200, Self::MA_HIST_CAP);

        // A zero previous close would divide by zero; count that bar as flat.
        let ret = if prev_cl != 0.0 {
            (close - prev_cl) / prev_cl
        } else {
            0.0
        };
        Self::push_capped(&mut self.ret_hist, ret, Self::RET_HIST_CAP);

        if self.ma200_hist.len() < Self::MIN_MA_HIST || self.ret_hist.len() < Self::MIN_RET_HIST
        {
            return;
        }

        let slope_n = self.normalised_slope();
        let vol_n = self.volatility_ratio();

        self.regime = if slope_n > 1.0 {
            "TRENDING↑"
        } else if slope_n < -1.0 {
            "TRENDING↓"
        } else if vol_n > 1.5 {
            "VOLATILE"
        } else if vol_n < 0.8 {
            "RANGING"
        } else {
            "NEUTRAL"
        };

        self.is_trending_u = self.regime == "TRENDING↑";
        self.is_trending_d = self.regime == "TRENDING↓";
        self.is_ranging = self.regime == "RANGING";
        self.is_volatile = self.regime == "VOLATILE";
    }

    /// Appends `value`, first evicting the oldest entry if `buf` already holds `cap`.
    fn push_capped(buf: &mut VecDeque<f64>, value: f64, cap: usize) {
        if buf.len() >= cap {
            buf.pop_front();
        }
        buf.push_back(value);
    }

    /// SMA change over the last `SLOPE_BARS` bars, normalised by `SLOPE_BARS` times
    /// the mean absolute per-bar SMA change over (up to) the newest `AVG_CHG_BARS`
    /// changes. Returns 0.0 when the SMA has not moved at all.
    ///
    /// Requires at least `MIN_MA_HIST` SMA values (guaranteed by `update`).
    fn normalised_slope(&self) -> f64 {
        let hist = &self.ma200_hist;
        let newest = hist.len() - 1;
        let n_chg = newest.min(Self::AVG_CHG_BARS);
        // Summed newest-first, matching the original reference ordering so float
        // results are bit-identical.
        let avg_chg = (0..n_chg)
            .map(|k| (hist[newest - k] - hist[newest - k - 1]).abs())
            .sum::<f64>()
            / n_chg as f64;
        if avg_chg > 0.0 {
            (hist[newest] - hist[newest - Self::SLOPE_BARS])
                / (avg_chg * Self::SLOPE_BARS as f64)
        } else {
            0.0
        }
    }

    /// Population std-dev of the newest (up to) `RET_STD_BARS` returns divided by
    /// the mean absolute value of the newest `RET_ABS_BARS` returns (floored at `EPS`).
    fn volatility_ratio(&self) -> f64 {
        // Newest-first window; the absolute-return window is its prefix.
        let recent: Vec<f64> = self
            .ret_hist
            .iter()
            .rev()
            .take(Self::RET_STD_BARS)
            .copied()
            .collect();
        let ret_s = std_dev(&recent);
        let abs_window = &recent[..recent.len().min(Self::RET_ABS_BARS)];
        let ret_sma = if abs_window.is_empty() {
            ret_s.max(Self::EPS)
        } else {
            (abs_window.iter().map(|r| r.abs()).sum::<f64>() / abs_window.len() as f64)
                .max(Self::EPS)
        };
        ret_s / ret_sma
    }
}

impl Default for MarketRegimeTracker {
    fn default() -> Self {
        Self::new()
    }
}

// ── helpers ───────────────────────────────────────────────────────────────────

/// Population standard deviation (divides by `N`, not `N - 1`).
///
/// Returns `0.0` for fewer than two samples.
fn std_dev(data: &[f64]) -> f64 {
    if data.len() < 2 {
        return 0.0;
    }
    let mean = data.iter().sum::<f64>() / data.len() as f64;
    let var = data.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / data.len() as f64;
    var.sqrt()
}
335    }
336    let mean = data.iter().sum::<f64>() / data.len() as f64;
337    let var = data.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / data.len() as f64;
338    var.sqrt()
339}