use std::time::Instant;
use cobre_comm::Communicator;
use cobre_core::WelfordAccumulator;
use super::{ForwardResult, SyncResult};
use crate::error::SddpError;
#[allow(unused_imports)]
use super::run_forward_pass;
/// Gathers every rank's forward-pass scenario costs and computes global statistics.
///
/// The `total_forward_passes` scenarios are expected to be split across the
/// `comm.size()` ranks as evenly as possible. Each rank owns
/// `total_forward_passes / num_ranks` scenarios, and the first
/// `total_forward_passes % num_ranks` ranks own one extra. These per-rank counts
/// are passed to `allgatherv`, together with displacements equal to the running
/// sum of the counts of the preceding ranks. Under the usual `Allgatherv`
/// semantics (assumed — confirm against the `Communicator` implementation), the
/// gathered buffer is therefore ordered by rank.
///
/// The gathered costs are fed through a [`WelfordAccumulator`] to obtain the
/// mean (`global_ub_mean`), the sample standard deviation (`global_ub_std`) and
/// the 95% confidence-interval half-width (`ci_95_half_width`). When fewer than
/// two costs are gathered, the standard deviation and half-width are reported
/// as `0.0`. `sync_time_ms` is the wall-clock time spent in this function, in
/// milliseconds, saturating at `u64::MAX`.
///
/// # Errors
///
/// Propagates any error returned by [`Communicator::allgatherv`], converted
/// into [`SddpError`] via `?`.
///
/// # Panics
///
/// Panics if `comm.size()` is zero (integer division by zero). In debug builds,
/// also panics if `local.scenario_costs.len()` differs from this rank's
/// expected share.
pub fn sync_forward<C: Communicator>(
    local: &ForwardResult,
    comm: &C,
    total_forward_passes: usize,
) -> Result<SyncResult, SddpError> {
    let start = Instant::now();
    let num_ranks = comm.size();
    let my_rank = comm.rank();

    // Block distribution: every rank gets `base` scenarios, and the first
    // `remainder` ranks get one extra so the counts sum to the total.
    let base = total_forward_passes / num_ranks;
    let remainder = total_forward_passes % num_ranks;
    let counts: Vec<usize> = (0..num_ranks)
        .map(|r| base + usize::from(r < remainder))
        .collect();

    // Exclusive prefix sum of `counts`: the offset at which each rank's block
    // begins in the gathered buffer.
    let displs: Vec<usize> = counts
        .iter()
        .scan(0usize, |offset, &count| {
            let here = *offset;
            *offset += count;
            Some(here)
        })
        .collect();

    let global_n = counts.iter().sum::<usize>();
    debug_assert_eq!(
        global_n, total_forward_passes,
        "counts sum {global_n} != total_forward_passes {total_forward_passes}",
    );
    let mut global_costs = vec![0.0_f64; global_n];
    // The local slice must match the count advertised for this rank; otherwise
    // the gather layout described by `counts`/`displs` is inconsistent.
    debug_assert_eq!(
        local.scenario_costs.len(),
        counts[my_rank],
        "rank {my_rank}: scenario_costs length {} != expected count {}",
        local.scenario_costs.len(),
        counts[my_rank],
    );
    comm.allgatherv(&local.scenario_costs, &mut global_costs, &counts, &displs)?;

    let mut welford = WelfordAccumulator::new();
    for &cost in &global_costs {
        welford.update(cost);
    }
    let mean = welford.mean();
    // Sample (n - 1) statistics need at least two observations; report zero
    // spread otherwise.
    let (std_dev, ci_95) = if global_n > 1 {
        (welford.sample_std_dev(), welford.sample_ci_95_half_width())
    } else {
        (0.0_f64, 0.0_f64)
    };

    // `Duration::as_millis` yields a u128; saturate rather than silently
    // truncate when narrowing to u64.
    let sync_time_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);

    Ok(SyncResult {
        global_ub_mean: mean,
        global_ub_std: std_dev,
        ci_95_half_width: ci_95,
        sync_time_ms,
    })
}