use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::mpsc::Sender;
use std::time::Instant;
use cobre_comm::Communicator;
use cobre_core::TrainingEvent;
use cobre_solver::{RowBatch, SolverInterface, StageTemplate};
use cobre_stochastic::context::ClassSchemes;
use cobre_stochastic::{
ClassDimensions, ForwardSampler, ForwardSamplerConfig, build_forward_sampler,
};
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
use crate::{
FutureCostFunction,
context::{StageContext, TrainingContext},
forward::{build_cut_row_batch_into, partition},
simulation::{
config::SimulationConfig,
error::SimulationError,
extraction::assign_scenarios,
pipeline::{
SIMULATION_SEED_OFFSET, ScenarioIds, SimulationOutputSpec, SimulationRunResult,
WorkerCosts, WorkerStats, dispatch_scenario_result, emit_sim_progress,
process_scenario_stages,
},
},
solver_stats::SolverStatsDelta,
workspace::{CapturedBasis, SolverWorkspace},
};
/// Borrowed inputs for one call to `SimulationState::run`.
///
/// `S` is the solver type behind each workspace; `C` is the communicator type
/// (`run` requires `C: Communicator`).
pub(crate) struct SimulationInputs<'a, S: SolverInterface + Send, C> {
/// Solver workspaces; `run` iterates over them in parallel, one worker per workspace.
pub workspaces: &'a mut [SolverWorkspace<S>],
/// Stage context; `run` debug-asserts that its `templates` and `base_rows`
/// each have one entry per stage.
pub ctx: &'a StageContext<'a>,
/// Future cost function; used to build the cut rows baked into templates when
/// no baked templates are supplied, and forwarded to scenario processing.
pub fcf: &'a FutureCostFunction,
/// Training context; `run` reads `horizon`, `indexer` and `initial_state` from it
/// and builds the forward sampler from it.
pub training_ctx: &'a TrainingContext<'a>,
/// Simulation configuration; `n_scenarios` and `basis_activity_window` are read here.
pub config: &'a SimulationConfig,
/// Output specification. Its `event_sender`, when set, receives progress events;
/// `run` takes the sender out of this field when it emits `SimulationFinished`.
pub output: SimulationOutputSpec<'a>,
/// Optional caller-provided baked templates. When `Some`, the length must equal
/// the number of stages; when `None`, `run` bakes its own templates.
pub baked_templates: Option<&'a [StageTemplate]>,
/// Per-stage captured bases, forwarded unchanged to scenario processing.
// NOTE(review): presumably one entry per stage — this file never checks the length; confirm.
pub stage_bases: &'a [Option<CapturedBasis>],
/// Communicator; `run` queries its `rank()` and `size()`.
pub comm: &'a C,
}
impl<'a, S: SolverInterface + Send, C> SimulationInputs<'a, S, C> {
pub(crate) fn new(
workspaces: &'a mut [SolverWorkspace<S>],
ctx: &'a StageContext<'a>,
fcf: &'a FutureCostFunction,
training_ctx: &'a TrainingContext<'a>,
config: &'a SimulationConfig,
output: SimulationOutputSpec<'a>,
baked_templates: Option<&'a [StageTemplate]>,
stage_bases: &'a [Option<CapturedBasis>],
comm: &'a C,
) -> Self {
Self {
workspaces,
ctx,
fcf,
training_ctx,
config,
output,
baked_templates,
stage_bases,
comm,
}
}
}
/// State kept across calls to `SimulationState::run`.
pub(crate) struct SimulationState {
/// Templates baked by `run` when the caller supplies none. Starts as `None` and
/// is reset to `None` whenever the caller does supply baked templates.
owned_baked: Option<Vec<StageTemplate>>,
/// Scratch row batch handed to the cut-row builder and the template baker for
/// every stage while baking.
bake_batch: RowBatch,
}
impl SimulationState {
    /// Creates a state with no owned baked templates and an empty bake batch
    /// (`num_rows == 0`, all vectors empty).
    ///
    /// `_num_stages` is accepted but not used.
    #[must_use]
    pub(crate) fn new(_num_stages: usize) -> Self {
        Self {
            owned_baked: None,
            bake_batch: RowBatch {
                num_rows: 0,
                row_starts: Vec::new(),
                col_indices: Vec::new(),
                values: Vec::new(),
                row_lower: Vec::new(),
                row_upper: Vec::new(),
            },
        }
    }

    /// Runs the simulation for the scenarios assigned to this rank.
    ///
    /// Steps, in order:
    /// 1. If `inputs.baked_templates` is provided, checks that its length equals
    ///    the number of stages reported by the training context's horizon.
    /// 2. If it is not provided, bakes one template per stage and keeps them in
    ///    `self` (see `rebake_templates_if_needed`).
    /// 3. Obtains this rank's scenario range from `assign_scenarios` and builds
    ///    the forward sampler.
    /// 4. Processes the local scenarios in parallel, one worker per entry of
    ///    `inputs.workspaces`.
    /// 5. Concatenates the per-worker results in workspace order.
    ///
    /// If `inputs.output.event_sender` is set, a `SimulationStarted` event is sent
    /// before the workers start and a `SimulationFinished` event after they all
    /// succeed. The sender is taken out of `inputs.output` for the final event and
    /// dropped right after it. Send failures are ignored. On an error return the
    /// final event is not sent and the sender is left in place.
    ///
    /// # Errors
    ///
    /// - `SimulationError::InvalidConfiguration` if `inputs.baked_templates` is
    ///   provided with a length different from the number of stages.
    /// - `SimulationError::InvalidConfiguration` if `inputs.workspaces` is empty
    ///   while at least one scenario is assigned to this rank.
    /// - Any error from building the sampler or from a worker is propagated.
    ///
    /// # Panics
    ///
    /// With debug assertions enabled, panics if the checks in
    /// `debug_assert_inputs` fail or if the concatenated results are not in
    /// ascending scenario-id order.
    pub(crate) fn run<S: SolverInterface + Send, C: Communicator>(
        &mut self,
        inputs: &mut SimulationInputs<'_, S, C>,
    ) -> Result<SimulationRunResult, SimulationError> {
        let training_ctx = inputs.training_ctx;
        let TrainingContext {
            horizon,
            indexer,
            initial_state,
            ..
        } = training_ctx;
        let num_stages = horizon.num_stages();
        let rank = inputs.comm.rank();
        debug_assert_inputs(inputs.ctx, num_stages, initial_state.len(), indexer.n_state);

        if let Some(baked) = inputs.baked_templates {
            if baked.len() != num_stages {
                return Err(SimulationError::InvalidConfiguration(format!(
                    "baked_templates length {} != num_stages {}",
                    baked.len(),
                    num_stages
                )));
            }
        }

        // Bakes templates only when the caller supplied none; otherwise clears any
        // previously owned set.
        rebake_templates_if_needed(
            inputs.fcf,
            inputs.ctx,
            indexer,
            num_stages,
            inputs.baked_templates,
            &mut self.bake_batch,
            &mut self.owned_baked,
        );
        // Caller-provided templates take precedence over the owned ones.
        let baked_templates: &[StageTemplate] =
            match (inputs.baked_templates, self.owned_baked.as_deref()) {
                (Some(b), _) | (None, Some(b)) => b,
                (None, None) => unreachable!("owned_baked is Some when baked_templates is None"),
            };

        let scenario_range = assign_scenarios(inputs.config.n_scenarios, rank, inputs.comm.size());
        #[allow(clippy::cast_possible_truncation)]
        let local_count = (scenario_range.end - scenario_range.start) as usize;
        let scenario_start = scenario_range.start as usize;

        // With no workspaces the parallel loop below runs zero workers and this
        // method would return `Ok` with no results although scenarios were
        // assigned to this rank. Report that instead of silently dropping them.
        if inputs.workspaces.is_empty() && local_count > 0 {
            return Err(SimulationError::InvalidConfiguration(format!(
                "no solver workspaces available for {local_count} locally assigned scenarios"
            )));
        }

        // `.max(1)` keeps the worker count non-zero for partitioning and reporting.
        let n_workers = inputs.workspaces.len().max(1);
        // Falls back to 1 if the communicator size does not fit in `u32`.
        let world_size = u32::try_from(inputs.comm.size()).unwrap_or(1).max(1);
        let sim_start = Instant::now();
        // Shared counter of scenarios completed on this rank, used for progress events.
        let scenarios_complete = AtomicU32::new(0);

        if let Some(sender) = inputs.output.event_sender.as_ref() {
            #[allow(clippy::cast_possible_truncation)]
            let _ = sender.send(TrainingEvent::SimulationStarted {
                case_name: String::new(),
                n_scenarios: inputs.config.n_scenarios,
                n_stages: num_stages as u32,
                ranks: world_size,
                threads_per_rank: n_workers as u32,
                timestamp: String::new(),
            });
        }

        let sampler = build_sim_sampler(training_ctx)?;

        // One result per workspace; the indexed parallel collect keeps workspace order.
        let worker_results: Vec<Result<(WorkerCosts, WorkerStats), SimulationError>> = inputs
            .workspaces
            .par_iter_mut()
            .enumerate()
            .map(|(w, ws)| {
                run_worker_scenarios(
                    w,
                    ws,
                    inputs.ctx,
                    inputs.fcf,
                    training_ctx,
                    &inputs.output,
                    inputs.config,
                    inputs.stage_bases,
                    baked_templates,
                    &scenarios_complete,
                    sim_start,
                    local_count,
                    n_workers,
                    scenario_start,
                    &sampler,
                    num_stages,
                    world_size,
                )
            })
            .collect();

        // Concatenate in worker order; the first worker error aborts the run.
        let mut all_costs = Vec::with_capacity(local_count);
        let mut all_stats = Vec::with_capacity(local_count);
        for result in worker_results {
            let (costs, stats) = result?;
            all_costs.extend(costs);
            all_stats.extend(stats);
        }
        debug_assert!(
            all_costs.windows(2).all(|w| w[0].0 <= w[1].0),
            "all_costs not pre-sorted: workers must emit ascending scenario_id"
        );
        debug_assert!(
            all_stats.windows(2).all(|w| w[0].0 <= w[1].0),
            "all_stats not pre-sorted: workers must emit ascending scenario_id"
        );

        // `take()` moves the sender out of `inputs.output`; it is dropped at the
        // end of this block.
        if let Some(sender) = inputs.output.event_sender.take() {
            #[allow(clippy::cast_possible_truncation)]
            let _ = sender.send(TrainingEvent::SimulationFinished {
                scenarios: inputs.config.n_scenarios,
                output_dir: String::new(),
                elapsed_ms: sim_start.elapsed().as_millis() as u64,
            });
        }

        Ok(SimulationRunResult {
            costs: all_costs,
            solver_stats: all_stats,
        })
    }
}
/// Debug-build sanity checks on the sizes the simulation relies on.
///
/// Asserts that `ctx.templates` and `ctx.base_rows` each hold `num_stages`
/// entries and that `n_initial` equals `n_state`. Every check uses
/// `debug_assert_eq!`, so nothing is checked when debug assertions are disabled.
fn debug_assert_inputs(
    ctx: &StageContext<'_>,
    num_stages: usize,
    n_initial: usize,
    n_state: usize,
) {
    let n_templates = ctx.templates.len();
    let n_base_rows = ctx.base_rows.len();

    // One template per stage.
    debug_assert_eq!(
        n_templates, num_stages,
        "templates.len()={n_templates} != num_stages={num_stages}"
    );
    // One base-row entry per stage.
    debug_assert_eq!(
        n_base_rows, num_stages,
        "base_rows.len()={n_base_rows} != num_stages={num_stages}"
    );
    // The initial state must match the indexer's state dimension.
    debug_assert_eq!(
        n_initial, n_state,
        "initial_state.len()={n_initial} != n_state={n_state}"
    );
}
/// Processes the slice of local scenarios that belongs to worker `w`.
///
/// The worker's local index range comes from `partition(local_count, n_workers, w)`.
/// For each local index the scenario id is `scenario_start + local_idx` and the
/// value passed on as `global_scenario` is `SIMULATION_SEED_OFFSET` plus that id
/// (saturating). Each scenario is run through `process_scenario_stages`, its
/// result is handed to `dispatch_scenario_result`, and a progress event is
/// emitted through a clone of `output.event_sender`.
///
/// Returns the per-scenario costs and solver-statistics deltas in the order the
/// scenarios were processed (ascending scenario id). Each statistics entry is
/// `(scenario_id, -1, delta)`.
///
/// The progress count is an estimate: the rank-local completed count from
/// `scenarios_complete` multiplied by `world_size`, capped at `config.n_scenarios`.
///
/// # Errors
///
/// Returns the first error produced by `process_scenario_stages` or
/// `dispatch_scenario_result`; later scenarios are not processed.
#[allow(clippy::too_many_arguments)]
fn run_worker_scenarios<S: SolverInterface + Send>(
    w: usize,
    ws: &mut SolverWorkspace<S>,
    ctx: &StageContext<'_>,
    fcf: &FutureCostFunction,
    training_ctx: &TrainingContext<'_>,
    output: &SimulationOutputSpec<'_>,
    config: &SimulationConfig,
    stage_bases: &[Option<CapturedBasis>],
    baked_templates: &[StageTemplate],
    scenarios_complete: &AtomicU32,
    sim_start: Instant,
    local_count: usize,
    n_workers: usize,
    scenario_start: usize,
    sampler: &ForwardSampler,
    num_stages: usize,
    world_size: u32,
) -> Result<(WorkerCosts, WorkerStats), SimulationError> {
    let (first_local, past_last_local) = partition(local_count, n_workers, w);
    let capacity = past_last_local - first_local;
    let mut costs_out = Vec::with_capacity(capacity);
    let mut stats_out = Vec::with_capacity(capacity);
    let progress_sender: Option<Sender<TrainingEvent>> = output.event_sender.clone();

    // Size the scratch buffers once, before the scenario loop.
    let noise_len = training_ctx.stochastic.dim();
    ws.scratch.raw_noise_buf.resize(noise_len, 0.0_f64);
    #[allow(clippy::cast_possible_truncation)]
    let perm_len = config.n_scenarios.max(1) as usize;
    ws.scratch.perm_scratch.resize(perm_len, 0_usize);

    // The load spec is the same for every scenario, so it is built once.
    let load_spec = crate::simulation::pipeline::SimScenarioLoadSpec {
        baked_templates,
        stage_bases,
        basis_activity_window: config.basis_activity_window,
    };

    for local_idx in first_local..past_last_local {
        #[allow(clippy::cast_possible_truncation)]
        let scenario_id = (scenario_start + local_idx) as u32;
        let global_scenario = SIMULATION_SEED_OFFSET.saturating_add(scenario_id);
        let stats_at_start = ws.solver.statistics();

        // Move the scratch buffers out of the workspace so they can be borrowed
        // mutably alongside `ws`.
        let mut raw_noise_buf = std::mem::take(&mut ws.scratch.raw_noise_buf);
        let mut perm_scratch = std::mem::take(&mut ws.scratch.perm_scratch);
        let outcome = process_scenario_stages(
            ws,
            ctx,
            fcf,
            training_ctx,
            &load_spec,
            output,
            &mut ScenarioIds {
                scenario_id,
                global_scenario,
                num_stages,
                total_scenarios: config.n_scenarios,
                raw_noise_buf: &mut raw_noise_buf,
                perm_scratch: &mut perm_scratch,
                sampler,
            },
        );
        // Hand the buffers back before propagating any error, so the workspace
        // keeps them on the failure path too.
        ws.scratch.raw_noise_buf = raw_noise_buf;
        ws.scratch.perm_scratch = perm_scratch;
        let (total_cost, stage_results) = outcome?;

        // Solver work attributable to this scenario alone.
        let stats_at_end = ws.solver.statistics();
        let delta = SolverStatsDelta::from_snapshots(&stats_at_start, &stats_at_end);
        let solve_time_ms = delta.solve_time_ms;
        let lp_solves = delta.lp_solves;
        stats_out.push((scenario_id, -1_i32, delta));

        let cost_entry = dispatch_scenario_result(output, scenario_id, total_cost, stage_results)?;
        costs_out.push(cost_entry);

        // Rank-local completed count scaled by the number of ranks, capped at the total.
        let done_local = scenarios_complete.fetch_add(1, Ordering::Relaxed) + 1;
        let done_global = done_local
            .saturating_mul(world_size)
            .min(config.n_scenarios);
        #[allow(clippy::cast_possible_truncation)]
        let elapsed_ms = sim_start.elapsed().as_millis() as u64;
        emit_sim_progress(
            progress_sender.as_ref(),
            total_cost,
            solve_time_ms,
            lp_solves,
            done_global,
            config.n_scenarios,
            elapsed_ms,
        );
    }

    Ok((costs_out, stats_out))
}
/// Builds the forward sampler for the simulation from `training_ctx`.
///
/// The inflow, load and NCS schemes are all passed as `Some(..)`. The class
/// dimensions are read from the stochastic context, and the historical and
/// external libraries are forwarded from `training_ctx` unchanged.
///
/// # Errors
///
/// Propagates any error returned by `build_forward_sampler`, converted into
/// `SimulationError`.
fn build_sim_sampler<'a>(
    training_ctx: &'a TrainingContext<'a>,
) -> Result<ForwardSampler<'a>, SimulationError> {
    let stochastic = training_ctx.stochastic;

    let class_schemes = ClassSchemes {
        inflow: Some(training_ctx.inflow_scheme),
        load: Some(training_ctx.load_scheme),
        ncs: Some(training_ctx.ncs_scheme),
    };
    let dims = ClassDimensions {
        n_hydros: stochastic.n_hydros(),
        n_load_buses: stochastic.n_load_buses(),
        n_ncs: stochastic.n_stochastic_ncs(),
    };

    let sampler = build_forward_sampler(ForwardSamplerConfig {
        class_schemes,
        ctx: stochastic,
        stages: training_ctx.stages,
        dims,
        historical_library: training_ctx.historical_library,
        external_inflow_library: training_ctx.external_inflow_library,
        external_load_library: training_ctx.external_load_library,
        external_ncs_library: training_ctx.external_ncs_library,
    })?;
    Ok(sampler)
}
/// Refreshes `owned_baked` according to whether the caller supplied templates.
///
/// - `caller_baked` is `Some`: `owned_baked` is set to `None` and nothing is baked.
/// - `caller_baked` is `None`: for each stage `0..num_stages`, the cut rows are
///   built into `bake_batch` via `build_cut_row_batch_into` and baked onto
///   `ctx.templates[stage]`; the resulting templates replace `owned_baked`.
///
/// # Panics
///
/// Panics if `ctx.templates` has fewer than `num_stages` entries and baking is
/// performed.
fn rebake_templates_if_needed(
    fcf: &FutureCostFunction,
    ctx: &StageContext<'_>,
    indexer: &crate::indexer::StageIndexer,
    num_stages: usize,
    caller_baked: Option<&[StageTemplate]>,
    bake_batch: &mut RowBatch,
    owned_baked: &mut Option<Vec<StageTemplate>>,
) {
    *owned_baked = match caller_baked {
        // The caller's templates are used; drop any previously owned set.
        Some(_) => None,
        None => {
            let mut baked_stages = Vec::with_capacity(num_stages);
            for stage in 0..num_stages {
                let template = &ctx.templates[stage];
                // `bake_batch` is filled for this stage, then baked into a fresh template.
                build_cut_row_batch_into(bake_batch, fcf, stage, indexer, &template.col_scale);
                let mut baked = StageTemplate::empty();
                cobre_solver::bake_rows_into_template(template, bake_batch, &mut baked);
                baked_stages.push(baked);
            }
            Some(baked_stages)
        }
    };
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simulation::types::ScenarioCategoryCosts;

    /// Shape of the cost entries used by the ordering tests below.
    type CostEntry = (u32, f64, ScenarioCategoryCosts);

    /// Builds a cost entry for `scenario_id` with a zero total and all category
    /// costs set to zero.
    fn zero_cost_entry(scenario_id: u32) -> CostEntry {
        let categories = ScenarioCategoryCosts {
            resource_cost: 0.0,
            recourse_cost: 0.0,
            violation_cost: 0.0,
            regularization_cost: 0.0,
            imputed_cost: 0.0,
        };
        (scenario_id, 0.0_f64, categories)
    }

    /// Test-local copy of the progress formula used in `run_worker_scenarios`.
    fn scaled_global_count(local_completed: u32, world_size: u32, total: u32) -> u32 {
        let scaled = local_completed.saturating_mul(world_size);
        std::cmp::min(scaled, total)
    }

    #[test]
    fn simulation_state_new_allocates_empty_bake_batch() {
        let fresh = SimulationState::new(3);
        assert!(fresh.owned_baked.is_none(), "owned_baked must be None");
        assert_eq!(
            fresh.bake_batch.num_rows, 0,
            "bake_batch.num_rows must be 0"
        );
    }

    #[test]
    fn scaled_global_count_single_rank_is_identity() {
        assert_eq!(scaled_global_count(0, 1, 100), 0);
        assert_eq!(scaled_global_count(50, 1, 100), 50);
        assert_eq!(scaled_global_count(100, 1, 100), 100);
    }

    #[test]
    fn scaled_global_count_balanced_two_ranks_tracks_global() {
        assert_eq!(scaled_global_count(1, 2, 100), 2);
        assert_eq!(scaled_global_count(25, 2, 100), 50);
        assert_eq!(scaled_global_count(50, 2, 100), 100);
    }

    #[test]
    fn scaled_global_count_clamps_at_total_when_unevenly_divided() {
        assert_eq!(scaled_global_count(33, 3, 100), 99);
        assert_eq!(scaled_global_count(34, 3, 100), 100);
    }

    #[test]
    fn aggregate_costs_is_ascending_post_extend() {
        // Four workers, each emitting three consecutive scenario ids.
        let per_worker: Vec<Vec<CostEntry>> = (0u32..4)
            .map(|worker| {
                let first_id = worker * 3;
                (first_id..first_id + 3).map(zero_cost_entry).collect()
            })
            .collect();

        // Mirrors the sequential `extend` used when merging worker results.
        let mut all_costs: Vec<CostEntry> = Vec::with_capacity(12);
        for chunk in per_worker {
            all_costs.extend(chunk);
        }

        assert!(
            all_costs.windows(2).all(|w| w[0].0 <= w[1].0),
            "all_costs must be ascending by scenario_id after sequential extend"
        );
        let ids: Vec<u32> = all_costs.iter().map(|entry| entry.0).collect();
        assert_eq!(ids, (0u32..12).collect::<Vec<_>>());
    }

    #[cfg(debug_assertions)]
    #[test]
    #[should_panic(expected = "all_costs not pre-sorted")]
    fn debug_assert_fires_for_out_of_order_costs() {
        // Ids 0, 2, 1 are deliberately out of order.
        let all_costs: Vec<CostEntry> = [0u32, 2, 1]
            .iter()
            .map(|&id| zero_cost_entry(id))
            .collect();
        debug_assert!(
            all_costs.windows(2).all(|w| w[0].0 <= w[1].0),
            "all_costs not pre-sorted: workers must emit ascending scenario_id"
        );
    }
}