use crate::solver::streaming_border::{BorderGramCheckpoint, StreamingBorderGram, chunk_gram_flat};
use ndarray::{Array2, ArrayView2};
use serde::{Deserialize, Serialize};
/// Describes how the rows of a border matrix are cut into fixed-size chunks and how
/// those chunks are distributed over ranks.
///
/// Chunk `c` covers rows `c * chunk_size ..` up to at most `n_rows`, and is owned by
/// rank `c % n_ranks` (see `owner_rank`).
///
/// The fields are public and the type is deserializable, so a value is not guaranteed
/// to have passed the checks in `CrossNodePartition::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct CrossNodePartition {
/// Number of columns every row chunk must have; `new` rejects zero.
pub border_dim: usize,
/// Total number of rows being partitioned; zero is accepted and yields no chunks.
pub n_rows: usize,
/// Number of rows per chunk (the last chunk may be shorter); `new` rejects zero.
pub chunk_size: usize,
/// Number of ranks the chunks are dealt out to; `new` rejects zero.
pub n_ranks: usize,
}
impl CrossNodePartition {
    /// Builds a partition after checking that the parameters used as divisors or
    /// dimensions are non-zero.
    ///
    /// `n_rows` may be zero, in which case the partition has no chunks.
    ///
    /// # Errors
    /// Returns a message naming the offending field when `border_dim`, `chunk_size`
    /// or `n_ranks` is zero.
    pub fn new(
        border_dim: usize,
        n_rows: usize,
        chunk_size: usize,
        n_ranks: usize,
    ) -> Result<Self, String> {
        if border_dim == 0 {
            return Err("CrossNodePartition: border_dim must be positive".to_string());
        }
        if chunk_size == 0 {
            return Err("CrossNodePartition: chunk_size must be positive".to_string());
        }
        if n_ranks == 0 {
            return Err("CrossNodePartition: n_ranks must be positive".to_string());
        }
        Ok(Self {
            border_dim,
            n_rows,
            chunk_size,
            n_ranks,
        })
    }

    /// Number of chunks needed to cover `n_rows` rows, i.e. `ceil(n_rows / chunk_size)`.
    ///
    /// # Panics
    /// Panics if `chunk_size` is zero, which can only happen when the value was built
    /// without going through [`CrossNodePartition::new`].
    pub fn n_chunks(&self) -> usize {
        self.n_rows.div_ceil(self.chunk_size)
    }

    /// Half-open row range covered by chunk `chunk_index`.
    ///
    /// For an index below `n_chunks()` this is `chunk_index * chunk_size ..` up to at
    /// most `n_rows`. For an index at or past the end the result is the empty range
    /// `n_rows..n_rows`.
    pub fn chunk_rows(&self, chunk_index: usize) -> std::ops::Range<usize> {
        // Saturate and clamp so that neither a huge index nor a partition whose last
        // chunk ends near `usize::MAX` can overflow or produce an inverted range.
        let lo = chunk_index
            .saturating_mul(self.chunk_size)
            .min(self.n_rows);
        let hi = lo.saturating_add(self.chunk_size).min(self.n_rows);
        lo..hi
    }

    /// Rank that owns chunk `chunk_index` under the round-robin assignment.
    ///
    /// # Panics
    /// Panics if `n_ranks` is zero, which can only happen when the value was built
    /// without going through [`CrossNodePartition::new`].
    #[inline]
    pub fn owner_rank(&self, chunk_index: usize) -> usize {
        chunk_index % self.n_ranks
    }

    /// Number of chunks assigned to `rank`; zero for a rank outside `0..n_ranks`.
    pub fn chunks_owned_by(&self, rank: usize) -> usize {
        if rank >= self.n_ranks {
            return 0;
        }
        let n = self.n_chunks();
        if rank >= n {
            // Fewer chunks than ranks: this rank never gets one.
            return 0;
        }
        // Owned chunks are rank, rank + n_ranks, ... strictly below n.
        (n - 1 - rank) / self.n_ranks + 1
    }

    /// Global index of the `ordinal`-th chunk owned by `rank`.
    ///
    /// Returns `None` when `rank` is out of range, when the rank owns fewer than
    /// `ordinal + 1` chunks, or when the index is not representable in `usize`.
    pub fn owned_chunk(&self, rank: usize, ordinal: usize) -> Option<usize> {
        if rank >= self.n_ranks {
            return None;
        }
        // Checked arithmetic: a wrapped product could otherwise land below
        // `n_chunks()` and be reported as a valid (but wrong) chunk.
        let idx = ordinal.checked_mul(self.n_ranks)?.checked_add(rank)?;
        (idx < self.n_chunks()).then_some(idx)
    }
}
/// One chunk's contribution as shipped from a worker to the coordinator.
///
/// Produced by `NodeWorker::emit` and consumed by `CrossNodeGramReduction::receive`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct NodePartial {
/// Rank of the worker that produced this partial.
pub rank: usize,
/// Global index of the chunk the Gram was computed from.
pub chunk_index: usize,
/// Output of `chunk_gram_flat` for the chunk's rows.
/// NOTE(review): the flattening layout is defined in `streaming_border` and is not
/// visible here — confirm before interpreting the entries.
pub gram: Vec<f64>,
}
/// Serializable snapshot of a `NodeWorker`, taken by `NodeWorker::checkpoint` and
/// turned back into a worker by `NodeWorker::resume`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct NodeWorkerCheckpoint {
/// Partition the worker was operating on.
pub partition: CrossNodePartition,
/// Rank of the worker.
pub rank: usize,
/// Ordinal (within the rank's owned sequence) of the next chunk to emit.
pub next_ordinal: usize,
}
/// Per-rank cursor that walks the chunks owned by one rank, in order, and turns each
/// chunk's rows into a `NodePartial`.
#[derive(Clone, Debug)]
pub struct NodeWorker {
// Partition shared by every worker and the coordinator.
partition: CrossNodePartition,
// Rank this worker acts as; checked against `partition.n_ranks` on construction.
rank: usize,
// How many owned chunks have been emitted so far; advanced only by `emit`.
next_ordinal: usize,
}
impl NodeWorker {
pub fn new(partition: CrossNodePartition, rank: usize) -> Result<Self, String> {
if rank >= partition.n_ranks {
return Err(format!(
"NodeWorker: rank {rank} out of range (n_ranks = {})",
partition.n_ranks
));
}
Ok(Self {
partition,
rank,
next_ordinal: 0,
})
}
pub fn resume(state: NodeWorkerCheckpoint) -> Result<Self, String> {
if state.rank >= state.partition.n_ranks {
return Err(format!(
"NodeWorkerCheckpoint: rank {} out of range (n_ranks = {})",
state.rank, state.partition.n_ranks
));
}
let owned = state.partition.chunks_owned_by(state.rank);
if state.next_ordinal > owned {
return Err(format!(
"NodeWorkerCheckpoint: next_ordinal {} exceeds owned chunk count {owned}",
state.next_ordinal
));
}
Ok(Self {
partition: state.partition,
rank: state.rank,
next_ordinal: state.next_ordinal,
})
}
pub fn checkpoint(&self) -> NodeWorkerCheckpoint {
NodeWorkerCheckpoint {
partition: self.partition,
rank: self.rank,
next_ordinal: self.next_ordinal,
}
}
pub fn is_done(&self) -> bool {
self.partition
.owned_chunk(self.rank, self.next_ordinal)
.is_none()
}
pub fn next_chunk_rows(&self) -> Option<(usize, std::ops::Range<usize>)> {
let idx = self.partition.owned_chunk(self.rank, self.next_ordinal)?;
Some((idx, self.partition.chunk_rows(idx)))
}
pub fn emit(&mut self, rows: ArrayView2<'_, f64>) -> Result<NodePartial, String> {
let (chunk_index, range) = self
.next_chunk_rows()
.ok_or_else(|| format!("NodeWorker rank {}: sequence exhausted", self.rank))?;
if rows.nrows() != range.len() || rows.ncols() != self.partition.border_dim {
return Err(format!(
"NodeWorker rank {}: chunk {chunk_index} has shape ({}, {}) but expected ({}, {})",
self.rank,
rows.nrows(),
rows.ncols(),
range.len(),
self.partition.border_dim
));
}
let gram = chunk_gram_flat(rows);
self.next_ordinal += 1;
Ok(NodePartial {
rank: self.rank,
chunk_index,
gram,
})
}
}
/// Serializable snapshot of a `CrossNodeGramReduction`, taken by its `checkpoint`
/// method and restored by `CrossNodeGramReduction::resume`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CrossNodeCheckpoint {
/// Partition the coordinator was reducing over.
pub partition: CrossNodePartition,
/// Checkpoint of the wrapped `StreamingBorderGram` accumulator.
pub inner: BorderGramCheckpoint,
/// Per-rank count of partials accepted so far; `resume` requires one entry per rank.
pub received_per_rank: Vec<usize>,
}
/// Coordinator that accepts `NodePartial`s from all ranks, checks that each rank
/// ships its owned chunks in order, and forwards them to a `StreamingBorderGram`.
pub struct CrossNodeGramReduction {
// Partition shared with every worker.
partition: CrossNodePartition,
// Accumulator that receives each accepted chunk Gram.
inner: StreamingBorderGram,
// For each rank, the ordinal of the next owned chunk the coordinator will accept.
received_per_rank: Vec<usize>,
}
impl CrossNodeGramReduction {
pub fn new(partition: CrossNodePartition) -> Result<Self, String> {
let inner =
StreamingBorderGram::new(partition.border_dim, partition.n_rows, partition.chunk_size)?;
Ok(Self {
received_per_rank: vec![0; partition.n_ranks],
partition,
inner,
})
}
pub fn partition(&self) -> CrossNodePartition {
self.partition
}
pub fn rank_cursor(&self, rank: usize) -> Option<usize> {
self.received_per_rank.get(rank).copied()
}
pub fn is_complete(&self) -> bool {
self.inner.is_complete()
}
pub fn receive(&mut self, partial: NodePartial) -> Result<(), String> {
let NodePartial {
rank,
chunk_index,
gram,
} = partial;
if rank >= self.partition.n_ranks {
return Err(format!(
"CrossNodeGramReduction: rank {rank} out of range (n_ranks = {})",
self.partition.n_ranks
));
}
if self.partition.owner_rank(chunk_index) != rank {
return Err(format!(
"CrossNodeGramReduction: chunk {chunk_index} is owned by rank {}, not rank {rank}",
self.partition.owner_rank(chunk_index)
));
}
let cursor = self.received_per_rank[rank];
match self.partition.owned_chunk(rank, cursor) {
Some(expected) if expected == chunk_index => {}
Some(expected) => {
return Err(format!(
"CrossNodeGramReduction: rank {rank} shipped chunk {chunk_index} but its \
cursor expects chunk {expected} (ordinal {cursor}); a worker resumed from \
a stale or future checkpoint"
));
}
None => {
return Err(format!(
"CrossNodeGramReduction: rank {rank} shipped chunk {chunk_index} past the \
end of its owned sequence"
));
}
}
self.inner.submit_chunk_gram(chunk_index, gram)?;
self.received_per_rank[rank] = cursor + 1;
Ok(())
}
pub fn checkpoint(&self) -> CrossNodeCheckpoint {
CrossNodeCheckpoint {
partition: self.partition,
inner: self.inner.checkpoint(),
received_per_rank: self.received_per_rank.clone(),
}
}
pub fn resume(state: CrossNodeCheckpoint) -> Result<Self, String> {
if state.received_per_rank.len() != state.partition.n_ranks {
return Err(format!(
"CrossNodeCheckpoint: {} rank cursors for n_ranks = {}",
state.received_per_rank.len(),
state.partition.n_ranks
));
}
if state.inner.border_dim != state.partition.border_dim
|| state.inner.n_rows != state.partition.n_rows
|| state.inner.chunk_size != state.partition.chunk_size
{
return Err(
"CrossNodeCheckpoint: inner accumulator partition disagrees with the cross-node \
partition"
.to_string(),
);
}
for (rank, &cursor) in state.received_per_rank.iter().enumerate() {
if cursor > state.partition.chunks_owned_by(rank) {
return Err(format!(
"CrossNodeCheckpoint: rank {rank} cursor {cursor} exceeds its owned chunk \
count {}",
state.partition.chunks_owned_by(rank)
));
}
}
let inner = StreamingBorderGram::resume(state.inner)?;
Ok(Self {
partition: state.partition,
inner,
received_per_rank: state.received_per_rank,
})
}
pub fn finish(self) -> Result<Array2<f64>, String> {
for (rank, &cursor) in self.received_per_rank.iter().enumerate() {
let owned = self.partition.chunks_owned_by(rank);
if cursor != owned {
return Err(format!(
"CrossNodeGramReduction: finish() with rank {rank} at ordinal {cursor} of \
{owned} owned chunks"
));
}
}
self.inner.finish()
}
}
// Tests for the cross-node reduction: bitwise agreement with a single-process
// accumulator, resume after a worker dies, receipt validation, and the chunk
// assignment itself.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::s;
// Builds an `n x k` matrix whose entries are a fixed function of the indices, so
// every test run sees the same data.
fn planted_rows(n: usize, k: usize) -> Array2<f64> {
Array2::from_shape_fn((n, k), |(i, j)| {
let x = (i as f64 + 1.0) * 0.7390851 + (j as f64 + 1.0) * 1.6180339;
(x.sin() * 43_758.547).fract() * 2.0 - 1.0
})
}
// Asserts that two matrices have the same shape and that every pair of entries has
// the same bit pattern (stricter than `==`, which treats 0.0 and -0.0 as equal).
fn assert_bit_identical(a: &Array2<f64>, b: &Array2<f64>, label: &str) {
assert_eq!(a.dim(), b.dim(), "{label}: shape mismatch");
for ((idx, x), y) in a.indexed_iter().zip(b.iter()) {
assert_eq!(
x.to_bits(),
y.to_bits(),
"{label}: entry {idx:?} differs bitwise: {x:?} vs {y:?}"
);
}
}
// Runs a full reduction of `rows` with `n_ranks` in-process workers and returns the
// coordinator's result.
fn run_fleet(rows: &Array2<f64>, chunk_size: usize, n_ranks: usize) -> Array2<f64> {
let partition =
CrossNodePartition::new(rows.ncols(), rows.nrows(), chunk_size, n_ranks).unwrap();
let mut coordinator = CrossNodeGramReduction::new(partition).unwrap();
let mut workers: Vec<NodeWorker> = (0..n_ranks)
.map(|r| NodeWorker::new(partition, r).unwrap())
.collect();
// Keep sweeping over the workers until a full sweep ships nothing and every
// worker reports done.
let mut any_live = true;
while any_live {
any_live = false;
for (r, worker) in workers.iter_mut().enumerate() {
// Rank r ships up to `r % 3 + 1` chunks per sweep, so ranks progress at
// different speeds and partials reach the coordinator interleaved.
for _ in 0..(r % 3 + 1) {
let Some((_, range)) = worker.next_chunk_rows() else {
break;
};
let partial = worker.emit(rows.slice(s![range, ..])).unwrap();
coordinator.receive(partial).unwrap();
any_live = true;
}
if !worker.is_done() {
any_live = true;
}
}
}
assert!(coordinator.is_complete());
coordinator.finish().unwrap()
}
// The fleet result must match a single `StreamingBorderGram` fed chunk by chunk,
// bit for bit, for 1, 3 and 5 ranks.
#[test]
fn node_count_never_changes_bits() {
let n = 977; let k = 4;
let chunk_size = 7;
let rows = planted_rows(n, k);
// Reference: one accumulator, chunks submitted in index order.
let mut single = StreamingBorderGram::new(k, n, chunk_size).unwrap();
for j in 0..single.n_chunks() {
let range = single.chunk_rows(j);
single.submit_chunk(j, rows.slice(s![range, ..])).unwrap();
}
let reference = single.finish().unwrap();
for n_ranks in [1usize, 3, 5] {
let fleet = run_fleet(&rows, chunk_size, n_ranks);
assert_bit_identical(
&reference,
&fleet,
&format!("single-process vs {n_ranks}-node fleet"),
);
}
}
// Simulates rank 1 dying mid-run and the coordinator restarting: both are restored
// from JSON checkpoints and the final result must match an uninterrupted run.
#[test]
fn dead_node_resumes_from_cursor_bit_identically() {
let n = 530;
let k = 3;
let chunk_size = 5;
let n_ranks = 3;
let rows = planted_rows(n, k);
let reference = run_fleet(&rows, chunk_size, n_ranks);
let partition = CrossNodePartition::new(k, n, chunk_size, n_ranks).unwrap();
let mut coordinator = CrossNodeGramReduction::new(partition).unwrap();
let mut workers: Vec<NodeWorker> = (0..n_ranks)
.map(|r| NodeWorker::new(partition, r).unwrap())
.collect();
let mut rank1_cursor = None;
// First phase: rank 1 ships four chunks, every other rank ships two; rank 1's
// checkpoint is serialized right after its last shipment.
for (r, worker) in workers.iter_mut().enumerate() {
let ship = if r == 1 { 4 } else { 2 };
for _ in 0..ship {
let Some((_, range)) = worker.next_chunk_rows() else {
break;
};
let partial = worker.emit(rows.slice(s![range, ..])).unwrap();
coordinator.receive(partial).unwrap();
}
if r == 1 {
let json = serde_json::to_string(&worker.checkpoint()).unwrap();
rank1_cursor = Some(json);
}
}
// Rank 1 "dies"; the coordinator is dropped and rebuilt from its JSON checkpoint.
workers.remove(1);
let coord_json = serde_json::to_string(&coordinator.checkpoint()).unwrap();
drop(coordinator);
let restored: CrossNodeCheckpoint = serde_json::from_str(&coord_json).unwrap();
let mut coordinator = CrossNodeGramReduction::resume(restored).unwrap();
let cp: NodeWorkerCheckpoint = serde_json::from_str(&rank1_cursor.unwrap()).unwrap();
// The restored coordinator and the worker checkpoint must agree on where rank 1 is.
assert_eq!(coordinator.rank_cursor(1), Some(cp.next_ordinal));
let replacement = NodeWorker::resume(cp).unwrap();
workers.insert(1, replacement);
// Second phase: drain all workers, one chunk per worker per sweep.
let mut any_live = true;
while any_live {
any_live = false;
for worker in workers.iter_mut() {
if let Some((_, range)) = worker.next_chunk_rows() {
let partial = worker.emit(rows.slice(s![range, ..])).unwrap();
coordinator.receive(partial).unwrap();
any_live = true;
}
}
}
let resumed = coordinator.finish().unwrap();
assert_bit_identical(&reference, &resumed, "death-resume vs straight-through");
}
// The coordinator must reject a partial relabelled with the wrong rank, a duplicate
// of an already accepted chunk, and a chunk that skips ahead in the rank's sequence.
#[test]
fn receipt_validation_rejects_misrouted_and_out_of_sequence_partials() {
let n = 60;
let k = 2;
let chunk_size = 4; let n_ranks = 3;
let rows = planted_rows(n, k);
let partition = CrossNodePartition::new(k, n, chunk_size, n_ranks).unwrap();
let mut coordinator = CrossNodeGramReduction::new(partition).unwrap();
let mut w0 = NodeWorker::new(partition, 0).unwrap();
let (idx, range) = w0.next_chunk_rows().unwrap();
assert_eq!(idx, 0);
let mut partial = w0.emit(rows.slice(s![range, ..])).unwrap();
// Misrouted: chunk 0 claimed by rank 1.
partial.rank = 1;
let err = coordinator.receive(partial.clone()).unwrap_err();
assert!(err.contains("owned by rank 0"), "got: {err}");
// Correctly labelled: accepted once, rejected as a duplicate the second time
// because rank 0's cursor has moved on to chunk 3.
partial.rank = 0;
coordinator.receive(partial.clone()).unwrap();
let err = coordinator.receive(partial).unwrap_err();
assert!(err.contains("cursor expects chunk 3"), "got: {err}");
// Rank 0 owns chunks 0, 3, 6, ...: emit 3 and 6 but deliver 6 first.
let (idx, range) = w0.next_chunk_rows().unwrap();
assert_eq!(idx, 3);
let skipped = w0.emit(rows.slice(s![range, ..])).unwrap();
let (idx6, range6) = w0.next_chunk_rows().unwrap();
assert_eq!(idx6, 6);
let ahead = w0.emit(rows.slice(s![range6, ..])).unwrap();
let err = coordinator.receive(ahead).unwrap_err();
assert!(err.contains("expects chunk 3"), "got: {err}");
// The rejected delivery must not have moved the cursor: chunk 3 is still accepted.
coordinator.receive(skipped).unwrap();
}
// Every chunk must be owned by exactly one rank, `owned_chunk` and `owner_rank` must
// agree, and `owned_chunk` must return `None` right past a rank's last chunk.
#[test]
fn assignment_is_a_pure_partition() {
// Cases: single rank, several ranks, more ranks than chunks, and zero rows.
for (n_rows, chunk_size, n_ranks) in [(100, 7, 1), (100, 7, 4), (3, 10, 8), (0, 5, 3)] {
let partition = CrossNodePartition::new(2, n_rows, chunk_size, n_ranks).unwrap();
let n_chunks = partition.n_chunks();
let mut seen = vec![false; n_chunks];
let mut total = 0usize;
for rank in 0..n_ranks {
let owned = partition.chunks_owned_by(rank);
for ordinal in 0..owned {
let idx = partition.owned_chunk(rank, ordinal).unwrap();
assert_eq!(partition.owner_rank(idx), rank);
assert!(!seen[idx], "chunk {idx} assigned twice");
seen[idx] = true;
total += 1;
}
assert!(partition.owned_chunk(rank, owned).is_none());
}
assert_eq!(total, n_chunks, "assignment must tile all chunks");
}
}
}