//! Pair-analysis statistics: spreads, half-life, rolling z-scores, correlation and Engle–Granger cointegration.
extern crate ndarray;

use crate::BoxErr;
use crate::adf::calculate_adf_test_statistic;
use crate::models::Coint;
use crate::simple::simple_linear_regression;
use crate::kalman::dynamic_hedge_kalman_filter;
use crate::mackinnon::{critical_values_mackinnon_cointegration, p_value_mackinnon_cointegration};

/// Half-life of mean reversion.
///
/// Regresses the one-step changes of `series` (`Δy_t = y_t - y_{t-1}`) on the
/// lagged values (`y_{t-1}`) with `simple_linear_regression` and returns
/// `-ln(2) / β`, where `β` is the fitted slope. The result is measured in
/// observations. A positive slope (a diverging series) yields a negative value.
///
/// # Errors
/// * the series has fewer than two elements;
/// * the regression helper fails;
/// * the fitted slope is within `f64::EPSILON` of zero, so the division would blow up.
pub fn half_life(series: &Vec<f64>) -> Result<f64, String> {
  if series.len() < 2 {
    return Err("Series length must be greater than 1.".to_string());
  }

  // y_{t-1}: every observation except the last one.
  let lagged: Vec<f64> = series[..series.len() - 1].to_vec();
  // Δy_t = y_t - y_{t-1}, index-aligned with `lagged`.
  let deltas: Vec<f64> = series
    .iter()
    .skip(1)
    .zip(series.iter())
    .map(|(next, prev)| next - prev)
    .collect();

  let ((_, slope), _) = simple_linear_regression(&lagged, &deltas)?;

  // A (near-)zero slope means there is no measurable pull back towards the mean.
  if slope.abs() < std::f64::EPSILON {
    return Err("Cannot calculate half life. Beta_1 value is too close to zero.".to_string());
  }

  Ok(-f64::ln(2.0) / slope)
}


/// Static-hedge-ratio spread.
///
/// Fits `y = intercept + hedge_ratio * x` with `simple_linear_regression` and
/// returns the residual series `y_i - hedge_ratio * x_i - intercept`, with one
/// value per observation.
///
/// # Errors
/// Returns an error when `x` and `y` differ in length, or when the regression fails.
pub fn spread_standard(x: &Vec<f64>, y: &Vec<f64>) -> Result<Vec<f64>, BoxErr> {
  if x.len() != y.len() {
    let err = std::io::Error::new(std::io::ErrorKind::Other, "Input vectors have different sizes");
    return Err(Box::new(err));
  }

  // alpha = intercept, beta = hedge ratio (slope of y on x).
  let ((alpha, beta), _) = simple_linear_regression(&x, &y)?;

  let mut spread: Vec<f64> = Vec::with_capacity(y.len());
  for (xi, yi) in x.iter().zip(y) {
    spread.push(yi - beta * xi - alpha);
  }
  Ok(spread)
}


/// Dynamic-hedge-ratio spread.
///
/// Estimates a time-varying hedge ratio with `dynamic_hedge_kalman_filter` and
/// returns `y_i - hedge_ratio_i * x_i` for every observation (there is no
/// intercept term). The filter's earliest estimates are still settling, so when
/// the spread has more than four points, its first four entries are overwritten
/// with the value at index 4.
///
/// # Errors
/// Returns an error in any of these cases:
/// * `x` and `y` differ in length;
/// * the Kalman filter fails;
/// * the filter returns a hedge-ratio vector whose length differs from `x`.
pub fn spread_dynamic(x: &Vec<f64>, y: &Vec<f64>) -> Result<Vec<f64>, BoxErr> {
  if x.len() != y.len() {
    let err = std::io::Error::new(std::io::ErrorKind::Other, "Input vectors have different sizes");
    return Err(Box::new(err));
  }

  let hedge_ratios: Vec<f64> = dynamic_hedge_kalman_filter(x, y)?;
  if hedge_ratios.len() != x.len() {
    let err = std::io::Error::new(std::io::ErrorKind::Other, "Hedge Ratio vector should match length of time series");
    return Err(Box::new(err));
  }

  let mut spread: Vec<f64> = hedge_ratios
    .iter()
    .zip(x.iter().zip(y.iter()))
    .map(|(&beta, (&xi, &yi))| yi - beta * xi)
    .collect();

  // Kalman warm-up: the first estimates are unreliable, so reuse the first settled value.
  const WARM_UP: usize = 4;
  if spread.len() > WARM_UP {
    let settled = spread[WARM_UP];
    for slot in &mut spread[..WARM_UP] {
      *slot = settled;
    }
  }

  Ok(spread)
}

/// Rolling z-score.
///
/// For every index `i >= window`, this standardises `series[i]` against the
/// mean and the sample standard deviation (denominator `window - 1`) of the
/// preceding `window` observations, `series[i - window..i]`. The current point
/// is not part of its own window. The first `window` entries have no full
/// look-back window and are padded with `0.0`, so the output has the same
/// length as `series`.
///
/// # Errors
/// * `window > series.len()`;
/// * `window < 2`, because a sample standard deviation needs at least two points;
/// * any look-back window has a standard deviation of exactly zero.
pub fn rolling_zscore(series: &Vec<f64>, window: usize) -> Result<Vec<f64>, Box<dyn std::error::Error>> {
  // Guard: Ensure correct window size
  if window > series.len() {
    return Err(Box::new(std::io::Error::new(std::io::ErrorKind::Other, "Window size is greater than vector length")));
  }

  // Guard: `window - 1` is the variance denominator; 0 would underflow and 1 would divide by zero.
  if window < 2 {
    return Err(Box::new(std::io::Error::new(std::io::ErrorKind::Other, "Window size must be at least 2")));
  }

  let mut z_scores: Vec<f64> = Vec::with_capacity(series.len());
  z_scores.resize(window, 0.0); // Padding with 0.0 for the first (window) elements

  // Loop-invariant normalisers.
  let n = window as f64;
  let dof = (window - 1) as f64;

  for (i, &value) in series.iter().enumerate().skip(window) {
    let window_data: &[f64] = &series[i - window..i];
    let mean: f64 = window_data.iter().sum::<f64>() / n;
    let var: f64 = window_data.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / dof;
    let std_dev: f64 = var.sqrt();
    if std_dev == 0.0 {
      return Err(Box::new(std::io::Error::new(std::io::ErrorKind::Other, "Standard deviation is zero")));
    }
    z_scores.push((value - mean) / std_dev);
  }
  Ok(z_scores)
}

/// Pearson correlation coefficient.
///
/// Computes
/// `r = Σ(x_i - x̄)(y_i - ȳ) / sqrt(Σ(x_i - x̄)² · Σ(y_i - ȳ)²)`.
/// The covariance and both standard deviations must share the same
/// normalisation factor (`1/n` or `1/(n-1)`). Written this way, the factors
/// cancel and `r` stays within `[-1, 1]`.
///
/// If either series has zero variance, the denominator is zero and the result
/// is not meaningful (typically `NaN`).
///
/// # Errors
/// Returns an error when `x` and `y` differ in length or contain fewer than two observations.
pub fn pearson_correlation_coefficient(x: &Vec<f64>, y: &Vec<f64>) -> Result<f64, BoxErr> {
  if x.len() != y.len() {
    return Err(Box::new(std::io::Error::new(std::io::ErrorKind::Other, "Input vectors have different sizes")));
  }

  // Guard: correlation is undefined for fewer than two points.
  if x.len() < 2 {
    return Err(Box::new(std::io::Error::new(std::io::ErrorKind::Other, "At least two observations are required")));
  }

  let n = x.len() as f64;
  let mean_x: f64 = x.iter().sum::<f64>() / n;
  let mean_y: f64 = y.iter().sum::<f64>() / n;

  // Accumulate the co-moment and both second moments in a single pass.
  let (s_xy, s_xx, s_yy) = x.iter().zip(y.iter()).fold(
    (0.0_f64, 0.0_f64, 0.0_f64),
    |(sxy, sxx, syy), (&xi, &yi)| {
      let dx = xi - mean_x;
      let dy = yi - mean_y;
      (sxy + dx * dy, sxx + dx * dx, syy + dy * dy)
    },
  );

  Ok(s_xy / (s_xx * s_yy).sqrt())
}


/// Engle–Granger two-step cointegration test.
///
/// 1. Regresses `series_2` on `series_1` with `simple_linear_regression` and
///    keeps the residuals.
/// 2. Computes the ADF test statistic on those residuals (together with their
///    first differences). It then compares the statistic with the MacKinnon
///    cointegration critical values and p-value.
///
/// The pair is flagged as cointegrated only when the statistic is below the 5%
/// critical value *and* the p-value is below 0.05. The returned [`Coint`]
/// carries:
/// * the flag;
/// * the test statistic;
/// * the (1%, 5%, 10%) critical values;
/// * the p-value.
///
/// # Errors
/// Propagates failures from the regression or the ADF computation as `String`s.
pub fn engle_granger_cointegration_test(series_1: &Vec<f64>, series_2: &Vec<f64>) -> Result<Coint, String> {
  // Step 1: residuals of the regression of series_2 on series_1.
  let (_, resid) = simple_linear_regression(series_1, series_2)?;
  let resid_delta: Vec<f64> = resid.windows(2).map(|pair| pair[1] - pair[0]).collect();

  // Step 2: unit-root statistic on the residuals.
  let test_statistic: f64 =
    calculate_adf_test_statistic(resid, resid_delta).map_err(|e| e.to_string())?;

  // (1%, 5%, 10%) critical values.
  let critical_values = critical_values_mackinnon_cointegration();
  let p_value: f64 = p_value_mackinnon_cointegration(test_statistic);
  let is_coint: bool = test_statistic < critical_values.1 as f64 && p_value < 0.05;

  Ok(Coint {
    is_coint,
    test_statistic,
    critical_values,
    p_value,
  })
}

/// Rolling Pearson correlation.
///
/// For each index `i >= window`, this correlates `series_1[i - window..i]`
/// with `series_2[i - window..i]` (the window ending just before `i`) using
/// [`pearson_correlation_coefficient`]. The first `window` entries are padded
/// with `0.0`, so the result has the same length as the inputs.
///
/// # Errors
/// Returns an error in any of these cases:
/// * the series differ in length;
/// * `window` exceeds their length;
/// * the correlation computation fails for any window.
pub fn rolling_correlation(series_1: &Vec<f64>, series_2: &Vec<f64>, window: usize) -> Result<Vec<f64>, Box<dyn std::error::Error>> {
  // Warm-up entries without a full look-back window are reported as 0.0.
  let mut correlations: Vec<f64> = vec![0.0; window];

  if series_1.len() != series_2.len() {
    let err = std::io::Error::new(std::io::ErrorKind::Other, "Input vectors have different sizes");
    return Err(Box::new(err));
  }

  if window > series_1.len() {
    let err = std::io::Error::new(std::io::ErrorKind::Other, "Window size is greater than vector length");
    return Err(Box::new(err));
  }

  for end in window..series_1.len() {
    let start = end - window;
    let lhs: Vec<f64> = series_1[start..end].to_vec();
    let rhs: Vec<f64> = series_2[start..end].to_vec();
    correlations.push(pearson_correlation_coefficient(&lhs, &rhs)?);
  }
  Ok(correlations)
}

/// Rolling Engle–Granger cointegration distance.
///
/// For each index `i >= window`, runs [`engle_granger_cointegration_test`] on
/// `series_1[i - window..i]` / `series_2[i - window..i]`. It records
/// `critical_value_5pct - test_statistic`. The test treats a statistic below
/// the 5% critical value as evidence of cointegration, so a positive entry
/// means that window clears the 5% threshold. The p-value condition is not
/// applied here. The first `window` entries are padded with `0.0`, so the
/// result has the same length as the inputs.
///
/// # Errors
/// Returns an error in any of these cases:
/// * the series differ in length;
/// * `window` exceeds their length;
/// * the cointegration test fails for any window.
pub fn rolling_cointegration(series_1: &Vec<f64>, series_2: &Vec<f64>, window: usize) -> Result<Vec<f64>, Box<dyn std::error::Error>> {
  // Warm-up entries without a full look-back window are reported as 0.0.
  let mut t_distances: Vec<f64> = vec![0.0; window];

  if series_1.len() != series_2.len() {
    let err = std::io::Error::new(std::io::ErrorKind::Other, "Input vectors have different sizes");
    return Err(Box::new(err));
  }

  if window > series_1.len() {
    let err = std::io::Error::new(std::io::ErrorKind::Other, "Window size is greater than vector length");
    return Err(Box::new(err));
  }

  for end in window..series_1.len() {
    let start = end - window;
    let lhs: Vec<f64> = series_1[start..end].to_vec();
    let rhs: Vec<f64> = series_2[start..end].to_vec();
    let result: Coint = engle_granger_cointegration_test(&lhs, &rhs)?;
    let critical_5pct: f64 = result.critical_values.1 as f64;
    // Kept as -(t - c) rather than c - t so the sign of zero matches exactly.
    t_distances.push(-(result.test_statistic - critical_5pct));
  }
  Ok(t_distances)
}

#[cfg(test)]
mod tests {
  //! Smoke tests run against the pair returned by `get_test_data`.
  use super::*;
  use crate::{get_test_data, utils::log_returns};

  #[test]
  fn tests_standard_spread() {
    let (x, y) = get_test_data();
    let spread: Vec<f64> = spread_standard(&x, &y).unwrap();
    assert!(!spread.is_empty(), "standard spread should not be empty");
  }

  #[test]
  fn tests_dynamic_spread() {
    let (x, y) = get_test_data();
    let dyn_spread: Vec<f64> = spread_dynamic(&x, &y).unwrap();
    assert!(!dyn_spread.is_empty(), "dynamic spread should not be empty");
  }

  #[test]
  fn tests_half_life() {
    let (x, y) = get_test_data();
    let spread: Vec<f64> = spread_standard(&x, &y).unwrap();
    let hl: f64 = half_life(&spread).unwrap();
    assert!(hl > 1.0, "half life = {}", hl);
  }

  #[test]
  fn tests_rolling_zscore() {
    let (x, y) = get_test_data();
    let spread: Vec<f64> = spread_standard(&x, &y).unwrap();
    let window: usize = 21;
    let zscore: Vec<f64> = rolling_zscore(&spread, window).unwrap();
    // Output is padded so it lines up index-for-index with the input.
    assert_eq!(zscore.len(), spread.len());
  }

  #[test]
  fn tests_correlation_simple() {
    let (x, y) = get_test_data();
    let corr: f64 = pearson_correlation_coefficient(&x, &y).unwrap();
    assert!(corr > 0.0, "price correlation = {}", corr);
  }

  #[test]
  fn tests_correlation_rets() {
    let (x, y) = get_test_data();
    let rets_x = log_returns(&x, false);
    let rets_y = log_returns(&y, false);
    let corr: f64 = pearson_correlation_coefficient(&rets_x, &rets_y).unwrap();
    assert!(corr > 0.0, "return correlation = {}", corr);
  }

  #[test]
  fn tests_cointegration() {
    let (x, y) = get_test_data();
    let coint: Coint = engle_granger_cointegration_test(&x, &y).unwrap();
    assert!(coint.is_coint, "cointegration result = {:?}", coint);
  }

  #[test]
  fn tests_rolling_correlation() {
    let (x, y) = get_test_data();
    let correlations: Vec<f64> = rolling_correlation(&x, &y, 20).unwrap();
    assert_eq!(correlations.len(), x.len());
  }

  #[test]
  fn tests_rolling_cointegration() {
    let (x, y) = get_test_data();
    let coints: Vec<f64> = rolling_cointegration(&x, &y, 20).unwrap();
    assert_eq!(coints.len(), x.len());
  }
}