use crate::indicators::IndicatorError;
/// Checks that `period` is at least `min_period`.
///
/// # Errors
///
/// Returns [`IndicatorError::InvalidParameter`] when `period < min_period`; the message
/// names the required minimum.
pub fn validate_period(period: usize, min_period: usize) -> Result<(), IndicatorError> {
    if period >= min_period {
        return Ok(());
    }
    Err(IndicatorError::InvalidParameter(format!(
        "Period must be greater than or equal to {}",
        min_period
    )))
}
/// Checks that `data` holds at least `min_length` elements.
///
/// # Errors
///
/// Returns [`IndicatorError::InsufficientData`] when `data.len() < min_length`; the
/// message names the required minimum length.
pub fn validate_data_length<T>(data: &[T], min_length: usize) -> Result<(), IndicatorError> {
    match data.len() {
        len if len >= min_length => Ok(()),
        _ => Err(IndicatorError::InsufficientData(format!(
            "Input data length must be at least {}",
            min_length
        ))),
    }
}
/// Computes the simple moving average of `data` over a trailing window of `period` samples.
///
/// The output has `data.len() - period + 1` elements; element `k` is the mean of
/// `data[k..k + period]`.
///
/// A running sum keeps each step O(1). A non-finite sample (NaN or ±∞) cannot be
/// subtracted back out of a running sum, because `NaN - NaN` and `∞ - ∞` are both NaN.
/// So whenever the running sum becomes non-finite, the current window's sum is recomputed
/// from scratch. This confines a NaN/∞ input to the windows that actually contain it
/// instead of poisoning every later average. Results for all-finite input are unchanged.
///
/// # Errors
///
/// * [`IndicatorError::InvalidParameter`] if `period` is 0.
/// * [`IndicatorError::InsufficientData`] if `data.len() < period`.
pub fn calculate_sma(data: &[f64], period: usize) -> Result<Vec<f64>, IndicatorError> {
    validate_period(period, 1)?;
    validate_data_length(data, period)?;
    let n = data.len();
    let divisor = period as f64;
    let mut result = Vec::with_capacity(n - period + 1);
    let mut sum = data[..period].iter().sum::<f64>();
    result.push(sum / divisor);
    for i in period..n {
        // Slide the window: add the entering sample, drop the one leaving.
        sum = sum + data[i] - data[i - period];
        if !sum.is_finite() {
            // Resynchronise from the actual window `data[i + 1 - period..=i]` so a NaN/∞
            // that has already left the window no longer contaminates the sum.
            sum = data[i + 1 - period..=i].iter().sum::<f64>();
        }
        result.push(sum / divisor);
    }
    Ok(result)
}
/// Computes an exponential moving average with smoothing factor `alpha = 2 / (period + 1)`.
///
/// The series is seeded with `data[0]` (no SMA warm-up), so the output has the same length
/// as `data`. Each later element is `(x - prev) * alpha + prev`, where `prev` is the previous
/// output value.
///
/// # Errors
///
/// * [`IndicatorError::InvalidParameter`] if `period` is 0.
/// * [`IndicatorError::InsufficientData`] if `data.len() < period`, even though the output
///   itself starts at the first sample.
pub fn calculate_ema(data: &[f64], period: usize) -> Result<Vec<f64>, IndicatorError> {
    validate_period(period, 1)?;
    validate_data_length(data, period)?;
    let alpha = 2.0 / (period as f64 + 1.0);
    // Validation guarantees `data` is non-empty, so indexing the seed cannot panic.
    let seed = data[0];
    let smoothed = data[1..].iter().scan(seed, |prev, &x| {
        *prev = (x - *prev) * alpha + *prev;
        Some(*prev)
    });
    Ok(std::iter::once(seed).chain(smoothed).collect())
}
/// Computes the population standard deviation of `data` (divides by `N`, not `N - 1`).
///
/// `mean` lets the caller supply a precomputed or reference mean to measure deviation from;
/// when it is `None`, the arithmetic mean of `data` is used.
///
/// # Errors
///
/// Returns [`IndicatorError::InsufficientData`] if `data` is empty.
pub fn standard_deviation(data: &[f64], mean: Option<f64>) -> Result<f64, IndicatorError> {
    if data.is_empty() {
        return Err(IndicatorError::InsufficientData(
            "Cannot calculate standard deviation of empty dataset".to_string(),
        ));
    }
    // A single sample is its own mean, so it has zero spread. The shortcut only applies when
    // no mean was supplied: against a caller-supplied mean, `[x]` deviates by `|x - mean|`.
    if data.len() == 1 && mean.is_none() {
        return Ok(0.0);
    }
    let count = data.len() as f64;
    let mean = mean.unwrap_or_else(|| data.iter().sum::<f64>() / count);
    let variance = data.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / count;
    Ok(variance.sqrt())
}
/// Computes the percentage rate of change between each value and the value `period` steps
/// earlier.
///
/// Element `k` of the output is `(data[k + period] - data[k]) / data[k] * 100.0`, so the
/// output has `data.len() - period` elements. A zero earlier value is not special-cased:
/// IEEE division yields ±∞, or NaN for `0 / 0`.
///
/// # Errors
///
/// * [`IndicatorError::InvalidParameter`] if `period` is 0.
/// * [`IndicatorError::InsufficientData`] if `data.len() <= period`.
pub fn rate_of_change(data: &[f64], period: usize) -> Result<Vec<f64>, IndicatorError> {
    validate_period(period, 1)?;
    // `period + 1` would overflow for `usize::MAX`: a debug panic, or in release a wrap to 0
    // that lets a nonsensical request through. Saturating keeps it an InsufficientData error.
    validate_data_length(data, period.saturating_add(1))?;
    let rates = data
        .iter()
        .zip(&data[period..])
        .map(|(&past, &current)| (current - past) / past * 100.0)
        .collect();
    Ok(rates)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparisons against reference values rounded to six decimals.
    const TOLERANCE: f64 = 0.000001;

    /// The ten-point ramp `1.0..=10.0` shared by the moving-average tests.
    fn ramp() -> Vec<f64> {
        vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
    }

    #[test]
    fn test_validate_period() {
        for (period, min_period) in [(10, 5), (5, 5), (1, 1)] {
            assert!(validate_period(period, min_period).is_ok());
        }
        match validate_period(4, 5) {
            Err(IndicatorError::InvalidParameter(msg)) => assert!(msg.contains("5")),
            _ => panic!("Expected InvalidParameter error"),
        }
    }

    #[test]
    fn test_validate_data_length() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        for min_length in [5, 3] {
            assert!(validate_data_length(&data, min_length).is_ok());
        }
        match validate_data_length(&data, 6) {
            Err(IndicatorError::InsufficientData(msg)) => assert!(msg.contains("6")),
            _ => panic!("Expected InsufficientData error"),
        }
    }

    #[test]
    fn test_calculate_sma() {
        let data = ramp();

        let sma3 = calculate_sma(&data, 3).unwrap();
        assert_eq!(sma3.len(), 8);
        assert_eq!(sma3[0], (1.0 + 2.0 + 3.0) / 3.0);
        assert_eq!(sma3[1], (2.0 + 3.0 + 4.0) / 3.0);
        assert_eq!(sma3[7], (8.0 + 9.0 + 10.0) / 3.0);

        let sma5 = calculate_sma(&data, 5).unwrap();
        assert_eq!(sma5.len(), 6);
        assert_eq!(sma5[0], (1.0 + 2.0 + 3.0 + 4.0 + 5.0) / 5.0);
        assert_eq!(sma5[5], (6.0 + 7.0 + 8.0 + 9.0 + 10.0) / 5.0);

        assert!(calculate_sma(&data, 11).is_err());
    }

    #[test]
    fn test_calculate_ema() {
        let data = ramp();

        let ema3 = calculate_ema(&data, 3).unwrap();
        assert_eq!(ema3.len(), data.len());
        assert_eq!(&ema3[..3], &[1.0, 1.5, 2.25]);

        assert!(calculate_ema(&data, 11).is_err());
    }

    #[test]
    fn test_standard_deviation() {
        let data = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        for mean in [Some(6.0), None] {
            let std_dev = standard_deviation(&data, mean).unwrap();
            assert!((std_dev - 2.828427).abs() < TOLERANCE);
        }

        assert!(standard_deviation(&[] as &[f64], None).is_err());

        let single_data = vec![5.0];
        assert_eq!(standard_deviation(&single_data, None).unwrap(), 0.0);
    }

    #[test]
    fn test_rate_of_change() {
        let data = vec![10.0, 11.0, 12.0, 13.0, 14.0, 15.0];

        let roc1 = rate_of_change(&data, 1).unwrap();
        assert_eq!(roc1.len(), 5);
        assert_eq!(roc1[0], 10.0);
        assert!((roc1[4] - 7.142857).abs() < TOLERANCE);

        let roc3 = rate_of_change(&data, 3).unwrap();
        assert_eq!(roc3.len(), 3);
        assert_eq!(roc3[0], 30.0);

        assert!(rate_of_change(&data, 6).is_err());
    }
}