use std::collections::VecDeque;
use crate::indicators::{validate_data_length, validate_period};
use crate::Candle;
use crate::Indicator;
use crate::IndicatorError;
/// Chaikin Money Flow (CMF) indicator.
///
/// Each candle contributes a money flow volume (MFV): the money flow multiplier
/// `(2 * close - high - low) / (high - low)` scaled by the candle's volume.
/// CMF over a window of `period` candles is `sum(MFV) / sum(volume)`.
///
/// The indicator keeps a rolling window. Values can therefore be produced one
/// candle at a time through `Indicator::next`, or for a whole slice through
/// `Indicator::calculate`. `Clone` lets callers snapshot that window state.
#[derive(Debug, Clone)]
pub struct Cmf {
    /// Number of candles per window; `Cmf::new` rejects 0.
    period: usize,
    /// Money flow volumes of the most recent candles, oldest first.
    mfv_buffer: VecDeque<f64>,
    /// Volumes of the same candles as `mfv_buffer`.
    ///
    /// Pushed and popped in lockstep with `mfv_buffer`, so both buffers always
    /// have the same length (never more than `period`).
    volume_buffer: VecDeque<f64>,
}
impl Cmf {
    /// Creates a CMF indicator whose rolling window spans `period` candles.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `validate_period(period, 1)`. In
    /// particular, a `period` of 0 is rejected.
    pub fn new(period: usize) -> Result<Self, IndicatorError> {
        validate_period(period, 1)?;
        let mfv_buffer = VecDeque::with_capacity(period);
        let volume_buffer = VecDeque::with_capacity(period);
        Ok(Self {
            period,
            mfv_buffer,
            volume_buffer,
        })
    }

    /// Money flow multiplier of a single candle: `(2 * close - high - low) / (high - low)`.
    ///
    /// The result is +1 when the candle closes at its high and -1 when it
    /// closes at its low.
    ///
    /// # Errors
    ///
    /// Returns `IndicatorError::CalculationError` when `high == low`, because
    /// the price range would be a zero divisor.
    fn money_flow_multiplier(candle: &Candle) -> Result<f64, IndicatorError> {
        let (high, low, close) = (candle.high, candle.low, candle.close);
        let spread = high - low;
        // NOTE(review): some libraries (e.g. TA-Lib's AD) let a zero-range bar
        // contribute 0 instead of failing; confirm which behaviour callers want.
        if spread != 0.0 {
            Ok((2.0 * close - high - low) / spread)
        } else {
            Err(IndicatorError::CalculationError(
                "Division by zero: high and low prices are equal".to_string(),
            ))
        }
    }

    /// Money flow volume of a single candle: its money flow multiplier times its volume.
    ///
    /// # Errors
    ///
    /// Fails exactly when `money_flow_multiplier` does, i.e. for a zero-range candle.
    fn money_flow_volume(candle: &Candle) -> Result<f64, IndicatorError> {
        Self::money_flow_multiplier(candle).map(|multiplier| multiplier * candle.volume)
    }
}
impl Indicator<Candle, f64> for Cmf {
    /// Computes CMF for every complete window in `data`.
    ///
    /// State left over from earlier `next` calls is discarded first. Each
    /// candle is then fed through [`Indicator::next`], so batch and streaming
    /// results always come from the same arithmetic. The output holds one value
    /// per run of `period` consecutive candles: `data.len() - period + 1` values
    /// when `data` has at least `period` candles.
    ///
    /// # Errors
    ///
    /// - Propagates any error from `validate_data_length(data, period)`.
    /// - `IndicatorError::CalculationError` if a candle has `high == low`, or if
    ///   a window's total volume is zero. The rolling window then holds the
    ///   candles processed so far; the next `calculate` call resets it.
    fn calculate(&mut self, data: &[Candle]) -> Result<Vec<f64>, IndicatorError> {
        validate_data_length(data, self.period)?;
        self.reset();
        // `saturating_sub` keeps the capacity hint from underflowing should the
        // length check ever admit fewer than `period` candles.
        let mut result = Vec::with_capacity(data.len().saturating_sub(self.period) + 1);
        for &candle in data {
            if let Some(value) = self.next(candle)? {
                result.push(value);
            }
        }
        Ok(result)
    }

    /// Feeds one candle into the rolling window.
    ///
    /// Returns `Ok(None)` while fewer than `period` candles have been seen since
    /// the last reset. After that it returns `Ok(Some(cmf))` for the window
    /// ending at `value`.
    ///
    /// # Errors
    ///
    /// - `IndicatorError::CalculationError` if `value.high == value.low`. The
    ///   window is left unchanged, because the check runs before anything is pushed.
    /// - `IndicatorError::CalculationError` if the window's total volume is zero.
    ///   In this case `value` has already been added to the window.
    fn next(&mut self, value: Candle) -> Result<Option<f64>, IndicatorError> {
        let mfv = Self::money_flow_volume(&value)?;
        self.mfv_buffer.push_back(mfv);
        self.volume_buffer.push_back(value.volume);
        // Both buffers move in lockstep, so one length check covers both.
        if self.mfv_buffer.len() > self.period {
            self.mfv_buffer.pop_front();
            self.volume_buffer.pop_front();
        }
        if self.mfv_buffer.len() < self.period {
            return Ok(None);
        }
        // Summing the window afresh (rather than keeping running totals) costs
        // O(period) per call but avoids accumulating floating-point drift.
        let sum_mfv: f64 = self.mfv_buffer.iter().sum();
        let sum_volume: f64 = self.volume_buffer.iter().sum();
        if sum_volume == 0.0 {
            return Err(IndicatorError::CalculationError(
                "Division by zero: sum of volumes is zero".to_string(),
            ));
        }
        Ok(Some(sum_mfv / sum_volume))
    }

    /// Clears the rolling window, so the next `period` candles start a fresh warm-up.
    fn reset(&mut self) {
        self.mfv_buffer.clear();
        self.volume_buffer.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::indicators::Candle;

    /// Absolute tolerance for comparing floating-point CMF values.
    const EPS: f64 = 1e-9;

    /// Builds candles from `[open, high, low, close, volume]` rows.
    ///
    /// Timestamps are numbered sequentially from 1.
    fn candles(rows: &[[f64; 5]]) -> Vec<Candle> {
        rows.iter()
            .enumerate()
            .map(|(i, &[open, high, low, close, volume])| Candle {
                timestamp: (i + 1) as _,
                open,
                high,
                low,
                close,
                volume,
            })
            .collect()
    }

    /// Three bars whose money flow multipliers are +0.5, +0.5 and -0.5.
    fn three_candles() -> Vec<Candle> {
        candles(&[
            [10.0, 12.0, 8.0, 11.0, 1000.0],
            [11.0, 13.0, 9.0, 12.0, 1200.0],
            [12.0, 14.0, 10.0, 11.0, 800.0],
        ])
    }

    /// The bars of `three_candles` plus two more bars with multiplier +0.5 each.
    fn five_candles() -> Vec<Candle> {
        candles(&[
            [10.0, 12.0, 8.0, 11.0, 1000.0],
            [11.0, 13.0, 9.0, 12.0, 1200.0],
            [12.0, 14.0, 10.0, 11.0, 800.0],
            [11.0, 13.0, 9.0, 12.0, 1500.0],
            [12.0, 15.0, 11.0, 14.0, 2000.0],
        ])
    }

    /// Asserts that `result` is the zero-volume-sum `CalculationError`.
    ///
    /// The message is lower-cased first, so the check does not depend on
    /// capitalisation.
    fn assert_zero_volume_error<T: std::fmt::Debug>(
        result: Result<T, IndicatorError>,
        context: &str,
    ) {
        match result {
            Err(IndicatorError::CalculationError(msg)) => {
                let lower = msg.to_lowercase();
                assert!(
                    lower.contains("division by zero") && lower.contains("sum of volumes is zero"),
                    "{}: unexpected message {:?}",
                    context,
                    msg
                );
            }
            other => panic!(
                "{}: expected CalculationError for zero volume sum, got {:?}",
                context, other
            ),
        }
    }

    #[test]
    fn test_cmf_new() {
        assert!(Cmf::new(14).is_ok());
        assert!(Cmf::new(0).is_err());
    }

    #[test]
    fn test_cmf_calculation() {
        let mut cmf = Cmf::new(2).unwrap();
        let result = cmf.calculate(&three_candles()).unwrap();
        assert_eq!(result.len(), 2);
        assert!(result.iter().all(|v| (-1.0..=1.0).contains(v)));
        // Window 1: (500 + 600) / (1000 + 1200); window 2: (600 - 400) / (1200 + 800).
        assert!((result[0] - 0.5).abs() < EPS, "got {}", result[0]);
        assert!((result[1] - 0.1).abs() < EPS, "got {}", result[1]);
    }

    #[test]
    fn test_cmf_zero_volume_sum() {
        let mut cmf = Cmf::new(2).unwrap();
        let data = candles(&[
            [10.0, 12.0, 8.0, 11.0, 0.0],
            [11.0, 13.0, 9.0, 12.0, 0.0],
        ]);
        assert_zero_volume_error(cmf.calculate(&data), "batch");

        cmf.reset();
        assert_eq!(cmf.next(data[0]).unwrap(), None);
        assert_zero_volume_error(cmf.next(data[1]), "streaming");
    }

    #[test]
    fn test_cmf_zero_range_candle_is_rejected() {
        let mut cmf = Cmf::new(1).unwrap();
        let flat = candles(&[[10.0, 10.0, 10.0, 10.0, 1000.0]]);
        assert!(matches!(
            cmf.calculate(&flat),
            Err(IndicatorError::CalculationError(_))
        ));
        cmf.reset();
        assert!(matches!(
            cmf.next(flat[0]),
            Err(IndicatorError::CalculationError(_))
        ));
    }

    #[test]
    fn test_cmf_boundary_conditions() {
        let mut cmf = Cmf::new(3).unwrap();
        // Every bar closes near its high.
        let max_candles = candles(&[
            [10.0, 12.0, 8.0, 11.9, 1000.0],
            [11.9, 14.0, 11.0, 13.9, 1000.0],
            [13.9, 16.0, 13.0, 15.9, 1000.0],
        ]);
        // Every bar closes near its low.
        let min_candles = candles(&[
            [10.0, 12.0, 8.0, 8.1, 1000.0],
            [8.1, 10.0, 7.0, 7.1, 1000.0],
            [7.1, 9.0, 6.0, 6.1, 1000.0],
        ]);

        let max_result = cmf.calculate(&max_candles).unwrap();
        assert_eq!(max_result.len(), 1);
        assert!(
            max_result[0] > 0.9,
            "CMF value should be close to +1, got {}",
            max_result[0]
        );
        assert!(
            max_result[0] <= 1.0,
            "CMF value should not exceed +1, got {}",
            max_result[0]
        );

        cmf.reset();
        let min_result = cmf.calculate(&min_candles).unwrap();
        assert_eq!(min_result.len(), 1);
        assert!(
            min_result[0] < -0.9,
            "CMF value should be close to -1, got {}",
            min_result[0]
        );
        assert!(
            min_result[0] >= -1.0,
            "CMF value should not be less than -1, got {}",
            min_result[0]
        );
    }

    #[test]
    fn test_cmf_minimum_period() {
        let mut cmf = Cmf::new(1).unwrap();
        let data = three_candles();
        let expected = [0.5, 0.5, -0.5];

        let result = cmf.calculate(&data).unwrap();
        assert_eq!(result.len(), expected.len());
        for (got, want) in result.iter().zip(expected) {
            assert!((got - want).abs() < EPS, "got {}, want {}", got, want);
        }

        cmf.reset();
        for (&candle, want) in data.iter().zip(expected) {
            let got = cmf.next(candle).unwrap().expect("period 1 emits on every candle");
            assert!((got - want).abs() < EPS, "got {}, want {}", got, want);
        }
    }

    #[test]
    fn test_cmf_reset_partial_data() {
        let mut cmf = Cmf::new(3).unwrap();
        let data = five_candles();

        cmf.next(data[0]).unwrap();
        cmf.next(data[1]).unwrap();
        cmf.reset();
        cmf.next(data[2]).unwrap();
        cmf.next(data[3]).unwrap();
        let result = cmf.next(data[4]).unwrap();
        assert!(result.is_some());

        cmf.reset();
        let expected = cmf.calculate(&data[2..5]).unwrap()[0];
        assert!((result.unwrap() - expected).abs() < EPS);
    }

    #[test]
    fn test_cmf_calculate_discards_prior_streaming_state() {
        let data = three_candles();
        let mut cmf = Cmf::new(2).unwrap();
        cmf.next(data[2]).unwrap();
        let after_streaming = cmf.calculate(&data).unwrap();
        let fresh = Cmf::new(2).unwrap().calculate(&data).unwrap();
        assert_eq!(after_streaming, fresh);
    }

    #[test]
    fn test_cmf_batch_vs_streaming() {
        let period = 3;
        let mut batch_cmf = Cmf::new(period).unwrap();
        let mut streaming_cmf = Cmf::new(period).unwrap();
        let data = five_candles();

        let batch_result = batch_cmf.calculate(&data).unwrap();
        let mut streaming_result = Vec::new();
        for &candle in &data {
            if let Some(value) = streaming_cmf.next(candle).unwrap() {
                streaming_result.push(value);
            }
        }

        assert_eq!(batch_result.len(), streaming_result.len());
        for (i, (batch, streaming)) in batch_result.iter().zip(&streaming_result).enumerate() {
            assert!(
                (batch - streaming).abs() < EPS,
                "Batch and streaming results differ at index {}: batch={}, streaming={}",
                i,
                batch,
                streaming
            );
        }
    }

    #[test]
    fn test_cmf_extreme_volume_values() {
        let mut cmf = Cmf::new(2).unwrap();
        let data = candles(&[
            [10.0, 12.0, 8.0, 11.0, 1_000_000_000.0],
            [11.0, 13.0, 9.0, 12.0, 2_000_000_000.0],
        ]);
        let result = cmf.calculate(&data).unwrap();
        assert_eq!(result.len(), 1);
        assert!(
            (-1.0..=1.0).contains(&result[0]),
            "CMF with extreme volumes should still be between -1 and 1, got: {}",
            result[0]
        );
        // Both bars have multiplier +0.5, so volume scale must not matter.
        assert!((result[0] - 0.5).abs() < EPS, "got {}", result[0]);
    }
}