// rsta 0.1.0
//
// Technical analysis indicators, streaming signals, and a single-asset backtesting engine for Rust.
// Documentation
//! Utility functions for technical indicators

use crate::indicators::IndicatorError;

/// Check that an indicator period meets a required minimum.
///
/// # Arguments
/// * `period` - The period supplied by the caller
/// * `min_period` - The smallest acceptable period (inclusive)
///
/// # Returns
/// * `Ok(())` when `period >= min_period`
///
/// # Errors
/// * `IndicatorError::InvalidParameter` when `period` is below `min_period`;
///   the message states the required minimum
pub fn validate_period(period: usize, min_period: usize) -> Result<(), IndicatorError> {
    if period >= min_period {
        Ok(())
    } else {
        Err(IndicatorError::InvalidParameter(format!(
            "Period must be greater than or equal to {}",
            min_period
        )))
    }
}

/// Check that a data slice holds at least a required number of elements.
///
/// # Arguments
/// * `data` - The slice whose length is checked (element type is irrelevant)
/// * `min_length` - The smallest acceptable length (inclusive)
///
/// # Returns
/// * `Ok(())` when `data.len() >= min_length`
///
/// # Errors
/// * `IndicatorError::InsufficientData` when the slice is shorter than
///   `min_length`; the message states the required minimum
pub fn validate_data_length<T>(data: &[T], min_length: usize) -> Result<(), IndicatorError> {
    let long_enough = data.len() >= min_length;
    if long_enough {
        return Ok(());
    }

    Err(IndicatorError::InsufficientData(format!(
        "Input data length must be at least {}",
        min_length
    )))
}

/// Calculate the Simple Moving Average (SMA) over a sliding window.
///
/// The first output is the mean of `data[0..period]`; each later output is
/// obtained by adding the value entering the window to a running sum and
/// subtracting the value leaving it. The result therefore holds
/// `data.len() - period + 1` values.
///
/// # Arguments
/// * `data` - Data values
/// * `period` - Window length for the average (must be at least 1)
///
/// # Returns
/// * `Result<Vec<f64>, IndicatorError>` - One SMA value per full window
///
/// # Errors
/// * Whatever `validate_period(period, 1)` or
///   `validate_data_length(data, period)` reports
pub fn calculate_sma(data: &[f64], period: usize) -> Result<Vec<f64>, IndicatorError> {
    validate_period(period, 1)?;
    validate_data_length(data, period)?;

    let divisor = period as f64;

    // `head` is the first full window; `tail` holds the values that enter later.
    // The length check above guarantees `period <= data.len()`.
    let (head, tail) = data.split_at(period);
    let mut window_sum: f64 = head.iter().sum();

    let mut averages = Vec::with_capacity(tail.len() + 1);
    averages.push(window_sum / divisor);

    // Pair each entering value with the one `period` positions earlier, which
    // is the value dropping out of the window.
    for (&entering, &leaving) in tail.iter().zip(data) {
        window_sum = window_sum + entering - leaving;
        averages.push(window_sum / divisor);
    }

    Ok(averages)
}

/// Calculate the recursive Exponential Moving Average (EMA) with
/// `adjust=False` semantics.
///
/// The first output is the seed `data[0]`. Every later value follows
/// `EMA[t] = EMA[t-1] + α * (x[t] - EMA[t-1])`, which is algebraically
/// `α * x[t] + (1 - α) * EMA[t-1]`, with `α = 2 / (period + 1)`. One EMA is
/// emitted per input value, so the output has the same length as the input.
///
/// This is the recurrence used by pandas'
/// `ewm(span=period, adjust=False).mean()`; it is also intended to agree with
/// the crate's streaming `Ema::next`, TradingView, and `ta-rs`. Inputs shorter
/// than `period` are still rejected so callers get a clear error for
/// trivially short data.
///
/// # Arguments
/// * `data` - Data values
/// * `period` - Period for the EMA (must be at least 1)
///
/// # Returns
/// * `Result<Vec<f64>, IndicatorError>` - One EMA value per input value
///
/// # Errors
/// * Whatever `validate_period(period, 1)` or
///   `validate_data_length(data, period)` reports
pub fn calculate_ema(data: &[f64], period: usize) -> Result<Vec<f64>, IndicatorError> {
    validate_period(period, 1)?;
    validate_data_length(data, period)?;

    let alpha = 2.0 / (period as f64 + 1.0);

    // The checks above guarantee at least one element, so `data[0]` exists.
    let mut ema = data[0];
    let mut smoothed = Vec::with_capacity(data.len());
    smoothed.push(ema);

    // Each step moves the running EMA a fraction `alpha` towards the new value.
    smoothed.extend(data[1..].iter().map(|&price| {
        ema = (price - ema) * alpha + ema;
        ema
    }));

    Ok(smoothed)
}

/// Calculate the population standard deviation of a data set.
///
/// The variance is the sum of squared deviations from the mean divided by
/// `n` (not `n - 1`), and the result is its square root.
///
/// # Arguments
/// * `data` - Data values
/// * `mean` - Mean to measure deviations from; when `None`, the arithmetic
///   mean of `data` is computed and used
///
/// # Returns
/// * `Result<f64, IndicatorError>` - Standard deviation value. A single-element
///   slice always yields `0.0`, regardless of any `mean` supplied.
///
/// # Errors
/// * `IndicatorError::InsufficientData` when `data` is empty
pub fn standard_deviation(data: &[f64], mean: Option<f64>) -> Result<f64, IndicatorError> {
    match data {
        [] => Err(IndicatorError::InsufficientData(
            "Cannot calculate standard deviation of empty dataset".to_string(),
        )),
        // One observation has no spread; this short-circuits before `mean` is consulted.
        [_] => Ok(0.0),
        _ => {
            let count = data.len() as f64;
            let center = match mean {
                Some(given) => given,
                None => data.iter().sum::<f64>() / count,
            };

            // Population variance: divide by n rather than n - 1.
            let squared_deviations = data
                .iter()
                .map(|&value| (value - center).powi(2))
                .sum::<f64>();

            Ok((squared_deviations / count).sqrt())
        }
    }
}

/// Calculate the rate of change (ROC) as a percentage.
///
/// For every index `i >= period` the output is
/// `(data[i] - data[i - period]) / data[i - period] * 100`, so the result
/// holds `data.len() - period` values.
///
/// A past value of zero is not special-cased: the division follows normal
/// `f64` arithmetic and produces an infinity or NaN for that entry.
///
/// # Arguments
/// * `data` - Data values
/// * `period` - Distance between the compared values (must be at least 1)
///
/// # Returns
/// * `Result<Vec<f64>, IndicatorError>` - Vector of ROC values
///
/// # Errors
/// * Whatever `validate_period(period, 1)` reports for a period below 1
/// * `IndicatorError::InsufficientData` (via `validate_data_length`) when
///   `data` has fewer than `period + 1` values, including periods so large
///   that `period + 1` is not representable
pub fn rate_of_change(data: &[f64], period: usize) -> Result<Vec<f64>, IndicatorError> {
    validate_period(period, 1)?;

    // A plain `period + 1` overflows for `usize::MAX` (panic in debug builds,
    // wrap to 0 in release). Saturating keeps the requirement above any real
    // slice length, so an oversized period is reported as insufficient data.
    validate_data_length(data, period.saturating_add(1))?;

    // Pair each value with the one `period` positions earlier; `zip` stops
    // when the shifted slice is exhausted. The length check guarantees
    // `period < data.len()`, so the slice start is in bounds.
    let roc: Vec<f64> = data[period..]
        .iter()
        .zip(data)
        .map(|(&current, &past)| (current - past) / past * 100.0)
        .collect();

    Ok(roc)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for results that are only compared approximately.
    const TOLERANCE: f64 = 0.000001;

    /// Ascending series 1.0 through 10.0 shared by the moving-average tests.
    fn ascending_series() -> Vec<f64> {
        vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
    }

    /// Assert that `actual` lies strictly within `TOLERANCE` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    #[test]
    fn test_validate_period() {
        // Periods at or above the minimum are accepted.
        for &(period, min_period) in &[(10, 5), (5, 5), (1, 1)] {
            assert!(validate_period(period, min_period).is_ok());
        }

        // A period below the minimum is rejected and the message names the minimum.
        match validate_period(4, 5) {
            Err(IndicatorError::InvalidParameter(msg)) => assert!(msg.contains("5")),
            _ => panic!("Expected InvalidParameter error"),
        }
    }

    #[test]
    fn test_validate_data_length() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        // Five values satisfy a requirement of five or fewer.
        for &min_length in &[5, 3] {
            assert!(validate_data_length(&data, min_length).is_ok());
        }

        // Requiring one more value than is available fails and names the requirement.
        match validate_data_length(&data, 6) {
            Err(IndicatorError::InsufficientData(msg)) => assert!(msg.contains("6")),
            _ => panic!("Expected InsufficientData error"),
        }
    }

    #[test]
    fn test_calculate_sma() {
        let data = ascending_series();

        // Period 3: eight windows; check the first two and the last.
        let by_three = calculate_sma(&data, 3).unwrap();
        assert_eq!(by_three.len(), 8);
        let expected_by_three = [
            (0, (1.0 + 2.0 + 3.0) / 3.0),
            (1, (2.0 + 3.0 + 4.0) / 3.0),
            (7, (8.0 + 9.0 + 10.0) / 3.0),
        ];
        for &(index, expected) in &expected_by_three {
            assert_eq!(by_three[index], expected);
        }

        // Period 5: six windows; check the first and the last.
        let by_five = calculate_sma(&data, 5).unwrap();
        assert_eq!(by_five.len(), 6);
        let expected_by_five = [
            (0, (1.0 + 2.0 + 3.0 + 4.0 + 5.0) / 5.0),
            (5, (6.0 + 7.0 + 8.0 + 9.0 + 10.0) / 5.0),
        ];
        for &(index, expected) in &expected_by_five {
            assert_eq!(by_five[index], expected);
        }

        // A period longer than the data is an error.
        assert!(calculate_sma(&data, 11).is_err());
    }

    #[test]
    fn test_calculate_ema() {
        let data = ascending_series();

        // Period 3 gives alpha = 2 / (3 + 1) = 0.5.
        let ema = calculate_ema(&data, 3).unwrap();

        // One EMA per input value, starting from the first.
        assert_eq!(ema.len(), data.len());

        // Seed, then (2 - 1) * 0.5 + 1 = 1.5, then (3 - 1.5) * 0.5 + 1.5 = 2.25.
        for (index, &expected) in [1.0, 1.5, 2.25].iter().enumerate() {
            assert_eq!(ema[index], expected);
        }

        // A period longer than the data is an error.
        assert!(calculate_ema(&data, 11).is_err());
    }

    #[test]
    fn test_standard_deviation() {
        let data = vec![2.0, 4.0, 6.0, 8.0, 10.0];

        // Mean is 6, squared deviations are 16 + 4 + 0 + 4 + 16 = 40,
        // population variance is 40 / 5 = 8, so the result is sqrt(8) ≈ 2.828427.
        // The same answer is expected whether the mean is supplied or derived.
        for &mean in &[Some(6.0), None] {
            assert_close(standard_deviation(&data, mean).unwrap(), 2.828427);
        }

        // Empty input is an error.
        let empty: [f64; 0] = [];
        assert!(standard_deviation(&empty, None).is_err());

        // A single value has zero spread.
        let single = vec![5.0];
        assert_eq!(standard_deviation(&single, None).unwrap(), 0.0);
    }

    #[test]
    fn test_rate_of_change() {
        let data = vec![10.0, 11.0, 12.0, 13.0, 14.0, 15.0];

        // Period 1: five results.
        let step_one = rate_of_change(&data, 1).unwrap();
        assert_eq!(step_one.len(), 5);
        // (11 - 10) / 10 * 100 = 10%
        assert_eq!(step_one[0], 10.0);
        // (15 - 14) / 14 * 100 ≈ 7.142857%
        assert_close(step_one[4], 7.142857);

        // Period 3: three results.
        let step_three = rate_of_change(&data, 3).unwrap();
        assert_eq!(step_three.len(), 3);
        // (13 - 10) / 10 * 100 = 30%
        assert_eq!(step_three[0], 30.0);

        // A period equal to the data length leaves nothing to compare against.
        assert!(rate_of_change(&data, 6).is_err());
    }
}