use std::path::{Path, PathBuf};
use std::sync::mpsc;
use clap::Args;
use console::Term;
use cobre_comm::{Communicator, ReduceOp, create_communicator};
use cobre_core::{System, TrainingEvent};
use cobre_io::output::{
write_correlation_json, write_fitting_report, write_inflow_ar_coefficients,
write_inflow_seasonal_stats, write_load_seasonal_stats, write_noise_openings,
};
use cobre_io::scenarios::LoadSeasonalStatsRow;
use cobre_sddp::{
EstimationReport, PrepareHydroModelsResult, PrepareStochasticResult, StudySetup,
build_hydro_model_summary, build_stochastic_summary, estimation_report_to_fitting_report,
inflow_models_to_ar_rows, inflow_models_to_stats_rows, prepare_hydro_models,
prepare_stochastic,
};
use cobre_solver::HighsSolver;
use cobre_stochastic::{
build_stochastic_context, context::OpeningTree, provenance::ComponentProvenance,
};
use crate::error::CliError;
use crate::summary::{SimulationSummary, TrainingSummary};
use super::broadcast::{
BroadcastConfig, BroadcastCutSelection, BroadcastOpeningTree, broadcast_value,
stopping_rules_from_broadcast,
};
// Command-line arguments of the `run` subcommand.
//
// Plain `//` comments are used deliberately in this block: clap's derive macros
// turn `///` doc comments into help text, which would change the CLI output.
#[derive(Debug, Args)]
#[command(about = "Load a case directory, train an SDDP policy, and run simulation")]
pub struct RunArgs {
// Case directory to load; `config.json` is read from inside it.
pub case_dir: PathBuf,
// Output directory override. When absent, `<case_dir>/output` is used.
#[arg(long, value_name = "DIR")]
pub output: Option<PathBuf>,
// Suppresses the banner, progress and summary lines written to stderr.
#[arg(long)]
pub quiet: bool,
// Worker thread count (must be >= 1). When absent, `resolve_thread_count`
// falls back to the `COBRE_THREADS` environment variable and then to 1.
#[arg(long, value_parser = clap::value_parser!(u32).range(1..))]
pub threads: Option<u32>,
}
/// Decides how many worker threads to use.
///
/// Precedence:
/// 1. the explicit CLI value, if any;
/// 2. the `COBRE_THREADS` environment variable, when it parses as a strictly
///    positive `usize`;
/// 3. `1`.
fn resolve_thread_count(cli_threads: Option<u32>) -> usize {
    match cli_threads {
        Some(requested) => requested as usize,
        None => std::env::var("COBRE_THREADS")
            .ok()
            .and_then(|raw| raw.parse::<usize>().ok())
            // A zero (or unparsable) value is ignored rather than rejected.
            .filter(|&count| count > 0)
            .unwrap_or(1),
    }
}
/// Everything `load_case_and_config` produces, in order: the prepared
/// stochastic result, the prepared hydro models, the broadcastable config
/// subset, and the full parsed `config.json`.
type LoadedCase = (
PrepareStochasticResult,
PrepareHydroModelsResult,
BroadcastConfig,
cobre_io::Config,
);
/// Loads the case directory and its `config.json`, then prepares the
/// stochastic and hydro models.
///
/// When `quiet` is false a "Loading case" line is written to `stderr`.
///
/// # Errors
///
/// Returns `CliError::Io` when the case directory does not exist, and
/// propagates any error from case loading, config parsing, broadcast-config
/// extraction, or stochastic / hydro model preparation.
fn load_case_and_config(
    args: &RunArgs,
    quiet: bool,
    stderr: &Term,
) -> Result<LoadedCase, CliError> {
    if !args.case_dir.exists() {
        let source = std::io::Error::new(
            std::io::ErrorKind::NotFound,
            "case directory does not exist",
        );
        let context = args.case_dir.display().to_string();
        return Err(CliError::Io { source, context });
    }

    if !quiet {
        let banner_line = format!("Loading case: {}", args.case_dir.display());
        let _ = stderr.write_line(&banner_line);
    }

    let system = cobre_io::load_case(&args.case_dir)?;
    let config = cobre_io::parse_config(&args.case_dir.join("config.json"))?;
    let bcast = BroadcastConfig::from_config(&config)?;

    // The seed used for stochastic preparation is the one carried by the
    // broadcast config.
    let prepared = prepare_stochastic(system, &args.case_dir, &config, bcast.seed)
        .map_err(CliError::from)?;
    let hydro_models =
        prepare_hydro_models(&prepared.system, &args.case_dir).map_err(CliError::from)?;

    Ok((prepared, hydro_models, bcast, config))
}
/// Per-process state shared by every phase of the `run` command; built once
/// by `setup_communicator`.
struct RunContext<C: Communicator> {
/// Communicator used for every collective operation.
comm: C,
/// True when this process has rank 0.
is_root: bool,
/// Effective quiet flag: the `--quiet` option, or always true on non-root ranks.
quiet: bool,
/// Worker thread count as resolved by `resolve_thread_count`.
n_threads: usize,
/// Directory that receives outputs (`--output`, or `<case_dir>/output`).
output_dir: PathBuf,
/// Terminal width handed to the progress display.
term_width: u16,
/// Handle on stderr used for user-facing messages.
stderr: Term,
}
/// Result of `broadcast_and_build_setup`.
struct LoadBroadcastResult {
/// The system as obtained from the broadcast.
system: System,
/// Study setup built from the broadcast parameters.
setup: StudySetup,
/// Full parsed config; only populated on the root rank.
root_config: Option<cobre_io::Config>,
/// Estimation report from stochastic preparation; never populated on
/// non-root ranks, and optional on the root rank as well.
root_estimation_report: Option<EstimationReport>,
/// Copied from `BroadcastConfig::training_enabled`.
training_enabled: bool,
/// Copied from `BroadcastConfig::policy_mode`.
policy_mode: cobre_io::PolicyMode,
}
/// Result of `run_training_phase`.
struct TrainingPhaseResult {
/// Training result taken from the training outcome.
result: cobre_sddp::TrainingResult,
/// Output records built from `result` and the collected training events.
output: cobre_io::TrainingOutput,
/// Error carried by the training outcome alongside its result, if any.
/// `execute` writes the training outputs first and only then fails on it.
error: Option<cobre_sddp::SddpError>,
}
/// Runs the `run` subcommand end to end.
///
/// Steps: set up the communicator, load and broadcast the case, run the
/// pre-training exports, then either
/// * train (optionally from a prior policy), write training outputs on the
///   root rank and simulate when scenarios are configured, or
/// * when training is disabled, load a stored policy and simulate, or report
///   that there is nothing to do.
///
/// # Errors
///
/// Propagates errors from every phase. A training error carried by the
/// training outcome is returned as `CliError::Internal` after the training
/// outputs have been written on the root rank.
pub fn execute(args: &RunArgs) -> Result<(), CliError> {
    let ctx = setup_communicator(args)?;
    let LoadBroadcastResult {
        system,
        mut setup,
        mut root_config,
        root_estimation_report,
        training_enabled,
        policy_mode,
    } = broadcast_and_build_setup(&ctx, args)?;

    run_pre_training(
        &ctx,
        &system,
        &setup,
        root_config.as_ref(),
        root_estimation_report.as_ref(),
    )?;

    // Simulation-only (or no-op) path.
    if !training_enabled {
        if setup.n_scenarios() > 0 {
            let loaded_policy =
                load_policy_for_simulation(&ctx, &system, &mut setup, root_config.as_ref())?;
            run_simulation_phase(&ctx, &system, &mut setup, &loaded_policy)?;
        } else if ctx.is_root && !ctx.quiet {
            let _ = ctx
                .stderr
                .write_line("Training disabled, simulation disabled — nothing to do.");
        }
        return Ok(());
    }

    apply_training_policy(&ctx, &system, &mut setup, root_config.as_ref(), policy_mode)?;
    let training = run_training_phase(&ctx, &mut setup)?;

    // Outputs are written before the training error (if any) is inspected, so
    // partial results survive a failed run.
    if ctx.is_root {
        let Some(config) = root_config.take() else {
            return Err(CliError::Internal {
                message: "root_config was None on rank 0 — internal invariant violated"
                    .to_string(),
            });
        };
        write_training_outputs(&WriteTrainingArgs {
            output_dir: &ctx.output_dir,
            system: &system,
            config: &config,
            training_output: &training.output,
            setup: &setup,
            training_result: &training.result,
            quiet: ctx.quiet,
            stderr: &ctx.stderr,
        })?;
    }

    if let Some(training_error) = training.error.as_ref() {
        if ctx.is_root {
            tracing::error!(
                "training failed after {} iterations: {training_error}",
                training.result.iterations
            );
            if !ctx.quiet {
                let notice = format!(
                    "Training failed after {} iterations. Partial outputs written to {}.",
                    training.result.iterations,
                    ctx.output_dir.display()
                );
                let _ = ctx.stderr.write_line(&notice);
            }
        }
        return Err(CliError::Internal {
            message: format!("training error: {training_error}"),
        });
    }

    if setup.n_scenarios() > 0 {
        run_simulation_phase(&ctx, &system, &mut setup, &training.result)?;
    }
    Ok(())
}
/// Reads the policy checkpoint stored in `policy_dir` and, when requested by
/// the configuration, checks that it is compatible with the current study.
///
/// Validation runs only when `root_config` is `Some` and its
/// `policy.validate_compatibility` flag is set; callers that pass `None`
/// get the checkpoint back unvalidated. The check compares the checkpoint
/// metadata against the state dimension of the current FCF and the number of
/// stages with a non-negative id.
///
/// # Errors
///
/// Returns `CliError::Internal` when the checkpoint cannot be read or when the
/// stage count / state dimension does not fit in a `u32`, and propagates the
/// error of `validate_policy_compatibility`.
fn load_and_validate_checkpoint(
    policy_dir: &Path,
    system: &System,
    setup: &StudySetup,
    root_config: Option<&cobre_io::Config>,
) -> Result<cobre_io::PolicyCheckpoint, CliError> {
    let checkpoint = cobre_io::output::policy::read_policy_checkpoint(policy_dir).map_err(|e| {
        CliError::Internal {
            message: format!("failed to read policy checkpoint: {e}"),
        }
    })?;

    let must_validate = root_config.is_some_and(|config| config.policy.validate_compatibility);
    if !must_validate {
        return Ok(checkpoint);
    }

    // Checked narrowing for both quantities: a silent `as u32` truncation here
    // could make an incompatible checkpoint look compatible.
    let stage_count = system.stages().iter().filter(|s| s.id >= 0).count();
    let n_stages = u32::try_from(stage_count).map_err(|e| CliError::Internal {
        message: format!("stage count overflows u32: {e}"),
    })?;
    let state_dim = u32::try_from(setup.fcf().state_dimension).map_err(|e| CliError::Internal {
        message: format!("state_dimension overflows u32: {e}"),
    })?;

    cobre_sddp::validate_policy_compatibility(
        &checkpoint.metadata,
        state_dim,
        n_stages,
        None,
        None,
    )
    .map_err(CliError::from)?;

    Ok(checkpoint)
}
/// Prepares `setup` for training according to `policy_mode`.
///
/// * `Fresh` — nothing to do.
/// * `WarmStart` — loads the prior policy and replaces the FCF with one seeded
///   from its cuts.
/// * `Resume` — same as `WarmStart`, and additionally sets the start iteration
///   to the checkpoint's completed iteration count. A warning is printed on
///   the root rank when that count already reaches `max_iterations`.
///
/// Status lines are written only on the root rank when not quiet.
///
/// # Errors
///
/// Returns `CliError::Internal` when the policy directory is missing, and
/// propagates checkpoint loading / validation and FCF construction errors.
fn apply_training_policy(
    ctx: &RunContext<impl Communicator>,
    system: &System,
    setup: &mut StudySetup,
    root_config: Option<&cobre_io::Config>,
    policy_mode: cobre_io::PolicyMode,
) -> Result<(), CliError> {
    let verbose = ctx.is_root && !ctx.quiet;
    match policy_mode {
        cobre_io::PolicyMode::WarmStart => {
            let checkpoint = load_prior_checkpoint(
                ctx,
                system,
                setup,
                root_config,
                "Cannot warm-start without a prior policy.",
                "Loading prior policy for warm-start training...",
            )?;
            install_warm_start_fcf(setup, &checkpoint)?;
            if verbose {
                let warm_count = warm_start_cut_count(setup);
                let _ = ctx.stderr.write_line(&format!(
                    "Warm-start: loaded {warm_count} cuts per stage from prior policy."
                ));
            }
        }
        cobre_io::PolicyMode::Resume => {
            let checkpoint = load_prior_checkpoint(
                ctx,
                system,
                setup,
                root_config,
                "Cannot resume without a prior checkpoint.",
                "Loading prior checkpoint for resume training...",
            )?;
            let completed = u64::from(checkpoint.metadata.completed_iterations);
            if completed >= setup.max_iterations() && verbose {
                let _ = ctx.stderr.write_line(&format!(
                    "WARNING: Checkpoint already completed {completed} iterations \
                     (max_iterations = {}). No additional training will occur.",
                    setup.max_iterations()
                ));
            }
            install_warm_start_fcf(setup, &checkpoint)?;
            setup.set_start_iteration(completed);
            if verbose {
                let warm_count = warm_start_cut_count(setup);
                let _ = ctx.stderr.write_line(&format!(
                    "Resume: loaded {warm_count} cuts per stage, \
                     resuming from iteration {completed}."
                ));
            }
        }
        cobre_io::PolicyMode::Fresh => {}
    }
    Ok(())
}

/// Locates the policy directory under the output directory, announces the load
/// with `progress_line` (root rank, not quiet) and reads the checkpoint.
///
/// `missing_hint` is appended to the "Policy directory not found" error.
fn load_prior_checkpoint(
    ctx: &RunContext<impl Communicator>,
    system: &System,
    setup: &StudySetup,
    root_config: Option<&cobre_io::Config>,
    missing_hint: &str,
    progress_line: &str,
) -> Result<cobre_io::PolicyCheckpoint, CliError> {
    let policy_dir = ctx.output_dir.join(setup.policy_path());
    if !policy_dir.exists() {
        return Err(CliError::Internal {
            message: format!(
                "Policy directory not found: {}. {missing_hint}",
                policy_dir.display()
            ),
        });
    }
    if ctx.is_root && !ctx.quiet {
        let _ = ctx.stderr.write_line(progress_line);
    }
    load_and_validate_checkpoint(&policy_dir, system, setup, root_config)
}

/// Replaces the FCF of `setup` with one warm-started from the checkpoint cuts.
fn install_warm_start_fcf(
    setup: &mut StudySetup,
    checkpoint: &cobre_io::PolicyCheckpoint,
) -> Result<(), CliError> {
    let warm_fcf = cobre_sddp::FutureCostFunction::new_with_warm_start(
        &checkpoint.stage_cuts,
        setup.forward_passes(),
        setup.max_iterations().saturating_add(1),
    )
    .map_err(CliError::from)?;
    setup.replace_fcf(warm_fcf);
    Ok(())
}

/// Warm-start cut count of the first pool, formatted for display.
///
/// Uses `first()` instead of `pools[0]` so that an FCF without pools yields
/// "0" rather than a panic while merely building a status line.
fn warm_start_cut_count(setup: &StudySetup) -> String {
    setup.fcf().pools.first().map_or_else(
        || "0".to_string(),
        |pool| pool.warm_start_count.to_string(),
    )
}
/// Loads a stored policy so that simulation can run without training.
///
/// The checkpoint's cuts replace the FCF in `setup`, and a `TrainingResult`
/// is synthesised from the checkpoint metadata (iterations, lower bound, best
/// upper bound or infinity) with zeroed timing / gap fields, an empty solver
/// stats log and a basis cache rebuilt from the stored bases.
///
/// # Errors
///
/// Returns `CliError::Internal` when the policy directory is missing, and
/// propagates checkpoint loading / validation and FCF deserialisation errors.
fn load_policy_for_simulation(
    ctx: &RunContext<impl Communicator>,
    system: &System,
    setup: &mut StudySetup,
    root_config: Option<&cobre_io::Config>,
) -> Result<cobre_sddp::TrainingResult, CliError> {
    let announce = ctx.is_root && !ctx.quiet;
    if announce {
        let _ = ctx
            .stderr
            .write_line("Training disabled. Loading policy for simulation-only mode...");
    }

    let policy_dir = ctx.output_dir.join(setup.policy_path());
    if !policy_dir.exists() {
        let message = format!(
            "Policy directory not found: {}. Cannot run simulation-only \
             mode without a trained policy.",
            policy_dir.display()
        );
        return Err(CliError::Internal { message });
    }

    let checkpoint = load_and_validate_checkpoint(&policy_dir, system, setup, root_config)?;

    let restored_fcf = cobre_sddp::FutureCostFunction::from_deserialized(&checkpoint.stage_cuts)
        .map_err(CliError::from)?;
    setup.replace_fcf(restored_fcf);

    let basis_cache =
        cobre_sddp::build_basis_cache_from_checkpoint(setup.num_stages(), &checkpoint.stage_bases);
    let final_ub = checkpoint
        .metadata
        .best_upper_bound
        .unwrap_or(f64::INFINITY);

    Ok(cobre_sddp::TrainingResult {
        iterations: checkpoint.metadata.completed_iterations.into(),
        final_lb: checkpoint.metadata.final_lower_bound,
        final_ub,
        final_ub_std: 0.0,
        final_gap: 0.0,
        total_time_ms: 0,
        reason: "loaded from checkpoint".to_string(),
        solver_stats_log: Vec::new(),
        basis_cache,
    })
}
/// Creates the communicator and the per-process `RunContext`.
///
/// Besides building the context this prints the banner (unless quiet),
/// forces stderr colours on the root rank of a multi-process run, and
/// initialises the global rayon pool with the resolved thread count. A pool
/// that is already initialised only produces a warning.
///
/// # Errors
///
/// Propagates the error of `create_communicator`.
fn setup_communicator(args: &RunArgs) -> Result<RunContext<impl Communicator>, CliError> {
    let comm = create_communicator()?;
    let is_root = comm.rank() == 0;
    // Non-root ranks are always quiet, whatever the CLI flag says.
    let quiet = !is_root || args.quiet;

    let multi_process = comm.size() > 1;
    if multi_process && is_root && !args.quiet {
        console::set_colors_enabled_stderr(true);
    }

    let stderr = Term::stderr();
    if !quiet {
        crate::banner::print_banner(&stderr);
    }

    let n_threads = resolve_thread_count(args.threads);
    let pool_init = rayon::ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .build_global();
    if pool_init.is_err() {
        tracing::warn!("rayon global thread pool already initialized; ignoring --threads");
    }

    let output_dir: PathBuf = match &args.output {
        Some(dir) => dir.clone(),
        None => args.case_dir.join("output"),
    };
    let term_width = crate::progress::resolve_term_width();

    Ok(RunContext {
        comm,
        is_root,
        quiet,
        n_threads,
        output_dir,
        term_width,
        stderr,
    })
}
/// Loads the case on the root rank, shares it with the other ranks through
/// `broadcast_value`, and builds the `StudySetup` on every rank.
///
/// Broadcast items: the system, the `BroadcastConfig`, and the opening tree
/// (only when its provenance is `UserSupplied`). Non-root ranks rebuild the
/// stochastic context and the hydro models locally from the broadcast system.
///
/// # Errors
///
/// Returns the root's load error (after the broadcasts have been issued), any
/// broadcast error, `CliError::Internal` for a missing root-side value or a
/// non-root rebuild failure, and any error from `build_study_setup`.
fn broadcast_and_build_setup(
ctx: &RunContext<impl Communicator>,
args: &RunArgs,
) -> Result<LoadBroadcastResult, CliError> {
// Only the root rank calls `load_case_and_config`; every other rank starts
// with all-`None` placeholders. A load failure is captured in `load_err`
// instead of being returned right away.
let (
raw_system,
raw_bcast_config,
mut root_config,
root_stochastic,
root_estimation_report,
raw_bcast_tree,
root_hydro_models,
load_err,
) = if ctx.is_root {
match load_case_and_config(args, ctx.quiet, &ctx.stderr) {
Ok((prepared, hydro_models, bcast, config)) => {
// The opening tree is shipped only when the user supplied it;
// otherwise `None` is broadcast.
let bcast_tree = if prepared.stochastic.provenance().opening_tree
== ComponentProvenance::UserSupplied
{
let t = prepared.stochastic.opening_tree();
Some(BroadcastOpeningTree {
data: t.data().to_vec(),
openings_per_stage: t.openings_per_stage_slice().to_vec(),
dim: t.dim(),
})
} else {
None
};
let cobre_sddp::PrepareStochasticResult {
system,
stochastic,
estimation_report,
} = prepared;
(
Some(system),
Some(bcast),
Some(config),
Some(stochastic),
estimation_report,
Some(bcast_tree),
Some(hydro_models),
None,
)
}
Err(e) => (None, None, None, None, None, None, None, Some(e)),
}
} else {
(None, None, None, None, None, None, None, None)
};
// The annotated rebindings below pin the tuple element types.
let root_estimation_report: Option<EstimationReport> = root_estimation_report;
// All three broadcasts are issued before the load error is inspected.
// NOTE(review): presumably so that every rank takes part in the same
// collective calls even when the root failed to load — confirm against
// `broadcast_value`.
let system_result = broadcast_value(raw_system, &ctx.comm);
let bcast_config_result = broadcast_value(raw_bcast_config, &ctx.comm);
let root_hydro_models: Option<PrepareHydroModelsResult> = root_hydro_models;
let tree_result = broadcast_value(raw_bcast_tree, &ctx.comm);
if let Some(e) = load_err {
return Err(e);
}
let system = system_result?;
let mut bcast_config = bcast_config_result?;
let seed = bcast_config.seed;
let stochastic = if ctx.is_root {
// The root keeps the context it prepared itself; the tree broadcast
// result (value or error) is discarded here.
drop(tree_result);
root_stochastic.ok_or_else(|| CliError::Internal {
message: "stochastic context missing on rank 0 after successful load".to_string(),
})?
} else {
// Non-root ranks rebuild the context from the broadcast system, the
// shared seed and the user tree, if one was broadcast.
let user_tree: Option<OpeningTree> =
tree_result?.map(|bt| OpeningTree::from_parts(bt.data, bt.openings_per_stage, bt.dim));
build_stochastic_context(&system, seed, &[], &[], user_tree).map_err(|e| {
CliError::Internal {
message: format!("stochastic context error: {e}"),
}
})?
};
let hydro_models = if ctx.is_root {
root_hydro_models.ok_or_else(|| CliError::Internal {
message: "hydro models missing on rank 0 after successful load".to_string(),
})?
} else {
// Hydro models are not broadcast; each non-root rank prepares them
// itself from `args.case_dir`.
prepare_hydro_models(&system, &args.case_dir).map_err(|e| CliError::Internal {
message: format!("hydro model preprocessing error on non-root rank: {e}"),
})?
};
// Read these before `build_study_setup` takes the config mutably.
let training_enabled = bcast_config.training_enabled;
let policy_mode = bcast_config.policy_mode;
let setup = build_study_setup(&system, &mut bcast_config, stochastic, hydro_models)?;
Ok(LoadBroadcastResult {
system,
setup,
root_config: root_config.take(),
root_estimation_report,
training_enabled,
policy_mode,
})
}
/// Builds the `StudySetup` from the broadcast configuration.
///
/// The cut-selection setting is moved out of `bcast_config` (leaving
/// `BroadcastCutSelection::Disabled` behind) and converted into a strategy.
/// An `io_channel_capacity` that does not fit in `usize` falls back to 64.
///
/// # Errors
///
/// Propagates the error of `StudySetup::from_broadcast_params`.
fn build_study_setup(
    system: &System,
    bcast_config: &mut BroadcastConfig,
    stochastic: cobre_stochastic::StochasticContext,
    hydro_models: PrepareHydroModelsResult,
) -> Result<StudySetup, CliError> {
    let stopping_rule_set = stopping_rules_from_broadcast(bcast_config);

    let taken_selection = std::mem::replace(
        &mut bcast_config.cut_selection,
        BroadcastCutSelection::Disabled,
    );
    let cut_selection = taken_selection.into_strategy();

    let io_channel_capacity = usize::try_from(bcast_config.io_channel_capacity).unwrap_or(64);

    let study = StudySetup::from_broadcast_params(
        system,
        stochastic,
        bcast_config.seed,
        bcast_config.forward_passes,
        stopping_rule_set,
        bcast_config.n_scenarios,
        io_channel_capacity,
        bcast_config.policy_path.clone(),
        bcast_config.inflow_method.clone(),
        cut_selection,
        bcast_config.cut_activity_tolerance,
        hydro_models,
    );
    study.map_err(CliError::from)
}
/// Work done once the setup exists and before training or simulation starts.
///
/// On the root rank: prints the stochastic and hydro-model summaries (unless
/// quiet), exports the stochastic artifacts when `exports.stochastic` is set
/// in the config, and writes `training/scaling_report.json`. Every rank then
/// meets at a barrier.
///
/// # Errors
///
/// Returns `CliError::Internal` when the scaling report cannot be written
/// (root rank, before the barrier) or when the barrier fails.
fn run_pre_training(
    ctx: &RunContext<impl Communicator>,
    system: &System,
    setup: &StudySetup,
    root_config: Option<&cobre_io::Config>,
    root_estimation_report: Option<&EstimationReport>,
) -> Result<(), CliError> {
    if ctx.is_root {
        if !ctx.quiet {
            let stochastic_summary = build_stochastic_summary(
                system,
                setup.stochastic(),
                root_estimation_report,
                setup.seed(),
            );
            crate::summary::print_stochastic_summary(&ctx.stderr, &stochastic_summary);

            let hydro_summary = build_hydro_model_summary(setup.hydro_models(), system);
            crate::summary::print_hydro_model_summary(&ctx.stderr, &hydro_summary);
        }

        let export_requested = matches!(root_config, Some(config) if config.exports.stochastic);
        if export_requested {
            export_stochastic_artifacts(
                &ctx.output_dir,
                setup.stochastic(),
                system,
                root_estimation_report,
                ctx.quiet,
                &ctx.stderr,
            );
        }

        let scaling_path = ctx.output_dir.join("training/scaling_report.json");
        if let Err(e) = cobre_io::write_scaling_report(&scaling_path, setup.scaling_report()) {
            return Err(CliError::Internal {
                message: format!("failed to write scaling report: {e}"),
            });
        }
    }

    match ctx.comm.barrier() {
        Ok(()) => Ok(()),
        Err(e) => Err(CliError::Internal {
            message: format!("post-export barrier error: {e}"),
        }),
    }
}
/// Runs SDDP training and assembles the training result, output records and
/// console summary.
///
/// Training events travel over an mpsc channel: when not quiet a progress
/// thread consumes them, otherwise the receiver is kept and drained once
/// training has returned. After training, the LP solve counts of all ranks
/// are summed with an allreduce and every rank meets at a barrier.
///
/// # Errors
///
/// Returns `CliError::Solver` when the HiGHS solver cannot be created, the
/// converted error of `setup.train`, and `CliError::Internal` when the
/// allreduce or the barrier fails. An error carried inside the training
/// outcome is not returned here; it is handed back in
/// `TrainingPhaseResult::error`.
fn run_training_phase(
ctx: &RunContext<impl Communicator>,
setup: &mut StudySetup,
) -> Result<TrainingPhaseResult, CliError> {
let solver_factory = HighsSolver::new;
let mut solver = HighsSolver::new().map_err(|e| CliError::Solver {
message: format!("HiGHS initialisation failed: {e}"),
})?;
let (event_tx, event_rx) = mpsc::channel::<TrainingEvent>();
// Exactly one of `quiet_rx` / `progress_handle` ends up owning the receiver.
let quiet_rx: Option<mpsc::Receiver<TrainingEvent>>;
let progress_handle = if ctx.quiet {
quiet_rx = Some(event_rx);
None
} else {
quiet_rx = None;
Some(crate::progress::run_progress_thread(
event_rx,
setup.max_iterations(),
ctx.term_width,
))
};
let training_outcome = match setup.train(
&mut solver,
&ctx.comm,
ctx.n_threads,
solver_factory,
Some(event_tx),
None,
) {
Ok(outcome) => outcome,
Err(e) => {
// Wait for the progress thread before reporting the failure.
if let Some(handle) = progress_handle {
let _ = handle.join();
}
return Err(CliError::from(e));
}
};
let training_result = training_outcome.result;
// Collect the events: from the progress handle (its `join` yields them), or
// from whatever is still queued in the receiver in quiet mode.
let events: Vec<TrainingEvent> = match (progress_handle, quiet_rx) {
(Some(handle), _) => handle.join(),
(None, Some(rx)) => rx.try_iter().collect(),
(None, None) => Vec::new(),
};
let training_output = setup.build_training_output(&training_result, &events);
// Sum this rank's LP solves, then add them up across all ranks.
let local_lp_solves: u64 = training_output
.convergence_records
.iter()
.map(|r| u64::from(r.lp_solves))
.sum();
let mut global_lp_solves = [0u64];
ctx.comm
.allreduce(&[local_lp_solves], &mut global_lp_solves, ReduceOp::Sum)
.map_err(|e| CliError::Internal {
message: format!("LP solve count allreduce error: {e}"),
})?;
let global_lp_solves = global_lp_solves[0];
ctx.comm.barrier().map_err(|e| CliError::Internal {
message: format!("post-training barrier error: {e}"),
})?;
// Solver statistics come from this rank's log only (no collective here).
let (
total_first_try,
total_retried,
total_failed,
total_solve_time_s,
total_basis_offered,
total_basis_rejections,
total_simplex_iter,
) = aggregate_solver_stats(&training_result.solver_stats_log);
let training_summary = TrainingSummary {
iterations: training_result.iterations,
converged: training_output.converged,
converged_at: if training_output.converged {
Some(training_result.iterations)
} else {
None
},
reason: training_result.reason.clone(),
lower_bound: training_result.final_lb,
upper_bound: training_result.final_ub,
upper_bound_std: training_result.final_ub_std,
// `final_gap` is multiplied by 100 for the percentage field.
gap_percent: training_result.final_gap * 100.0,
total_cuts_active: training_output.cut_stats.total_active,
total_cuts_generated: training_output.cut_stats.total_generated,
total_lp_solves: global_lp_solves,
total_time_ms: training_result.total_time_ms,
total_first_try,
total_retried,
total_failed,
total_solve_time_seconds: total_solve_time_s,
total_basis_offered,
total_basis_rejections,
total_simplex_iterations: total_simplex_iter,
};
if !ctx.quiet && ctx.is_root {
crate::summary::print_training_summary(&ctx.stderr, &training_summary);
}
Ok(TrainingPhaseResult {
result: training_result,
output: training_output,
error: training_outcome.error,
})
}
/// Sums the solver statistics of every entry in `stats_log`.
///
/// Returns, in order: first-try successes, retried successes (successes minus
/// first-try successes, saturating per entry), failures, total solve time in
/// seconds (the per-entry values are milliseconds), bases offered, basis
/// rejections and simplex iterations.
fn aggregate_solver_stats(
    stats_log: &[(u64, String, i32, cobre_sddp::SolverStatsDelta)],
) -> (u64, u64, u64, f64, u64, u64, u64) {
    let zero = (0u64, 0u64, 0u64, 0.0_f64, 0u64, 0u64, 0u64);
    let totals = stats_log.iter().fold(zero, |acc, (_, _, _, delta)| {
        (
            acc.0 + delta.first_try_successes,
            acc.1 + delta.lp_successes.saturating_sub(delta.first_try_successes),
            acc.2 + delta.lp_failures,
            acc.3 + delta.solve_time_ms,
            acc.4 + delta.basis_offered,
            acc.5 + delta.basis_rejections,
            acc.6 + delta.simplex_iterations,
        )
    });
    let (first_try, retried, failed, solve_time_ms, basis_offered, basis_rejections, simplex) =
        totals;
    (
        first_try,
        retried,
        failed,
        // Milliseconds to seconds.
        solve_time_ms / 1000.0,
        basis_offered,
        basis_rejections,
        simplex,
    )
}
/// Runs the simulation, streams scenario results to the Parquet writer on a
/// background thread, merges the per-rank metadata and writes the simulation
/// outputs on the root rank.
///
/// # Errors
///
/// Returns `CliError::Solver` when the workspace pool cannot be created, the
/// converted errors of the Parquet writer construction and of
/// `setup.simulate`, `CliError::Internal` on metadata merge or barrier
/// failures, and any error of `write_simulation_outputs`.
///
/// # Panics
///
/// Panics when the drain thread itself panicked.
fn run_simulation_phase(
ctx: &RunContext<impl Communicator>,
system: &System,
setup: &mut StudySetup,
training_result: &cobre_sddp::TrainingResult,
) -> Result<(), CliError> {
let solver_factory = HighsSolver::new;
let n_scenarios = setup.n_scenarios();
let sim_config = setup.simulation_config();
let mut sim_pool = setup
.create_workspace_pool(ctx.n_threads, solver_factory)
.map_err(|e| CliError::Solver {
message: format!("HiGHS initialisation failed for simulation pool: {e}"),
})?;
let (sim_event_tx, sim_event_rx) = mpsc::channel::<TrainingEvent>();
// In quiet mode nobody listens to progress events: the receiver is dropped.
let sim_progress_handle = if ctx.quiet {
drop(sim_event_rx);
None
} else {
Some(crate::progress::run_progress_thread(
sim_event_rx,
u64::from(n_scenarios),
ctx.term_width,
))
};
let io_capacity = sim_config.io_channel_capacity;
// Bounded channel towards the writer; the capacity is clamped to at least 1.
let (result_tx, result_rx) = mpsc::sync_channel(io_capacity.max(1));
let parquet_config = cobre_io::ParquetWriterConfig::default();
let mut sim_writer = cobre_io::output::simulation_writer::SimulationParquetWriter::new(
&ctx.output_dir,
system,
&parquet_config,
)
.map_err(CliError::from)?;
// The drain thread owns the writer, counts failed writes (logged, not
// fatal) and hands the writer back together with that count.
let drain_handle = std::thread::spawn(move || {
let mut failed: u32 = 0;
for scenario_result in result_rx {
let payload =
cobre_io::output::simulation_writer::ScenarioWritePayload::from(scenario_result);
if let Err(e) = sim_writer.write_scenario(payload) {
tracing::error!("simulation write error: {e}");
failed += 1;
}
}
(sim_writer, failed)
});
let sim_start = std::time::Instant::now();
// The `?` on the simulation result is deferred until both helper threads
// have been joined below.
let sim_result = setup
.simulate(
&mut sim_pool.workspaces,
&ctx.comm,
&result_tx,
Some(sim_event_tx),
&training_result.basis_cache,
)
.map_err(CliError::from);
if let Some(handle) = sim_progress_handle {
let _ = handle.join();
}
// Dropping the sender lets the drain loop finish.
// NOTE(review): assumes `simulate` keeps no clone of the sender alive — confirm.
drop(result_tx);
#[allow(clippy::expect_used)]
let (sim_writer, write_failures) = drain_handle.join().expect("drain thread panicked");
let sim_run_result = sim_result?;
// Elapsed time is taken after the drain thread has been joined.
#[allow(clippy::cast_possible_truncation)]
let sim_time_ms = sim_start.elapsed().as_millis() as u64;
let mut local_sim_output = sim_writer.finalize(sim_time_ms);
// The `failed` count reported by the writer is replaced by the number of
// write errors observed on the drain thread.
local_sim_output.failed = write_failures;
let merged_sim_output = merge_simulation_metadata(&ctx.comm, &local_sim_output)?;
ctx.comm.barrier().map_err(|e| CliError::Internal {
message: format!("post-simulation barrier error: {e}"),
})?;
if !ctx.quiet && ctx.is_root {
let agg = cobre_sddp::SolverStatsDelta::aggregate(
&sim_run_result
.solver_stats
.iter()
.map(|(_, delta)| delta.clone())
.collect::<Vec<_>>(),
);
// NOTE(review): the printed summary receives `n_scenarios`, not the
// merged completed/failed counts — confirm this is intended.
print_sim_summary(&ctx.stderr, n_scenarios, sim_time_ms, &agg);
}
if ctx.is_root {
write_simulation_outputs(&WriteSimulationArgs {
output_dir: &ctx.output_dir,
sim_output: &merged_sim_output,
sim_solver_stats: &sim_run_result.solver_stats,
quiet: ctx.quiet,
stderr: &ctx.stderr,
})?;
}
Ok(())
}
/// Prints the simulation summary to `stderr`.
///
/// The summary always reports every scenario as completed (`completed` equals
/// `n_scenarios`, `failed` is 0); solver totals come from `agg`, with the
/// solve time converted from milliseconds to seconds.
fn print_sim_summary(
    stderr: &Term,
    n_scenarios: u32,
    sim_time_ms: u64,
    agg: &cobre_sddp::SolverStatsDelta,
) {
    let retried = agg.lp_successes.saturating_sub(agg.first_try_successes);
    let solve_seconds = agg.solve_time_ms / 1000.0;

    let summary = SimulationSummary {
        n_scenarios,
        completed: n_scenarios,
        failed: 0,
        total_time_ms: sim_time_ms,
        total_lp_solves: agg.lp_solves,
        total_first_try: agg.first_try_successes,
        total_retried: retried,
        total_failed_solves: agg.lp_failures,
        total_solve_time_seconds: solve_seconds,
        total_basis_offered: agg.basis_offered,
        total_basis_rejections: agg.basis_rejections,
        total_simplex_iterations: agg.simplex_iterations,
    };
    crate::summary::print_simulation_summary(stderr, &summary);
}
/// Combines the per-rank simulation metadata into one `SimulationOutput`.
///
/// * scenario / completed / failed counts are summed over all ranks;
/// * the total time is the maximum over all ranks;
/// * the partition paths of every rank are gathered (newline-joined bytes,
///   lengths exchanged first), split again, and returned sorted.
///
/// # Errors
///
/// Returns `CliError::Internal` when a collective call fails or when the bytes
/// received from a rank are not valid UTF-8.
#[allow(clippy::cast_possible_truncation)]
fn merge_simulation_metadata<C: Communicator>(
    comm: &C,
    local: &cobre_io::SimulationOutput,
) -> Result<cobre_io::SimulationOutput, CliError> {
    // Counts: element-wise sum.
    let local_counts = [local.n_scenarios, local.completed, local.failed];
    let mut summed_counts = [0u32; 3];
    comm.allreduce(&local_counts, &mut summed_counts, ReduceOp::Sum)
        .map_err(|e| CliError::Internal {
            message: format!("simulation metadata count allreduce error: {e}"),
        })?;

    // Wall time: the slowest rank wins.
    let local_time = [local.total_time_ms];
    let mut slowest_time = [0u64; 1];
    comm.allreduce(&local_time, &mut slowest_time, ReduceOp::Max)
        .map_err(|e| CliError::Internal {
            message: format!("simulation metadata time allreduce error: {e}"),
        })?;

    // Paths, step 1: every rank learns the payload length of every other rank.
    let payload = local.partitions_written.join("\n").into_bytes();
    let payload_len = [payload.len() as u64];
    let rank_count = comm.size();
    let mut gathered_lens = vec![0u64; rank_count];
    let unit_counts: Vec<usize> = vec![1; rank_count];
    let unit_displs: Vec<usize> = (0..rank_count).collect();
    comm.allgatherv(&payload_len, &mut gathered_lens, &unit_counts, &unit_displs)
        .map_err(|e| CliError::Internal {
            message: format!("partition path length exchange error: {e}"),
        })?;

    // Paths, step 2: displacements are the running sum of the lengths.
    let recv_counts: Vec<usize> = gathered_lens.iter().map(|&len| len as usize).collect();
    let mut recv_displs: Vec<usize> = Vec::with_capacity(recv_counts.len());
    let mut running_offset = 0usize;
    for &count in &recv_counts {
        recv_displs.push(running_offset);
        running_offset += count;
    }
    let mut gathered_bytes = vec![0u8; running_offset];
    comm.allgatherv(&payload, &mut gathered_bytes, &recv_counts, &recv_displs)
        .map_err(|e| CliError::Internal {
            message: format!("partition path gather error: {e}"),
        })?;

    // Paths, step 3: decode each rank's slice and split it back into paths.
    let mut all_partitions: Vec<String> = Vec::new();
    for (rank, (&count, &start)) in recv_counts.iter().zip(&recv_displs).enumerate() {
        if count == 0 {
            continue;
        }
        let rank_bytes = &gathered_bytes[start..start + count];
        let text = std::str::from_utf8(rank_bytes).map_err(|e| CliError::Internal {
            message: format!("partition path UTF-8 decode error from rank {rank}: {e}"),
        })?;
        all_partitions.extend(
            text.split('\n')
                .filter(|segment| !segment.is_empty())
                .map(str::to_owned),
        );
    }
    all_partitions.sort();

    let [n_scenarios, completed, failed] = summed_counts;
    let [total_time_ms] = slowest_time;
    Ok(cobre_io::SimulationOutput {
        n_scenarios,
        completed,
        failed,
        total_time_ms,
        partitions_written: all_partitions,
    })
}
/// Converts one solver statistics delta into an output row.
///
/// `id` is stored in the row's `iteration` field. Counters are narrowed to
/// `u32` with `as` casts (hence the lint allowance); `lp_retries` is the
/// number of successes that were not first-try successes.
#[allow(clippy::cast_possible_truncation)]
fn delta_to_stats_row(
    id: u32,
    phase: &str,
    stage: i32,
    delta: &cobre_sddp::SolverStatsDelta,
) -> cobre_io::SolverStatsRow {
    let retried_successes = delta.lp_successes.saturating_sub(delta.first_try_successes);
    let phase_label = phase.to_owned();

    cobre_io::SolverStatsRow {
        iteration: id,
        phase: phase_label,
        stage,
        lp_solves: delta.lp_solves as u32,
        lp_successes: delta.lp_successes as u32,
        lp_retries: retried_successes as u32,
        lp_failures: delta.lp_failures as u32,
        retry_attempts: delta.retry_attempts as u32,
        basis_offered: delta.basis_offered as u32,
        basis_rejections: delta.basis_rejections as u32,
        simplex_iterations: delta.simplex_iterations,
        solve_time_ms: delta.solve_time_ms,
        load_model_time_ms: delta.load_model_time_ms,
        add_rows_time_ms: delta.add_rows_time_ms,
        set_bounds_time_ms: delta.set_bounds_time_ms,
        basis_set_time_ms: delta.basis_set_time_ms,
        retry_level_histogram: delta.retry_level_histogram,
    }
}
/// Borrowed inputs of `write_training_outputs`.
struct WriteTrainingArgs<'a> {
/// Root output directory; the policy directory is resolved below it.
output_dir: &'a Path,
/// System passed on to the training results writer.
system: &'a System,
/// Full config passed on to the training results writer.
config: &'a cobre_io::Config,
/// Training records to write (also source of the cut selection records).
training_output: &'a cobre_io::TrainingOutput,
/// Study setup: source of the policy path, the FCF and checkpoint parameters.
setup: &'a StudySetup,
/// Training result: written to the checkpoint, source of the solver stats log.
training_result: &'a cobre_sddp::TrainingResult,
/// Suppresses the progress lines when true.
quiet: bool,
/// Terminal handle for the progress lines.
stderr: &'a Term,
}
fn write_training_outputs(args: &WriteTrainingArgs<'_>) -> Result<(), CliError> {
if !args.quiet {
use std::io::Write;
let _ = args.stderr.write_line("Writing training outputs...");
let _ = std::io::stderr().flush();
}
let write_start = std::time::Instant::now();
let policy_dir = args.output_dir.join(args.setup.policy_path());
crate::policy_io::write_checkpoint(
&policy_dir,
args.setup.fcf(),
args.training_result,
&crate::policy_io::CheckpointParams {
max_iterations: args.setup.max_iterations(),
forward_passes: args.setup.forward_passes(),
seed: args.setup.seed(),
},
)?;
cobre_io::write_training_results(
args.output_dir,
args.training_output,
args.system,
args.config,
)
.map_err(CliError::from)?;
if !args.training_result.solver_stats_log.is_empty() {
let rows: Vec<cobre_io::SolverStatsRow> = args
.training_result
.solver_stats_log
.iter()
.map(|(iter, phase, stage, delta)| {
#[allow(clippy::cast_possible_truncation)] delta_to_stats_row(*iter as u32, phase, *stage, delta)
})
.collect();
cobre_io::write_solver_stats(args.output_dir, &rows).map_err(CliError::from)?;
}
if !args.training_output.cut_selection_records.is_empty() {
let parquet_config = cobre_io::ParquetWriterConfig::default();
cobre_io::write_cut_selection_records(
args.output_dir,
&args.training_output.cut_selection_records,
&parquet_config,
)
.map_err(CliError::from)?;
}
if !args.quiet {
let write_secs = write_start.elapsed().as_secs_f64();
crate::summary::print_output_path(args.stderr, args.output_dir, write_secs);
}
Ok(())
}
/// Borrowed inputs of `write_simulation_outputs`.
struct WriteSimulationArgs<'a> {
/// Root output directory.
output_dir: &'a Path,
/// Simulation metadata to write (the merged output when called from
/// `run_simulation_phase`).
sim_output: &'a cobre_io::SimulationOutput,
/// Solver statistics as `(scenario id, delta)` pairs.
sim_solver_stats: &'a [(u32, cobre_sddp::SolverStatsDelta)],
/// Suppresses the progress lines when true.
quiet: bool,
/// Terminal handle for the progress lines.
stderr: &'a Term,
}
fn write_simulation_outputs(args: &WriteSimulationArgs<'_>) -> Result<(), CliError> {
if !args.quiet {
use std::io::Write;
let _ = args.stderr.write_line("Writing simulation outputs...");
let _ = std::io::stderr().flush();
}
let write_start = std::time::Instant::now();
cobre_io::write_simulation_results(args.output_dir, args.sim_output).map_err(CliError::from)?;
if !args.sim_solver_stats.is_empty() {
let rows: Vec<cobre_io::SolverStatsRow> = args
.sim_solver_stats
.iter()
.map(|(scenario_id, delta)| delta_to_stats_row(*scenario_id, "simulation", -1, delta))
.collect();
cobre_io::write_simulation_solver_stats(args.output_dir, &rows).map_err(CliError::from)?;
}
if !args.quiet {
let write_secs = write_start.elapsed().as_secs_f64();
crate::summary::print_output_path(args.stderr, args.output_dir, write_secs);
}
Ok(())
}
/// Best-effort export of the stochastic artifacts into
/// `<output_dir>/stochastic`.
///
/// Written files: noise openings, inflow seasonal stats, inflow AR
/// coefficients, correlation, load seasonal stats (only when at least one
/// load model has a positive `std_mw`) and the fitting report (only when an
/// estimation report is available).
///
/// A failed export never aborts the run: it produces a warning line on
/// `stderr` unless `quiet` is set, in which case it is ignored.
fn export_stochastic_artifacts(
    output_dir: &Path,
    stochastic: &cobre_stochastic::StochasticContext,
    system: &System,
    estimation_report: Option<&EstimationReport>,
    quiet: bool,
    stderr: &Term,
) {
    use cobre_core::scenario::LoadModel;

    let verbose = !quiet;
    let stochastic_dir = output_dir.join("stochastic");
    if verbose {
        let _ = stderr.write_line("Exporting stochastic artifacts...");
    }

    // Noise openings.
    match write_noise_openings(
        &stochastic_dir.join("noise_openings.parquet"),
        stochastic.opening_tree(),
    ) {
        Err(e) if verbose => {
            let _ = stderr.write_line(&format!(
                "warning: stochastic export failed (noise_openings): {e}"
            ));
        }
        _ => {}
    }

    // Inflow seasonal statistics.
    let stats_rows = inflow_models_to_stats_rows(system.inflow_models());
    match write_inflow_seasonal_stats(
        &stochastic_dir.join("inflow_seasonal_stats.parquet"),
        &stats_rows,
    ) {
        Err(e) if verbose => {
            let _ = stderr.write_line(&format!(
                "warning: stochastic export failed (inflow_seasonal_stats): {e}"
            ));
        }
        _ => {}
    }

    // Inflow AR coefficients.
    let ar_rows = inflow_models_to_ar_rows(system.inflow_models());
    match write_inflow_ar_coefficients(
        &stochastic_dir.join("inflow_ar_coefficients.parquet"),
        &ar_rows,
    ) {
        Err(e) if verbose => {
            let _ = stderr.write_line(&format!(
                "warning: stochastic export failed (inflow_ar_coefficients): {e}"
            ));
        }
        _ => {}
    }

    // Correlation.
    match write_correlation_json(
        &stochastic_dir.join("correlation.json"),
        system.correlation(),
    ) {
        Err(e) if verbose => {
            let _ = stderr.write_line(&format!(
                "warning: stochastic export failed (correlation): {e}"
            ));
        }
        _ => {}
    }

    // Load seasonal statistics, only when some load is actually stochastic.
    let has_stochastic_load = system
        .load_models()
        .iter()
        .any(|m: &LoadModel| m.std_mw > 0.0);
    if has_stochastic_load {
        let mut load_rows: Vec<LoadSeasonalStatsRow> = Vec::new();
        for model in system.load_models().iter() {
            load_rows.push(LoadSeasonalStatsRow {
                bus_id: model.bus_id,
                stage_id: model.stage_id,
                mean_mw: model.mean_mw,
                std_mw: model.std_mw,
            });
        }
        match write_load_seasonal_stats(
            &stochastic_dir.join("load_seasonal_stats.parquet"),
            &load_rows,
        ) {
            Err(e) if verbose => {
                let _ = stderr.write_line(&format!(
                    "warning: stochastic export failed (load_seasonal_stats): {e}"
                ));
            }
            _ => {}
        }
    }

    // Fitting report, only when an estimation was performed.
    if let Some(report) = estimation_report {
        let fitting = estimation_report_to_fitting_report(report);
        match write_fitting_report(&stochastic_dir.join("fitting_report.json"), &fitting) {
            Err(e) if verbose => {
                let _ = stderr.write_line(&format!(
                    "warning: stochastic export failed (fitting_report): {e}"
                ));
            }
            _ => {}
        }
    }
}
// Unit tests for `resolve_thread_count`.
#[cfg(test)]
mod tests {
use super::resolve_thread_count;
// An explicit CLI value is returned unchanged.
#[test]
fn test_resolve_thread_count_cli_value() {
assert_eq!(resolve_thread_count(Some(4)), 4);
}
// NOTE(review): despite its name this test passes `Some(1)`, so it exercises
// the explicit-CLI path; the `None` fallback (environment variable, then 1)
// is not covered here.
#[test]
fn test_resolve_thread_count_default() {
assert_eq!(resolve_thread_count(Some(1)), 1);
}
}