use std::path::{Path, PathBuf};
use std::sync::mpsc;
use clap::Args;
use console::Term;
use cobre_comm::{
Communicator, ExecutionTopology, ReduceOp, TopologyProvider, create_communicator,
};
use cobre_core::{System, TrainingEvent};
use cobre_io::output::{
write_correlation_json, write_fitting_report, write_inflow_annual_component,
write_inflow_ar_coefficients, write_inflow_seasonal_stats, write_load_seasonal_stats,
write_noise_openings,
};
use cobre_io::scenarios::LoadSeasonalStatsRow;
use cobre_sddp::{
EstimationReport, PrepareHydroModelsResult, PrepareStochasticResult, StudySetup,
build_hydro_model_summary, estimation_report_to_fitting_report,
inflow_models_to_annual_component_rows, inflow_models_to_ar_rows, inflow_models_to_stats_rows,
prepare_hydro_models, prepare_stochastic,
setup::{ConstructionConfig, build_ncs_factor_entries, load_load_factors_for_stochastic},
};
use cobre_solver::HighsSolver;
use cobre_stochastic::{
OpeningTreeInputs, build_stochastic_context, context::OpeningTree,
provenance::ComponentProvenance,
};
use crate::error::CliError;
use crate::summary::{SimulationSummary, TrainingSummary};
use super::broadcast::{
BroadcastConfig, BroadcastOpeningTree, broadcast_value, stopping_rules_from_broadcast,
};
// Command-line arguments of the `run` subcommand.
//
// Plain `//` comments are used deliberately: clap's derive macros turn `///`
// doc comments on an `Args` struct and its fields into help text, so adding
// them would change the program's `--help` output.
#[derive(Debug, Args)]
#[command(about = "Load a case directory, train an SDDP policy, and run simulation")]
pub struct RunArgs {
// Positional argument: the case directory to load.
pub case_dir: PathBuf,
// `--output <DIR>`: output directory. When absent, `setup_communicator`
// falls back to `<case_dir>/output`.
#[arg(long, value_name = "DIR")]
pub output: Option<PathBuf>,
// `--quiet`: suppresses the banner, progress, and summary lines on stderr.
#[arg(long)]
pub quiet: bool,
// `--threads <N>`: worker thread count, restricted to N >= 1 by the value
// parser. When given it takes precedence over the `COBRE_THREADS`
// environment variable (see `resolve_thread_count`).
#[arg(long, value_parser = clap::value_parser!(u32).range(1..))]
pub threads: Option<u32>,
}
/// Determines how many worker threads to use.
///
/// Precedence:
/// 1. the value passed on the command line, when present;
/// 2. otherwise a strictly positive integer parsed from the `COBRE_THREADS`
///    environment variable;
/// 3. otherwise `1`.
///
/// An unset, unparsable, or zero `COBRE_THREADS` is ignored.
fn resolve_thread_count(cli_threads: Option<u32>) -> usize {
    cli_threads
        .map(|n| n as usize)
        .or_else(|| {
            // Only consulted when no CLI value was supplied.
            std::env::var("COBRE_THREADS")
                .ok()
                .and_then(|raw| raw.parse::<usize>().ok())
                .filter(|&n| n > 0)
        })
        .unwrap_or(1)
}
/// Everything `load_case_and_config` produces on success, in this order:
/// the prepared stochastic result, the prepared hydro models, the
/// broadcast-ready configuration derived from `config.json`, and the parsed
/// configuration itself.
type LoadedCase = (
PrepareStochasticResult,
PrepareHydroModelsResult,
BroadcastConfig,
cobre_io::Config,
);
/// Loads the case directory and its `config.json`, then prepares the
/// stochastic and hydro model inputs.
///
/// Prints a `Loading case: …` line on `stderr` unless `quiet` is set.
///
/// # Errors
/// Returns `CliError::Io` (kind `NotFound`) when the case directory does not
/// exist, and otherwise propagates the errors of loading, config parsing,
/// broadcast-config construction, and the two preparation steps.
fn load_case_and_config(
    args: &RunArgs,
    quiet: bool,
    stderr: &Term,
) -> Result<LoadedCase, CliError> {
    let case_dir = &args.case_dir;

    // Fail early with a dedicated I/O error when the directory is missing.
    if !case_dir.exists() {
        let source = std::io::Error::new(
            std::io::ErrorKind::NotFound,
            "case directory does not exist",
        );
        return Err(CliError::Io {
            source,
            context: case_dir.display().to_string(),
        });
    }

    if !quiet {
        let banner = format!("Loading case: {}", case_dir.display());
        let _ = stderr.write_line(&banner);
    }

    let system = cobre_io::load_case(case_dir)?;
    let config = cobre_io::parse_config(&case_dir.join("config.json"))?;
    let bcast = BroadcastConfig::from_config(&config)?;

    let prepared = prepare_stochastic(
        system,
        case_dir,
        &config,
        bcast.seed,
        &bcast.training_source,
    )
    .map_err(CliError::from)?;

    // Hydro models are derived from the system returned by the stochastic
    // preparation step, not from the freshly loaded one.
    let hydro_models = prepare_hydro_models(&prepared.system, case_dir).map_err(CliError::from)?;

    Ok((prepared, hydro_models, bcast, config))
}
/// Per-process execution context built once by `setup_communicator` and
/// shared by every phase of the run.
struct RunContext<C: Communicator> {
/// Communicator used for all collective operations.
comm: C,
/// `true` when this process has rank 0.
is_root: bool,
/// Effective quiet flag: `--quiet` was given, or this is not the root rank.
quiet: bool,
/// Number of threads actually available in the rayon pool.
n_threads: usize,
/// Directory under which outputs are written.
output_dir: PathBuf,
/// Terminal width as resolved by `crate::progress::resolve_term_width`.
term_width: u16,
/// Handle used for all human-readable status output.
stderr: Term,
/// Progress rendering mode chosen by `RenderMode::auto()`.
render_mode: crate::progress::RenderMode,
/// Copy of the communicator's execution topology.
topology: ExecutionTopology,
/// Version string reported by `cobre_solver::highs_version()`.
solver_version: String,
}
/// Result of `broadcast_and_build_setup`: the data every rank needs plus the
/// root-only extras (the `root_*` fields are `None` on non-root ranks).
struct LoadBroadcastResult {
/// System received through the broadcast.
system: System,
/// Study setup built on this rank.
setup: StudySetup,
/// Full parsed configuration; only populated on rank 0.
root_config: Option<cobre_io::Config>,
/// Estimation report from stochastic preparation, when the root produced one.
root_estimation_report: Option<EstimationReport>,
/// Estimation path from stochastic preparation; only populated on rank 0.
root_estimation_path: Option<cobre_sddp::EstimationPath>,
/// Copied from the broadcast config's `training_enabled`.
training_enabled: bool,
/// Copied from the broadcast config's `policy_mode`.
policy_mode: cobre_io::PolicyMode,
}
/// Outcome of `run_training_phase`.
struct TrainingPhaseResult {
/// Training result taken from the training outcome.
result: cobre_sddp::TrainingResult,
/// Output records built from the result and the collected training events.
output: cobre_io::TrainingOutput,
/// Error carried by the training outcome, if any. The caller writes the
/// training outputs first and only then turns this into a failure.
error: Option<cobre_sddp::SddpError>,
}
pub fn execute(args: &RunArgs) -> Result<(), CliError> {
let ctx = setup_communicator(args)?;
let result = execute_inner(&ctx, args);
if let Err(ref e) = result {
if ctx.comm.size() > 1 {
ctx.comm.abort(e.exit_code());
}
}
result
}
/// Runs the whole workflow on an already-initialised context.
///
/// Steps, in order: load/broadcast the case and build the study setup, run the
/// pre-training exports, then either
/// - train, write the training outputs on the root rank, and simulate when
///   `n_scenarios > 0`;
/// - load a stored policy and simulate only (training disabled); or
/// - report that there is nothing to do.
///
/// # Errors
/// Propagates any `CliError` from the phases it calls. An error carried in
/// `TrainingPhaseResult::error` is returned as `CliError::Internal`, but only
/// after the root rank has written the training outputs.
fn execute_inner<C: Communicator>(ctx: &RunContext<C>, args: &RunArgs) -> Result<(), CliError> {
let LoadBroadcastResult {
system,
mut setup,
mut root_config,
root_estimation_report,
root_estimation_path,
training_enabled,
policy_mode,
} = broadcast_and_build_setup(ctx, args)?;
run_pre_training(
ctx,
&system,
&setup,
root_config.as_ref(),
root_estimation_report.as_ref(),
root_estimation_path,
)?;
let hostname = ctx.topology.leader_hostname().to_string();
// A world size that does not fit in u32 is clamped rather than treated as an error.
let mpi_world_size = u32::try_from(ctx.topology.world_size).unwrap_or(u32::MAX);
if training_enabled {
apply_training_policy(ctx, &system, &mut setup, root_config.as_ref(), policy_mode)?;
// Timestamps bracket the training call only.
let training_started_at = cobre_io::now_iso8601();
let training = run_training_phase(ctx, &mut setup)?;
let training_completed_at = cobre_io::now_iso8601();
// Only the root rank writes training outputs.
if ctx.is_root {
// `take()` moves the config out; `root_config` is `None` afterwards.
let config = root_config.take().ok_or_else(|| CliError::Internal {
message: "root_config was None on rank 0 — internal invariant violated".to_string(),
})?;
let training_ctx = cobre_io::OutputContext {
hostname: hostname.clone(),
solver: "highs".to_string(),
solver_version: Some(ctx.solver_version.clone()),
started_at: training_started_at,
completed_at: training_completed_at,
distribution: build_distribution_info(&ctx.topology, ctx.n_threads, mpi_world_size),
};
write_training_outputs(&WriteTrainingArgs {
output_dir: &ctx.output_dir,
system: &system,
config: &config,
training_output: &training.output,
setup: &setup,
training_result: &training.result,
output_ctx: &training_ctx,
hydro_models: &setup.hydro_models,
quiet: ctx.quiet,
stderr: &ctx.stderr,
})?;
drop(config);
}
// A training error is surfaced only now, after the outputs above were
// written. Every rank returns the error; only the root logs it.
if let Some(ref training_error) = training.error {
if ctx.is_root {
tracing::error!(
"training failed after {} iterations: {training_error}",
training.result.iterations
);
if !ctx.quiet {
let _ = ctx.stderr.write_line(&format!(
"Training failed after {} iterations. Partial outputs written to {}.",
training.result.iterations,
ctx.output_dir.display()
));
}
}
return Err(CliError::Internal {
message: format!("training error: {training_error}"),
});
}
if setup.simulation_config.n_scenarios > 0 {
run_simulation_phase(ctx, &system, &mut setup, &training.result, &hostname)?;
}
} else if setup.simulation_config.n_scenarios > 0 {
// Simulation-only mode: rebuild a training result from the stored policy.
let training_result =
load_policy_for_simulation(ctx, &system, &mut setup, root_config.as_ref())?;
run_simulation_phase(ctx, &system, &mut setup, &training_result, &hostname)?;
} else {
// Neither training nor simulation was requested.
if ctx.is_root && !ctx.quiet {
let _ = ctx
.stderr
.write_line("Training disabled, simulation disabled — nothing to do.");
}
}
Ok(())
}
/// Reads the policy checkpoint stored in `policy_dir` and, when requested by
/// the configuration, checks that it is compatible with the current study.
///
/// Validation runs only when `root_config` is `Some` and its
/// `policy.validate_compatibility` flag is set. It compares the checkpoint
/// metadata against the setup's state dimension and the number of study
/// stages (stages whose `id` is non-negative).
///
/// # Errors
/// - `CliError::Internal` when the checkpoint cannot be read, or when the
///   stage count or state dimension does not fit in `u32`.
/// - The converted error of `validate_policy_compatibility` when the
///   checkpoint does not match the study.
fn load_and_validate_checkpoint(
    policy_dir: &Path,
    system: &System,
    setup: &StudySetup,
    root_config: Option<&cobre_io::Config>,
) -> Result<cobre_io::PolicyCheckpoint, CliError> {
    let checkpoint = cobre_io::output::policy::read_policy_checkpoint(policy_dir).map_err(|e| {
        CliError::Internal {
            message: format!("failed to read policy checkpoint: {e}"),
        }
    })?;

    let validate = root_config.is_some_and(|config| config.policy.validate_compatibility);
    if validate {
        // Checked conversion, matching `state_dim` below: a silently truncated
        // stage count could make an incompatible checkpoint look compatible.
        let study_stage_count = system.stages().iter().filter(|s| s.id >= 0).count();
        let n_stages = u32::try_from(study_stage_count).map_err(|e| CliError::Internal {
            message: format!("study stage count overflows u32: {e}"),
        })?;
        let state_dim =
            u32::try_from(setup.fcf.state_dimension).map_err(|e| CliError::Internal {
                message: format!("state_dimension overflows u32: {e}"),
            })?;
        cobre_sddp::validate_policy_compatibility(&checkpoint.metadata, state_dim, n_stages)
            .map_err(CliError::from)?;
    }

    Ok(checkpoint)
}
/// Prepares `setup` for training according to `policy_mode`, then injects
/// boundary cuts when the configuration asks for them.
///
/// - `WarmStart`: loads the prior policy from `<output_dir>/<policy_path>` and
///   replaces the future cost function with one seeded from its cuts.
/// - `Resume`: same as warm start, and additionally sets the start iteration
///   to the checkpoint's completed iteration count. A warning is printed on
///   the root rank when that count already reaches `max_iterations`.
/// - `Fresh`: nothing is loaded.
///
/// Status lines are printed only on the root rank and only when not quiet.
///
/// # Errors
/// `CliError::Internal` when the policy directory is missing or the state
/// dimension does not fit in `u32`; otherwise the converted errors from
/// checkpoint loading, warm-start construction, and boundary-cut loading.
fn apply_training_policy(
    ctx: &RunContext<impl Communicator>,
    system: &System,
    setup: &mut StudySetup,
    root_config: Option<&cobre_io::Config>,
    policy_mode: cobre_io::PolicyMode,
) -> Result<(), CliError> {
    match policy_mode {
        cobre_io::PolicyMode::WarmStart => {
            let policy_dir = ctx.output_dir.join(&setup.policy_path);
            if !policy_dir.exists() {
                return Err(CliError::Internal {
                    message: format!(
                        "Policy directory not found: {}. Cannot warm-start \
                         without a prior policy.",
                        policy_dir.display()
                    ),
                });
            }
            if ctx.is_root && !ctx.quiet {
                let _ = ctx
                    .stderr
                    .write_line("Loading prior policy for warm-start training...");
            }
            let checkpoint = load_and_validate_checkpoint(&policy_dir, system, setup, root_config)?;
            let warm_fcf = cobre_sddp::FutureCostFunction::new_with_warm_start(
                &checkpoint.stage_cuts,
                setup.loop_params.forward_passes,
                setup.loop_params.max_iterations.saturating_add(1),
            )
            .map_err(CliError::from)?;
            setup.replace_fcf(warm_fcf);
            if ctx.is_root && !ctx.quiet {
                // `first()` instead of `[0]`: a status line must never panic,
                // even if the pool list is empty.
                let warm_count = setup
                    .fcf
                    .pools
                    .first()
                    .map_or(0, |pool| pool.warm_start_count);
                let _ = ctx.stderr.write_line(&format!(
                    "Warm-start: loaded {warm_count} cuts per stage from prior policy."
                ));
            }
        }
        cobre_io::PolicyMode::Resume => {
            let policy_dir = ctx.output_dir.join(&setup.policy_path);
            if !policy_dir.exists() {
                return Err(CliError::Internal {
                    message: format!(
                        "Policy directory not found: {}. Cannot resume \
                         without a prior checkpoint.",
                        policy_dir.display()
                    ),
                });
            }
            if ctx.is_root && !ctx.quiet {
                let _ = ctx
                    .stderr
                    .write_line("Loading prior checkpoint for resume training...");
            }
            let checkpoint = load_and_validate_checkpoint(&policy_dir, system, setup, root_config)?;
            let completed = u64::from(checkpoint.metadata.completed_iterations);
            // Warn, but do not fail, when there is nothing left to train.
            if completed >= setup.loop_params.max_iterations && ctx.is_root && !ctx.quiet {
                let _ = ctx.stderr.write_line(&format!(
                    "WARNING: Checkpoint already completed {completed} iterations \
                     (max_iterations = {}). No additional training will occur.",
                    setup.loop_params.max_iterations
                ));
            }
            let warm_fcf = cobre_sddp::FutureCostFunction::new_with_warm_start(
                &checkpoint.stage_cuts,
                setup.loop_params.forward_passes,
                setup.loop_params.max_iterations.saturating_add(1),
            )
            .map_err(CliError::from)?;
            setup.replace_fcf(warm_fcf);
            setup.set_start_iteration(completed);
            if ctx.is_root && !ctx.quiet {
                let warm_count = setup
                    .fcf
                    .pools
                    .first()
                    .map_or(0, |pool| pool.warm_start_count);
                let _ = ctx.stderr.write_line(&format!(
                    "Resume: loaded {warm_count} cuts per stage, \
                     resuming from iteration {completed}."
                ));
            }
        }
        cobre_io::PolicyMode::Fresh => {}
    }

    // NOTE(review): this branch only runs where `root_config` is `Some`. The
    // loader in this file populates `root_config` on rank 0 only, so boundary
    // cuts appear to be injected on the root rank alone — confirm that this
    // is intended for multi-rank runs.
    if let Some(bp) = root_config.and_then(|c| c.policy.boundary.as_ref()) {
        let boundary_path = ctx.output_dir.join(&bp.path);
        // Checked conversion, consistent with `load_and_validate_checkpoint`;
        // a truncated dimension would be passed on to the cut loader.
        let state_dim =
            u32::try_from(setup.fcf.state_dimension).map_err(|e| CliError::Internal {
                message: format!("state_dimension overflows u32: {e}"),
            })?;
        let boundary_records =
            cobre_sddp::load_boundary_cuts(&boundary_path, bp.source_stage, state_dim)
                .map_err(CliError::from)?;
        cobre_sddp::inject_boundary_cuts(setup, &boundary_records);
        if ctx.is_root && !ctx.quiet {
            let _ = ctx.stderr.write_line(&format!(
                "Boundary cuts: loaded {} cuts from stage {} of {}",
                boundary_records.len(),
                bp.source_stage,
                boundary_path.display()
            ));
        }
    }
    Ok(())
}
/// Loads a stored policy so that simulation can run without training.
///
/// Replaces the setup's future cost function with the one deserialised from
/// the checkpoint in `<output_dir>/<policy_path>` and returns a
/// `TrainingResult` assembled from the checkpoint metadata and a basis cache
/// built from the checkpoint's stage bases.
///
/// # Errors
/// `CliError::Internal` when the policy directory does not exist; otherwise
/// the errors of checkpoint loading/validation and of deserialising the cuts.
fn load_policy_for_simulation(
    ctx: &RunContext<impl Communicator>,
    system: &System,
    setup: &mut StudySetup,
    root_config: Option<&cobre_io::Config>,
) -> Result<cobre_sddp::TrainingResult, CliError> {
    let announce = ctx.is_root && !ctx.quiet;
    if announce {
        let _ = ctx
            .stderr
            .write_line("Training disabled. Loading policy for simulation-only mode...");
    }

    let policy_dir = ctx.output_dir.join(&setup.policy_path);
    if !policy_dir.exists() {
        let message = format!(
            "Policy directory not found: {}. Cannot run simulation-only \
             mode without a trained policy.",
            policy_dir.display()
        );
        return Err(CliError::Internal { message });
    }

    let checkpoint = load_and_validate_checkpoint(&policy_dir, system, setup, root_config)?;

    let loaded_fcf = cobre_sddp::FutureCostFunction::from_deserialized(&checkpoint.stage_cuts)
        .map_err(CliError::from)?;
    setup.replace_fcf(loaded_fcf);

    let template_count = setup.stage_data.stage_templates.templates.len();
    let basis_cache =
        cobre_sddp::build_basis_cache_from_checkpoint(template_count, &checkpoint.stage_bases);

    // A checkpoint without a recorded upper bound is reported as +infinity.
    let lower_bound = checkpoint.metadata.final_lower_bound;
    let upper_bound = checkpoint
        .metadata
        .best_upper_bound
        .unwrap_or(f64::INFINITY);

    let loaded = cobre_sddp::TrainingResult::new(
        lower_bound,
        upper_bound,
        0.0,
        0.0,
        checkpoint.metadata.completed_iterations.into(),
        "loaded from checkpoint".to_string(),
        0,
        basis_cache,
        Vec::new(),
        None,
        None,
    );
    Ok(loaded)
}
/// Creates the communicator and assembles the per-process `RunContext`.
///
/// Besides creating the communicator this initialises the global rayon pool
/// with the resolved thread count (falling back to the existing pool's size,
/// with a warning, when the pool was already initialised), queries the solver
/// version, and — unless quiet — prints the banner and execution topology.
///
/// # Errors
/// Propagates communicator creation failures and returns
/// `CliError::Internal` if rayon reports zero threads.
fn setup_communicator(args: &RunArgs) -> Result<RunContext<impl Communicator>, CliError> {
    let comm = create_communicator()?;
    let is_root = comm.rank() == 0;
    // Non-root ranks are always quiet.
    let quiet = args.quiet || !is_root;

    let multi_rank = comm.size() > 1;
    if multi_rank && is_root && !args.quiet {
        console::set_colors_enabled_stderr(true);
    }

    let stderr = Term::stderr();
    let topology = comm.topology().clone();

    let configured_threads = resolve_thread_count(args.threads);
    let pool_init = rayon::ThreadPoolBuilder::new()
        .num_threads(configured_threads)
        .build_global();
    let actual_threads = if let Err(err) = pool_init {
        // The global pool can only be built once; report and use what exists.
        let actual = rayon::current_num_threads();
        tracing::warn!(
            configured = configured_threads,
            actual,
            %err,
            "rayon global thread pool init failed; using existing pool",
        );
        actual
    } else {
        configured_threads
    };
    if actual_threads == 0 {
        return Err(CliError::Internal {
            message: "rayon reported zero active threads — unexpected state".to_string(),
        });
    }

    let solver_version = cobre_solver::highs_version();
    if !quiet {
        crate::banner::print_banner(&stderr);
        crate::summary::print_execution_topology(
            &stderr,
            &topology,
            actual_threads,
            "HiGHS",
            Some(&solver_version),
        );
    }

    // Default output location is `<case_dir>/output`.
    let output_dir: PathBuf = match &args.output {
        Some(dir) => dir.clone(),
        None => args.case_dir.join("output"),
    };
    let term_width = crate::progress::resolve_term_width();
    let render_mode = crate::progress::RenderMode::auto();

    Ok(RunContext {
        comm,
        is_root,
        quiet,
        n_threads: actual_threads,
        output_dir,
        term_width,
        stderr,
        render_mode,
        topology,
        solver_version,
    })
}
/// Loads the case on the root rank, broadcasts what the other ranks need, and
/// builds the `StudySetup` on every rank.
///
/// The root rank keeps the stochastic context and hydro models it produced
/// while loading. Every other rank rebuilds them locally from the broadcast
/// system, the broadcast configuration, and files under `args.case_dir`.
///
/// # Errors
/// Returns the root's load error (after the broadcasts have been issued), any
/// broadcast error, and `CliError::Internal` for failures while rebuilding the
/// stochastic context or hydro models on non-root ranks.
#[allow(clippy::too_many_lines)]
fn broadcast_and_build_setup(
ctx: &RunContext<impl Communicator>,
args: &RunArgs,
) -> Result<LoadBroadcastResult, CliError> {
// On the root rank every slot is `Some` on success (the estimation report
// stays as produced); on failure only `load_err` is set. Non-root ranks
// start with everything `None`.
let (
raw_system,
raw_bcast_config,
mut root_config,
root_stochastic,
root_estimation_report,
root_estimation_path,
raw_bcast_tree,
root_hydro_models,
load_err,
) = if ctx.is_root {
match load_case_and_config(args, ctx.quiet, &ctx.stderr) {
Ok((prepared, hydro_models, bcast, config)) => {
// Only a user-supplied opening tree is shipped to the other ranks.
let bcast_tree = if prepared.stochastic.provenance().opening_tree
== ComponentProvenance::UserSupplied
{
let t = prepared.stochastic.opening_tree();
Some(BroadcastOpeningTree {
data: t.data().to_vec(),
openings_per_stage: t.openings_per_stage_slice().to_vec(),
dim: t.dim(),
})
} else {
None
};
let cobre_sddp::PrepareStochasticResult {
system,
stochastic,
estimation_report,
estimation_path,
} = prepared;
(
Some(system),
Some(bcast),
Some(config),
Some(stochastic),
estimation_report,
Some(estimation_path),
Some(bcast_tree),
Some(hydro_models),
None,
)
}
Err(e) => (None, None, None, None, None, None, None, None, Some(e)),
}
} else {
(None, None, None, None, None, None, None, None, None)
};
// Re-bindings with explicit types pin the `Option` payload types, which the
// all-`None` tuples above do not determine on their own.
let root_estimation_report: Option<EstimationReport> = root_estimation_report;
let root_estimation_path: Option<cobre_sddp::EstimationPath> = root_estimation_path;
// All three broadcasts are issued before any early return, so a failed load
// on the root does not skip them.
let system_result = broadcast_value(raw_system, &ctx.comm);
let bcast_config_result = broadcast_value(raw_bcast_config, &ctx.comm);
let root_hydro_models: Option<PrepareHydroModelsResult> = root_hydro_models;
let tree_result = broadcast_value(raw_bcast_tree, &ctx.comm);
if let Some(e) = load_err {
return Err(e);
}
let system = system_result?;
let mut bcast_config = bcast_config_result?;
let seed = bcast_config.seed;
let stochastic = if ctx.is_root {
// The root already owns its stochastic context; the broadcast tree is unused here.
drop(tree_result);
root_stochastic.ok_or_else(|| CliError::Internal {
message: "stochastic context missing on rank 0 after successful load".to_string(),
})?
} else {
// Non-root ranks rebuild the stochastic context from broadcast data.
let user_tree: Option<OpeningTree> =
tree_result?.map(|bt| OpeningTree::from_parts(bt.data, bt.openings_per_stage, bt.dim));
let training_src = &bcast_config.training_source;
let forward_seed = training_src.seed.map(i64::unsigned_abs);
// Load factors are read from the case directory on this rank.
let load_factor_entries =
load_load_factors_for_stochastic(&args.case_dir).map_err(|e| CliError::Internal {
message: format!("load factor error on non-root rank: {e}"),
})?;
// Owned (block id, factor) pairs; `load_entity_factors` below borrows them as slices.
let load_block_pairs: Vec<Vec<cobre_stochastic::normal::precompute::BlockFactorPair>> =
load_factor_entries
.iter()
.map(|e| {
e.block_factors
.iter()
.map(|bf| (bf.block_id, bf.factor))
.collect()
})
.collect();
let load_entity_factors: Vec<cobre_stochastic::normal::precompute::EntityFactorEntry<'_>> =
load_factor_entries
.iter()
.zip(load_block_pairs.iter())
.map(|(e, pairs)| (e.bus_id, e.stage_id, pairs.as_slice()))
.collect();
let ncs_raw = build_ncs_factor_entries(&system);
let ncs_entity_factors: Vec<cobre_stochastic::normal::precompute::EntityFactorEntry<'_>> =
ncs_raw
.iter()
.map(|(ncs_id, stage_id, pairs)| (*ncs_id, *stage_id, pairs.as_slice()))
.collect();
// A historical scenario library is built only when at least one study
// stage (id >= 0) uses the `HistoricalResiduals` noise method.
let opening_tree_library = {
use cobre_core::temporal::NoiseMethod;
let needs_historical_tree = system.stages().iter().any(|s| {
s.id >= 0 && s.scenario_config.noise_method == NoiseMethod::HistoricalResiduals
});
if needs_historical_tree {
let study_stages: Vec<_> = system
.stages()
.iter()
.filter(|s| s.id >= 0)
.cloned()
.collect();
let hydro_ids: Vec<cobre_core::EntityId> =
system.hydros().iter().map(|h| h.id).collect();
let par = cobre_stochastic::PrecomputedPar::build(
system.inflow_models(),
&study_stages,
&hydro_ids,
)
.map_err(|e| CliError::Internal {
message: format!("PAR build error on non-root rank: {e}"),
})?;
let max_order = par.max_order();
let user_pool = training_src.historical_years.as_ref();
let window_years = cobre_stochastic::discover_historical_windows(
system.inflow_history(),
&hydro_ids,
&study_stages,
max_order,
user_pool,
system.policy_graph().season_map.as_ref(),
1,
)
.map_err(|e| CliError::Internal {
message: format!("historical window discovery error on non-root rank: {e}"),
})?;
let mut lib = cobre_stochastic::HistoricalScenarioLibrary::new(
window_years.len(),
study_stages.len(),
hydro_ids.len(),
max_order,
window_years.clone(),
);
cobre_stochastic::standardize_historical_windows(
&mut lib,
system.inflow_history(),
&hydro_ids,
&study_stages,
&par,
&window_years,
system.policy_graph().season_map.as_ref(),
);
Some(lib)
} else {
None
}
};
build_stochastic_context(
&system,
seed,
forward_seed,
&load_entity_factors,
&ncs_entity_factors,
OpeningTreeInputs {
user_tree,
historical_library: opening_tree_library.as_ref(),
external_scenario_counts: None,
noise_group_ids: None,
},
cobre_stochastic::ClassSchemes {
inflow: Some(training_src.inflow_scheme),
load: Some(training_src.load_scheme),
ncs: Some(training_src.ncs_scheme),
},
)
.map_err(|e| CliError::Internal {
message: format!("stochastic context error: {e}"),
})?
};
// Hydro models: reused on the root, recomputed from the case directory elsewhere.
let hydro_models = if ctx.is_root {
root_hydro_models.ok_or_else(|| CliError::Internal {
message: "hydro models missing on rank 0 after successful load".to_string(),
})?
} else {
prepare_hydro_models(&system, &args.case_dir).map_err(|e| CliError::Internal {
message: format!("hydro model preprocessing error on non-root rank: {e}"),
})?
};
// Read these before `build_study_setup` borrows the config mutably.
let training_enabled = bcast_config.training_enabled;
let policy_mode = bcast_config.policy_mode;
let setup = build_study_setup(&system, &mut bcast_config, stochastic, hydro_models)?;
Ok(LoadBroadcastResult {
system,
setup,
root_config: root_config.take(),
root_estimation_report,
root_estimation_path,
training_enabled,
policy_mode,
})
}
/// Builds the `StudySetup` from the broadcast configuration plus the
/// stochastic context and hydro models prepared for this rank.
///
/// Takes `cut_selection` out of `bcast_config` (leaving `None` behind), which
/// is why the configuration is borrowed mutably.
///
/// # Errors
/// Returns the converted error of `StudySetup::from_broadcast_params`.
fn build_study_setup(
    system: &System,
    bcast_config: &mut BroadcastConfig,
    stochastic: cobre_stochastic::StochasticContext,
    hydro_models: PrepareHydroModelsResult,
) -> Result<StudySetup, CliError> {
    let stopping_rule_set = stopping_rules_from_broadcast(bcast_config);
    let cut_selection = bcast_config.cut_selection.take();
    // A capacity that does not fit in `usize` falls back to 64.
    let io_channel_capacity = usize::try_from(bcast_config.io_channel_capacity).unwrap_or(64);

    let construction = ConstructionConfig {
        seed: bcast_config.seed,
        forward_passes: bcast_config.forward_passes,
        stopping_rule_set,
        n_scenarios: bcast_config.n_scenarios,
        io_channel_capacity,
        policy_path: bcast_config.policy_path.clone(),
        inflow_method: bcast_config.inflow_method.clone(),
        cut_selection,
        cut_activity_tolerance: bcast_config.cut_activity_tolerance,
        basis_activity_window: bcast_config.basis_activity_window,
        budget: bcast_config.budget,
        export_states: bcast_config.export_states,
    };

    let built = StudySetup::from_broadcast_params(
        system,
        stochastic,
        construction,
        hydro_models,
        &bcast_config.training_source,
        &bcast_config.simulation_source,
    );
    built.map_err(CliError::from)
}
/// Performs the root-only reporting and export steps that precede training,
/// then synchronises all ranks on a barrier.
///
/// On the root rank, in order: print the hydro model summary (unless quiet),
/// build/print/write the provenance report when an estimation path is
/// available, export the stochastic artifacts when `exports.stochastic` is
/// set, and write the scaling report.
///
/// # Errors
/// `CliError::Internal` when the provenance or scaling report cannot be
/// written, or when the barrier fails.
fn run_pre_training(
    ctx: &RunContext<impl Communicator>,
    system: &System,
    setup: &StudySetup,
    root_config: Option<&cobre_io::Config>,
    root_estimation_report: Option<&EstimationReport>,
    root_estimation_path: Option<cobre_sddp::EstimationPath>,
) -> Result<(), CliError> {
    if ctx.is_root {
        if !ctx.quiet {
            let hydro_summary = build_hydro_model_summary(&setup.hydro_models, system);
            crate::summary::print_hydro_model_summary(&ctx.stderr, &hydro_summary);
        }

        if let Some(path) = root_estimation_path {
            let provenance = cobre_sddp::build_provenance_report(
                path,
                root_estimation_report,
                setup.stochastic.provenance(),
                system.hydros().len(),
            );
            if !ctx.quiet {
                crate::summary::print_provenance_summary(&ctx.stderr, &provenance);
            }
            let provenance_path = ctx.output_dir.join("training/model_provenance.json");
            if let Err(e) = cobre_io::write_provenance_report(&provenance_path, &provenance) {
                return Err(CliError::Internal {
                    message: format!("failed to write provenance report: {e}"),
                });
            }
        }

        let export_stochastic = root_config.is_some_and(|c| c.exports.stochastic);
        if export_stochastic {
            export_stochastic_artifacts(
                &ctx.output_dir,
                &setup.stochastic,
                system,
                root_estimation_report,
                ctx.quiet,
                &ctx.stderr,
            );
        }

        let scaling_path = ctx.output_dir.join("training/scaling_report.json");
        if let Err(e) =
            cobre_io::write_scaling_report(&scaling_path, &setup.stage_data.scaling_report)
        {
            return Err(CliError::Internal {
                message: format!("failed to write scaling report: {e}"),
            });
        }
    }

    // Every rank waits here until the root has finished its exports.
    match ctx.comm.barrier() {
        Ok(_) => Ok(()),
        Err(e) => Err(CliError::Internal {
            message: format!("post-export barrier error: {e}"),
        }),
    }
}
/// Runs SDDP training and assembles its result, output records, and summary.
///
/// Training events are either rendered live by a progress thread (non-quiet)
/// or buffered in the channel and collected afterwards (quiet). LP-solve
/// counts and solver statistics are summed across ranks with `allreduce`, and
/// the root rank prints the training summary unless quiet.
///
/// # Errors
/// `CliError::Solver` when HiGHS cannot be initialised, the converted error of
/// `setup.train`, and `CliError::Internal` for allreduce/barrier failures or a
/// solver counter too large to pack into an `f64`. An error carried inside the
/// training outcome is NOT returned here; it is passed on in
/// `TrainingPhaseResult::error`.
#[allow(clippy::too_many_lines)]
fn run_training_phase(
ctx: &RunContext<impl Communicator>,
setup: &mut StudySetup,
) -> Result<TrainingPhaseResult, CliError> {
let solver_factory = || HighsSolver::new();
let mut solver = HighsSolver::new().map_err(|e| CliError::Solver {
message: format!("HiGHS initialisation failed: {e}"),
})?;
let (event_tx, event_rx) = mpsc::channel::<TrainingEvent>();
// In quiet mode the receiver is kept so the events stay in the channel and
// can be drained after training; otherwise the progress thread owns it.
let quiet_rx: Option<mpsc::Receiver<TrainingEvent>>;
let progress_handle = if ctx.quiet {
quiet_rx = Some(event_rx);
None
} else {
quiet_rx = None;
Some(crate::progress::run_progress_thread(
event_rx,
ctx.render_mode,
setup.loop_params.max_iterations,
ctx.term_width,
))
};
let training_outcome = match setup.train(
&mut solver,
&ctx.comm,
ctx.n_threads,
solver_factory,
Some(event_tx),
None,
) {
Ok(outcome) => outcome,
Err(e) => {
// Wait for the progress thread before reporting the failure.
if let Some(handle) = progress_handle {
let _ = handle.join();
}
return Err(CliError::from(e));
}
};
let training_result = training_outcome.result;
// Collect the events: from the progress thread when one ran, otherwise
// from whatever is buffered in the channel.
let events: Vec<TrainingEvent> = match (progress_handle, quiet_rx) {
(Some(handle), _) => handle.join(),
(None, Some(rx)) => rx.try_iter().collect(),
(None, None) => Vec::new(),
};
let training_output = setup.build_training_output(&training_result, &events);
// Sum this rank's LP solves over all convergence records, then across ranks.
let local_lp_solves: u64 = training_output
.convergence_records
.iter()
.map(|r| u64::from(r.lp_solves))
.sum();
let mut global_lp_solves = [0u64];
ctx.comm
.allreduce(&[local_lp_solves], &mut global_lp_solves, ReduceOp::Sum)
.map_err(|e| CliError::Internal {
message: format!("LP solve count allreduce error: {e}"),
})?;
let global_lp_solves = global_lp_solves[0];
ctx.comm.barrier().map_err(|e| CliError::Internal {
message: format!("post-training barrier error: {e}"),
})?;
#[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
let my_rank = ctx.comm.rank() as i32;
// Only log entries tagged with this rank are counted locally.
let (
local_first_try,
local_retried,
local_failed,
local_forward_solve_s,
local_backward_solve_s,
) = aggregate_solver_stats(&training_result.solver_stats_log, my_rank);
// The counters are sent as f64 below, so check first that the local
// values are exactly representable.
let training_guard_delta = cobre_sddp::SolverStatsDelta {
lp_successes: local_first_try.saturating_add(local_retried),
first_try_successes: local_first_try,
lp_failures: local_failed,
..cobre_sddp::SolverStatsDelta::default()
};
check_stats_overflow(&training_guard_delta)?;
// Counters and timings share one f64 buffer and a single allreduce.
#[allow(clippy::cast_precision_loss)]
let send_stats = [
local_first_try as f64,
local_retried as f64,
local_failed as f64,
local_forward_solve_s,
local_backward_solve_s,
];
let mut recv_stats = [0.0_f64; 5];
ctx.comm
.allreduce(&send_stats, &mut recv_stats, ReduceOp::Sum)
.map_err(|e| CliError::Internal {
message: format!("training solver stats allreduce error: {e}"),
})?;
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
let (
total_first_try,
total_retried,
total_failed,
total_forward_solve_s,
total_backward_solve_s,
) = (
recv_stats[0] as u64,
recv_stats[1] as u64,
recv_stats[2] as u64,
recv_stats[3],
recv_stats[4],
);
// Gap of the first convergence record, when there is one.
let initial_gap_percent = training_output
.convergence_records
.first()
.and_then(|r| r.gap_percent);
// Threads per rank times number of ranks, saturating on overflow.
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
let parallelism = (ctx.n_threads as u32).saturating_mul(ctx.comm.size() as u32);
let training_summary = TrainingSummary {
iterations: training_result.iterations,
converged: training_output.converged,
converged_at: if training_output.converged {
Some(training_result.iterations)
} else {
None
},
reason: training_result.reason.clone(),
lower_bound: training_result.final_lb,
upper_bound: training_result.final_ub,
upper_bound_std: training_result.final_ub_std,
// `final_gap` is multiplied by 100 for display as a percentage.
gap_percent: training_result.final_gap * 100.0,
total_rows_active: training_output.cut_stats.total_active,
total_rows_generated: training_output.cut_stats.total_generated,
total_lp_solves: global_lp_solves,
total_time_ms: training_result.total_time_ms,
total_first_try: Some(total_first_try),
total_retried: Some(total_retried),
total_failed: Some(total_failed),
total_forward_solve_seconds: Some(total_forward_solve_s),
total_backward_solve_seconds: Some(total_backward_solve_s),
parallelism: Some(parallelism),
initial_gap_percent,
};
if !ctx.quiet && ctx.is_root {
crate::summary::print_training_summary(&ctx.stderr, &training_summary);
}
Ok(TrainingPhaseResult {
result: training_result,
output: training_output,
error: training_outcome.error,
})
}
/// Sums the solver statistics of the log entries that belong to `my_rank`.
///
/// The fifth tuple element of each entry is compared against `my_rank`;
/// entries of other ranks are ignored. Solve time is accumulated separately
/// for the `"forward"` and `"backward"` phases; other phases contribute to the
/// counters only.
///
/// Returns `(first_try, retried, failed, forward_seconds, backward_seconds)`,
/// where `retried` is `lp_successes - first_try_successes` (saturating) and
/// the times are converted from milliseconds to seconds.
fn aggregate_solver_stats(
    stats_log: &[(
        u64,
        &'static str,
        i32,
        i32,
        i32,
        i32,
        cobre_sddp::SolverStatsDelta,
    )],
    my_rank: i32,
) -> (u64, u64, u64, f64, f64) {
    let (mut first_try, mut retried, mut failed) = (0u64, 0u64, 0u64);
    let (mut forward_ms, mut backward_ms) = (0.0_f64, 0.0_f64);

    let own_entries = stats_log.iter().filter(|entry| entry.4 == my_rank);
    for (_, phase, _, _, _, _, delta) in own_entries {
        first_try += delta.first_try_successes;
        retried += delta.lp_successes.saturating_sub(delta.first_try_successes);
        failed += delta.lp_failures;

        if *phase == "forward" {
            forward_ms += delta.solve_time_ms;
        } else if *phase == "backward" {
            backward_ms += delta.solve_time_ms;
        }
    }

    let forward_s = forward_ms / 1000.0;
    let backward_s = backward_ms / 1000.0;
    (first_try, retried, failed, forward_s, backward_s)
}
/// Verifies that every integer counter in `delta` can be represented exactly
/// as an `f64`.
///
/// The counters are checked in a fixed order and the first one strictly
/// greater than 2^53 is reported.
///
/// # Errors
/// `CliError::Internal` naming the offending counter and its value.
fn check_stats_overflow(delta: &cobre_sddp::SolverStatsDelta) -> Result<(), CliError> {
    const F64_INTEGER_LIMIT: u64 = 1u64 << 53;

    let counters: [(&str, u64); 9] = [
        ("lp_solves", delta.lp_solves),
        ("lp_successes", delta.lp_successes),
        ("first_try_successes", delta.first_try_successes),
        ("lp_failures", delta.lp_failures),
        ("retry_attempts", delta.retry_attempts),
        ("basis_offered", delta.basis_offered),
        (
            "basis_consistency_failures",
            delta.basis_consistency_failures,
        ),
        ("simplex_iterations", delta.simplex_iterations),
        ("load_model_count", delta.load_model_count),
    ];

    let offender = counters
        .iter()
        .find(|(_, value)| *value > F64_INTEGER_LIMIT);

    match offender {
        None => Ok(()),
        Some((label, value)) => Err(CliError::Internal {
            message: format!(
                "solver stats counter '{label}' = {value} \
                 exceeds 2^53 (f64 integer-precision limit). MPI \
                 allreduce(Sum) packing would lose precision. \
                 Reduce iteration count or split the run."
            ),
        }),
    }
}
/// Runs the simulation phase and writes its outputs.
///
/// Scenario results are sent over a bounded channel to a background thread
/// that writes them with a `SimulationParquetWriter`. After the simulation
/// the per-rank metadata, solver statistics, and costs are combined across
/// ranks; the root rank prints the summary (unless quiet) and writes the
/// simulation outputs.
///
/// # Errors
/// `CliError::Solver` when the workspace pool cannot be created, the converted
/// errors of writer creation and of the simulation itself, and
/// `CliError::Internal` for barrier or aggregation failures.
///
/// # Panics
/// Panics if the background writer thread panicked.
fn run_simulation_phase(
ctx: &RunContext<impl Communicator>,
system: &System,
setup: &mut StudySetup,
training_result: &cobre_sddp::TrainingResult,
hostname: &str,
) -> Result<(), CliError> {
let solver_factory = || HighsSolver::new();
let n_scenarios = setup.simulation_config.n_scenarios;
let sim_config = setup.simulation_config();
let mut sim_pool = setup
.create_workspace_pool(&ctx.comm, ctx.n_threads, solver_factory)
.map_err(|e| CliError::Solver {
message: format!("HiGHS initialisation failed for simulation pool: {e}"),
})?;
let (sim_event_tx, sim_event_rx) = mpsc::channel::<TrainingEvent>();
// In quiet mode nobody listens for progress events, so the receiver is dropped.
let sim_progress_handle = if ctx.quiet {
drop(sim_event_rx);
None
} else {
Some(crate::progress::run_progress_thread(
sim_event_rx,
ctx.render_mode,
u64::from(n_scenarios),
ctx.term_width,
))
};
let io_capacity = sim_config.io_channel_capacity;
// The channel capacity is forced to at least 1.
let (result_tx, result_rx) = mpsc::sync_channel(io_capacity.max(1));
let parquet_config = cobre_io::ParquetWriterConfig::default();
let mut sim_writer = cobre_io::output::simulation_writer::SimulationParquetWriter::new(
&ctx.output_dir,
system,
&parquet_config,
)
.map_err(CliError::from)?;
// Writer thread: writes every received scenario, counts (and logs) write
// failures instead of stopping, and hands the writer back when done.
let drain_handle = std::thread::spawn(move || {
let mut failed: u32 = 0;
for scenario_result in result_rx {
let payload =
cobre_io::output::simulation_writer::ScenarioWritePayload::from(scenario_result);
if let Err(e) = sim_writer.write_scenario(payload) {
tracing::error!("simulation write error: {e}");
failed += 1;
}
}
(sim_writer, failed)
});
let sim_started_at = cobre_io::now_iso8601();
let sim_start = std::time::Instant::now();
// The result is kept as a `Result` so both helper threads are joined
// before a simulation error is propagated.
let sim_result = setup
.simulate(
&mut sim_pool.workspaces,
&ctx.comm,
&result_tx,
Some(sim_event_tx),
training_result.baked_templates.as_deref(),
&training_result.basis_cache,
)
.map_err(CliError::from);
if let Some(handle) = sim_progress_handle {
let _ = handle.join();
}
// Dropping this sender lets the writer thread's receive loop end once no
// other sender remains.
drop(result_tx);
#[allow(clippy::expect_used)]
let (sim_writer, write_failures) = drain_handle.join().expect("drain thread panicked");
let sim_run_result = sim_result?;
#[allow(clippy::cast_possible_truncation)]
let sim_time_ms = sim_start.elapsed().as_millis() as u64;
let mut local_sim_output = sim_writer.finalize(sim_time_ms);
// The failure count reported for this rank is the number of failed writes.
local_sim_output.failed = write_failures;
let merged_sim_output = merge_simulation_metadata(&ctx.comm, &local_sim_output)?;
ctx.comm.barrier().map_err(|e| CliError::Internal {
message: format!("post-simulation barrier error: {e}"),
})?;
let (global_agg, global_scenario_stats) =
aggregate_simulation_solver_stats(&ctx.comm, &sim_run_result.solver_stats)?;
let cost_summary =
cobre_sddp::aggregate_simulation(&sim_run_result.costs, sim_config, &ctx.comm).map_err(
|e| CliError::Internal {
message: format!("simulation cost aggregation error: {e}"),
},
)?;
if !ctx.quiet && ctx.is_root {
// Threads per rank times number of ranks, saturating on overflow.
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
let parallelism = (ctx.n_threads as u32).saturating_mul(ctx.comm.size() as u32);
print_sim_summary(
&ctx.stderr,
n_scenarios,
sim_time_ms,
&global_agg,
&cost_summary,
parallelism,
);
}
if ctx.is_root {
write_sim_outputs_on_root(
ctx,
hostname,
sim_started_at,
&merged_sim_output,
&global_scenario_stats,
)?;
}
Ok(())
}
/// Writes the simulation outputs; called on the root rank only.
///
/// Builds the output context (host, solver, timestamps, distribution info)
/// and forwards everything to `write_simulation_outputs`. The completion
/// timestamp is taken at the moment this function runs.
///
/// The solver version is recorded from the run context, matching what the
/// training output context records, so both metadata files identify the same
/// solver build.
///
/// # Errors
/// Returns whatever `write_simulation_outputs` returns.
fn write_sim_outputs_on_root(
    ctx: &RunContext<impl Communicator>,
    hostname: &str,
    sim_started_at: String,
    merged_sim_output: &cobre_io::SimulationOutput,
    global_scenario_stats: &[(u32, cobre_sddp::SolverStatsDelta)],
) -> Result<(), CliError> {
    // A world size that does not fit in u32 is clamped rather than rejected.
    let mpi_world_size = u32::try_from(ctx.topology.world_size).unwrap_or(u32::MAX);
    let sim_ctx = cobre_io::OutputContext {
        hostname: hostname.to_string(),
        solver: "highs".to_string(),
        solver_version: Some(ctx.solver_version.clone()),
        started_at: sim_started_at,
        completed_at: cobre_io::now_iso8601(),
        distribution: build_distribution_info(&ctx.topology, ctx.n_threads, mpi_world_size),
    };
    write_simulation_outputs(&WriteSimulationArgs {
        output_dir: &ctx.output_dir,
        sim_output: merged_sim_output,
        sim_solver_stats: global_scenario_stats,
        output_ctx: &sim_ctx,
        quiet: ctx.quiet,
        stderr: &ctx.stderr,
    })
}
/// Describes how the run was distributed, for inclusion in output metadata.
///
/// Numeric values that do not fit in `u32` are clamped to `u32::MAX`. The
/// MPI-related fields are `None` when the topology carries no MPI
/// information, and `slurm_job_id` is `None` without SLURM information.
fn build_distribution_info(
    topology: &ExecutionTopology,
    n_threads: usize,
    ranks_participated: u32,
) -> cobre_io::DistributionInfo {
    use cobre_comm::BackendKind;

    let backend_name = match topology.backend {
        BackendKind::Mpi => "mpi",
        BackendKind::Local => "local",
        BackendKind::Auto => "unknown",
    };
    let mpi = topology.mpi.as_ref();
    let slurm = topology.slurm.as_ref();

    cobre_io::DistributionInfo {
        backend: backend_name.to_string(),
        world_size: u32::try_from(topology.world_size).unwrap_or(u32::MAX),
        ranks_participated,
        num_nodes: u32::try_from(topology.num_hosts()).unwrap_or(u32::MAX),
        threads_per_rank: u32::try_from(n_threads).unwrap_or(u32::MAX),
        mpi_library: mpi.map(|info| info.library_version.clone()),
        mpi_standard: mpi.map(|info| info.standard_version.clone()),
        thread_level: mpi.map(|info| info.thread_level.clone()),
        slurm_job_id: slurm.map(|info| info.job_id.clone()),
    }
}
/// Prints the end-of-simulation summary to `stderr`.
///
/// * `n_scenarios` - reported both as the total and as the completed count;
///   the failed count is always reported as `0`.
/// * `sim_time_ms` - total simulation time in milliseconds.
/// * `agg` - aggregated solver statistics; retried solves are derived as
///   `lp_successes - first_try_successes` (saturating at zero) and the solve
///   time is converted from milliseconds to seconds.
/// * `cost_summary` - source of the mean and standard deviation of cost.
/// * `parallelism` - reported as-is.
fn print_sim_summary(
    stderr: &Term,
    n_scenarios: u32,
    sim_time_ms: u64,
    agg: &cobre_sddp::SolverStatsDelta,
    cost_summary: &cobre_sddp::SimulationSummary,
    parallelism: u32,
) {
    let retried = agg.lp_successes.saturating_sub(agg.first_try_successes);
    let solve_seconds = agg.solve_time_ms / 1000.0;

    let summary = SimulationSummary {
        n_scenarios,
        completed: n_scenarios,
        failed: 0,
        total_time_ms: sim_time_ms,
        mean_cost: Some(cost_summary.mean_cost),
        std_cost: Some(cost_summary.std_cost),
        total_lp_solves: Some(agg.lp_solves),
        total_first_try: Some(agg.first_try_successes),
        total_retried: Some(retried),
        total_failed_solves: Some(agg.lp_failures),
        total_solve_time_seconds: Some(solve_seconds),
        parallelism: Some(parallelism),
    };
    crate::summary::print_simulation_summary(stderr, &summary);
}
/// Combines this rank's simulation metadata with that of every other rank.
///
/// Scenario counters (`n_scenarios`, `completed`, `failed`) are reduced with
/// `Sum`, `total_time_ms` with `Max`, and the partition paths written by each
/// rank are gathered, split back into individual entries and sorted.
///
/// The function issues collective operations on `comm` (two `allreduce`
/// calls followed by two `allgatherv` calls), so presumably every rank has to
/// call it; verify against the `Communicator` contract.
///
/// # Errors
///
/// Returns `CliError::Internal` when a collective operation fails or when the
/// bytes received from a rank are not valid UTF-8.
#[allow(clippy::cast_possible_truncation)]
fn merge_simulation_metadata<C: Communicator>(
    comm: &C,
    local: &cobre_io::SimulationOutput,
) -> Result<cobre_io::SimulationOutput, CliError> {
    let internal = |message: String| CliError::Internal { message };

    // Scenario counters: element-wise sum over all ranks.
    let local_counts = [local.n_scenarios, local.completed, local.failed];
    let mut global_counts = [0u32; 3];
    comm.allreduce(&local_counts, &mut global_counts, ReduceOp::Sum)
        .map_err(|e| internal(format!("simulation metadata count allreduce error: {e}")))?;

    // Elapsed time: the largest value reported by any rank.
    let local_time = [local.total_time_ms];
    let mut global_time = [0u64; 1];
    comm.allreduce(&local_time, &mut global_time, ReduceOp::Max)
        .map_err(|e| internal(format!("simulation metadata time allreduce error: {e}")))?;

    // Partition paths travel as one newline-joined byte string per rank.
    // First exchange the byte lengths so receive offsets can be computed.
    let encoded_paths = local.partitions_written.join("\n").into_bytes();
    let local_len = [encoded_paths.len() as u64];
    let rank_count = comm.size();
    let mut lens_by_rank = vec![0u64; rank_count];
    let unit_counts = vec![1usize; rank_count];
    let unit_displs: Vec<usize> = (0..rank_count).collect();
    comm.allgatherv(&local_len, &mut lens_by_rank, &unit_counts, &unit_displs)
        .map_err(|e| internal(format!("partition path length exchange error: {e}")))?;

    // Offsets are the exclusive running sum of the per-rank byte counts.
    let byte_counts: Vec<usize> = lens_by_rank.iter().map(|&len| len as usize).collect();
    let mut byte_displs: Vec<usize> = Vec::with_capacity(byte_counts.len());
    let mut total_bytes = 0usize;
    for &count in &byte_counts {
        byte_displs.push(total_bytes);
        total_bytes += count;
    }

    let mut gathered = vec![0u8; total_bytes];
    comm.allgatherv(&encoded_paths, &mut gathered, &byte_counts, &byte_displs)
        .map_err(|e| internal(format!("partition path gather error: {e}")))?;

    // Decode each rank's slice and split it back into individual paths.
    let mut partitions_written: Vec<String> = Vec::new();
    for (i, (&start, &count)) in byte_displs.iter().zip(&byte_counts).enumerate() {
        if count == 0 {
            continue;
        }
        let text = std::str::from_utf8(&gathered[start..start + count]).map_err(|e| {
            internal(format!(
                "partition path UTF-8 decode error from rank {i}: {e}"
            ))
        })?;
        for path in text.split('\n') {
            if !path.is_empty() {
                partitions_written.push(String::from(path));
            }
        }
    }
    partitions_written.sort();

    let [n_scenarios, completed, failed] = global_counts;
    let [total_time_ms] = global_time;
    Ok(cobre_io::SimulationOutput {
        n_scenarios,
        completed,
        failed,
        total_time_ms,
        partitions_written,
    })
}
/// Aggregates simulation solver statistics across all ranks.
///
/// `local_stats` holds this rank's entries as `(scenario id, opening, delta)`.
/// The result is a pair of:
///
/// 1. the sum of every rank's aggregated delta, and
/// 2. the per-scenario deltas of all ranks (opening dropped), sorted by
///    scenario id.
///
/// The local aggregate is passed through `check_stats_overflow` before it is
/// packed into `f64` scalars for the `Sum` reduction. The function issues
/// collective operations on `comm`, so presumably every rank has to call it;
/// verify against the `Communicator` contract.
///
/// # Errors
///
/// Propagates the overflow-check error, and returns `CliError::Internal` when
/// a collective operation fails.
#[allow(clippy::cast_possible_truncation)]
fn aggregate_simulation_solver_stats<C: Communicator>(
    comm: &C,
    local_stats: &[(u32, i32, cobre_sddp::SolverStatsDelta)],
) -> Result<
    (
        cobre_sddp::SolverStatsDelta,
        Vec<(u32, cobre_sddp::SolverStatsDelta)>,
    ),
    CliError,
> {
    let internal = |message: String| CliError::Internal { message };

    // Global aggregate: sum of each rank's locally aggregated delta.
    let local_total =
        cobre_sddp::SolverStatsDelta::aggregate(local_stats.iter().map(|entry| &entry.2));
    check_stats_overflow(&local_total)?;
    let packed_total = cobre_sddp::pack_delta_scalars(&local_total);
    let mut reduced_total = [0.0_f64; cobre_sddp::SOLVER_STATS_DELTA_SCALAR_FIELDS];
    comm.allreduce(&packed_total, &mut reduced_total, ReduceOp::Sum)
        .map_err(|e| internal(format!("simulation solver stats allreduce error: {e}")))?;
    let global_agg = cobre_sddp::unpack_delta_scalars(&reduced_total);

    // Per-scenario entries are exchanged without the opening index.
    let mut per_scenario: Vec<(u32, cobre_sddp::SolverStatsDelta)> =
        Vec::with_capacity(local_stats.len());
    for (id, _opening, delta) in local_stats {
        per_scenario.push((*id, delta.clone()));
    }

    // Exchange packed lengths first so receive offsets can be computed.
    let rank_count = comm.size();
    let packed_local = cobre_sddp::pack_scenario_stats(&per_scenario);
    let local_len = [packed_local.len() as u64];
    let mut lens_by_rank = vec![0u64; rank_count];
    let unit_counts = vec![1usize; rank_count];
    let unit_displs: Vec<usize> = (0..rank_count).collect();
    comm.allgatherv(&local_len, &mut lens_by_rank, &unit_counts, &unit_displs)
        .map_err(|e| internal(format!("simulation solver stats length exchange error: {e}")))?;

    // Offsets are the exclusive running sum of the per-rank element counts.
    let float_counts: Vec<usize> = lens_by_rank.iter().map(|&len| len as usize).collect();
    let mut float_displs: Vec<usize> = Vec::with_capacity(float_counts.len());
    let mut total_floats = 0usize;
    for &count in &float_counts {
        float_displs.push(total_floats);
        total_floats += count;
    }

    let mut gathered = vec![0.0_f64; total_floats];
    comm.allgatherv(&packed_local, &mut gathered, &float_counts, &float_displs)
        .map_err(|e| internal(format!("simulation solver stats gather error: {e}")))?;

    let mut global_scenario_stats = cobre_sddp::unpack_scenario_stats(&gathered);
    global_scenario_stats.sort_by_key(|entry| entry.0);
    Ok((global_agg, global_scenario_stats))
}
#[allow(clippy::cast_possible_truncation)]
fn delta_to_stats_row(
id: u32,
phase: &str,
stage: i32,
opening: Option<i32>,
rank: Option<i32>,
worker_id: Option<i32>,
delta: &cobre_sddp::SolverStatsDelta,
) -> cobre_io::SolverStatsRow {
cobre_io::SolverStatsRow {
iteration: id,
phase: phase.to_string(),
stage,
opening,
rank,
worker_id,
lp_solves: delta.lp_solves as u32,
lp_successes: delta.lp_successes as u32,
lp_retries: delta.lp_successes.saturating_sub(delta.first_try_successes) as u32,
lp_failures: delta.lp_failures as u32,
retry_attempts: delta.retry_attempts as u32,
basis_offered: delta.basis_offered as u32,
basis_consistency_failures: delta.basis_consistency_failures as u32,
simplex_iterations: delta.simplex_iterations,
solve_time_ms: delta.solve_time_ms,
load_model_time_ms: delta.load_model_time_ms,
set_bounds_time_ms: delta.set_bounds_time_ms,
basis_set_time_ms: delta.basis_set_time_ms,
basis_reconstructions: delta.basis_reconstructions,
retry_level_histogram: delta.retry_level_histogram.clone(),
}
}
/// Borrowed inputs for `write_training_outputs`, bundled to keep its
/// signature short.
struct WriteTrainingArgs<'a> {
/// Root directory under which all training artifacts are written.
output_dir: &'a Path,
/// System description forwarded to the training results writer.
system: &'a System,
/// Study configuration; also supplies the `exports.states` checkpoint flag.
config: &'a cobre_io::Config,
/// Training output forwarded to the results writer; its
/// `cut_selection_records` are written separately when non-empty.
training_output: &'a cobre_io::TrainingOutput,
/// Study setup providing the policy path, the FCF and the loop parameters
/// recorded in the checkpoint.
setup: &'a StudySetup,
/// Training result written into the checkpoint; its `solver_stats_log` is
/// exported when non-empty.
training_result: &'a cobre_sddp::TrainingResult,
/// Run metadata forwarded to the training results writer.
output_ctx: &'a cobre_io::OutputContext,
/// Hydro model preparation result; its `fpha_export_rows` are exported
/// when non-empty.
hydro_models: &'a cobre_sddp::PrepareHydroModelsResult,
/// When `true`, progress messages are suppressed.
quiet: bool,
/// Terminal used for progress messages.
stderr: &'a Term,
}
/// Writes every training-phase artifact under `args.output_dir`.
///
/// Artifacts are written in this order:
///
/// 1. the policy checkpoint, in `output_dir` joined with the setup's policy path;
/// 2. the training results;
/// 3. `hydro_models/fpha_hyperplanes.parquet`, only when FPHA export rows exist;
/// 4. the solver statistics, only when the solver stats log is non-empty;
/// 5. the cut selection records, only when there are any.
///
/// Unless `args.quiet` is set, a progress line is printed before writing and
/// the output path with the elapsed write time is printed afterwards.
///
/// # Errors
///
/// Stops at, and returns, the first error reported by any writer.
fn write_training_outputs(args: &WriteTrainingArgs<'_>) -> Result<(), CliError> {
    let verbose = !args.quiet;
    if verbose {
        use std::io::Write;
        let _ = args.stderr.write_line("Writing training outputs...");
        let _ = std::io::stderr().flush();
    }
    let started = std::time::Instant::now();

    // Policy checkpoint.
    let setup = args.setup;
    let loop_params = &setup.loop_params;
    let policy_dir = args.output_dir.join(&setup.policy_path);
    let checkpoint_params = crate::policy_io::CheckpointParams {
        max_iterations: loop_params.max_iterations,
        forward_passes: loop_params.forward_passes,
        seed: loop_params.seed,
        export_states: args.config.exports.states,
    };
    crate::policy_io::write_checkpoint(
        &policy_dir,
        &setup.fcf,
        args.training_result,
        &checkpoint_params,
    )?;

    // Training results.
    cobre_io::write_training_results(
        args.output_dir,
        args.training_output,
        args.system,
        args.config,
        args.output_ctx,
    )
    .map_err(CliError::from)?;

    // FPHA hyperplanes (optional).
    let fpha_rows = &args.hydro_models.fpha_export_rows;
    if !fpha_rows.is_empty() {
        let fpha_path = args
            .output_dir
            .join("hydro_models")
            .join("fpha_hyperplanes.parquet");
        cobre_io::output::write_fpha_hyperplanes(&fpha_path, fpha_rows)
            .map_err(CliError::from)?;
    }

    // Solver statistics (optional).
    let stats_log = &args.training_result.solver_stats_log;
    if !stats_log.is_empty() {
        // The log encodes "no value" for opening and worker id as -1.
        let none_if_sentinel = |value: i32| Some(value).filter(|&v| v != -1);
        let rows: Vec<cobre_io::SolverStatsRow> = stats_log
            .iter()
            .map(|(iter, phase, stage, opening, rank, worker_id, delta)| {
                #[allow(clippy::cast_possible_truncation)]
                let iteration = *iter as u32;
                delta_to_stats_row(
                    iteration,
                    phase,
                    *stage,
                    none_if_sentinel(*opening),
                    Some(*rank),
                    none_if_sentinel(*worker_id),
                    delta,
                )
            })
            .collect();
        cobre_io::write_solver_stats(args.output_dir, &rows).map_err(CliError::from)?;
    }

    // Cut selection records (optional).
    let selection_records = &args.training_output.cut_selection_records;
    if !selection_records.is_empty() {
        let parquet_config = cobre_io::ParquetWriterConfig::default();
        cobre_io::write_row_selection_records(args.output_dir, selection_records, &parquet_config)
            .map_err(CliError::from)?;
    }

    if verbose {
        let elapsed_secs = started.elapsed().as_secs_f64();
        crate::summary::print_output_path(args.stderr, args.output_dir, elapsed_secs);
    }
    Ok(())
}
/// Borrowed inputs for `write_simulation_outputs`, bundled to keep its
/// signature short.
struct WriteSimulationArgs<'a> {
/// Root directory under which the simulation artifacts are written.
output_dir: &'a Path,
/// Simulation metadata forwarded to the simulation results writer.
sim_output: &'a cobre_io::SimulationOutput,
/// Per-scenario solver statistics as `(scenario id, delta)` pairs;
/// exported when non-empty.
sim_solver_stats: &'a [(u32, cobre_sddp::SolverStatsDelta)],
/// Run metadata forwarded to the simulation results writer.
output_ctx: &'a cobre_io::OutputContext,
/// When `true`, progress messages are suppressed.
quiet: bool,
/// Terminal used for progress messages.
stderr: &'a Term,
}
/// Writes the simulation-phase artifacts under `args.output_dir`.
///
/// The simulation results are always written. The per-scenario solver
/// statistics are written only when `args.sim_solver_stats` is non-empty;
/// each entry becomes a row with phase `"simulation"`, stage `-1`, and no
/// opening, rank or worker id.
///
/// Unless `args.quiet` is set, a progress line is printed before writing and
/// the output path with the elapsed write time is printed afterwards.
///
/// # Errors
///
/// Stops at, and returns, the first error reported by either writer.
fn write_simulation_outputs(args: &WriteSimulationArgs<'_>) -> Result<(), CliError> {
    let verbose = !args.quiet;
    if verbose {
        use std::io::Write;
        let _ = args.stderr.write_line("Writing simulation outputs...");
        let _ = std::io::stderr().flush();
    }
    let started = std::time::Instant::now();

    cobre_io::write_simulation_results(args.output_dir, args.sim_output, args.output_ctx)
        .map_err(CliError::from)?;

    if !args.sim_solver_stats.is_empty() {
        let mut rows: Vec<cobre_io::SolverStatsRow> =
            Vec::with_capacity(args.sim_solver_stats.len());
        for (scenario_id, delta) in args.sim_solver_stats {
            rows.push(delta_to_stats_row(
                *scenario_id,
                "simulation",
                -1,
                None,
                None,
                None,
                delta,
            ));
        }
        cobre_io::write_simulation_solver_stats(args.output_dir, &rows)
            .map_err(CliError::from)?;
    }

    if verbose {
        let elapsed_secs = started.elapsed().as_secs_f64();
        crate::summary::print_output_path(args.stderr, args.output_dir, elapsed_secs);
    }
    Ok(())
}
/// Exports the stochastic model artifacts into `<output_dir>/stochastic`.
///
/// The following files are attempted, in order:
///
/// * `noise_openings.parquet` - from the stochastic context's opening tree;
/// * `inflow_seasonal_stats.parquet`, `inflow_ar_coefficients.parquet` and
///   `inflow_annual_component.parquet` - from the system's inflow models;
/// * `correlation.json` - from the system's correlation data;
/// * `load_seasonal_stats.parquet` - only when at least one load model has
///   `std_mw > 0.0`;
/// * `fitting_report.json` - only when `estimation_report` is `Some`.
///
/// The export is best-effort: a failed write never aborts the run. Each
/// failure is reported as a warning on `stderr` unless `quiet` is set, in
/// which case it is dropped silently.
fn export_stochastic_artifacts(
    output_dir: &Path,
    stochastic: &cobre_stochastic::StochasticContext,
    system: &System,
    estimation_report: Option<&EstimationReport>,
    quiet: bool,
    stderr: &Term,
) {
    use cobre_core::scenario::LoadModel;

    // Emits a warning unless running quietly. Taking `fmt::Arguments` keeps
    // the message unformatted when nothing will be printed.
    let warn = |message: std::fmt::Arguments<'_>| {
        if !quiet {
            let _ = stderr.write_line(&std::fmt::format(message));
        }
    };

    let stochastic_dir = output_dir.join("stochastic");
    if !quiet {
        let _ = stderr.write_line("Exporting stochastic artifacts...");
    }

    let noise_path = stochastic_dir.join("noise_openings.parquet");
    if let Err(e) = write_noise_openings(&noise_path, stochastic.opening_tree()) {
        warn(format_args!(
            "warning: stochastic export failed (noise_openings): {e}"
        ));
    }

    let stats_rows = inflow_models_to_stats_rows(system.inflow_models());
    let stats_path = stochastic_dir.join("inflow_seasonal_stats.parquet");
    if let Err(e) = write_inflow_seasonal_stats(&stats_path, &stats_rows) {
        warn(format_args!(
            "warning: stochastic export failed (inflow_seasonal_stats): {e}"
        ));
    }

    let ar_rows = inflow_models_to_ar_rows(system.inflow_models());
    let ar_path = stochastic_dir.join("inflow_ar_coefficients.parquet");
    if let Err(e) = write_inflow_ar_coefficients(&ar_path, &ar_rows) {
        warn(format_args!(
            "warning: stochastic export failed (inflow_ar_coefficients): {e}"
        ));
    }

    let annual_rows = inflow_models_to_annual_component_rows(system.inflow_models());
    let annual_path = stochastic_dir.join("inflow_annual_component.parquet");
    if let Err(e) = write_inflow_annual_component(&annual_path, &annual_rows) {
        warn(format_args!(
            "warning: stochastic export failed (inflow_annual_component): {e}"
        ));
    }

    let correlation_path = stochastic_dir.join("correlation.json");
    if let Err(e) = write_correlation_json(&correlation_path, system.correlation()) {
        warn(format_args!(
            "warning: stochastic export failed (correlation): {e}"
        ));
    }

    // Load statistics are exported only if some load model is actually stochastic.
    let has_stochastic_load = system
        .load_models()
        .iter()
        .any(|m: &LoadModel| m.std_mw > 0.0);
    if has_stochastic_load {
        let load_rows: Vec<LoadSeasonalStatsRow> = system
            .load_models()
            .iter()
            .map(|m| LoadSeasonalStatsRow {
                bus_id: m.bus_id,
                stage_id: m.stage_id,
                mean_mw: m.mean_mw,
                std_mw: m.std_mw,
            })
            .collect();
        let load_path = stochastic_dir.join("load_seasonal_stats.parquet");
        if let Err(e) = write_load_seasonal_stats(&load_path, &load_rows) {
            warn(format_args!(
                "warning: stochastic export failed (load_seasonal_stats): {e}"
            ));
        }
    }

    if let Some(report) = estimation_report {
        let fitting = estimation_report_to_fitting_report(report);
        let fitting_path = stochastic_dir.join("fitting_report.json");
        if let Err(e) = write_fitting_report(&fitting_path, &fitting) {
            warn(format_args!(
                "warning: stochastic export failed (fitting_report): {e}"
            ));
        }
    }
}
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::float_cmp,
    clippy::panic
)]
mod tests {
    use super::{check_stats_overflow, delta_to_stats_row, resolve_thread_count};
    use cobre_sddp::SolverStatsDelta;

    /// Names of the `u64` counters exercised by the overflow-guard tests.
    const GUARDED_COUNTERS: [&str; 9] = [
        "lp_solves",
        "lp_successes",
        "first_try_successes",
        "lp_failures",
        "retry_attempts",
        "basis_offered",
        "basis_consistency_failures",
        "simplex_iterations",
        "load_model_count",
    ];

    /// Overwrites the counter named `field` on `delta` with `value`.
    ///
    /// Panics when `field` is not one of the known counter names.
    fn set_counter(delta: &mut SolverStatsDelta, field: &str, value: u64) {
        let slot = match field {
            "lp_solves" => &mut delta.lp_solves,
            "lp_successes" => &mut delta.lp_successes,
            "first_try_successes" => &mut delta.first_try_successes,
            "lp_failures" => &mut delta.lp_failures,
            "retry_attempts" => &mut delta.retry_attempts,
            "basis_offered" => &mut delta.basis_offered,
            "basis_consistency_failures" => &mut delta.basis_consistency_failures,
            "simplex_iterations" => &mut delta.simplex_iterations,
            "load_model_count" => &mut delta.load_model_count,
            other => panic!("unknown field: {other}"),
        };
        *slot = value;
    }

    /// Returns a default delta with only the counter `field` set to `value`.
    fn delta_with_field(field: &str, value: u64) -> SolverStatsDelta {
        let mut d = SolverStatsDelta::default();
        set_counter(&mut d, field, value);
        d
    }

    /// Returns a default delta with only `lp_solves` set.
    fn make_delta(lp_solves: u64) -> SolverStatsDelta {
        delta_with_field("lp_solves", lp_solves)
    }

    #[test]
    fn test_resolve_thread_count_cli_value() {
        assert_eq!(resolve_thread_count(Some(4)), 4);
    }

    // NOTE(review): despite the name this passes an explicit `Some(1)`; the
    // `None` path is not exercised here, presumably because it depends on the
    // `COBRE_THREADS` environment variable.
    #[test]
    fn test_resolve_thread_count_default() {
        assert_eq!(resolve_thread_count(Some(1)), 1);
    }

    #[test]
    fn test_delta_to_stats_row_backward_carries_opening_rank_worker() {
        let row = delta_to_stats_row(1, "backward", 2, Some(0), Some(1), Some(3), &make_delta(10));
        assert_eq!(row.opening, Some(0));
        assert_eq!(row.rank, Some(1));
        assert_eq!(row.worker_id, Some(3));
        assert_eq!(row.stage, 2);
        assert_eq!(row.phase, "backward");
        assert_eq!(row.lp_solves, 10);
    }

    #[test]
    fn test_delta_to_stats_row_forward_opening_and_worker_id_are_none() {
        let row = delta_to_stats_row(1, "forward", 0, None, Some(0), None, &make_delta(4));
        assert_eq!(row.opening, None);
        assert_eq!(row.rank, Some(0));
        assert_eq!(row.worker_id, None);
        assert_eq!(row.stage, 0);
        assert_eq!(row.lp_solves, 4);
    }

    #[test]
    fn test_delta_to_stats_row_simulation_rank_and_worker_id_are_none() {
        let row = delta_to_stats_row(42, "simulation", -1, None, None, None, &make_delta(7));
        assert_eq!(row.opening, None);
        assert_eq!(row.rank, None);
        assert_eq!(row.worker_id, None);
        assert_eq!(row.iteration, 42);
    }

    #[test]
    fn test_overflow_guard_rejects_excessive_counter() {
        let over_limit = (1u64 << 53) + 1;
        for field in GUARDED_COUNTERS {
            // Only one counter is over the limit at a time, so the error must name it.
            let msg = check_stats_overflow(&delta_with_field(field, over_limit))
                .unwrap_err()
                .to_string();
            assert!(
                msg.contains("exceeds 2^53"),
                "field '{field}': message was: {msg}"
            );
            assert!(
                msg.contains(field),
                "field '{field}': label missing in message: {msg}"
            );
        }
    }

    #[test]
    fn test_overflow_guard_allows_exact_limit() {
        let at_limit = 1u64 << 53;
        let mut delta = SolverStatsDelta::default();
        for field in GUARDED_COUNTERS {
            set_counter(&mut delta, field, at_limit);
        }
        check_stats_overflow(&delta)
            .expect("2^53 is representable in f64 and must not trigger the guard");
    }
}