use std::collections::VecDeque;
use crate::error::{Error, Result};
use crate::ohlcv::Candle;
use crate::traits::Indicator;
/// Streaming Money Flow Index (MFI) state.
///
/// Money flow (`typical_price * volume`) is classified per candle as positive
/// (typical price rose), negative (it fell) or neither (unchanged), and the last
/// `period` classified flows are kept in two parallel rolling windows.
#[derive(Debug, Clone)]
pub struct Mfi {
    /// Number of money-flow samples in the rolling window (always non-zero).
    period: usize,
    /// Typical price of the most recent candle; `None` until a candle has been seen.
    prev_tp: Option<f64>,
    /// Positive money-flow samples, oldest at the front (zero for non-rising candles).
    pos_window: VecDeque<f64>,
    /// Negative money-flow samples, oldest at the front (zero for non-falling candles).
    neg_window: VecDeque<f64>,
    /// Total of `pos_window`.
    pos_sum: f64,
    /// Total of `neg_window`.
    neg_sum: f64,
}
impl Mfi {
pub fn new(period: usize) -> Result<Self> {
if period == 0 {
return Err(Error::PeriodZero);
}
Ok(Self {
period,
prev_tp: None,
pos_window: VecDeque::with_capacity(period),
neg_window: VecDeque::with_capacity(period),
pos_sum: 0.0,
neg_sum: 0.0,
})
}
pub const fn period(&self) -> usize {
self.period
}
}
impl Indicator for Mfi {
    type Input = Candle;
    type Output = f64;

    /// Consumes one candle and returns the MFI once the window is full.
    ///
    /// The very first candle has no predecessor, so it only records its typical
    /// price and yields `None`. Each later candle contributes one money-flow sample
    /// (`typical_price * volume`) to the positive side if its typical price rose, to
    /// the negative side if it fell, or to neither if it is unchanged. Output starts
    /// once `period` samples are buffered, i.e. on candle number `period + 1`.
    ///
    /// Special cases: a window with no positive and no negative flow yields `50.0`;
    /// a window with no negative flow yields `100.0`.
    fn update(&mut self, candle: Candle) -> Option<f64> {
        let tp = candle.typical_price();
        // Store this candle's typical price; with no previous one there is no flow yet.
        let prev = self.prev_tp.replace(tp)?;

        let raw_flow = tp * candle.volume;
        let (pos_flow, neg_flow) = if tp > prev {
            (raw_flow, 0.0)
        } else if tp < prev {
            (0.0, raw_flow)
        } else {
            (0.0, 0.0)
        };

        if self.pos_window.len() == self.period {
            // Both windows are always pushed/popped together, so they share a length.
            self.pos_window.pop_front();
            self.neg_window.pop_front();
        }
        self.pos_window.push_back(pos_flow);
        self.neg_window.push_back(neg_flow);

        // Re-derive the totals from the windows instead of adding the new flow and
        // subtracting the evicted one. The add/subtract scheme leaves floating-point
        // residue behind (e.g. `0.1 + 0.2 - 0.1 - 0.2` is not `0.0`), so a window whose
        // flows are all zero could fail the exact `== 0.0` checks below and report 100
        // (or a non-finite value) instead of 50. Each window holds at most `period`
        // values, so this costs O(period) per update.
        self.pos_sum = self.pos_window.iter().sum();
        self.neg_sum = self.neg_window.iter().sum();

        if self.pos_window.len() < self.period {
            return None;
        }
        if self.pos_sum == 0.0 && self.neg_sum == 0.0 {
            return Some(50.0);
        }
        if self.neg_sum == 0.0 {
            return Some(100.0);
        }
        let money_ratio = self.pos_sum / self.neg_sum;
        Some(100.0 - 100.0 / (1.0 + money_ratio))
    }

    /// Clears all state, returning the indicator to its freshly constructed condition.
    fn reset(&mut self) {
        self.prev_tp = None;
        self.pos_window.clear();
        self.neg_window.clear();
        self.pos_sum = 0.0;
        self.neg_sum = 0.0;
    }

    /// One seed candle plus `period` money-flow samples.
    fn warmup_period(&self) -> usize {
        self.period + 1
    }

    /// `true` once the window holds `period` samples, i.e. once `update` emits values.
    fn is_ready(&self) -> bool {
        self.pos_window.len() == self.period
    }

    fn name(&self) -> &'static str {
        "MFI"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BatchExt;
    use approx::assert_relative_eq;

    /// Candle whose open, high, low and close all equal `price`, stamped at time 0.
    fn candle_at(price: f64, volume: f64) -> Candle {
        Candle::new(price, price, price, price, volume, 0).unwrap()
    }

    /// Runs `candles` through a fresh `Mfi` of `period`; returns the last emitted value.
    fn last_emitted(period: usize, candles: Vec<Candle>) -> Option<f64> {
        let mut indicator = Mfi::new(period).unwrap();
        indicator.batch(&candles).into_iter().flatten().last()
    }

    #[test]
    fn pure_uptrend_yields_high_mfi() {
        let rising = (1..30).map(|step| candle_at(f64::from(step), 100.0)).collect();
        let value = last_emitted(14, rising).unwrap();
        assert_relative_eq!(value, 100.0, epsilon = 1e-9);
    }

    #[test]
    fn pure_downtrend_yields_low_mfi() {
        let falling = (1..30)
            .rev()
            .map(|step| candle_at(f64::from(step), 100.0))
            .collect();
        let value = last_emitted(14, falling).unwrap();
        assert_relative_eq!(value, 0.0, epsilon = 1e-9);
    }

    #[test]
    fn batch_equals_streaming() {
        let series: Vec<Candle> = (0..40)
            .map(|k| candle_at(f64::from(k) + 10.0, 50.0))
            .collect();
        let mut batch_mfi = Mfi::new(14).unwrap();
        let mut stream_mfi = Mfi::new(14).unwrap();
        let streamed: Vec<_> = series
            .iter()
            .map(|candle| stream_mfi.update(*candle))
            .collect();
        assert_eq!(batch_mfi.batch(&series), streamed);
    }

    #[test]
    fn reset_clears_state() {
        let mut mfi = Mfi::new(14).unwrap();
        let history: Vec<Candle> = (1..30).map(|step| candle_at(f64::from(step), 100.0)).collect();
        mfi.batch(&history);
        assert!(mfi.is_ready());
        mfi.reset();
        assert!(!mfi.is_ready());
    }

    #[test]
    fn accessors_and_metadata() {
        let mfi = Mfi::new(14).unwrap();
        assert_eq!((mfi.period(), mfi.name()), (14, "MFI"));
    }

    #[test]
    fn flat_typical_prices_default_to_50() {
        let unchanged: Vec<Candle> = (0..6)
            .map(|ts| Candle::new(10.0, 10.0, 10.0, 10.0, 1.0, ts).unwrap())
            .collect();
        let last = last_emitted(3, unchanged).expect("emits");
        assert_eq!(last, 50.0);
    }

    #[test]
    fn rejects_zero_period() {
        let outcome = Mfi::new(0);
        assert!(outcome.is_err());
    }

    #[test]
    fn first_value_emitted_on_period_plus_one_candle() {
        let ramp: Vec<Candle> = (1..=20).map(|step| candle_at(f64::from(step), 100.0)).collect();
        let mut mfi = Mfi::new(5).unwrap();
        let out = mfi.batch(&ramp);
        for (i, v) in out.iter().take(5).enumerate() {
            assert!(v.is_none(), "candle index {i} must be None during warmup");
        }
        assert!(out[5].is_some(), "first MFI value lands at index period (5)");
        assert_eq!(mfi.warmup_period(), 6);
    }

    #[test]
    fn known_value_period_2() {
        // Assuming a flat candle's typical price equals its price: the up-move flow is
        // 12 * 100 = 1200 and the down-move flow is 11 * 100 = 1100, so
        // MFI = 100 * 1200 / 2300 = 1200 / 23.
        let candles = vec![
            candle_at(10.0, 100.0),
            candle_at(12.0, 100.0),
            candle_at(11.0, 100.0),
        ];
        let mut mfi = Mfi::new(2).unwrap();
        let out = mfi.batch(&candles);
        assert!(out[0].is_none());
        assert!(out[1].is_none());
        assert_relative_eq!(out[2].unwrap(), 1200.0 / 23.0, epsilon = 1e-9);
    }
}