use std::collections::VecDeque;
use crate::error::{Error, Result};
use crate::ohlcv::Candle;
use crate::traits::Indicator;
/// Streaming Ultimate Oscillator over three look-back periods.
///
/// Every candle after the first yields one buying-pressure / true-range pair
/// measured against the previous candle's close:
///
/// - true low = `min(low, previous close)`
/// - buying pressure (bp) = `close - true low`
/// - true range (tr) = `max(high, previous close) - true low`
///
/// For each period the average is `sum(bp) / sum(tr)` over that many most
/// recent pairs (0.5 when `sum(tr)` is zero), and the emitted value is
/// `100 * (4 * avg_short + 2 * avg_mid + avg_long) / 7`.
#[derive(Debug, Clone)]
pub struct UltimateOscillator {
    /// Length, in pairs, of the average weighted 4.
    short: usize,
    /// Length, in pairs, of the average weighted 2.
    mid: usize,
    /// Length, in pairs, of the average weighted 1.
    long: usize,
    /// Largest of the three periods: the number of pairs kept in `window`
    /// and the number of pairs required before the first value is emitted.
    longest: usize,
    /// Close of the previous candle; `None` until a candle has been seen.
    prev_close: Option<f64>,
    /// Most recent `(bp, tr)` pairs, oldest at the front; holds at most
    /// `longest` entries once an update has finished.
    window: VecDeque<(f64, f64)>,
    /// Sum of bp over the last `short` pairs.
    sum_bp_short: f64,
    /// Sum of tr over the last `short` pairs.
    sum_tr_short: f64,
    /// Sum of bp over the last `mid` pairs.
    sum_bp_mid: f64,
    /// Sum of tr over the last `mid` pairs.
    sum_tr_mid: f64,
    /// Sum of bp over the last `long` pairs.
    sum_bp_long: f64,
    /// Sum of tr over the last `long` pairs.
    sum_tr_long: f64,
    /// Number of pairs produced since construction or the last reset; gates
    /// the first emission.
    pairs: usize,
    /// Most recently emitted oscillator value, if any.
    last: Option<f64>,
}
impl UltimateOscillator {
    /// Creates an oscillator whose three averages span `short`, `mid` and
    /// `long` buying-pressure / true-range pairs (weighted 4, 2 and 1).
    ///
    /// The periods are not required to be ordered; whichever is largest
    /// decides how many pairs must be seen before the first value appears.
    ///
    /// # Errors
    ///
    /// Returns [`Error::PeriodZero`] when any of the three periods is zero.
    pub fn new(short: usize, mid: usize, long: usize) -> Result<Self> {
        if [short, mid, long].contains(&0) {
            return Err(Error::PeriodZero);
        }
        let longest = long.max(mid).max(short);
        // One slot of headroom: a new pair is pushed before the oldest one is
        // evicted, so the window briefly holds `longest + 1` entries.
        let window = VecDeque::with_capacity(longest + 1);
        Ok(Self {
            short,
            mid,
            long,
            longest,
            prev_close: None,
            window,
            sum_bp_short: 0.0,
            sum_tr_short: 0.0,
            sum_bp_mid: 0.0,
            sum_tr_mid: 0.0,
            sum_bp_long: 0.0,
            sum_tr_long: 0.0,
            pairs: 0,
            last: None,
        })
    }

    /// Convenience constructor using periods 7, 14 and 28.
    ///
    /// # Panics
    ///
    /// Does not panic in practice: the fixed periods are all non-zero, so
    /// [`UltimateOscillator::new`] cannot fail for them.
    pub fn classic() -> Self {
        const CLASSIC_PERIODS: (usize, usize, usize) = (7, 14, 28);
        let (short, mid, long) = CLASSIC_PERIODS;
        Self::new(short, mid, long).expect("classic Ultimate Oscillator periods are valid")
    }

    /// Returns the configured `(short, mid, long)` periods.
    pub const fn periods(&self) -> (usize, usize, usize) {
        let Self { short, mid, long, .. } = *self;
        (short, mid, long)
    }

    /// Most recently emitted oscillator value, or `None` before warm-up has
    /// completed (and again after a reset).
    pub const fn value(&self) -> Option<f64> {
        self.last
    }
}
impl Indicator for UltimateOscillator {
    type Input = Candle;
    type Output = f64;

    /// Feeds one candle and returns the oscillator once it is warmed up.
    ///
    /// The first candle (after construction or a reset) only records its
    /// close and yields `None`. Every later candle becomes one `(bp, tr)`
    /// pair measured against the previous close. Values are emitted from the
    /// `longest`-th pair onward, i.e. starting with candle number
    /// `warmup_period()`.
    fn update(&mut self, candle: Candle) -> Option<f64> {
        // Remember this close for the next call; bail out on the seed candle.
        let prev_close = self.prev_close.replace(candle.close)?;

        let true_low = candle.low.min(prev_close);
        let bp = candle.close - true_low;
        let tr = candle.high.max(prev_close) - true_low;

        self.window.push_back((bp, tr));
        if self.window.len() > self.longest {
            self.window.pop_front();
        }

        // Rebuild every trailing sum from the window instead of adding the new
        // pair and subtracting the evicted one. Incremental add/subtract leaves
        // rounding residue behind (0.1 + 0.2 - 0.1 - 0.2 != 0.0), so a window
        // that had turned flat after volatile bars kept a tiny non-zero
        // true-range sum: the 0.5 fallback below never fired and the ratio of
        // two residues could land anywhere, even outside [0, 100]. A fresh sum
        // is exactly 0.0 when every pair is zero, and because each pair has
        // bp <= tr (given low <= close <= high — presumably enforced by
        // `Candle`, not visible here) the ratio stays within [0, 1].
        // Cost: O(short + mid + long) per update.
        let window = &self.window;
        let tail_sums = |len: usize| {
            window
                .iter()
                .rev()
                .take(len)
                .fold((0.0_f64, 0.0_f64), |(bp_sum, tr_sum), &(b, t)| {
                    (bp_sum + b, tr_sum + t)
                })
        };
        (self.sum_bp_short, self.sum_tr_short) = tail_sums(self.short);
        (self.sum_bp_mid, self.sum_tr_mid) = tail_sums(self.mid);
        (self.sum_bp_long, self.sum_tr_long) = tail_sums(self.long);

        // Saturating: an extremely long stream must not overflow the counter
        // (a panic in debug builds, a silent wrap back into warm-up in release).
        self.pairs = self.pairs.saturating_add(1);
        if self.pairs < self.longest {
            return None;
        }

        // A period with no true range at all (a perfectly flat market) counts
        // as neutral: ratio 0.5.
        let ratio = |bp_sum: f64, tr_sum: f64| {
            if tr_sum == 0.0 {
                0.5
            } else {
                bp_sum / tr_sum
            }
        };
        let avg_short = ratio(self.sum_bp_short, self.sum_tr_short);
        let avg_mid = ratio(self.sum_bp_mid, self.sum_tr_mid);
        let avg_long = ratio(self.sum_bp_long, self.sum_tr_long);

        // Weights 4 / 2 / 1, normalised by their total (7), scaled to 0..=100.
        let uo = 100.0 * (4.0 * avg_short + 2.0 * avg_mid + avg_long) / 7.0;
        self.last = Some(uo);
        Some(uo)
    }

    /// Clears all accumulated state (previous close, window, sums, counter
    /// and last value); the configured periods are kept.
    fn reset(&mut self) {
        self.prev_close = None;
        self.window.clear();
        self.sum_bp_short = 0.0;
        self.sum_tr_short = 0.0;
        self.sum_bp_mid = 0.0;
        self.sum_tr_mid = 0.0;
        self.sum_bp_long = 0.0;
        self.sum_tr_long = 0.0;
        self.pairs = 0;
        self.last = None;
    }

    /// Candles needed before the first value: one seed candle for the
    /// previous close plus `longest` pairs.
    fn warmup_period(&self) -> usize {
        self.longest + 1
    }

    /// `true` once a value has been emitted since construction or the last
    /// reset.
    fn is_ready(&self) -> bool {
        self.last.is_some()
    }

    fn name(&self) -> &'static str {
        "UltimateOscillator"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// A candle with open = high = low = close = `price` and unit volume.
    fn flat(price: f64, ts: i64) -> Candle {
        Candle::new(price, price, price, price, 1.0, ts).unwrap()
    }

    /// `len` flat candles starting at 100.0 and moving by `step` per bar.
    fn staircase(len: i64, step: f64) -> Vec<Candle> {
        (0..len).map(|i| flat(100.0 + step * i as f64, i)).collect()
    }

    /// Runs a (2, 3, 5) oscillator over `candles` and checks that every
    /// emitted value is within 1e-9 of `expected`.
    fn assert_every_output(candles: Vec<Candle>, expected: f64) {
        let mut uo = UltimateOscillator::new(2, 3, 5).unwrap();
        for v in uo.batch(&candles).into_iter().flatten() {
            assert_relative_eq!(v, expected, epsilon = 1e-9);
        }
    }

    #[test]
    fn new_rejects_zero_period() {
        for (short, mid, long) in [(0, 14, 28), (7, 0, 28), (7, 14, 0)] {
            assert!(matches!(
                UltimateOscillator::new(short, mid, long),
                Err(Error::PeriodZero)
            ));
        }
    }

    #[test]
    fn accessors_and_metadata() {
        let mut uo = UltimateOscillator::new(7, 14, 28).unwrap();
        assert_eq!(uo.periods(), (7, 14, 28));
        assert_eq!(uo.name(), "UltimateOscillator");
        assert_eq!(uo.value(), None);

        let warmup = i64::try_from(uo.warmup_period()).unwrap();
        for ts in 0..warmup {
            let p = 100.0 + (ts as f64 * 0.3).sin() * 5.0;
            uo.update(Candle::new(p, p + 1.0, p - 1.0, p, 1.0, ts).unwrap());
        }
        assert!(uo.value().is_some());
    }

    #[test]
    fn first_emission_at_warmup_period() {
        let mut uo = UltimateOscillator::new(2, 3, 5).unwrap();
        assert_eq!(uo.warmup_period(), 6);
        let candles = staircase(20, 1.0);
        let out = uo.batch(&candles);
        assert!(out.iter().take(5).all(Option::is_none));
        assert!(out[5].is_some());
    }

    #[test]
    fn pure_uptrend_saturates_at_100() {
        assert_every_output(staircase(30, 1.0), 100.0);
    }

    #[test]
    fn pure_downtrend_saturates_at_0() {
        assert_every_output(staircase(30, -1.0), 0.0);
    }

    #[test]
    fn flat_market_reads_50() {
        assert_every_output(staircase(30, 0.0), 50.0);
    }

    #[test]
    fn output_stays_within_0_100() {
        let candles: Vec<Candle> = (0..200)
            .map(|i| {
                let mid = 100.0 + (i as f64 * 0.2).sin() * 12.0;
                Candle::new(mid, mid + 3.0, mid - 3.0, mid + 1.0, 10.0, i).unwrap()
            })
            .collect();
        let mut uo = UltimateOscillator::classic();
        for v in uo.batch(&candles).into_iter().flatten() {
            assert!((0.0..=100.0).contains(&v), "UO out of range: {v}");
        }
    }

    #[test]
    fn reset_clears_state() {
        let mut uo = UltimateOscillator::new(2, 3, 5).unwrap();
        let candles = staircase(20, 1.0);
        uo.batch(&candles);
        assert!(uo.is_ready());

        uo.reset();
        assert!(!uo.is_ready());
        assert_eq!(uo.update(candles[0]), None);
    }

    #[test]
    fn batch_equals_streaming() {
        let candles: Vec<Candle> = (0..120)
            .map(|i| {
                let mid = 100.0 + (i as f64 * 0.3).sin() * 10.0;
                Candle::new(mid, mid + 2.0, mid - 2.0, mid + 0.5, 10.0, i).unwrap()
            })
            .collect();
        let batch = UltimateOscillator::classic().batch(&candles);

        let mut streaming = UltimateOscillator::classic();
        let streamed: Vec<_> = candles
            .iter()
            .copied()
            .map(|c| streaming.update(c))
            .collect();
        assert_eq!(batch, streamed);
    }
}