use std::collections::VecDeque;
use crate::error::{Error, Result};
use crate::traits::Indicator;
/// Streaming smoothed moving average (SMMA).
///
/// The first `period` finite samples are averaged arithmetically to seed the
/// indicator. Each later sample `x` updates the value as
/// `smma = (smma * (period - 1) + x) / period`. Algebraically, this is an
/// exponential average with smoothing factor `1 / period`.
#[derive(Debug, Clone)]
pub struct Smma {
// Number of samples in the smoothing window; `Smma::new` rejects `0`.
period: usize,
// Finite samples collected during warm-up. Only its length is consulted when
// deciding whether seeding is complete; it is emptied again by `reset`.
seed: VecDeque<f64>,
// Running sum of the values pushed into `seed`, used for the seed mean.
seed_sum: f64,
// Latest SMMA value; `None` until `period` finite samples have been seen.
current: Option<f64>,
}
impl Smma {
pub fn new(period: usize) -> Result<Self> {
if period == 0 {
return Err(Error::PeriodZero);
}
Ok(Self {
period,
seed: VecDeque::with_capacity(period),
seed_sum: 0.0,
current: None,
})
}
pub const fn period(&self) -> usize {
self.period
}
pub const fn value(&self) -> Option<f64> {
self.current
}
}
impl Indicator for Smma {
    type Input = f64;
    type Output = f64;

    /// Feeds one sample and returns the SMMA after it, or `None` while warming up.
    ///
    /// Non-finite samples (`NaN`, `±inf`) are ignored entirely. No state
    /// changes, and the previously returned value (possibly `None`) is returned
    /// again.
    ///
    /// The first `period` finite samples are collected, and their arithmetic
    /// mean becomes the initial value. Every later finite sample is blended in
    /// with `smma = (smma * (period - 1) + input) / period`.
    fn update(&mut self, input: f64) -> Option<f64> {
        // Skipping NaN / infinities keeps a single bad tick from poisoning the average.
        if input.is_finite() {
            let n = self.period as f64;
            match self.current {
                Some(prev) => self.current = Some((prev * (n - 1.0) + input) / n),
                None => {
                    // Still seeding: accumulate until exactly `period` samples are held.
                    self.seed_sum += input;
                    self.seed.push_back(input);
                    if self.seed.len() == self.period {
                        self.current = Some(self.seed_sum / n);
                    }
                }
            }
        }
        self.current
    }

    /// Discards all accumulated state so the next `update` starts a new warm-up.
    /// The seed buffer is cleared in place and keeps its allocation.
    fn reset(&mut self) {
        self.current = None;
        self.seed_sum = 0.0;
        self.seed.clear();
    }

    /// Number of finite samples required before the first value is produced.
    fn warmup_period(&self) -> usize {
        self.period
    }

    /// `true` once the seed window has been filled and a value is available.
    fn is_ready(&self) -> bool {
        matches!(self.current, Some(_))
    }

    /// Short display name of this indicator.
    fn name(&self) -> &'static str {
        "SMMA"
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `Smma`: construction, warm-up/seed semantics, the
    // recurrence, non-finite input handling, reset, and batch/stream parity.
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    #[test]
    fn new_rejects_zero_period() {
        assert!(matches!(Smma::new(0), Err(Error::PeriodZero)));
    }

    #[test]
    fn accessors_and_metadata() {
        let mut smma = Smma::new(7).unwrap();
        assert_eq!(smma.period(), 7);
        assert_eq!(smma.warmup_period(), 7);
        assert_eq!(smma.name(), "SMMA");
        assert_eq!(smma.value(), None);
        for i in 1..=7 {
            smma.update(f64::from(i));
        }
        assert!(smma.value().is_some());
    }

    #[test]
    fn warmup_then_recurrence() {
        let mut smma = Smma::new(3).unwrap();
        assert_eq!(smma.update(1.0), None);
        assert_eq!(smma.update(2.0), None);
        assert_eq!(smma.update(3.0), Some(2.0));
        assert_relative_eq!(
            smma.update(4.0).unwrap(),
            (2.0 * 2.0 + 4.0) / 3.0,
            epsilon = 1e-12
        );
        assert_relative_eq!(
            smma.update(5.0).unwrap(),
            ((2.0 * 2.0 + 4.0) / 3.0 * 2.0 + 5.0) / 3.0,
            epsilon = 1e-12
        );
    }

    #[test]
    fn period_one_is_pass_through() {
        let mut smma = Smma::new(1).unwrap();
        assert_eq!(smma.update(5.0), Some(5.0));
        assert_eq!(smma.update(10.0), Some(10.0));
    }

    #[test]
    fn constant_series_yields_the_constant() {
        let mut smma = Smma::new(5).unwrap();
        let out = smma.batch(&[7.0; 20]);
        for x in out.iter().skip(4) {
            assert_relative_eq!(x.unwrap(), 7.0, epsilon = 1e-12);
        }
    }

    #[test]
    fn ignores_non_finite_input() {
        let mut smma = Smma::new(3).unwrap();
        smma.batch(&[1.0, 2.0, 3.0]);
        assert_eq!(smma.update(f64::NAN), Some(2.0));
        assert_eq!(smma.update(f64::INFINITY), Some(2.0));
    }

    // Non-finite samples arriving before the seed is complete must neither count
    // toward the `period` seed samples nor contaminate the seed mean.
    #[test]
    fn non_finite_input_during_warmup_is_skipped() {
        let mut smma = Smma::new(3).unwrap();
        assert_eq!(smma.update(1.0), None);
        assert_eq!(smma.update(f64::NAN), None);
        assert_eq!(smma.update(2.0), None);
        assert_eq!(smma.update(f64::NEG_INFINITY), None);
        assert!(!smma.is_ready());
        assert_eq!(smma.update(3.0), Some(2.0));
        assert!(smma.is_ready());
    }

    #[test]
    fn reset_clears_state() {
        let mut smma = Smma::new(3).unwrap();
        smma.batch(&[1.0, 2.0, 3.0, 4.0]);
        assert!(smma.is_ready());
        smma.reset();
        assert!(!smma.is_ready());
        assert_eq!(smma.update(10.0), None);
    }

    // After `reset`, an indicator must behave exactly like a freshly built one
    // for an entire warm-up and beyond, not just for the first sample.
    #[test]
    fn reset_then_rewarm_matches_fresh_instance() {
        let prices = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0];
        let mut reused = Smma::new(4).unwrap();
        reused.batch(&[10.0, 20.0, 30.0, 40.0, 50.0]);
        reused.reset();
        let mut fresh = Smma::new(4).unwrap();
        for &p in &prices {
            assert_eq!(reused.update(p), fresh.update(p));
        }
        assert_eq!(reused.value(), fresh.value());
    }

    #[test]
    fn batch_equals_streaming() {
        let prices: Vec<f64> = (1..=30).map(f64::from).collect();
        let batch = Smma::new(7).unwrap().batch(&prices);
        let mut b = Smma::new(7).unwrap();
        let streamed: Vec<_> = prices.iter().map(|p| b.update(*p)).collect();
        assert_eq!(batch, streamed);
    }
}