use std::collections::VecDeque;
use crate::error::{Error, Result};
use crate::indicators::rolling_quantile::quantile_sorted;
use crate::traits::Indicator;
/// One sample of a [`MedianChannel`]: a centre line with a symmetric band around it.
///
/// `middle` is the rolling median of the window. `upper` and `lower` sit
/// `multiplier * MAD` above and below it, where MAD is the median absolute
/// deviation of the window from that median.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct MedianChannelOutput {
    /// Top of the band: `middle + multiplier * MAD`.
    pub upper: f64,
    /// Rolling median of the window.
    pub middle: f64,
    /// Bottom of the band: `middle - multiplier * MAD`.
    pub lower: f64,
}
/// Rolling median channel: a median centre line banded by a scaled median
/// absolute deviation (MAD) over the last `period` inputs.
///
/// `scratch` and `deviations` are working buffers that every update clears and
/// refills, so the sorted copies are rebuilt in existing storage rather than
/// in freshly created vectors.
#[derive(Clone, Debug)]
pub struct MedianChannel {
    /// Number of most recent samples kept in the window (non-zero after `new`).
    period: usize,
    /// Band half-width in MAD units (finite and positive after `new`).
    multiplier: f64,
    /// The last `period` inputs, oldest at the front.
    window: VecDeque<f64>,
    /// Sorted copy of `window`, used to take the median.
    scratch: Vec<f64>,
    /// Sorted absolute deviations from the median, used to take the MAD.
    deviations: Vec<f64>,
}
impl MedianChannel {
pub fn new(period: usize, multiplier: f64) -> Result<Self> {
if period == 0 {
return Err(Error::PeriodZero);
}
if !multiplier.is_finite() || multiplier <= 0.0 {
return Err(Error::NonPositiveMultiplier);
}
Ok(Self {
period,
multiplier,
window: VecDeque::with_capacity(period),
scratch: Vec::with_capacity(period),
deviations: Vec::with_capacity(period),
})
}
pub const fn period(&self) -> usize {
self.period
}
pub const fn multiplier(&self) -> f64 {
self.multiplier
}
}
impl Indicator for MedianChannel {
    type Input = f64;
    type Output = MedianChannelOutput;

    /// Pushes `value` into the rolling window and returns the channel once the
    /// window holds `period` samples. Returns `None` while warming up.
    ///
    /// When the window is already full, the oldest sample is evicted first.
    /// Inputs are not validated. NaNs are ordered by [`f64::total_cmp`] and take
    /// part in the median/MAD like any other value, so the output may become NaN.
    fn update(&mut self, value: f64) -> Option<MedianChannelOutput> {
        if self.window.len() == self.period {
            self.window.pop_front();
        }
        self.window.push_back(value);
        if self.window.len() < self.period {
            return None;
        }

        // Centre line: median of the current window.
        self.scratch.clear();
        self.scratch.extend(self.window.iter().copied());
        let median = sort_and_take_median(&mut self.scratch);

        // Band width: median absolute deviation from that centre.
        self.deviations.clear();
        self.deviations
            .extend(self.window.iter().map(|&v| (v - median).abs()));
        let mad = sort_and_take_median(&mut self.deviations);

        let offset = self.multiplier * mad;
        Some(MedianChannelOutput {
            upper: median + offset,
            middle: median,
            lower: median - offset,
        })
    }

    /// Drops all buffered samples. The next `period` updates warm up again.
    fn reset(&mut self) {
        self.window.clear();
        self.scratch.clear();
        self.deviations.clear();
    }

    /// Number of updates needed before `update` first returns `Some`.
    fn warmup_period(&self) -> usize {
        self.period
    }

    /// `true` once the window holds `period` samples.
    fn is_ready(&self) -> bool {
        self.window.len() == self.period
    }

    fn name(&self) -> &'static str {
        "MedianChannel"
    }
}

/// Sorts `buf` in place and returns `quantile_sorted(buf, 0.5)`.
///
/// This uses `sort_unstable_by`, which sorts in place, rather than the stable
/// `sort_by`, which may allocate auxiliary memory on each call. That allocation
/// would defeat the reusable work buffers on every update. Stability does not
/// matter here: `f64::total_cmp` reports `Equal` only for bit-identical values,
/// so both sorts produce the same slice.
fn sort_and_take_median(buf: &mut [f64]) -> f64 {
    buf.sort_unstable_by(f64::total_cmp);
    quantile_sorted(buf, 0.5)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    #[test]
    fn rejects_zero_period() {
        assert!(matches!(MedianChannel::new(0, 2.0), Err(Error::PeriodZero)));
        // The period check runs before the multiplier check.
        assert!(matches!(MedianChannel::new(0, -1.0), Err(Error::PeriodZero)));
        assert!(MedianChannel::new(1, 2.0).is_ok());
    }

    #[test]
    fn rejects_non_positive_multiplier() {
        for bad in [0.0, -1.0, f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            assert!(
                matches!(
                    MedianChannel::new(20, bad),
                    Err(Error::NonPositiveMultiplier)
                ),
                "multiplier {bad} should be rejected"
            );
        }
    }

    #[test]
    fn accessors_and_metadata() {
        let mc = MedianChannel::new(20, 2.0).unwrap();
        assert_eq!(mc.period(), 20);
        assert_relative_eq!(mc.multiplier(), 2.0, epsilon = 1e-12);
        assert_eq!(mc.warmup_period(), 20);
        assert_eq!(mc.name(), "MedianChannel");
        assert!(!mc.is_ready());
    }

    #[test]
    fn warms_up_then_emits() {
        let mut mc = MedianChannel::new(5, 2.0).unwrap();
        for v in [1.0, 2.0, 3.0, 4.0] {
            assert!(mc.update(v).is_none());
        }
        assert!(mc.update(5.0).is_some());
        assert!(mc.is_ready());
    }

    #[test]
    fn known_channel() {
        let mut mc = MedianChannel::new(5, 2.0).unwrap();
        let out = mc.batch(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let last = out[4].unwrap();
        assert_relative_eq!(last.middle, 3.0, epsilon = 1e-12);
        assert_relative_eq!(last.upper, 5.0, epsilon = 1e-12);
        assert_relative_eq!(last.lower, 1.0, epsilon = 1e-12);
    }

    #[test]
    fn wide_window_matches_hand_computed_channel() {
        // 1..=25: median 13; |v - 13| sorted is 0,1,1,2,2,...,12,12, so the
        // MAD is 6 and the offset is 2 * 6 = 12.
        let data: [f64; 25] = std::array::from_fn(|i| (i + 1) as f64);
        let mut mc = MedianChannel::new(25, 2.0).unwrap();
        let out = mc.batch(&data);
        assert!(out[23].is_none());
        let last = out[24].unwrap();
        assert_relative_eq!(last.middle, 13.0, epsilon = 1e-12);
        assert_relative_eq!(last.upper, 25.0, epsilon = 1e-12);
        assert_relative_eq!(last.lower, 1.0, epsilon = 1e-12);
    }

    #[test]
    fn constant_series_collapses_channel() {
        // Every deviation is zero, so the band has zero width.
        let mut mc = MedianChannel::new(4, 3.0).unwrap();
        let out = mc.batch(&[7.5; 6]);
        assert!(out[2].is_none());
        let last = out[5].unwrap();
        assert_relative_eq!(last.middle, 7.5, epsilon = 1e-12);
        assert_relative_eq!(last.upper, 7.5, epsilon = 1e-12);
        assert_relative_eq!(last.lower, 7.5, epsilon = 1e-12);
    }

    #[test]
    fn robust_to_outlier() {
        let mut mc = MedianChannel::new(5, 2.0).unwrap();
        let out = mc.batch(&[1.0, 2.0, 3.0, 4.0, 1_000.0]);
        assert_relative_eq!(out[4].unwrap().middle, 3.0, epsilon = 1e-12);
    }

    #[test]
    fn rolling_window_evicts_oldest() {
        let mut mc = MedianChannel::new(5, 2.0).unwrap();
        let out = mc.batch(&[10.0, 10.0, 10.0, 10.0, 10.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
        let last = out[9].unwrap();
        assert_relative_eq!(last.middle, 3.0, epsilon = 1e-12);
        assert_relative_eq!(last.upper, 5.0, epsilon = 1e-12);
        assert_relative_eq!(last.lower, 1.0, epsilon = 1e-12);
    }

    #[test]
    fn reset_clears_state() {
        let mut mc = MedianChannel::new(5, 2.0).unwrap();
        for v in [1.0, 2.0, 3.0, 4.0, 5.0] {
            mc.update(v);
        }
        assert!(mc.is_ready());
        mc.reset();
        assert!(!mc.is_ready());
        assert!(mc.update(1.0).is_none());
    }

    #[test]
    fn reset_then_refill_matches_fresh_instance() {
        let mut fresh = MedianChannel::new(3, 1.5).unwrap();
        let mut reused = MedianChannel::new(3, 1.5).unwrap();
        let _ = reused.batch(&[100.0, -100.0, 7.0]);
        reused.reset();
        let a = fresh.batch(&[1.0, 4.0, 2.0]);
        let b = reused.batch(&[1.0, 4.0, 2.0]);
        assert!(a[2].is_some());
        assert_eq!(a[2], b[2]);
    }
}