use cobre_comm::Communicator;
use crate::{
FutureCostFunction, SddpError,
cut::wire::{cut_wire_size, deserialize_cuts_from_buffer, serialize_cut},
};
/// Reusable serialization buffers plus the `allgatherv` layout used to
/// exchange cuts between ranks.
///
/// Every entry of `counts` and `displs` is expressed in bytes; a single cut
/// occupies `record_size` bytes in `send_buf` / `recv_buf`.
#[derive(Debug, Clone)]
pub struct CutSyncBuffers {
    /// This rank's outgoing cuts, serialized back to back.
    send_buf: Vec<u8>,
    /// Cuts gathered from every rank; rank `r`'s bytes start at `displs[r]`.
    recv_buf: Vec<u8>,
    /// Per-rank byte counts handed to `allgatherv`.
    counts: Vec<usize>,
    /// Per-rank byte offsets into `recv_buf` handed to `allgatherv`.
    displs: Vec<usize>,
    /// Number of coefficients carried by each cut.
    n_state: usize,
    /// Number of ranks taking part in the exchange.
    num_ranks: usize,
    /// Bytes per serialized cut (`cut_wire_size(n_state)` at construction).
    record_size: usize,
    /// Expected number of cuts contributed by each rank (static distribution).
    per_rank_cuts: Vec<usize>,
}
impl CutSyncBuffers {
    /// Creates buffers for a uniform distribution in which every rank
    /// contributes `max_cuts_per_rank` cuts per synchronization.
    ///
    /// Equivalent to [`CutSyncBuffers::with_distribution`] with
    /// `total_forward_passes = max_cuts_per_rank * num_ranks`.
    ///
    /// # Panics
    ///
    /// Panics if `num_ranks` is zero.
    #[must_use]
    pub fn new(n_state: usize, max_cuts_per_rank: usize, num_ranks: usize) -> Self {
        Self::with_distribution(
            n_state,
            max_cuts_per_rank,
            num_ranks,
            max_cuts_per_rank * num_ranks,
        )
    }

    /// Creates buffers for `total_forward_passes` cuts block-distributed over
    /// `num_ranks` ranks.
    ///
    /// Rank `r` is expected to contribute `total_forward_passes / num_ranks`
    /// cuts, plus one extra when `r < total_forward_passes % num_ranks`.
    /// The send buffer holds `max_cuts_per_rank` wire records. The receive
    /// buffer holds the sum of all per-rank shares. The initial
    /// `counts`/`displs` describe that static distribution in bytes.
    ///
    /// # Panics
    ///
    /// Panics if `num_ranks` is zero.
    #[must_use]
    pub fn with_distribution(
        n_state: usize,
        max_cuts_per_rank: usize,
        num_ranks: usize,
        total_forward_passes: usize,
    ) -> Self {
        // Without this the division below would panic with an opaque
        // "attempt to divide by zero".
        assert!(num_ranks > 0, "CutSyncBuffers requires num_ranks > 0");
        let record_size = cut_wire_size(n_state);
        let base = total_forward_passes / num_ranks;
        let remainder = total_forward_passes % num_ranks;
        // The lowest `remainder` ranks each take one extra forward pass.
        let per_rank_cuts: Vec<usize> = (0..num_ranks)
            .map(|r| base + usize::from(r < remainder))
            .collect();
        let counts: Vec<usize> = per_rank_cuts.iter().map(|&c| c * record_size).collect();
        let mut displs = vec![0usize; num_ranks];
        let recv_cap = Self::fill_displacements(&counts, &mut displs);
        Self {
            send_buf: vec![0u8; max_cuts_per_rank * record_size],
            recv_buf: vec![0u8; recv_cap],
            counts,
            displs,
            n_state,
            num_ranks,
            record_size,
            per_rank_cuts,
        }
    }

    /// Serializes `local_cuts` into the send buffer, gathers cuts from every
    /// rank, and inserts each *remote* cut into `fcf` at `stage`.
    ///
    /// Each tuple is `(slot_index, iteration, forward_pass_index, intercept,
    /// coefficients)`, and `coefficients` must have `n_state` entries. The
    /// local cuts are not inserted into `fcf`. This is
    /// [`CutSyncBuffers::sync_packed_cuts`] preceded by serialization.
    ///
    /// Returns the number of remote cuts inserted.
    ///
    /// # Errors
    ///
    /// Propagates (via `?`) the communicator error if `allgatherv` fails.
    ///
    /// # Panics
    ///
    /// Slice indexing panics if `local_cuts` does not fit in the send buffer,
    /// or if the resulting gather layout exceeds the receive buffer.
    pub fn sync_cuts<C: Communicator>(
        &mut self,
        stage: usize,
        local_cuts: &[(u32, u32, u32, f64, &[f64])],
        fcf: &mut FutureCostFunction,
        comm: &C,
    ) -> Result<usize, SddpError> {
        let n_local = local_cuts.len();
        let send_len = n_local * self.record_size;
        debug_assert!(
            send_len <= self.send_buf.len(),
            "send_len {send_len} exceeds send_buf capacity {}",
            self.send_buf.len()
        );
        // Index-based slicing (rather than a zip over chunks) is deliberate:
        // overflowing the buffer must panic, not silently drop cuts.
        for (i, &(slot_index, iteration, forward_pass_index, intercept, coefficients)) in
            local_cuts.iter().enumerate()
        {
            debug_assert!(
                coefficients.len() == self.n_state,
                "cut {i} coefficient length {} != n_state {}",
                coefficients.len(),
                self.n_state,
            );
            let start = i * self.record_size;
            serialize_cut(
                &mut self.send_buf[start..start + self.record_size],
                slot_index,
                iteration,
                forward_pass_index,
                intercept,
                coefficients,
            );
        }
        self.sync_packed_cuts(stage, n_local, fcf, comm)
    }

    /// Serializes every active cut of `fcf.pools[stage]` generated at
    /// `iteration` into the send buffer and returns how many were packed.
    ///
    /// Pass the returned count to [`CutSyncBuffers::sync_packed_cuts`].
    ///
    /// # Panics
    ///
    /// Panics if `stage` is out of range for `fcf.pools`, or (via slice
    /// indexing) if the packed cuts do not fit in the send buffer.
    // Slot indices and iteration numbers are narrowed to the `u32` wire
    // format; they are assumed to fit (TODO confirm upstream bounds).
    #[allow(clippy::cast_possible_truncation)]
    pub fn pack_local_cuts(
        &mut self,
        fcf: &FutureCostFunction,
        stage: usize,
        iteration: u64,
    ) -> usize {
        let pool = &fcf.pools[stage];
        let dim = pool.state_dimension;
        let mut n_packed = 0usize;
        for slot in 0..pool.populated_count {
            if !pool.active[slot] {
                continue;
            }
            let meta = &pool.metadata[slot];
            if meta.iteration_generated != iteration {
                continue;
            }
            let start = n_packed * self.record_size;
            debug_assert!(
                start + self.record_size <= self.send_buf.len(),
                "pack_local_cuts: cut {n_packed} exceeds send_buf capacity {}",
                self.send_buf.len()
            );
            let coefficients = &pool.coefficients[slot * dim..(slot + 1) * dim];
            // Same record-shape check that `sync_cuts` applies to its inputs.
            debug_assert!(
                coefficients.len() == self.n_state,
                "pack_local_cuts: slot {slot} coefficient length {} != n_state {}",
                coefficients.len(),
                self.n_state,
            );
            serialize_cut(
                &mut self.send_buf[start..start + self.record_size],
                slot as u32,
                iteration as u32,
                meta.forward_pass_index,
                pool.intercepts[slot],
                coefficients,
            );
            n_packed += 1;
        }
        n_packed
    }

    /// Gathers the first `n_local` cuts already serialized in the send buffer
    /// from every rank, and inserts each *remote* cut into `fcf` at `stage`.
    ///
    /// The local rank's byte count is taken from `n_local`. Every other rank's
    /// count comes from the static distribution (`per_rank_cuts`). When
    /// `num_ranks > 1`, each rank should therefore send exactly its share.
    /// NOTE(review): presumably the communicator requires identical
    /// counts on all ranks (MPI `Allgatherv` semantics); confirm.
    ///
    /// Returns the number of remote cuts inserted.
    ///
    /// # Errors
    ///
    /// Propagates (via `?`) the communicator error if `allgatherv` fails.
    ///
    /// # Panics
    ///
    /// Slice indexing panics if `n_local` records exceed the send buffer, or
    /// if the gather layout exceeds the receive buffer.
    pub fn sync_packed_cuts<C: Communicator>(
        &mut self,
        stage: usize,
        n_local: usize,
        fcf: &mut FutureCostFunction,
        comm: &C,
    ) -> Result<usize, SddpError> {
        let send_len = n_local * self.record_size;
        debug_assert!(
            send_len <= self.send_buf.len(),
            "send_len {send_len} exceeds send_buf capacity {}",
            self.send_buf.len()
        );
        let my_rank = comm.rank();
        let recv_len = self.prepare_layout(my_rank, n_local);
        debug_assert!(
            recv_len <= self.recv_buf.len(),
            "recv_len {recv_len} exceeds recv_buf capacity {}",
            self.recv_buf.len()
        );
        comm.allgatherv(
            &self.send_buf[..send_len],
            &mut self.recv_buf[..recv_len],
            &self.counts,
            &self.displs,
        )?;
        Ok(self.insert_remote_cuts(stage, my_rank, fcf))
    }

    /// Returns the send buffer size in bytes.
    #[must_use]
    pub fn send_capacity(&self) -> usize {
        self.send_buf.len()
    }

    /// Returns the receive buffer size in bytes.
    #[must_use]
    pub fn recv_capacity(&self) -> usize {
        self.recv_buf.len()
    }

    /// Rewrites `counts`/`displs` for one exchange and returns the total
    /// number of bytes to receive.
    ///
    /// `my_rank` contributes `n_local` cuts. Every other rank contributes its
    /// `per_rank_cuts` share.
    fn prepare_layout(&mut self, my_rank: usize, n_local: usize) -> usize {
        let record_size = self.record_size;
        for (r, (count, &expected)) in self
            .counts
            .iter_mut()
            .zip(&self.per_rank_cuts)
            .enumerate()
        {
            let cuts_for_r = if r == my_rank { n_local } else { expected };
            *count = cuts_for_r * record_size;
        }
        Self::fill_displacements(&self.counts, &mut self.displs)
    }

    /// Writes the exclusive prefix sum of `counts` into `displs` and returns
    /// the sum of all counts.
    fn fill_displacements(counts: &[usize], displs: &mut [usize]) -> usize {
        let mut offset = 0usize;
        for (displ, &count) in displs.iter_mut().zip(counts) {
            *displ = offset;
            offset += count;
        }
        offset
    }

    /// Deserializes every rank's segment except `my_rank`'s from `recv_buf`,
    /// adds those cuts to `fcf` at `stage`, and returns how many were added.
    fn insert_remote_cuts(
        &self,
        stage: usize,
        my_rank: usize,
        fcf: &mut FutureCostFunction,
    ) -> usize {
        let mut remote_count = 0usize;
        for (r, (&start, &len)) in self.displs.iter().zip(&self.counts).enumerate() {
            // Local cuts are already known to this rank; only import remote ones.
            if r == my_rank {
                continue;
            }
            let segment = &self.recv_buf[start..start + len];
            for (header, coefficients) in deserialize_cuts_from_buffer(segment, self.n_state) {
                fcf.add_cut(
                    stage,
                    u64::from(header.iteration),
                    header.forward_pass_index,
                    header.intercept,
                    &coefficients,
                );
                remote_count += 1;
            }
        }
        remote_count
    }
}
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::float_cmp
    )]
    use cobre_comm::{CommData, CommError, Communicator, LocalBackend, ReduceOp};

    use super::CutSyncBuffers;
    use crate::{
        SddpError,
        cut::{
            fcf::FutureCostFunction,
            wire::{cut_wire_size, deserialize_cuts_from_buffer, serialize_cut},
        },
    };

    #[test]
    fn new_send_buf_capacity_is_max_cuts_times_record_size() {
        let bufs = CutSyncBuffers::new(2, 3, 1);
        let expected = 3 * cut_wire_size(2);
        assert_eq!(bufs.send_capacity(), expected);
    }

    #[test]
    fn new_recv_buf_capacity_is_max_cuts_times_num_ranks_times_record_size() {
        let bufs = CutSyncBuffers::new(3, 10, 4);
        let expected = 10 * 4 * cut_wire_size(3);
        assert_eq!(bufs.recv_capacity(), expected);
        assert_eq!(expected, 1920);
    }

    #[test]
    fn new_counts_length_equals_num_ranks() {
        let bufs = CutSyncBuffers::new(3, 10, 4);
        assert_eq!(bufs.counts.len(), 4);
    }

    #[test]
    fn new_displs_length_equals_num_ranks() {
        let bufs = CutSyncBuffers::new(3, 10, 4);
        assert_eq!(bufs.displs.len(), 4);
    }

    #[test]
    fn new_counts_and_displs_initialized_to_max_uniform_values() {
        let bufs = CutSyncBuffers::new(2, 3, 2);
        let per_rank = 3 * cut_wire_size(2);
        assert_eq!(bufs.counts[0], per_rank);
        assert_eq!(bufs.counts[1], per_rank);
        assert_eq!(bufs.displs[0], 0);
        assert_eq!(bufs.displs[1], per_rank);
    }

    #[test]
    fn new_n_state_zero_record_size_is_24() {
        let bufs = CutSyncBuffers::new(0, 5, 1);
        assert_eq!(bufs.send_capacity(), 5 * 24);
        assert_eq!(bufs.recv_capacity(), 5 * 24);
    }

    #[test]
    fn send_buf_serialization_round_trip_two_cuts() {
        let mut bufs = CutSyncBuffers::new(2, 2, 1);
        let local_cuts: &[(u32, u32, u32, f64, &[f64])] =
            &[(0, 1, 0, 10.0, &[1.0, 2.0]), (1, 1, 1, 20.0, &[3.0, 4.0])];
        let record_size = cut_wire_size(2);
        let send_len = local_cuts.len() * record_size;
        assert_eq!(send_len, 80);
        for (i, &(slot_index, iteration, forward_pass_index, intercept, coefficients)) in
            local_cuts.iter().enumerate()
        {
            let start = i * record_size;
            serialize_cut(
                &mut bufs.send_buf[start..start + record_size],
                slot_index,
                iteration,
                forward_pass_index,
                intercept,
                coefficients,
            );
        }
        let recovered = deserialize_cuts_from_buffer(&bufs.send_buf[..send_len], 2);
        assert_eq!(recovered.len(), 2);
        let (h0, c0) = &recovered[0];
        assert_eq!(h0.slot_index, 0);
        assert_eq!(h0.iteration, 1);
        assert_eq!(h0.forward_pass_index, 0);
        assert_eq!(h0.intercept, 10.0);
        assert_eq!(c0, &[1.0, 2.0]);
        let (h1, c1) = &recovered[1];
        assert_eq!(h1.slot_index, 1);
        assert_eq!(h1.iteration, 1);
        assert_eq!(h1.forward_pass_index, 1);
        assert_eq!(h1.intercept, 20.0);
        assert_eq!(c1, &[3.0, 4.0]);
    }

    // Exercises the real layout computation instead of hand-assigned values.
    #[test]
    fn counts_and_displs_computation_for_various_cut_counts() {
        let record_size = cut_wire_size(2);
        assert_eq!(record_size, 40);

        // Even split: 6 forward passes over 3 ranks -> 2 cuts per rank.
        let even = CutSyncBuffers::with_distribution(2, 5, 3, 6);
        assert_eq!(even.counts, vec![80, 80, 80]);
        assert_eq!(even.displs, vec![0, 80, 160]);
        assert_eq!(even.recv_capacity(), 240);

        // Uneven split: 7 forward passes over 3 ranks -> [3, 2, 2]; the
        // remainder is assigned to the lowest ranks.
        let uneven = CutSyncBuffers::with_distribution(2, 5, 3, 7);
        assert_eq!(uneven.counts, vec![120, 80, 80]);
        assert_eq!(uneven.displs, vec![0, 120, 200]);
        assert_eq!(uneven.recv_capacity(), 280);
        assert_eq!(uneven.send_capacity(), 5 * record_size);

        // `sync_cuts` replaces the local rank's count with the number of cuts
        // actually sent.
        let mut single = CutSyncBuffers::new(2, 3, 1);
        let mut fcf = FutureCostFunction::new(1, 2, 3, 10, 0);
        let comm = LocalBackend;
        let local_cuts: &[(u32, u32, u32, f64, &[f64])] =
            &[(0, 1, 0, 10.0, &[1.0, 2.0]), (1, 1, 1, 20.0, &[3.0, 4.0])];
        single.sync_cuts(0, local_cuts, &mut fcf, &comm).unwrap();
        assert_eq!(single.counts, vec![80]);
        assert_eq!(single.displs, vec![0]);
    }

    #[test]
    fn sync_cuts_single_rank_returns_zero_remote_cuts() {
        let mut bufs = CutSyncBuffers::new(2, 3, 1);
        let mut fcf = FutureCostFunction::new(2, 2, 3, 10, 0);
        let comm = LocalBackend;
        let local_cuts: &[(u32, u32, u32, f64, &[f64])] =
            &[(0, 1, 0, 10.0, &[1.0, 2.0]), (0, 1, 1, 20.0, &[3.0, 4.0])];
        let result = bufs.sync_cuts(0, local_cuts, &mut fcf, &comm).unwrap();
        assert_eq!(result, 0, "expected zero remote cuts in single-rank mode");
    }

    #[test]
    fn sync_cuts_single_rank_does_not_insert_local_cuts_into_fcf() {
        let mut bufs = CutSyncBuffers::new(2, 3, 1);
        let mut fcf = FutureCostFunction::new(2, 2, 3, 10, 0);
        let comm = LocalBackend;
        let local_cuts: &[(u32, u32, u32, f64, &[f64])] =
            &[(0, 1, 0, 10.0, &[1.0, 2.0]), (0, 1, 1, 20.0, &[3.0, 4.0])];
        bufs.sync_cuts(0, local_cuts, &mut fcf, &comm).unwrap();
        assert_eq!(
            fcf.total_active_cuts(),
            0,
            "sync_cuts must not insert local cuts into FCF"
        );
    }

    #[test]
    fn sync_cuts_serialization_round_trip_via_allgatherv_identity() {
        let mut bufs = CutSyncBuffers::new(2, 2, 1);
        let mut fcf = FutureCostFunction::new(2, 2, 2, 10, 0);
        let comm = LocalBackend;
        fcf.add_cut(0, 1, 0, 10.0, &[1.0, 2.0]);
        let local_cuts: &[(u32, u32, u32, f64, &[f64])] = &[(0, 1, 0, 10.0, &[1.0, 2.0])];
        let remote_inserted = bufs.sync_cuts(0, local_cuts, &mut fcf, &comm).unwrap();
        assert_eq!(remote_inserted, 0);
        assert_eq!(fcf.total_active_cuts(), 1);
    }

    #[test]
    fn sync_cuts_zero_local_cuts_returns_zero() {
        let mut bufs = CutSyncBuffers::new(2, 5, 1);
        let mut fcf = FutureCostFunction::new(2, 2, 5, 10, 0);
        let comm = LocalBackend;
        let result = bufs.sync_cuts(0, &[], &mut fcf, &comm).unwrap();
        assert_eq!(result, 0);
        assert_eq!(fcf.total_active_cuts(), 0);
    }

    #[test]
    fn sync_cuts_error_maps_to_sddp_communication_error() {
        // Communicator whose `allgatherv` always fails; the other collectives
        // must not be reached by `sync_cuts`.
        struct FailingComm;
        impl Communicator for FailingComm {
            fn allgatherv<T: CommData>(
                &self,
                _send: &[T],
                _recv: &mut [T],
                _counts: &[usize],
                _displs: &[usize],
            ) -> Result<(), CommError> {
                Err(CommError::CollectiveFailed {
                    operation: "allgatherv",
                    mpi_error_code: 42,
                    message: "simulated failure".to_string(),
                })
            }
            fn allreduce<T: CommData>(
                &self,
                _send: &[T],
                _recv: &mut [T],
                _op: ReduceOp,
            ) -> Result<(), CommError> {
                unreachable!()
            }
            fn broadcast<T: CommData>(
                &self,
                _buf: &mut [T],
                _root: usize,
            ) -> Result<(), CommError> {
                unreachable!()
            }
            fn barrier(&self) -> Result<(), CommError> {
                unreachable!()
            }
            fn rank(&self) -> usize {
                0
            }
            fn size(&self) -> usize {
                1
            }
        }
        let mut bufs = CutSyncBuffers::new(2, 2, 1);
        let mut fcf = FutureCostFunction::new(2, 2, 2, 10, 0);
        let local_cuts: &[(u32, u32, u32, f64, &[f64])] = &[(0, 1, 0, 5.0, &[1.0, 2.0])];
        let result = bufs.sync_cuts(0, local_cuts, &mut fcf, &FailingComm);
        assert!(
            matches!(result, Err(SddpError::Communication(_))),
            "expected SddpError::Communication, got: {result:?}",
        );
    }

    #[test]
    fn sync_cuts_three_ranks_returns_four_remote_cuts() {
        // Emulates a 3-rank gather from rank 0's point of view: only rank 0's
        // own segment is copied; the segments of ranks 1 and 2 are pre-seeded
        // in `recv_buf` below.
        struct ThreeRankComm;
        impl Communicator for ThreeRankComm {
            fn allgatherv<T: CommData>(
                &self,
                send: &[T],
                recv: &mut [T],
                counts: &[usize],
                _displs: &[usize],
            ) -> Result<(), CommError> {
                let r0_len = counts[0];
                recv[..r0_len].copy_from_slice(&send[..r0_len]);
                Ok(())
            }
            fn allreduce<T: CommData>(
                &self,
                _send: &[T],
                _recv: &mut [T],
                _op: ReduceOp,
            ) -> Result<(), CommError> {
                unreachable!()
            }
            fn broadcast<T: CommData>(
                &self,
                _buf: &mut [T],
                _root: usize,
            ) -> Result<(), CommError> {
                unreachable!()
            }
            fn barrier(&self) -> Result<(), CommError> {
                unreachable!()
            }
            fn rank(&self) -> usize {
                0
            }
            fn size(&self) -> usize {
                3
            }
        }
        let n_state = 2;
        let record_size = cut_wire_size(n_state);
        let n_local = 2;
        let per_rank_bytes = n_local * record_size;
        let mut fcf = FutureCostFunction::new(1, n_state, 6, 10, 0);
        let mut bufs = CutSyncBuffers::new(n_state, n_local, 3);

        // Rank 1's segment.
        let r1_start = per_rank_bytes;
        serialize_cut(
            &mut bufs.recv_buf[r1_start..r1_start + record_size],
            10,
            1,
            10,
            100.0,
            &[1.0, 2.0],
        );
        serialize_cut(
            &mut bufs.recv_buf[r1_start + record_size..r1_start + 2 * record_size],
            11,
            1,
            11,
            200.0,
            &[3.0, 4.0],
        );

        // Rank 2's segment.
        let r2_start = 2 * per_rank_bytes;
        serialize_cut(
            &mut bufs.recv_buf[r2_start..r2_start + record_size],
            20,
            1,
            20,
            300.0,
            &[5.0, 6.0],
        );
        serialize_cut(
            &mut bufs.recv_buf[r2_start + record_size..r2_start + 2 * record_size],
            21,
            1,
            21,
            400.0,
            &[7.0, 8.0],
        );

        let local_cuts: &[(u32, u32, u32, f64, &[f64])] =
            &[(0, 1, 0, 50.0, &[0.1, 0.2]), (1, 1, 1, 60.0, &[0.3, 0.4])];
        let remote_inserted = bufs
            .sync_cuts(0, local_cuts, &mut fcf, &ThreeRankComm)
            .unwrap();
        assert_eq!(remote_inserted, 4, "expected 4 remote cuts inserted");
        assert_eq!(fcf.total_active_cuts(), 4);
    }

    #[test]
    fn sync_cuts_preserves_cut_fields_after_deserialization() {
        let n_state = 2usize;
        let mut bufs = CutSyncBuffers::new(n_state, 1, 1);
        let mut fcf = FutureCostFunction::new(1, n_state, 1, 10, 0);
        let comm = LocalBackend;
        let coeffs = [7.5_f64, -3.25_f64];
        let local_cuts: &[(u32, u32, u32, f64, &[f64])] = &[(5, 3, 2, 99.0, &coeffs)];
        bufs.sync_cuts(0, local_cuts, &mut fcf, &comm).unwrap();

        // The single rank's own record must come back bit-exact in `recv_buf`.
        let record_size = cut_wire_size(n_state);
        let recovered = deserialize_cuts_from_buffer(&bufs.recv_buf[..record_size], n_state);
        assert_eq!(recovered.len(), 1);
        let (header, rec_coeffs) = &recovered[0];
        assert_eq!(header.slot_index, 5);
        assert_eq!(header.iteration, 3);
        assert_eq!(header.forward_pass_index, 2);
        assert_eq!(header.intercept, 99.0);
        assert_eq!(rec_coeffs[0].to_bits(), coeffs[0].to_bits());
        assert_eq!(rec_coeffs[1].to_bits(), coeffs[1].to_bits());
    }
}