use std::collections::BTreeSet;
use chrono::{DateTime, Datelike, Utc};
use rust_decimal::Decimal;
#[path = "composition_hrp.rs"]
mod composition_hrp;
use crate::MetricsError;
use super::{PortfolioEquityPoint, ReturnLookup, ReturnSeries};
pub(crate) use composition_hrp::compute_hrp_weights;
/// Calendar cadence at which the portfolio is reset to its target weights.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RebalanceMode {
/// Never rebalance; allocations only move with each leg's returns.
None,
/// Rebalance when a timestamp falls on a different day than the previous one seen.
Daily,
/// Rebalance when a timestamp falls in a different ISO week than the previous one seen.
Weekly,
/// Rebalance when a timestamp falls in a different month than the previous one seen.
Monthly,
/// Rebalance when a timestamp falls in a different calendar quarter than the previous one seen.
Quarterly,
}
/// Record of a single rebalance performed during composition.
#[derive(Debug, Clone)]
pub struct RebalanceEvent {
/// Timeline timestamp at which the rebalance was executed.
pub timestamp: DateTime<Utc>,
/// Sum over legs of the absolute difference between target weight and pre-rebalance weight.
pub turnover: Decimal,
/// Each leg's share of total portfolio value immediately before the rebalance.
pub weights_before: Vec<Decimal>,
/// Target weights the portfolio was reset to.
pub weights_after: Vec<Decimal>,
}
/// One dated override of the target weights used when rebalancing.
#[derive(Debug, Clone)]
pub struct WeightScheduleEntry {
/// The entry applies to rebalances whose timestamp is at or after this instant.
pub date: DateTime<Utc>,
/// Target weights, one per leg, that replace the default weights while this entry is in effect.
pub weights: Vec<Decimal>,
}
/// Strategy used to derive the initial (default) leg weights.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AllocationMethod {
/// Use the weights supplied by the caller unchanged.
Custom,
/// Give every leg the same weight, `1 / number_of_legs`.
EqualWeight,
/// Weight each leg proportionally to the inverse of its sample return volatility.
InverseVol {
/// Number of most recent return points per leg used to estimate volatility.
lookback: usize,
},
/// Delegate to `compute_hrp_weights` (presumably hierarchical risk parity — confirm in
/// `composition_hrp.rs`).
Hrp,
}
/// Configuration for `compose_mixed`.
#[derive(Debug, Clone)]
pub struct ComposeOptions {
/// Starting capital; each leg initially receives `weight * capital`.
pub capital: Decimal,
/// Cadence at which allocations are reset to the target weights.
pub rebalance: RebalanceMode,
/// Dated target-weight overrides; consulted only when a rebalance fires.
pub weight_schedule: Vec<WeightScheduleEntry>,
/// How the initial weights are derived.
pub allocation: AllocationMethod,
}
/// Output of `compose_mixed`.
#[derive(Debug, Clone)]
pub struct MixedCompositionResult {
/// Total portfolio value over time. The first point is synthetic: it sits one second before
/// the earliest leg timestamp and carries the starting capital.
pub equity_curve: Vec<PortfolioEquityPoint>,
/// Per-leg dollar allocation over time, in the same order as the input series. Each curve
/// starts with the same synthetic point holding the leg's initial allocation.
pub leg_equity_curves: Vec<Vec<PortfolioEquityPoint>>,
/// Largest `periods_per_year` among the input legs' frequencies.
pub periods_per_year: u32,
/// Labels of the input legs, in input order.
pub leg_labels: Vec<String>,
/// Initial weight assigned to each leg, paired with its label.
pub effective_weights: Vec<(String, Decimal)>,
/// Each leg's share of the final portfolio value; all zeros when that value is not positive.
pub final_weights: Vec<(String, Decimal)>,
/// Every rebalance that was executed, in chronological order.
pub rebalance_events: Vec<RebalanceEvent>,
/// `true` when total portfolio value dropped to zero or below, which stops the simulation.
pub margin_call: bool,
/// Non-fatal diagnostics collected while resolving weights (including a leverage notice).
pub warnings: Vec<String>,
}
/// Picks the target weights in force at `ts`.
///
/// The schedule is consulted in slice order: the weights of the *last listed* entry whose
/// `date` is at or before `ts` win. When no entry qualifies, a copy of `default` is returned.
fn resolve_target_weights(
    ts: DateTime<Utc>,
    schedule: &[WeightScheduleEntry],
    default: &[Decimal],
) -> Vec<Decimal> {
    schedule
        .iter()
        .rev()
        // Searching from the back finds the same entry a forward overwrite loop ends on.
        .find(|entry| entry.date <= ts)
        .map_or_else(|| default.to_vec(), |entry| entry.weights.clone())
}
/// Refreshes each leg's "current" return for timestamp `ts`.
///
/// For every leg index below `n_legs`:
/// * if `ts` is strictly earlier than the leg's first timestamp, the return is forced to zero;
/// * otherwise, if the leg has an observation at `ts`, that observation becomes the return;
/// * otherwise the previously stored return is left in place (forward fill).
///
/// Panics if any of the slices is shorter than `n_legs`.
fn forward_fill_returns(
    ts: DateTime<Utc>,
    n_legs: usize,
    leg_lookups: &[ReturnLookup],
    leg_first_ts: &[Option<DateTime<Utc>>],
    last_known_return: &mut [Decimal],
) {
    for leg in 0..n_legs {
        let observed = leg_lookups[leg].get(&ts).copied();
        let precedes_first_point = matches!(leg_first_ts[leg], Some(first) if ts < first);

        if precedes_first_point {
            // The zeroing rule takes precedence over any observation at `ts`.
            last_known_return[leg] = Decimal::ZERO;
        } else if let Some(value) = observed {
            last_known_return[leg] = value;
        }
    }
}
/// Remembers the most recently seen period key for each rebalance cadence so that
/// `should_rebalance` can detect when a new period starts. All slots start as `None`.
#[derive(Default)]
pub(crate) struct RebalanceState {
// Period key last recorded while running in `RebalanceMode::Daily`.
last_day: Option<u32>,
// Period key last recorded while running in `RebalanceMode::Weekly`.
last_week: Option<u32>,
// Period key last recorded while running in `RebalanceMode::Monthly`.
last_month: Option<u32>,
// Period key last recorded while running in `RebalanceMode::Quarterly`.
last_quarter: Option<u32>,
}
/// Records `current` in `slot` and reports whether it marks the start of a new period.
///
/// Returns `false` without touching `slot` when it already holds `current`. Otherwise the
/// slot is updated and the result is `true`, except on the very first step (`is_first`),
/// where the period is recorded but not reported as a change.
fn period_changed(slot: &mut Option<u32>, current: u32, is_first: bool) -> bool {
    match *slot {
        Some(previous) if previous == current => false,
        _ => {
            *slot = Some(current);
            !is_first
        }
    }
}
/// Decides whether the portfolio should be rebalanced at `ts`.
///
/// Each cadence derives a period key from `ts` and compares it with the key stored in
/// `state` for that cadence; a rebalance is signalled when the key differs from the stored
/// one, except on the first step (`is_first`), which only seeds the state.
///
/// The key always includes the year, so that two timestamps sharing a day-of-year, week
/// number, month or quarter but lying in different years are recognised as different
/// periods (this matters for sparse series, e.g. one point per year). Weekly mode uses the
/// ISO week-based year so that a single ISO week straddling New Year stays one period.
///
/// `RebalanceMode::None` never rebalances and leaves `state` untouched.
pub(crate) fn should_rebalance(
    mode: &RebalanceMode,
    ts: DateTime<Utc>,
    is_first: bool,
    state: &mut RebalanceState,
) -> bool {
    /// Folds a year and a within-year index (< 1000) into one `u32` key.
    ///
    /// The `as` cast deliberately reinterprets the sign bit: chrono years span roughly
    /// ±262 000, and over that range `year * 1000 + sub_period` stays collision-free under
    /// wrapping arithmetic.
    fn year_scoped(year: i32, sub_period: u32) -> u32 {
        (year as u32).wrapping_mul(1_000).wrapping_add(sub_period)
    }

    match mode {
        RebalanceMode::None => false,
        RebalanceMode::Daily => period_changed(
            &mut state.last_day,
            year_scoped(ts.year(), ts.ordinal()),
            is_first,
        ),
        RebalanceMode::Weekly => {
            let week = ts.iso_week();
            period_changed(
                &mut state.last_week,
                year_scoped(week.year(), week.week()),
                is_first,
            )
        }
        RebalanceMode::Monthly => period_changed(
            &mut state.last_month,
            year_scoped(ts.year(), ts.month()),
            is_first,
        ),
        RebalanceMode::Quarterly => period_changed(
            &mut state.last_quarter,
            // Months 1-3 -> quarter 0, 4-6 -> 1, 7-9 -> 2, 10-12 -> 3.
            year_scoped(ts.year(), (ts.month() - 1) / 3),
            is_first,
        ),
    }
}
/// Resets `dollar_alloc` to `target_weights * total` and reports what changed.
///
/// Returns `(turnover, weights_before)` where `weights_before` holds each leg's share of
/// `total` prior to the reset and `turnover` is the sum of absolute differences between
/// target and prior weight.
///
/// Panics if `target_weights` is shorter than `dollar_alloc`; `total` must be non-zero.
fn execute_rebalance(
    dollar_alloc: &mut [Decimal],
    total: Decimal,
    target_weights: &[Decimal],
) -> (Decimal, Vec<Decimal>) {
    let weights_before: Vec<Decimal> = dollar_alloc
        .iter()
        .map(|dollars| *dollars / total)
        .collect();

    let mut turnover = Decimal::ZERO;
    for (leg, dollars) in dollar_alloc.iter_mut().enumerate() {
        let target = target_weights[leg];
        // The pre-trade weight was already computed above; reuse it instead of dividing again.
        turnover += (target - weights_before[leg]).abs();
        *dollars = target * total;
    }

    (turnover, weights_before)
}
/// Derives inverse-volatility weights from the trailing `lookback` returns of each leg.
///
/// Volatility is the sample standard deviation (n − 1 denominator) of up to `lookback` of
/// the most recent points, computed in `f64`. A leg with fewer than two usable points gets
/// a volatility of zero, and a warning is recorded for any leg with fewer than `lookback`
/// points.
///
/// If any leg's volatility is below `1e-12`, the function falls back to equal weights and
/// says so in the warnings. Otherwise each weight is `(1 / vol) / Σ(1 / vol)`, converted
/// back to `Decimal` (zero if the conversion fails).
///
/// Returns `(weights, warnings)` with one weight per input series.
pub(crate) fn compute_inverse_vol_weights(
    series: &[ReturnSeries],
    lookback: usize,
) -> (Vec<Decimal>, Vec<String>) {
    let leg_count = series.len();
    let mut warnings: Vec<String> = Vec::new();
    let mut vols: Vec<f64> = Vec::with_capacity(leg_count);

    for leg in series {
        // Most recent observations first; values that do not fit an f64 count as 0.0.
        let window: Vec<f64> = leg
            .points
            .iter()
            .rev()
            .take(lookback)
            .map(|point| point.value.try_into().unwrap_or(0.0))
            .collect();
        let sample_count = window.len();

        if sample_count < lookback {
            warnings.push(format!(
                "leg '{}': only {} periods available for {}-period lookback",
                leg.label, sample_count, lookback
            ));
        }

        let vol = if sample_count < 2 {
            // A sample deviation needs at least two observations.
            0.0
        } else {
            let mean: f64 = window.iter().sum::<f64>() / sample_count as f64;
            let squared_deviations: f64 = window.iter().map(|r| (r - mean).powi(2)).sum::<f64>();
            (squared_deviations / (sample_count - 1) as f64).sqrt()
        };
        vols.push(vol);
    }

    if vols.iter().any(|vol| *vol < 1e-12) {
        warnings.push("zero volatility detected in one or more legs — using equal weights".into());
        let equal_w = Decimal::ONE / Decimal::from(leg_count as u32);
        return (vec![equal_w; leg_count], warnings);
    }

    let inverse_total: f64 = vols.iter().map(|vol| 1.0 / vol).sum();
    let weights: Vec<Decimal> = vols
        .iter()
        .map(|vol| Decimal::try_from((1.0 / vol) / inverse_total).unwrap_or(Decimal::ZERO))
        .collect();

    (weights, warnings)
}
/// Outcome of advancing the portfolio by one timeline step.
enum StepResult {
/// The portfolio is still solvent; carries the total value after applying returns.
Continue(Decimal),
/// Total value was zero or negative before or after applying returns.
MarginCall,
}
/// Advances every leg's dollar allocation by one timeline step.
///
/// First refreshes `last_known_return` for `ts`, then compounds each of the first `n_legs`
/// allocations by `1 + return`. Yields `StepResult::MarginCall` when the portfolio total is
/// zero or negative either before or after compounding; otherwise yields
/// `StepResult::Continue` carrying the post-compounding total.
///
/// Panics if `dollar_alloc` or `last_known_return` is shorter than `n_legs`.
fn step_one_period(
    ts: DateTime<Utc>,
    n_legs: usize,
    leg_lookups: &[ReturnLookup],
    leg_first_ts: &[Option<DateTime<Utc>>],
    last_known_return: &mut [Decimal],
    dollar_alloc: &mut [Decimal],
) -> StepResult {
    forward_fill_returns(ts, n_legs, leg_lookups, leg_first_ts, last_known_return);

    let opening_total: Decimal = dollar_alloc.iter().sum();
    if opening_total <= Decimal::ZERO {
        return StepResult::MarginCall;
    }

    let growth_inputs = dollar_alloc[..n_legs]
        .iter_mut()
        .zip(&last_known_return[..n_legs]);
    for (dollars, period_return) in growth_inputs {
        *dollars *= Decimal::ONE + *period_return;
    }

    match dollar_alloc.iter().sum::<Decimal>() {
        closing_total if closing_total > Decimal::ZERO => StepResult::Continue(closing_total),
        _ => StepResult::MarginCall,
    }
}
/// Collects every timestamp that appears in any leg into one ascending, duplicate-free list.
fn build_timeline(series: &[ReturnSeries]) -> Vec<DateTime<Utc>> {
    // The ordered set both de-duplicates and sorts the timestamps.
    let unique: BTreeSet<DateTime<Utc>> = series
        .iter()
        .flat_map(|leg| leg.points.iter().map(|point| point.timestamp))
        .collect();
    unique.into_iter().collect()
}
/// Checks the basic shape of the composition inputs.
///
/// # Errors
///
/// Returns `MetricsError::InvalidParameter` when `series` is empty, or when `series` and
/// `weights` differ in length. The emptiness check is evaluated first.
fn validate_mixed_inputs(series: &[ReturnSeries], weights: &[Decimal]) -> Result<(), MetricsError> {
    let problem = if series.is_empty() {
        Some("at least one leg required")
    } else if series.len() != weights.len() {
        Some("series and weights must have same length")
    } else {
        None
    };

    match problem {
        Some(message) => Err(MetricsError::InvalidParameter(message.into())),
        None => Ok(()),
    }
}
/// Produces the initial per-leg weights for the chosen allocation method.
///
/// * `Custom` copies the caller's `weights`.
/// * `EqualWeight` assigns `1 / series.len()` to every leg.
/// * `InverseVol` and `Hrp` delegate to their estimators and pass on any warnings.
///
/// Returns `(weights, warnings)`.
fn resolve_allocation_weights(
    series: &[ReturnSeries],
    weights: &[Decimal],
    allocation: &AllocationMethod,
) -> (Vec<Decimal>, Vec<String>) {
    match allocation {
        AllocationMethod::Custom => (weights.to_vec(), Vec::new()),
        AllocationMethod::EqualWeight => {
            let leg_count = series.len();
            let share = Decimal::ONE / Decimal::from(leg_count as u32);
            (vec![share; leg_count], Vec::new())
        }
        AllocationMethod::InverseVol { lookback } => {
            compute_inverse_vol_weights(series, *lookback)
        }
        AllocationMethod::Hrp => {
            let (hrp_weights, hrp_notes) = compute_hrp_weights(series);
            let mut warnings = Vec::new();
            warnings.extend(hrp_notes);
            (hrp_weights, warnings)
        }
    }
}
/// Expresses the ending dollar allocations as fractions of the ending portfolio value.
///
/// Each leg's label is paired with `dollar_alloc[i] / total`. When the total is zero or
/// negative, every leg is reported with a weight of zero instead.
fn compute_final_weights(
    series: &[ReturnSeries],
    dollar_alloc: &[Decimal],
) -> Vec<(String, Decimal)> {
    let ending_total: Decimal = dollar_alloc.iter().sum();
    let has_equity = ending_total > Decimal::ZERO;

    series
        .iter()
        .enumerate()
        .map(|(leg, leg_series)| {
            let share = if has_equity {
                dollar_alloc[leg] / ending_total
            } else {
                Decimal::ZERO
            };
            (leg_series.label.clone(), share)
        })
        .collect()
}
/// Simulates a multi-leg portfolio over the union of all leg timestamps.
///
/// Initial weights come from `options.allocation` (`weights` is used as-is for
/// `AllocationMethod::Custom`), and each leg starts with `weight * options.capital`. At
/// every timestamp each leg's allocation is compounded by its current return; a leg with no
/// observation at that timestamp reuses its most recent return, and contributes a zero
/// return before its first point. When `options.rebalance` signals a new period, the
/// allocations are reset to the target weights in effect (the latest applicable
/// `options.weight_schedule` entry, or the initial weights) and a `RebalanceEvent` is
/// recorded.
///
/// If total value drops to zero or below, `margin_call` is set, a zero-valued point is
/// appended to the portfolio equity curve only (the leg curves get no point for that
/// timestamp), and the simulation stops.
///
/// A warning is added when the absolute initial weights sum to more than one.
///
/// # Errors
///
/// * `MetricsError::InvalidParameter` if `series` is empty, if `series` and `weights`
///   differ in length, or if the target weights in effect at a rebalance do not contain
///   exactly one entry per leg.
/// * `MetricsError::InsufficientData` if no leg contains any point.
pub fn compose_mixed(
    series: &[ReturnSeries],
    weights: &[Decimal],
    options: &ComposeOptions,
) -> Result<MixedCompositionResult, MetricsError> {
    validate_mixed_inputs(series, weights)?;
    let n_legs = series.len();
    let timeline = build_timeline(series);
    if timeline.is_empty() {
        return Err(MetricsError::InsufficientData {
            required: 1,
            actual: 0,
        });
    }

    // Report the annualisation factor of the highest-frequency leg.
    let max_freq = series
        .iter()
        .map(|s| s.frequency.periods_per_year())
        .max()
        .unwrap_or(365);
    let leg_lookups: Vec<ReturnLookup> = series
        .iter()
        .map(|s| s.points.iter().map(|p| (p.timestamp, p.value)).collect())
        .collect();
    let leg_first_ts: Vec<Option<DateTime<Utc>>> = series
        .iter()
        .map(|s| s.points.first().map(|p| p.timestamp))
        .collect();

    let (effective_weights, mut warnings) =
        resolve_allocation_weights(series, weights, &options.allocation);
    let initial_effective_weights: Vec<(String, Decimal)> = series
        .iter()
        .enumerate()
        .map(|(i, s)| (s.label.clone(), effective_weights[i]))
        .collect();

    let mut last_known_return: Vec<Decimal> = vec![Decimal::ZERO; n_legs];
    let mut dollar_alloc: Vec<Decimal> = effective_weights
        .iter()
        .map(|w| *w * options.capital)
        .collect();

    // Seed every curve with a starting point just before the first real observation.
    let synthetic_t0 = timeline[0] - chrono::Duration::seconds(1);
    let mut equity_curve = vec![PortfolioEquityPoint {
        timestamp: synthetic_t0,
        value: options.capital,
    }];
    let mut leg_equity_curves: Vec<Vec<PortfolioEquityPoint>> = (0..n_legs)
        .map(|i| {
            vec![PortfolioEquityPoint {
                timestamp: synthetic_t0,
                value: dollar_alloc[i],
            }]
        })
        .collect();

    let mut rebalance_events: Vec<RebalanceEvent> = Vec::new();
    let mut margin_call = false;
    let mut rebalance_state = RebalanceState::default();

    for (step_idx, &ts) in timeline.iter().enumerate() {
        let total_after = match step_one_period(
            ts,
            n_legs,
            &leg_lookups,
            &leg_first_ts,
            &mut last_known_return,
            &mut dollar_alloc,
        ) {
            StepResult::MarginCall => {
                margin_call = true;
                equity_curve.push(PortfolioEquityPoint {
                    timestamp: ts,
                    value: Decimal::ZERO,
                });
                break;
            }
            StepResult::Continue(total) => total,
        };

        if should_rebalance(&options.rebalance, ts, step_idx == 0, &mut rebalance_state) {
            let target = resolve_target_weights(ts, &options.weight_schedule, &effective_weights);
            // A mis-sized target would index out of bounds (too short) or be silently
            // truncated (too long) inside the rebalance; reject it as a caller error.
            if target.len() != n_legs {
                return Err(MetricsError::InvalidParameter(
                    format!(
                        "target weights in effect at {} have {} entries but {} legs were supplied",
                        ts,
                        target.len(),
                        n_legs
                    )
                    .into(),
                ));
            }
            let (turnover, weights_before) =
                execute_rebalance(&mut dollar_alloc, total_after, &target);
            rebalance_events.push(RebalanceEvent {
                timestamp: ts,
                turnover,
                weights_before,
                weights_after: target,
            });
        }

        // Rebalancing redistributes value between legs but leaves the total unchanged.
        equity_curve.push(PortfolioEquityPoint {
            timestamp: ts,
            value: total_after,
        });
        for i in 0..n_legs {
            leg_equity_curves[i].push(PortfolioEquityPoint {
                timestamp: ts,
                value: dollar_alloc[i],
            });
        }
    }

    let final_weights = compute_final_weights(series, &dollar_alloc);
    let leg_labels = series.iter().map(|s| s.label.clone()).collect();

    let weight_sum: Decimal = effective_weights.iter().map(|w| w.abs()).sum();
    if weight_sum > Decimal::ONE {
        warnings.push(format!(
            "leverage detected: absolute weight sum is {} (>1.0)",
            weight_sum
        ));
    }

    Ok(MixedCompositionResult {
        equity_curve,
        leg_equity_curves,
        periods_per_year: max_freq,
        leg_labels,
        effective_weights: initial_effective_weights,
        final_weights,
        rebalance_events,
        margin_call,
        warnings,
    })
}