use std::collections::{HashMap, HashSet};
use std::path::Path;
use chrono::NaiveDate;
use cobre_core::{EntityId, System};
use cobre_io::{
Config, FileManifest, LoadError, ValidationContext,
config::OrderSelectionMethod,
parse_inflow_history,
scenarios::{InflowArCoefficientRow, InflowSeasonalStatsRow, assemble_inflow_models},
validate_structure,
};
use cobre_stochastic::{
StochasticError,
par::fitting::{
ArCoefficientEstimate, SeasonalStats, estimate_ar_coefficients, estimate_correlation,
estimate_seasonal_stats, find_season_for_date, levinson_durbin, select_order_aic,
},
};
/// Error type returned by the history-based estimation functions in this file.
///
/// Both variants carry `#[from]`, so the `?` operator converts the wrapped
/// error types into `EstimationError` automatically.
#[derive(Debug, thiserror::Error)]
pub enum EstimationError {
/// Wraps a [`LoadError`]; rendered as `load error: <inner>`.
#[error("load error: {0}")]
Load(#[from] LoadError),
/// Wraps a [`StochasticError`]; rendered as `estimation failed: <inner>`.
#[error("estimation failed: {0}")]
Stochastic(#[from] StochasticError),
}
/// Runs inflow-model estimation from history when the case directory calls for it.
///
/// The case structure is validated first. `system` is handed back untouched when
/// any of the following holds:
///
/// * structure validation recorded an error,
/// * the manifest flag `scenarios_inflow_history_parquet` is not set,
/// * both `scenarios_inflow_seasonal_stats_parquet` and
///   `scenarios_inflow_ar_coefficients_parquet` are set.
///
/// Otherwise the work is delegated to `run_estimation`.
///
/// # Errors
///
/// Propagates whatever `run_estimation` returns. Validation failures are not
/// reported here; they only cause the input system to be returned as-is.
pub fn estimate_from_history(
    system: System,
    case_dir: &Path,
    config: &Config,
) -> Result<System, EstimationError> {
    let mut ctx = ValidationContext::new();
    let manifest = validate_structure(case_dir, &mut ctx);

    let structure_ok = ctx.into_result().is_ok();
    let has_history = manifest.scenarios_inflow_history_parquet;
    let has_explicit_models = manifest.scenarios_inflow_seasonal_stats_parquet
        && manifest.scenarios_inflow_ar_coefficients_parquet;

    if structure_ok && has_history && !has_explicit_models {
        run_estimation(system, case_dir, config, &manifest)
    } else {
        Ok(system)
    }
}
/// Performs the estimation pipeline and returns `system` with new scenario models.
///
/// Steps, in order:
///
/// 1. parse `scenarios/inflow_history.parquet` under `case_dir`,
/// 2. estimate seasonal statistics for every hydro in `system`,
/// 3. estimate AR coefficients using `config.estimation.max_order` and
///    `config.estimation.order_selection`,
/// 4. reuse the system's correlation when the manifest flag
///    `scenarios_correlation_json` is set, otherwise estimate it,
/// 5. convert the estimates to rows, assemble inflow models, and install them
///    with `System::with_scenario_models`.
///
/// # Errors
///
/// Returns the first error raised by any parsing, estimation, or assembly step.
fn run_estimation(
    system: System,
    case_dir: &Path,
    config: &Config,
    manifest: &FileManifest,
) -> Result<System, EstimationError> {
    let history = parse_inflow_history(&case_dir.join("scenarios/inflow_history.parquet"))?;

    let mut observations: Vec<(EntityId, NaiveDate, f64)> = Vec::new();
    observations.extend(
        history
            .iter()
            .map(|record| (record.hydro_id, record.date, record.value_m3s)),
    );

    let hydro_ids: Vec<EntityId> = system.hydros().iter().map(|hydro| hydro.id).collect();
    let stages = system.stages();
    let estimation_cfg = &config.estimation;

    let seasonal_stats = estimate_seasonal_stats(&observations, stages, &hydro_ids)?;
    let ar_estimates = estimate_ar_coefficients_with_selection(
        &observations,
        &seasonal_stats,
        stages,
        &hydro_ids,
        estimation_cfg.max_order as usize,
        &estimation_cfg.order_selection,
    )?;

    let correlation = if manifest.scenarios_correlation_json {
        // A correlation file is flagged in the manifest: keep what the system already holds.
        system.correlation().clone()
    } else {
        estimate_correlation(
            &observations,
            &ar_estimates,
            &seasonal_stats,
            stages,
            &hydro_ids,
        )?
    };

    let inflow_models = assemble_inflow_models(
        seasonal_stats_to_rows(&seasonal_stats, stages),
        ar_estimates_to_rows(&ar_estimates, stages),
    )?;

    let updated = system.with_scenario_models(inflow_models, correlation);
    Ok(updated)
}
/// Estimates AR coefficients using the order-selection strategy given by `method`.
///
/// * `OrderSelectionMethod::Fixed` calls `estimate_ar_coefficients` directly with
///   `max_order`.
/// * `OrderSelectionMethod::Aic` calls `estimate_ar_with_aic` with the same arguments.
///
/// # Errors
///
/// Propagates the `StochasticError` of the chosen estimator.
fn estimate_ar_coefficients_with_selection(
    observations: &[(EntityId, NaiveDate, f64)],
    seasonal_stats: &[SeasonalStats],
    stages: &[cobre_core::temporal::Stage],
    hydro_ids: &[EntityId],
    max_order: usize,
    method: &OrderSelectionMethod,
) -> Result<Vec<ArCoefficientEstimate>, StochasticError> {
    // Kept as an exhaustive match so that a new variant forces a decision here.
    let refine_with_aic = match method {
        OrderSelectionMethod::Fixed => false,
        OrderSelectionMethod::Aic => true,
    };

    if refine_with_aic {
        estimate_ar_with_aic(observations, seasonal_stats, stages, hydro_ids, max_order)
    } else {
        estimate_ar_coefficients(observations, seasonal_stats, stages, hydro_ids, max_order)
    }
}
/// Fits AR coefficients at `max_order`, then shortens each estimate to its AIC-selected order.
///
/// The initial fit comes from `estimate_ar_coefficients`. When `max_order` is zero
/// that fit is returned as-is. Otherwise, for every estimate, sample
/// autocorrelations are computed, fed to `levinson_durbin`, and the resulting
/// per-order variances are passed to `select_order_aic`. If the selected order is
/// lower than the fitted order, the coefficient vector is truncated and
/// `residual_std_ratio` is replaced by the square root of the selected order's
/// variance (1.0 for order zero), clamped to `[f64::EPSILON, 1.0]`.
///
/// An estimate is left untouched when its (hydro, season) pair has no seasonal
/// stats, a zero standard deviation, fewer than two observations, no
/// coefficients, too few usable autocorrelations, or an empty Levinson-Durbin
/// result.
///
/// # Errors
///
/// Propagates the error of the initial `estimate_ar_coefficients` call.
fn estimate_ar_with_aic(
    observations: &[(EntityId, NaiveDate, f64)],
    seasonal_stats: &[SeasonalStats],
    stages: &[cobre_core::temporal::Stage],
    hydro_ids: &[EntityId],
    max_order: usize,
) -> Result<Vec<ArCoefficientEstimate>, StochasticError> {
    let mut estimates =
        estimate_ar_coefficients(observations, seasonal_stats, stages, hydro_ids, max_order)?;
    if max_order == 0 {
        return Ok(estimates);
    }

    // (start, end, stage id, season id) for every stage that carries a season,
    // ordered by start date.
    let mut dated_stages = Vec::new();
    for stage in stages {
        if let Some(season_id) = stage.season_id {
            dated_stages.push((stage.start_date, stage.end_date, stage.id, season_id));
        }
    }
    dated_stages.sort_unstable_by_key(|&(start, ..)| start);

    let mut season_by_stage: HashMap<i32, usize> = HashMap::with_capacity(dated_stages.len());
    for &(_, _, stage_id, season_id) in &dated_stages {
        season_by_stage.insert(stage_id, season_id);
    }

    // Seasonal stats keyed by (entity, season); a later entry for the same key
    // replaces an earlier one.
    let mut stats_map: HashMap<(EntityId, usize), &SeasonalStats> = HashMap::new();
    for stats in seasonal_stats {
        if let Some(&season_id) = season_by_stage.get(&stats.stage_id) {
            stats_map.insert((stats.entity_id, season_id), stats);
        }
    }

    // Observations of the requested hydros, bucketed by (entity, season) in input order.
    let wanted: HashSet<EntityId> = hydro_ids.iter().copied().collect();
    let mut group_obs: HashMap<(EntityId, usize), Vec<f64>> = HashMap::new();
    let season_tagged = observations
        .iter()
        .filter(|(entity_id, _, _)| wanted.contains(entity_id))
        .filter_map(|&(entity_id, date, value)| {
            find_season_for_date(&dated_stages, date)
                .map(|season_id| ((entity_id, season_id), value))
        });
    for (key, value) in season_tagged {
        group_obs.entry(key).or_default().push(value);
    }

    // One past the largest season id seen; zero when no stage has a season.
    let n_seasons: usize = dated_stages
        .iter()
        .map(|&(_, _, _, season_id)| season_id + 1)
        .max()
        .unwrap_or(0);

    for est in estimates.iter_mut() {
        let key = (est.hydro_id, est.season_id);
        let (Some(stats_m), Some(pair_obs)) = (stats_map.get(&key), group_obs.get(&key)) else {
            continue;
        };

        let n_obs = pair_obs.len();
        let actual_order = est.coefficients.len();
        if stats_m.std == 0.0 || n_obs < 2 || actual_order == 0 {
            continue;
        }

        let autocorrelations = compute_autocorrelations(
            est.hydro_id,
            est.season_id,
            actual_order,
            n_seasons,
            pair_obs,
            &stats_map,
            &group_obs,
        );
        if autocorrelations.len() < actual_order {
            continue;
        }

        let ld = levinson_durbin(&autocorrelations, actual_order);
        if ld.sigma2_per_order.is_empty() {
            continue;
        }

        let effective_n = n_obs.saturating_sub(actual_order);
        let selected = select_order_aic(&ld.sigma2_per_order, effective_n).selected_order;
        if selected >= actual_order {
            continue;
        }

        // NOTE(review): truncation keeps the leading coefficients of the
        // max-order fit rather than refitting at the selected order — confirm
        // this is the intended behavior.
        est.coefficients.truncate(selected);
        let sigma2_selected = match selected {
            0 => 1.0,
            order => ld.sigma2_per_order[order - 1],
        };
        est.residual_std_ratio = sigma2_selected.sqrt().clamp(f64::EPSILON, 1.0);
    }

    Ok(estimates)
}
/// Computes sample autocorrelations of one (hydro, season) series for lags `1..=max_order`.
///
/// For each lag the lagged season is found by stepping back `lag` seasons modulo
/// `n_seasons`. The cross-covariance between `pair_obs` and the lagged season's
/// observations is divided by the product of both standard deviations and
/// clamped to `[-1.0, 1.0]`.
///
/// The returned vector holds one entry per lag, in increasing lag order. It is
/// shorter than `max_order` (possibly empty) when the loop stops early: at the
/// first lag whose season has no stats, a non-positive standard deviation, no
/// observations, or fewer than two pairable observations.
///
/// An empty vector is also returned when `n_seasons` is zero or when
/// `(hydro_id, season_id)` has no entry in `stats_map`.
///
/// The function does not check that the season's own standard deviation is
/// non-zero; the caller in this file skips such series before calling.
///
/// NOTE(review): observations are paired by position within each season's list.
/// This assumes both lists are aligned cycle-by-cycle, including for lags that
/// wrap into the previous cycle — confirm against how the history is ordered.
fn compute_autocorrelations(
    hydro_id: EntityId,
    season_id: usize,
    max_order: usize,
    n_seasons: usize,
    pair_obs: &[f64],
    stats_map: &HashMap<(EntityId, usize), &SeasonalStats>,
    group_obs: &HashMap<(EntityId, usize), Vec<f64>>,
) -> Vec<f64> {
    // `% n_seasons` below would panic on a zero divisor.
    if n_seasons == 0 {
        return Vec::new();
    }
    let Some(stats_m) = stats_map.get(&(hydro_id, season_id)) else {
        return Vec::new();
    };
    let mu_m = stats_m.mean;
    let std_m = stats_m.std;
    // Reduce once up front so the subtraction below can neither underflow nor
    // depend on wrapping arithmetic.
    let base_season = season_id % n_seasons;

    let mut autocorrelations = Vec::with_capacity(max_order);
    for lag in 1..=max_order {
        let lag_season = (base_season + n_seasons - lag % n_seasons) % n_seasons;
        let lag_key = (hydro_id, lag_season);

        let stats_lag = match stats_map.get(&lag_key) {
            Some(s) if s.std > 0.0 => s,
            _ => break,
        };
        let Some(lag_obs) = group_obs.get(&lag_key) else {
            break;
        };

        let n_pairs = pair_obs.len().min(lag_obs.len());
        if n_pairs < 2 {
            break;
        }

        // `zip` stops at the shorter series, i.e. after exactly `n_pairs` pairs.
        let mut cross_sum = 0.0_f64;
        for (&current, &lagged) in pair_obs.iter().zip(lag_obs.iter()) {
            cross_sum += (current - mu_m) * (lagged - stats_lag.mean);
        }

        #[allow(clippy::cast_precision_loss)]
        let gamma = cross_sum / (n_pairs - 1) as f64;
        let rho = gamma / (std_m * stats_lag.std);
        autocorrelations.push(rho.clamp(-1.0, 1.0));
    }
    autocorrelations
}
/// Converts seasonal statistics into `InflowSeasonalStatsRow` values, one per input entry.
///
/// Field mapping: `entity_id` → `hydro_id`, `stage_id` → `stage_id`,
/// `mean` → `mean_m3s`, `std` → `std_m3s`. Input order is preserved.
/// `_stages` is accepted but not used.
fn seasonal_stats_to_rows(
    stats: &[SeasonalStats],
    _stages: &[cobre_core::temporal::Stage],
) -> Vec<InflowSeasonalStatsRow> {
    let mut rows = Vec::with_capacity(stats.len());
    for entry in stats {
        rows.push(InflowSeasonalStatsRow {
            hydro_id: entry.entity_id,
            stage_id: entry.stage_id,
            mean_m3s: entry.mean,
            std_m3s: entry.std,
        });
    }
    rows
}
/// Flattens AR estimates into `InflowArCoefficientRow` values, one row per coefficient.
///
/// Each estimate's season is mapped to the smallest stage id among the stages
/// that carry that season; estimates whose season matches no stage are dropped.
/// Lags are numbered from 1 in coefficient order, and every row of an estimate
/// repeats that estimate's `residual_std_ratio`. An estimate with no
/// coefficients therefore yields no rows.
///
/// The result is sorted by `(hydro_id.0, stage_id, lag)` with a stable sort.
fn ar_estimates_to_rows(
    ar_estimates: &[ArCoefficientEstimate],
    stages: &[cobre_core::temporal::Stage],
) -> Vec<InflowArCoefficientRow> {
    // season id -> lowest stage id carrying that season
    let mut first_stage_of_season: HashMap<usize, i32> = HashMap::new();
    for stage in stages {
        let Some(season_id) = stage.season_id else {
            continue;
        };
        first_stage_of_season
            .entry(season_id)
            .and_modify(|current| *current = (*current).min(stage.id))
            .or_insert(stage.id);
    }

    let mut rows: Vec<InflowArCoefficientRow> = ar_estimates
        .iter()
        .filter_map(|est| {
            first_stage_of_season
                .get(&est.season_id)
                .map(|&stage_id| (est, stage_id))
        })
        .flat_map(|(est, stage_id)| {
            est.coefficients
                .iter()
                .enumerate()
                .map(move |(lag_idx, &coefficient)| {
                    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
                    let lag = (lag_idx + 1) as i32;
                    InflowArCoefficientRow {
                        hydro_id: est.hydro_id,
                        stage_id,
                        lag,
                        coefficient,
                        residual_std_ratio: est.residual_std_ratio,
                    }
                })
        })
        .collect();

    rows.sort_by_key(|row| (row.hydro_id.0, row.stage_id, row.lag));
    rows
}
/// Unit tests for `System::with_scenario_models` and the early-return paths of
/// `estimate_from_history`.
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::too_many_lines,
    clippy::float_cmp
)]
mod tests {
    use super::*;
    use cobre_core::scenario::{CorrelationModel, InflowModel};
    use cobre_core::{EntityId, SystemBuilder};

    /// Builds a system that contains only the given inflow models.
    fn minimal_system_with_inflow_models(models: Vec<InflowModel>) -> System {
        SystemBuilder::new()
            .inflow_models(models)
            .build()
            .expect("valid system")
    }

    /// Replacing scenario models swaps inflow models and correlation while
    /// leaving buses, hydros and stages as they were.
    #[test]
    fn test_with_scenario_models_replaces_fields() {
        use cobre_core::{Bus, DeficitSegment};

        let bus = Bus {
            id: EntityId(1),
            name: "B1".to_string(),
            deficit_segments: vec![DeficitSegment {
                depth_mw: Some(f64::INFINITY),
                cost_per_mwh: 1000.0,
            }],
            excess_cost: 0.0,
        };
        let old_model = InflowModel {
            hydro_id: EntityId(1),
            stage_id: 0,
            mean_m3s: 10.0,
            std_m3s: 1.0,
            ar_coefficients: vec![],
            residual_std_ratio: 1.0,
        };
        let second_model = InflowModel {
            stage_id: 1,
            ..old_model.clone()
        };
        let system = SystemBuilder::new()
            .buses(vec![bus])
            .inflow_models(vec![old_model, second_model])
            .build()
            .expect("valid system");
        assert_eq!(system.inflow_models().len(), 2);
        assert_eq!(system.n_buses(), 1);

        let new_models: Vec<InflowModel> = (0..4)
            .map(|i| InflowModel {
                hydro_id: EntityId(1),
                stage_id: i,
                mean_m3s: 50.0,
                std_m3s: 5.0,
                ar_coefficients: vec![0.4],
                residual_std_ratio: 0.9,
            })
            .collect();
        let new_corr = CorrelationModel::default();
        let updated = system.with_scenario_models(new_models, new_corr.clone());

        assert_eq!(updated.inflow_models().len(), 4, "expected 4 inflow models");
        assert_eq!(
            *updated.correlation(),
            new_corr,
            "correlation should equal new_corr"
        );
        assert_eq!(updated.n_buses(), 1, "buses must be preserved");
        assert!(
            updated.hydros().is_empty(),
            "hydros must be preserved (empty)"
        );
        assert!(
            updated.stages().is_empty(),
            "stages must be preserved (empty)"
        );
    }

    /// Passing an empty model list removes all existing inflow models.
    #[test]
    fn test_with_scenario_models_clears_when_empty() {
        let model = InflowModel {
            hydro_id: EntityId(1),
            stage_id: 0,
            mean_m3s: 100.0,
            std_m3s: 10.0,
            ar_coefficients: vec![],
            residual_std_ratio: 1.0,
        };
        let system = minimal_system_with_inflow_models(vec![model]);
        assert_eq!(system.inflow_models().len(), 1);

        let updated = system.with_scenario_models(vec![], CorrelationModel::default());
        assert!(updated.inflow_models().is_empty());
    }

    /// With history, seasonal stats and AR coefficient files all present, the
    /// system must come back unchanged.
    ///
    /// NOTE(review): `estimate_from_history` also returns the system unchanged
    /// when structure validation fails, so this test cannot tell the two paths
    /// apart. The fixture helper now panics on I/O errors so that at least a
    /// broken fixture is not mistaken for a pass.
    #[test]
    fn test_estimate_explicit_stats_returns_unchanged() {
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();
        let case_dir = dir.path();
        create_required_files(case_dir);

        let scenarios = case_dir.join("scenarios");
        std::fs::create_dir_all(&scenarios).unwrap();
        std::fs::write(scenarios.join("inflow_history.parquet"), b"").unwrap();
        std::fs::write(scenarios.join("inflow_seasonal_stats.parquet"), b"").unwrap();
        std::fs::write(scenarios.join("inflow_ar_coefficients.parquet"), b"").unwrap();

        let model = InflowModel {
            hydro_id: EntityId(1),
            stage_id: 0,
            mean_m3s: 100.0,
            std_m3s: 10.0,
            ar_coefficients: vec![0.5],
            residual_std_ratio: 0.87,
        };
        let system = minimal_system_with_inflow_models(vec![model]);
        let original_len = system.inflow_models().len();
        let config = default_config();

        let result = estimate_from_history(system, case_dir, &config).unwrap();
        assert_eq!(
            result.inflow_models().len(),
            original_len,
            "explicit stats: system must be unchanged"
        );
    }

    /// Without an inflow history file, the system must come back unchanged.
    #[test]
    fn test_estimate_no_history_returns_unchanged() {
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();
        let case_dir = dir.path();
        create_required_files(case_dir);

        let model = InflowModel {
            hydro_id: EntityId(1),
            stage_id: 0,
            mean_m3s: 100.0,
            std_m3s: 10.0,
            ar_coefficients: vec![],
            residual_std_ratio: 1.0,
        };
        let system = minimal_system_with_inflow_models(vec![model]);
        let original_len = system.inflow_models().len();
        let config = default_config();

        let result = estimate_from_history(system, case_dir, &config).unwrap();
        assert_eq!(
            result.inflow_models().len(),
            original_len,
            "no history: system must be unchanged"
        );
    }

    /// Parses `MINIMAL_CONFIG_JSON` and overrides its estimation section with
    /// a fixed-order setup (`max_order` 2, at least 2 observations per season).
    fn default_config() -> Config {
        use cobre_io::config::{EstimationConfig, OrderSelectionMethod};

        let mut cfg: Config = serde_json::from_str(MINIMAL_CONFIG_JSON).unwrap();
        cfg.estimation = EstimationConfig {
            max_order: 2,
            order_selection: OrderSelectionMethod::Fixed,
            min_observations_per_season: 2,
        };
        cfg
    }

    /// JSON document deserialized by `default_config`.
    const MINIMAL_CONFIG_JSON: &str = r#"{
"training": { "seed": 42 },
"simulation": { "enabled": false, "num_scenarios": 0, "io_channel_capacity": 16 },
"modeling": {},
"policy": {},
"exports": {},
"output": {}
}"#;

    /// Creates the `system/` and `scenarios/` directories plus a set of `{}`
    /// placeholder JSON files under `case_dir`.
    ///
    /// Panics on any I/O failure: a half-built fixture would otherwise go
    /// unnoticed by tests that only assert "system unchanged".
    fn create_required_files(case_dir: &std::path::Path) {
        for dir in ["system", "scenarios"] {
            std::fs::create_dir_all(case_dir.join(dir))
                .expect("fixture directory must be creatable");
        }

        let placeholder_files = [
            "config.json",
            "penalties.json",
            "stages.json",
            "initial_conditions.json",
            "system/buses.json",
            "system/lines.json",
            "system/hydros.json",
            "system/thermals.json",
        ];
        for name in placeholder_files {
            std::fs::write(case_dir.join(name), b"{}").expect("fixture file must be writable");
        }
    }
}