use cobre_comm::Communicator;
use crate::{SddpError, TrajectoryRecord};
/// Pre-allocated send/receive buffers used to gather per-scenario state
/// vectors from every rank through `Communicator::allgatherv`.
///
/// Every rank is given the same number of scenario slots (`local_count`), so
/// the gathered buffer is laid out rank-major with a fixed stride of
/// `local_count * n_state` values per rank. Ranks that own fewer real
/// scenarios than `local_count` leave trailing padding slots in their region;
/// `actual_counts` records how many slots per rank are real.
#[derive(Debug, Clone)]
pub struct ExchangeBuffers {
/// Local states packed back to back; length `local_count * n_state`.
send_buf: Vec<f64>,
/// Gathered states of all ranks; length `local_count * n_state * num_ranks`.
recv_buf: Vec<f64>,
/// Number of `f64` values contributed by each rank (`local_count * n_state` for every rank).
counts: Vec<usize>,
/// Offset of each rank's region inside `recv_buf` (`rank * local_count * n_state`).
displs: Vec<usize>,
/// Length of a single state vector.
n_state: usize,
/// Scenario slots per rank (the maximum over ranks, padding included).
local_count: usize,
/// Number of ranks taking part in the exchange.
num_ranks: usize,
/// Number of real (non-padding) scenarios owned by each rank.
actual_counts: Vec<usize>,
/// Sum of `actual_counts`, i.e. the number of real scenarios across all ranks.
real_total: usize,
}
impl ExchangeBuffers {
    /// Creates buffers for an even distribution in which every rank owns
    /// exactly `local_count` scenarios of `n_state` values each.
    ///
    /// Equivalent to [`ExchangeBuffers::with_actual_counts`] with
    /// `actual_per_rank == [local_count; num_ranks]`.
    #[must_use]
    pub fn new(n_state: usize, local_count: usize, num_ranks: usize) -> Self {
        let actual_per_rank = vec![local_count; num_ranks];
        Self::with_actual_counts(n_state, local_count, num_ranks, &actual_per_rank)
    }

    /// Creates buffers for a possibly uneven distribution of scenarios.
    ///
    /// Every rank is allotted `max_local_count` scenario slots so that all
    /// ranks contribute the same number of values to the gather;
    /// `actual_per_rank[r]` states how many of rank `r`'s slots hold real
    /// scenarios. The remaining slots are padding and are skipped by
    /// [`ExchangeBuffers::pack_real_states_into`].
    ///
    /// # Panics
    ///
    /// In builds with debug assertions enabled, panics if
    /// `actual_per_rank.len() != num_ranks` or if any entry exceeds
    /// `max_local_count`.
    #[must_use]
    pub fn with_actual_counts(
        n_state: usize,
        max_local_count: usize,
        num_ranks: usize,
        actual_per_rank: &[usize],
    ) -> Self {
        debug_assert_eq!(
            actual_per_rank.len(),
            num_ranks,
            "actual_per_rank.len() {} != num_ranks {}",
            actual_per_rank.len(),
            num_ranks,
        );
        debug_assert!(
            actual_per_rank.iter().all(|&c| c <= max_local_count),
            "all actual_per_rank entries must be <= max_local_count {max_local_count}",
        );

        // Every rank contributes the same (padded) number of values, so the
        // receive buffer has a uniform stride of `per_rank` values per rank.
        let per_rank = max_local_count * n_state;
        let total = per_rank * num_ranks;
        let real_total: usize = actual_per_rank.iter().sum();
        let counts: Vec<usize> = vec![per_rank; num_ranks];
        let displs: Vec<usize> = (0..num_ranks).map(|r| r * per_rank).collect();

        Self {
            send_buf: vec![0.0_f64; per_rank],
            recv_buf: vec![0.0_f64; total],
            counts,
            displs,
            n_state,
            local_count: max_local_count,
            num_ranks,
            actual_counts: actual_per_rank.to_vec(),
            real_total,
        }
    }

    /// Packs the local states of `stage` into the send buffer and gathers the
    /// states of all ranks into the receive buffer via `comm.allgatherv`.
    ///
    /// `records` is read scenario-major: the record of scenario `m` at
    /// `stage` is `records[m * num_stages + stage]`. All `local_count` slots
    /// are packed, so `records` must hold at least
    /// `local_count * num_stages` entries.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `comm.allgatherv`, converted into
    /// [`SddpError`].
    ///
    /// # Panics
    ///
    /// In builds with debug assertions enabled, panics if `records` is too
    /// short, if `stage >= num_stages` (while there is at least one local
    /// slot to pack), or if a selected record's state length differs from
    /// `n_state`. In all builds, indexing or `copy_from_slice` panics if
    /// those preconditions are violated in a way that leaves the slice
    /// bounds or lengths inconsistent.
    pub fn exchange<C: Communicator>(
        &mut self,
        records: &[TrajectoryRecord],
        stage: usize,
        num_stages: usize,
        comm: &C,
    ) -> Result<(), SddpError> {
        debug_assert!(
            records.len() >= self.local_count * num_stages,
            "records.len() {} < local_count {} * num_stages {}",
            records.len(),
            self.local_count,
            num_stages,
        );
        // The length check above only bounds `m * num_stages + stage` when the
        // stage is in range; otherwise the index would silently land on a
        // different scenario's record.
        debug_assert!(
            self.local_count == 0 || stage < num_stages,
            "stage {stage} >= num_stages {num_stages}",
        );

        for m in 0..self.local_count {
            let record_idx = m * num_stages + stage;
            debug_assert_eq!(
                records[record_idx].state.len(),
                self.n_state,
                "records[{record_idx}].state.len() {} != n_state {}",
                records[record_idx].state.len(),
                self.n_state,
            );
            self.send_buf[m * self.n_state..(m + 1) * self.n_state]
                .copy_from_slice(&records[record_idx].state);
        }

        comm.allgatherv(
            &self.send_buf,
            &mut self.recv_buf,
            &self.counts,
            &self.displs,
        )?;
        Ok(())
    }

    /// Returns the whole receive buffer, padding slots included.
    ///
    /// The layout is rank-major with `local_count * n_state` values per rank.
    #[must_use]
    pub fn gathered_states(&self) -> &[f64] {
        &self.recv_buf
    }

    /// Returns the `n_state` values stored in slot `scenario` of `rank`.
    ///
    /// The slot may be a padding slot when `rank` owns fewer real scenarios
    /// than `local_count`.
    ///
    /// # Panics
    ///
    /// In builds with debug assertions enabled, panics if
    /// `rank >= num_ranks` or `scenario >= local_count`. In all builds,
    /// panics if the computed range falls outside the receive buffer.
    #[must_use]
    pub fn state_at(&self, rank: usize, scenario: usize) -> &[f64] {
        debug_assert!(
            rank < self.num_ranks,
            "rank {rank} >= num_ranks {}",
            self.num_ranks
        );
        debug_assert!(
            scenario < self.local_count,
            "scenario {scenario} >= local_count {}",
            self.local_count
        );
        let base = rank * self.local_count * self.n_state + scenario * self.n_state;
        &self.recv_buf[base..base + self.n_state]
    }

    /// Returns the number of scenario slots per rank (padding included).
    #[must_use]
    pub fn local_count(&self) -> usize {
        self.local_count
    }

    /// Returns the number of scenario slots across all ranks, padding
    /// included (`local_count * num_ranks`).
    #[must_use]
    pub fn total_scenarios(&self) -> usize {
        self.local_count * self.num_ranks
    }

    /// Returns the number of real (non-padding) scenarios across all ranks.
    #[must_use]
    pub fn real_total_scenarios(&self) -> usize {
        self.real_total
    }

    /// Replaces the contents of `buf` with the real states of all ranks, in
    /// rank order, with padding slots removed.
    ///
    /// `buf` is cleared first, so its existing allocation is reused when it
    /// is already large enough.
    pub fn pack_real_states_into(&self, buf: &mut Vec<f64>) {
        buf.clear();
        // The final length is known up front; reserve once instead of letting
        // the vector grow while each rank's slice is appended.
        buf.reserve(self.real_total * self.n_state);
        let stride = self.local_count * self.n_state;
        for r in 0..self.num_ranks {
            let base = r * stride;
            let real_len = self.actual_counts[r] * self.n_state;
            buf.extend_from_slice(&self.recv_buf[base..base + real_len]);
        }
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `ExchangeBuffers`: buffer sizing, slot indexing, the
    // gather itself (through `LocalBackend`), error propagation, and the
    // removal of padding slots when packing real states.
    use cobre_comm::{CommData, CommError, Communicator, ReduceOp};

    use super::ExchangeBuffers;
    use crate::TrajectoryRecord;

    /// Builds a record in which only `state` carries meaningful data.
    fn make_record(state: Vec<f64>) -> TrajectoryRecord {
        TrajectoryRecord {
            primal: vec![],
            dual: vec![],
            stage_cost: 0.0,
            state,
        }
    }

    #[test]
    fn new_allocates_correct_send_buf_length() {
        // local_count * n_state = 4 * 3
        let bufs = ExchangeBuffers::new(3, 4, 2);
        assert_eq!(bufs.send_buf.len(), 12);
    }

    #[test]
    fn new_allocates_correct_recv_buf_length() {
        // local_count * n_state * num_ranks = 4 * 3 * 2
        let bufs = ExchangeBuffers::new(3, 4, 2);
        assert_eq!(bufs.recv_buf.len(), 24);
    }

    #[test]
    fn new_allocates_correct_counts_length_and_values() {
        let bufs = ExchangeBuffers::new(3, 4, 2);
        assert_eq!(bufs.counts.len(), 2);
        assert_eq!(bufs.counts[0], 12);
        assert_eq!(bufs.counts[1], 12);
    }

    #[test]
    fn new_allocates_correct_displs_length_and_values() {
        let bufs = ExchangeBuffers::new(3, 4, 2);
        assert_eq!(bufs.displs.len(), 2);
        assert_eq!(bufs.displs[0], 0);
        assert_eq!(bufs.displs[1], 12);
    }

    #[test]
    fn new_single_rank_counts_is_one_element() {
        let bufs = ExchangeBuffers::new(2, 3, 1);
        assert_eq!(bufs.counts, vec![6]);
        assert_eq!(bufs.displs, vec![0]);
    }

    #[test]
    fn total_scenarios_returns_local_count_times_num_ranks() {
        let bufs = ExchangeBuffers::new(2, 5, 4);
        assert_eq!(bufs.total_scenarios(), 20);
    }

    #[test]
    fn total_scenarios_single_rank() {
        let bufs = ExchangeBuffers::new(2, 3, 1);
        assert_eq!(bufs.total_scenarios(), 3);
    }

    #[test]
    fn state_at_indexing_arithmetic() {
        let n_state = 3;
        let local_count = 2;
        let num_ranks = 3;
        let mut bufs = ExchangeBuffers::new(n_state, local_count, num_ranks);

        // Encode (rank, scenario, component) into each value so that a wrong
        // offset is visible in the value read back. The integers involved are
        // tiny, so the usize -> f64 cast is exact.
        #[allow(clippy::cast_precision_loss)]
        for r in 0..num_ranks {
            for s in 0..local_count {
                let base = r * local_count * n_state + s * n_state;
                for i in 0..n_state {
                    bufs.recv_buf[base + i] = (r * 100 + s * 10 + i) as f64;
                }
            }
        }

        let slice = bufs.state_at(1, 0);
        assert_eq!(slice.len(), n_state);
        assert_eq!(slice[0], 100.0);
        assert_eq!(slice[1], 101.0);
        assert_eq!(slice[2], 102.0);

        let slice = bufs.state_at(2, 1);
        assert_eq!(slice[0], 210.0);
        assert_eq!(slice[1], 211.0);
        assert_eq!(slice[2], 212.0);
    }

    #[test]
    fn exchange_single_rank_three_scenarios_two_state() {
        use cobre_comm::LocalBackend;

        let mut bufs = ExchangeBuffers::new(2, 3, 1);
        let records = vec![
            make_record(vec![1.0, 2.0]),
            make_record(vec![3.0, 4.0]),
            make_record(vec![5.0, 6.0]),
        ];
        let comm = LocalBackend;
        bufs.exchange(&records, 0, 1, &comm).unwrap();

        assert_eq!(bufs.gathered_states(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(bufs.total_scenarios(), 3);
    }

    #[test]
    fn exchange_selects_correct_stage_in_multi_stage_records() {
        use cobre_comm::LocalBackend;

        // Scenario-major layout: index = scenario * num_stages + stage.
        let records = vec![
            make_record(vec![10.0, 11.0]), // scenario 0, stage 0
            make_record(vec![30.0, 31.0]), // scenario 0, stage 1
            make_record(vec![50.0, 51.0]), // scenario 0, stage 2
            make_record(vec![20.0, 21.0]), // scenario 1, stage 0
            make_record(vec![40.0, 41.0]), // scenario 1, stage 1
            make_record(vec![60.0, 61.0]), // scenario 1, stage 2
        ];
        let mut bufs = ExchangeBuffers::new(2, 2, 1);
        let comm = LocalBackend;
        bufs.exchange(&records, 1, 3, &comm).unwrap();

        assert_eq!(bufs.gathered_states(), &[30.0, 31.0, 40.0, 41.0]);
        assert_ne!(bufs.gathered_states()[0], 10.0, "stage 0 must not appear");
        assert_ne!(bufs.gathered_states()[0], 50.0, "stage 2 must not appear");
    }

    #[test]
    fn state_at_matches_record_state_after_exchange() {
        use cobre_comm::LocalBackend;

        let records = vec![
            make_record(vec![1.0, 2.0, 3.0]),
            make_record(vec![4.0, 5.0, 6.0]),
        ];
        let mut bufs = ExchangeBuffers::new(3, 2, 1);
        let comm = LocalBackend;
        bufs.exchange(&records, 0, 1, &comm).unwrap();

        let state = bufs.state_at(0, 1);
        assert_eq!(state, &[4.0, 5.0, 6.0]);
    }

    #[test]
    fn exchange_error_maps_to_sddp_communication_error() {
        use crate::SddpError;

        // Communicator whose gather always fails; `exchange` calls nothing
        // else, so the remaining collectives are never reached.
        struct FailingComm;

        impl Communicator for FailingComm {
            fn allgatherv<T: CommData>(
                &self,
                _send: &[T],
                _recv: &mut [T],
                _counts: &[usize],
                _displs: &[usize],
            ) -> Result<(), CommError> {
                Err(CommError::CollectiveFailed {
                    operation: "allgatherv",
                    mpi_error_code: 42,
                    message: "simulated failure".to_string(),
                })
            }

            fn allreduce<T: CommData>(
                &self,
                _send: &[T],
                _recv: &mut [T],
                _op: ReduceOp,
            ) -> Result<(), CommError> {
                unreachable!()
            }

            fn broadcast<T: CommData>(
                &self,
                _buf: &mut [T],
                _root: usize,
            ) -> Result<(), CommError> {
                unreachable!()
            }

            fn barrier(&self) -> Result<(), CommError> {
                unreachable!()
            }

            fn rank(&self) -> usize {
                0
            }

            fn size(&self) -> usize {
                1
            }
        }

        let mut bufs = ExchangeBuffers::new(2, 1, 1);
        let records = vec![make_record(vec![1.0, 2.0])];
        let result = bufs.exchange(&records, 0, 1, &FailingComm);
        assert!(
            matches!(result, Err(SddpError::Communication(_))),
            "expected SddpError::Communication, got: {result:?}",
        );
    }

    #[test]
    fn real_total_scenarios_uneven_distribution() {
        let bufs = ExchangeBuffers::with_actual_counts(2, 3, 2, &[3, 2]);
        assert_eq!(bufs.real_total_scenarios(), 5);
        // The padded total still counts the unused slot of rank 1.
        assert_eq!(bufs.total_scenarios(), 6);
    }

    #[test]
    fn real_total_scenarios_even_distribution() {
        let bufs = ExchangeBuffers::with_actual_counts(2, 3, 2, &[3, 3]);
        assert_eq!(bufs.real_total_scenarios(), 6);
        assert_eq!(bufs.total_scenarios(), 6);
    }

    #[test]
    fn real_total_scenarios_single_rank() {
        let bufs = ExchangeBuffers::with_actual_counts(2, 3, 1, &[3]);
        assert_eq!(bufs.real_total_scenarios(), 3);
        assert_eq!(bufs.total_scenarios(), 3);
    }

    #[test]
    // The `0 *` / `1 *` factors are kept on purpose so that both base offsets
    // visibly follow the same `rank * max_local * n_state` formula.
    #[allow(clippy::erasing_op, clippy::identity_op)]
    fn pack_real_states_into_excludes_padding() {
        let n_state = 2;
        let max_local = 3;
        let num_ranks = 2;
        let mut bufs = ExchangeBuffers::with_actual_counts(n_state, max_local, num_ranks, &[3, 2]);

        let rank0_base = 0 * max_local * n_state;
        let rank1_base = 1 * max_local * n_state;
        bufs.recv_buf[rank0_base..rank0_base + 2].copy_from_slice(&[10.0, 11.0]);
        bufs.recv_buf[rank0_base + 2..rank0_base + 4].copy_from_slice(&[20.0, 21.0]);
        bufs.recv_buf[rank0_base + 4..rank0_base + 6].copy_from_slice(&[30.0, 31.0]);
        bufs.recv_buf[rank1_base..rank1_base + 2].copy_from_slice(&[40.0, 41.0]);
        bufs.recv_buf[rank1_base + 2..rank1_base + 4].copy_from_slice(&[50.0, 51.0]);
        // The third slot of rank 1 is padding and must not be packed.

        let mut out = Vec::new();
        bufs.pack_real_states_into(&mut out);

        assert_eq!(
            out.len(),
            10,
            "expected 10 f64 values (5 state vectors × 2)"
        );
        assert_eq!(
            out,
            vec![10.0, 11.0, 20.0, 21.0, 30.0, 31.0, 40.0, 41.0, 50.0, 51.0]
        );
    }

    #[test]
    fn pack_real_states_into_even_distribution_matches_gathered_states() {
        use cobre_comm::LocalBackend;

        let mut bufs = ExchangeBuffers::new(2, 3, 1);
        let records = vec![
            make_record(vec![1.0, 2.0]),
            make_record(vec![3.0, 4.0]),
            make_record(vec![5.0, 6.0]),
        ];
        let comm = LocalBackend;
        bufs.exchange(&records, 0, 1, &comm).unwrap();

        let mut packed = Vec::new();
        bufs.pack_real_states_into(&mut packed);
        assert_eq!(packed, bufs.gathered_states());
    }

    #[test]
    fn pack_real_states_into_reuses_buffer_capacity() {
        let n_state = 2;
        let mut bufs = ExchangeBuffers::with_actual_counts(n_state, 3, 2, &[3, 2]);
        bufs.recv_buf.fill(1.0);

        let mut out: Vec<f64> = Vec::with_capacity(20);
        let initial_capacity = out.capacity();

        bufs.pack_real_states_into(&mut out);
        assert_eq!(out.len(), 10);
        assert_eq!(
            out.capacity(),
            initial_capacity,
            "first pack must fit in the pre-allocated buffer"
        );

        // A second call must overwrite rather than append, and must still not
        // touch the allocation.
        bufs.pack_real_states_into(&mut out);
        assert_eq!(out.len(), 10);
        assert_eq!(
            out.capacity(),
            initial_capacity,
            "second pack must reuse the same allocation"
        );
        assert!(out.iter().all(|&v| v == 1.0));
    }
}