use std::collections::HashMap;
use chrono::{DateTime, Utc};
use rust_decimal::Decimal;
use quant_indicators::hrp;
use super::ReturnSeries;
/// Computes Hierarchical Risk Parity (HRP) weights for `series`.
///
/// Returns `(weights, warnings)`.
/// * `weights[i]` is the weight of `series[i]`. The vector always has exactly
///   `series.len()` entries, so it stays index-aligned with the input.
/// * `warnings` collects human-readable notes about every degraded or
///   fallback path taken.
///
/// Fallbacks:
/// * Zero or one leg: every leg gets `Decimal::ONE` (empty input yields an
///   empty vector).
/// * Equal weights `1/n` are used when any of these holds:
///   * a leg has fewer than 30 observations;
///   * a leg's variance (as computed by `hrp::variance`) is below `1e-20`;
///   * an HRP weight is rejected by `hrp::decimal_from_f64`.
///
/// Per-leg variances use every observation of that leg. Correlations use only
/// the timestamps each pair of legs has in common.
pub(crate) fn compute_hrp_weights(series: &[ReturnSeries]) -> (Vec<Decimal>, Vec<String>) {
    let mut warnings = Vec::new();
    let n = series.len();
    if n <= 1 {
        return (vec![Decimal::ONE; n], warnings);
    }
    let min_obs = 30;

    // Decimal -> f64 per leg, in point order; a failed conversion becomes 0.0.
    let per_leg: Vec<Vec<f64>> = series
        .iter()
        .map(|s| {
            s.points
                .iter()
                .map(|p| p.value.try_into().unwrap_or(0.0))
                .collect()
        })
        .collect();

    for (i, returns) in per_leg.iter().enumerate() {
        if returns.len() < min_obs {
            warnings.push(format!(
                "HRP: leg '{}' has only {} observations (< {min_obs}) — using equal weights",
                series[i].label,
                returns.len()
            ));
            return (hrp_equal_weights(n), warnings);
        }
    }

    let means: Vec<f64> = per_leg.iter().map(|r| hrp::mean(r)).collect();
    let variances: Vec<f64> = per_leg
        .iter()
        .enumerate()
        .map(|(i, r)| hrp::variance(r, means[i]))
        .collect();
    // Near-zero variance would make the inverse-variance split degenerate.
    if variances.iter().any(|v| *v < 1e-20) {
        warnings.push("HRP: zero variance detected — using equal weights".into());
        return (hrp_equal_weights(n), warnings);
    }

    // Timestamp -> return lookups, used to align each pair of legs for correlation.
    let ts_lookups: Vec<HashMap<DateTime<Utc>, f64>> = series
        .iter()
        .map(|s| {
            s.points
                .iter()
                .map(|p| (p.timestamp, p.value.try_into().unwrap_or(0.0)))
                .collect()
        })
        .collect();
    let corr = pairwise_correlation(&ts_lookups, series, min_obs, &mut warnings);
    let dist = hrp::distance_matrix(&corr);
    // A pairwise (not jointly estimated) correlation matrix can be inconsistent;
    // flag it, but still build the tree.
    if !check_triangle_inequality(&dist) {
        warnings.push(
            "HRP: pairwise correlation matrix may not be positive semi-definite \
             (triangle inequality violated) — dendrogram quality degraded"
                .into(),
        );
    }

    let tree = hrp::build_dendrogram(&dist, n);
    let mut weights_f64 = vec![0.0; n];
    hrp::recursive_bisect(&tree, &variances, 1.0, &mut weights_f64);

    // Convert back to Decimal. Every weight must convert. Dropping a failed one
    // would shorten the vector and shift later weights onto the wrong legs, so
    // fall back to equal weights instead.
    let mut weights = Vec::with_capacity(n);
    for (i, &w) in weights_f64.iter().enumerate() {
        match hrp::decimal_from_f64(w) {
            Ok(d) => weights.push(d),
            Err(_) => {
                warnings.push(format!(
                    "HRP: weight {w} for leg '{}' is not representable as Decimal — using equal weights",
                    series[i].label
                ));
                return (hrp_equal_weights(n), warnings);
            }
        }
    }
    (weights, warnings)
}

/// Equal weights `1/n` for `n` legs.
///
/// `n` must be non-zero, because it is the divisor. `n` is narrowed to `u32`,
/// which is far above any realistic leg count.
fn hrp_equal_weights(n: usize) -> Vec<Decimal> {
    let w = Decimal::ONE / Decimal::from(n as u32);
    vec![w; n]
}
/// Builds the symmetric pairwise Pearson correlation matrix of the legs.
///
/// Each pair is correlated using only the timestamps the two legs share.
///
/// * `ts_lookups[i]` maps timestamp -> return for leg `i`. `series[i]` is used
///   only for its label in warnings, so both slices must be index-aligned.
/// * The diagonal is `1.0`.
/// * A pair with fewer than `max(min_obs, 2)` common timestamps gets correlation
///   `0.0` and a warning. At least two points are always required, because the
///   sample (`common - 1`) normalisation is undefined below that.
/// * A pair whose overlap is below 80% of the longer leg is still correlated,
///   but a "may be noisy" warning is pushed.
/// * A pair where either leg is (numerically) constant over the overlap gets
///   `0.0`. All other correlations are clamped to `[-1, 1]`.
fn pairwise_correlation(
    ts_lookups: &[HashMap<DateTime<Utc>, f64>],
    series: &[ReturnSeries],
    min_obs: usize,
    warnings: &mut Vec<String>,
) -> Vec<Vec<f64>> {
    let n = ts_lookups.len();
    debug_assert_eq!(series.len(), n, "series and ts_lookups must be index-aligned");
    // Guard the (common - 1) divisor below even if a caller passes min_obs < 2.
    let required = min_obs.max(2);
    let mut corr = vec![vec![1.0; n]; n];
    for i in 0..n {
        for j in (i + 1)..n {
            // Align the two legs on their shared timestamps. HashMap iteration
            // order is randomised per process, so sort by timestamp. That keeps
            // the f64 summation order, and hence the correlation, reproducible.
            let mut aligned: Vec<(DateTime<Utc>, f64, f64)> = ts_lookups[i]
                .iter()
                .filter_map(|(ts, &vi)| ts_lookups[j].get(ts).map(|&vj| (*ts, vi, vj)))
                .collect();
            aligned.sort_unstable_by_key(|&(ts, _, _)| ts);
            let (ri, rj): (Vec<f64>, Vec<f64>) =
                aligned.iter().map(|&(_, vi, vj)| (vi, vj)).unzip();

            let common = ri.len();
            let longer = ts_lookups[i].len().max(ts_lookups[j].len());
            if common < required {
                warnings.push(format!(
                    "HRP: pair ('{}','{}') has only {} common observations (< {required}) — assuming zero correlation",
                    series[i].label, series[j].label, common
                ));
                corr[i][j] = 0.0;
                corr[j][i] = 0.0;
                continue;
            }
            // Integer arithmetic: warn when the overlap is under 80% of the longer leg.
            if common < longer * 80 / 100 {
                warnings.push(format!(
                    "HRP: pair ('{}','{}') shares only {}/{} timestamps — correlation may be noisy",
                    series[i].label, series[j].label, common, longer
                ));
            }

            let dof = (common - 1) as f64;
            let mean_i = ri.iter().sum::<f64>() / common as f64;
            let mean_j = rj.iter().sum::<f64>() / common as f64;
            let cov: f64 = ri
                .iter()
                .zip(rj.iter())
                .map(|(a, b)| (a - mean_i) * (b - mean_j))
                .sum::<f64>()
                / dof;
            let var_i: f64 = ri.iter().map(|x| (x - mean_i).powi(2)).sum::<f64>() / dof;
            let var_j: f64 = rj.iter().map(|x| (x - mean_j).powi(2)).sum::<f64>() / dof;
            let denom = var_i.sqrt() * var_j.sqrt();
            let r = if denom < 1e-20 {
                0.0
            } else {
                (cov / denom).clamp(-1.0, 1.0)
            };
            corr[i][j] = r;
            corr[j][i] = r;
        }
    }
    corr
}
/// Returns `true` when every triple of points in `dist` satisfies the triangle
/// inequality, within an absolute tolerance of `1e-9`.
///
/// Only entries above the diagonal (`dist[a][b]` with `a < b`) are read, so the
/// matrix is assumed to be symmetric. The scan stops at the first violating
/// triple and returns `false`.
fn check_triangle_inequality(dist: &[Vec<f64>]) -> bool {
    const TOLERANCE: f64 = 1e-9;
    let size = dist.len();

    // `side` is too long if it exceeds the sum of the other two sides plus tolerance.
    let too_long = |side: f64, first: f64, second: f64| side > first + second + TOLERANCE;

    (0..size).all(|a| {
        ((a + 1)..size).all(|b| {
            ((b + 1)..size).all(|c| {
                let (ab, ac, bc) = (dist[a][b], dist[a][c], dist[b][c]);
                !(too_long(ac, ab, bc) || too_long(ab, ac, bc) || too_long(bc, ab, ac))
            })
        })
    })
}