//! Moving-average crossover trading strategy ([`MovingAverageCross`]).
use std::collections::VecDeque;
use crate::strategy_engine::Strategy;

/// Moving-average crossover strategy.
///
/// Keeps two rolling windows over the most recent prices fed through
/// [`MovingAverageCross::update_price`]: a short window and a long window.
/// Trading signals are derived by comparing the averages of the two windows.
#[derive(Debug, Clone)]
pub struct MovingAverageCross {
    /// Most recent prices, at most `short_window_size` of them (oldest first).
    short_window: VecDeque<f64>,
    /// Most recent prices, at most `long_window_size` of them (oldest first).
    long_window: VecDeque<f64>,
    /// Maximum number of prices retained in `short_window`.
    short_window_size: usize,
    /// Maximum number of prices retained in `long_window`.
    long_window_size: usize,
}

impl MovingAverageCross {
    /// Creates a strategy with empty windows that retain at most
    /// `short_window_size` and `long_window_size` prices respectively.
    ///
    /// A size of `0` keeps no history for that window.
    pub fn new(short_window_size: usize, long_window_size: usize) -> Self {
        Self {
            short_window: VecDeque::with_capacity(short_window_size),
            long_window: VecDeque::with_capacity(long_window_size),
            short_window_size,
            long_window_size,
        }
    }

    /// Records a new price in both windows, evicting the oldest prices so
    /// that neither window exceeds its configured size.
    pub fn update_price(&mut self, price: f64) {
        // The limits are stored explicitly: `VecDeque::capacity()` cannot serve as
        // the window size because `push_back` grows the capacity whenever the deque
        // is full, so `len() > capacity()` never holds and the windows would grow
        // without bound.
        Self::push_bounded(&mut self.short_window, self.short_window_size, price);
        Self::push_bounded(&mut self.long_window, self.long_window_size, price);
    }

    /// Appends `price` to `window`, then drops prices from the front until at
    /// most `max_len` remain.
    fn push_bounded(window: &mut VecDeque<f64>, max_len: usize, price: f64) {
        window.push_back(price);
        while window.len() > max_len {
            window.pop_front();
        }
    }
}

impl Strategy for MovingAverageCross {
    fn evaluate(&self, _price: &f64) -> f64 {
        let short_avg = self.short_window.iter().sum::<f64>() / self.short_window.len() as f64;
        let long_avg = self.long_window.iter().sum::<f64>() / self.long_window.len() as f64;

        if short_avg > long_avg {
            1.0 // 买入信号
        } else if short_avg < long_avg {
            -1.0 // 卖出信号
        } else {
            0.0 // 无信号
        }
    }

    fn calculate_indicator(&self, prices: &Vec<f64>) -> f64 {
        // 计算夏普比率作为指标
        let mut returns = vec![];
        for price in prices {
            let signal = self.evaluate(price);
            if signal > 0.0 {
                returns.push(*price / prices[0] - 1.0); // 买入信号,计算收益
            } else if signal < 0.0 {
                returns.push(1.0 - (*price / prices[0])); // 卖出信号,计算收益
            } else {
                returns.push(0.0); // 无信号,收益为0
            }
        }
        let avg_return = returns.iter().sum::<f64>() / returns.len() as f64;
        let std_dev = (returns.iter().map(|x| (x - avg_return).powi(2)).sum::<f64>() / returns.len() as f64).sqrt();
        avg_return / std_dev // 返回夏普比率
    }
}