gam_solve/cross_node.rs
1//! Cross-node deterministic border-Gram reduction (#987, extending #973).
2//!
3//! [`crate::streaming_border`] made the in-process accumulation of the
4//! Schur border Gram `G = Σ_n x_n x_nᵀ` bit-reproducible by construction: the
5//! chunk partition is a pure function of `(n_rows, chunk_size)`, per-chunk
6//! partials are deterministic [`chunk_gram_flat`] reductions, and the
7//! cross-chunk fold is a fixed pairwise tree keyed by **chunk index** — never by
8//! arrival order, thread timing, or device count. This module extends that same
9//! fixed-shape-by-construction discipline **one level up**, to a fleet of
10//! worker nodes, with three properties the frontier corpus regime
11//! (10⁹–10¹¹ tokens, hundreds of TB of activations) demands:
12//!
13//! 1. **Node count never changes bits.** A node's partials are not leaves of a
14//! *separate* per-node tree that then gets merged (that shape would depend
15//! on the node count). Instead every node computes the **globally indexed**
16//! per-chunk partials it owns and ships `(chunk_index, k·k partial)`
17//! messages; the coordinator folds them through the *single* global cascade
18//! of [`StreamingBorderGram`], which accepts any arrival order and folds in
19//! chunk-index order. The reduction topology is therefore a pure function of
20//! `(n_rows, chunk_size)` alone — running on 1 node, 3 nodes, or 64 nodes
21//! yields the identical bit pattern, because the tree never saw the node
22//! count. (The chunk→node *assignment* is rank-indexed and deterministic,
23//! but it only decides who computes a partial, never how partials combine.)
//! 2. **Checkpoint/resume is the job model, not an afterthought.** A worker
//!    that dies is replaced by a process that resumes from the dead worker's
//!    serialized [`NodeWorkerCheckpoint`] (a cursor into its owned chunk
//!    sequence). The coordinator's full state — the in-order fold forest, the
//!    pending out-of-order partials, and the per-rank receipt cursors —
//!    serializes to a [`CrossNodeCheckpoint`]. Resume-equals-straight-through
//!    holds at the bit level on both sides because both cursors are positions
//!    in deterministic sequences.
//! 3. **Partials, never rows, cross the wire.** A worker streams its shard
//!    rows locally (object store / mmap — `gam_sae::corpus`) and ships only
//!    `k·k` f64 partials. The coordinator's ingest seam is
//!    [`StreamingBorderGram::submit_chunk_gram`]. Both producers of partials —
//!    the in-process [`StreamingBorderGram::submit_chunk`] path and the
//!    worker-side [`NodeWorker::emit`] — route through the single
//!    [`chunk_gram_flat`] free function. A shipped partial is therefore
//!    bit-identical to the partial the coordinator would have computed from
//!    the same rows.
38//!
39//! ## Chunk→rank assignment
40//!
41//! Round-robin by chunk index: rank `r` of `n_ranks` owns chunks
42//! `{j : j ≡ r (mod n_ranks)}`, in increasing order. Round-robin (rather than
43//! contiguous ranges) keeps the coordinator's in-order fold frontier advancing
44//! steadily while all ranks make progress at similar rates, which bounds the
45//! pending out-of-order buffer by O(`n_ranks` × inter-node skew) instead of
46//! O(total chunks). The assignment is a pure function of
47//! `(chunk_index, n_ranks)`; no scheduler, no work stealing — work stealing
48//! would not change bits (the fold is index-keyed) but it *would* break the
49//! one-cursor-per-rank resume model, so it is deliberately absent.
50//!
51//! Pure library: no networking, no flags, no environment variables. The
52//! transport (MPI, gRPC, files on a shared filesystem) is the caller's; this
53//! module owns the deterministic topology, the cursors, and the validation.
54
55use crate::streaming_border::{BorderGramCheckpoint, StreamingBorderGram, chunk_gram_flat};
56use ndarray::{Array2, ArrayView2};
57use serde::{Deserialize, Serialize};
58
59/// The deterministic chunk partition + rank-indexed assignment shared by every
60/// participant of one cross-node pass. A pure function of its four fields; two
61/// participants constructed with the same fields agree on every derived
62/// quantity, with no communication.
63#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
64pub struct CrossNodePartition {
65 /// Border dimension `k` (columns of every chunk; partials are `k·k`).
66 pub border_dim: usize,
67 /// Total row count of the full pass.
68 pub n_rows: usize,
69 /// Fixed chunk size (rows per chunk; the last chunk may be shorter).
70 pub chunk_size: usize,
71 /// Number of worker ranks in the fleet.
72 pub n_ranks: usize,
73}
74
75impl CrossNodePartition {
76 pub fn new(
77 border_dim: usize,
78 n_rows: usize,
79 chunk_size: usize,
80 n_ranks: usize,
81 ) -> Result<Self, String> {
82 if border_dim == 0 {
83 return Err("CrossNodePartition: border_dim must be positive".to_string());
84 }
85 if chunk_size == 0 {
86 return Err("CrossNodePartition: chunk_size must be positive".to_string());
87 }
88 if n_ranks == 0 {
89 return Err("CrossNodePartition: n_ranks must be positive".to_string());
90 }
91 Ok(Self {
92 border_dim,
93 n_rows,
94 chunk_size,
95 n_ranks,
96 })
97 }
98
99 /// Total number of chunks of the pass: `ceil(n_rows / chunk_size)`.
100 /// Identical to [`StreamingBorderGram::n_chunks`] for the same partition
101 /// parameters — the global tree this assignment feeds.
102 pub fn n_chunks(&self) -> usize {
103 self.n_rows.div_ceil(self.chunk_size)
104 }
105
106 /// Row range covered by global chunk `chunk_index` — the same pure function
107 /// as [`StreamingBorderGram::chunk_rows`], duplicated here so a worker can
108 /// slice its rows without constructing a coordinator-side accumulator.
109 pub fn chunk_rows(&self, chunk_index: usize) -> std::ops::Range<usize> {
110 let lo = chunk_index * self.chunk_size;
111 let hi = ((chunk_index + 1) * self.chunk_size).min(self.n_rows);
112 lo..hi
113 }
114
115 /// Which rank owns global chunk `chunk_index`: round-robin by index.
116 #[inline]
117 pub fn owner_rank(&self, chunk_index: usize) -> usize {
118 chunk_index % self.n_ranks
119 }
120
121 /// Number of chunks rank `rank` owns.
122 pub fn chunks_owned_by(&self, rank: usize) -> usize {
123 let n = self.n_chunks();
124 if rank >= self.n_ranks || n == 0 {
125 return 0;
126 }
127 // Chunks r, r + n_ranks, r + 2·n_ranks, … below n.
128 if rank < n {
129 (n - rank - 1) / self.n_ranks + 1
130 } else {
131 0
132 }
133 }
134
135 /// The `ordinal`-th (0-based) global chunk index owned by `rank`, or `None`
136 /// past the end of the rank's sequence. The worker cursor is an ordinal
137 /// into exactly this sequence.
138 pub fn owned_chunk(&self, rank: usize, ordinal: usize) -> Option<usize> {
139 if rank >= self.n_ranks {
140 return None;
141 }
142 let idx = rank + ordinal * self.n_ranks;
143 if idx < self.n_chunks() {
144 Some(idx)
145 } else {
146 None
147 }
148 }
149}
150
151/// One shipped partial: the global chunk index plus the deterministic `k·k`
152/// per-chunk Gram. This is the only message that crosses the node boundary —
153/// `k·k` f64 values per chunk, never rows.
154#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
155pub struct NodePartial {
156 /// Rank that produced this partial (its assignment is validated on
157 /// receipt, so a misconfigured worker is rejected loudly).
158 pub rank: usize,
159 /// Global chunk index of the partial.
160 pub chunk_index: usize,
161 /// Flattened `k·k` row-major per-chunk Gram, as produced by
162 /// [`chunk_gram_flat`] over the chunk's rows.
163 pub gram: Vec<f64>,
164}
165
166/// Serialized cursor of one worker: everything needed for a **replacement**
167/// process (same rank, any host) to continue the dead worker's deterministic
168/// chunk sequence from where receipts stopped. Pure data; the worker's row
169/// source re-seeks by row range, which is a pure function of the partition.
170#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
171pub struct NodeWorkerCheckpoint {
172 pub partition: CrossNodePartition,
173 pub rank: usize,
174 /// Ordinal (into the rank's owned-chunk sequence) of the next chunk to
175 /// compute and ship. Everything below it has been durably shipped.
176 pub next_ordinal: usize,
177}
178
179/// Worker-side driver for one rank: walks the rank's deterministic owned-chunk
180/// sequence, turning row slices into shippable [`NodePartial`]s.
181///
182/// The worker does **not** do I/O: the caller streams rows (from its shards /
183/// object store) for the row range [`NodeWorker::next_chunk_rows`] names, hands
184/// them to [`NodeWorker::emit`], and ships the returned partial. The cursor
185/// advances only on `emit`, so "ship durably, then checkpoint" gives exactly-
186/// once production under crash-resume (re-shipping an already-folded chunk is
187/// rejected by the coordinator as a duplicate, which is the safe failure).
188#[derive(Clone, Debug)]
189pub struct NodeWorker {
190 partition: CrossNodePartition,
191 rank: usize,
192 next_ordinal: usize,
193}
194
195impl NodeWorker {
196 /// Fresh worker for `rank`, starting at the beginning of its sequence.
197 pub fn new(partition: CrossNodePartition, rank: usize) -> Result<Self, String> {
198 if rank >= partition.n_ranks {
199 return Err(format!(
200 "NodeWorker: rank {rank} out of range (n_ranks = {})",
201 partition.n_ranks
202 ));
203 }
204 Ok(Self {
205 partition,
206 rank,
207 next_ordinal: 0,
208 })
209 }
210
211 /// Resume a (replacement) worker from a serialized cursor. Validates the
212 /// cursor against the partition so a checkpoint from a different pass is
213 /// rejected loudly.
214 pub fn resume(state: NodeWorkerCheckpoint) -> Result<Self, String> {
215 if state.rank >= state.partition.n_ranks {
216 return Err(format!(
217 "NodeWorkerCheckpoint: rank {} out of range (n_ranks = {})",
218 state.rank, state.partition.n_ranks
219 ));
220 }
221 let owned = state.partition.chunks_owned_by(state.rank);
222 if state.next_ordinal > owned {
223 return Err(format!(
224 "NodeWorkerCheckpoint: next_ordinal {} exceeds owned chunk count {owned}",
225 state.next_ordinal
226 ));
227 }
228 Ok(Self {
229 partition: state.partition,
230 rank: state.rank,
231 next_ordinal: state.next_ordinal,
232 })
233 }
234
235 /// Serialize the cursor. Write this (durably) after each successful ship.
236 pub fn checkpoint(&self) -> NodeWorkerCheckpoint {
237 NodeWorkerCheckpoint {
238 partition: self.partition,
239 rank: self.rank,
240 next_ordinal: self.next_ordinal,
241 }
242 }
243
244 /// `true` once this rank's sequence is exhausted.
245 pub fn is_done(&self) -> bool {
246 self.partition
247 .owned_chunk(self.rank, self.next_ordinal)
248 .is_none()
249 }
250
251 /// Global chunk index and row range of the next chunk to compute, or
252 /// `None` when done. The caller fetches exactly these rows.
253 pub fn next_chunk_rows(&self) -> Option<(usize, std::ops::Range<usize>)> {
254 let idx = self.partition.owned_chunk(self.rank, self.next_ordinal)?;
255 Some((idx, self.partition.chunk_rows(idx)))
256 }
257
258 /// Compute the next chunk's deterministic partial from its rows and advance
259 /// the cursor. `rows` must be exactly the rows of
260 /// [`NodeWorker::next_chunk_rows`] (shape-validated here; content is the
261 /// caller's contract, same as the in-process path).
262 pub fn emit(&mut self, rows: ArrayView2<'_, f64>) -> Result<NodePartial, String> {
263 let (chunk_index, range) = self
264 .next_chunk_rows()
265 .ok_or_else(|| format!("NodeWorker rank {}: sequence exhausted", self.rank))?;
266 if rows.nrows() != range.len() || rows.ncols() != self.partition.border_dim {
267 return Err(format!(
268 "NodeWorker rank {}: chunk {chunk_index} has shape ({}, {}) but expected ({}, {})",
269 self.rank,
270 rows.nrows(),
271 rows.ncols(),
272 range.len(),
273 self.partition.border_dim
274 ));
275 }
276 let gram = chunk_gram_flat(rows);
277 self.next_ordinal += 1;
278 Ok(NodePartial {
279 rank: self.rank,
280 chunk_index,
281 gram,
282 })
283 }
284}
285
286/// Serializable coordinator state: the inner accumulation state plus the
287/// per-rank receipt cursors.
288#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
289pub struct CrossNodeCheckpoint {
290 pub partition: CrossNodePartition,
291 /// The wrapped [`StreamingBorderGram`] state (fold forest + pending +
292 /// chunk frontier).
293 pub inner: BorderGramCheckpoint,
294 /// Per-rank count of partials received so far — each is an ordinal cursor
295 /// into that rank's deterministic owned-chunk sequence, used both for
296 /// receipt validation (in-sequence, no gaps per rank) and to tell a
297 /// restarted fleet where each rank should resume.
298 pub received_per_rank: Vec<usize>,
299}
300
301/// Coordinator-side reduction: receives [`NodePartial`]s from the fleet and
302/// folds them into the single global fixed-tree accumulator.
303///
304/// Receipt validation is per rank and in-sequence: rank `r`'s `i`-th accepted
305/// partial must be its `i`-th owned chunk. This makes the per-rank cursor in
306/// [`CrossNodeCheckpoint::received_per_rank`] a complete description of what
307/// has been received, which is what lets a dead rank resume from a bare
308/// ordinal. Cross-rank arrival order is unconstrained (the inner accumulator
309/// buffers out-of-order chunks), so slow nodes never block fast ones.
310pub struct CrossNodeGramReduction {
311 partition: CrossNodePartition,
312 inner: StreamingBorderGram,
313 received_per_rank: Vec<usize>,
314}
315
316impl CrossNodeGramReduction {
317 /// Fresh coordinator for the given partition.
318 pub fn new(partition: CrossNodePartition) -> Result<Self, String> {
319 let inner =
320 StreamingBorderGram::new(partition.border_dim, partition.n_rows, partition.chunk_size)?;
321 Ok(Self {
322 received_per_rank: vec![0; partition.n_ranks],
323 partition,
324 inner,
325 })
326 }
327
328 /// The shared partition (workers must be constructed with an equal one).
329 pub fn partition(&self) -> CrossNodePartition {
330 self.partition
331 }
332
333 /// How many partials rank `rank` has had accepted — the ordinal a
334 /// replacement worker for that rank should resume from.
335 pub fn rank_cursor(&self, rank: usize) -> Option<usize> {
336 self.received_per_rank.get(rank).copied()
337 }
338
339 /// `true` once every chunk of every rank has been received and folded.
340 pub fn is_complete(&self) -> bool {
341 self.inner.is_complete()
342 }
343
344 /// Receive one shipped partial. Validates rank, ownership, and per-rank
345 /// sequence position, then folds through
346 /// [`StreamingBorderGram::submit_chunk_gram`] (which re-validates index
347 /// range, duplicates, and partial shape). A duplicate of an already-folded
348 /// chunk — the signature of an at-least-once transport retry or a worker
349 /// that resumed from a stale cursor — is rejected with an error naming the
350 /// chunk, never silently double-counted.
351 pub fn receive(&mut self, partial: NodePartial) -> Result<(), String> {
352 let NodePartial {
353 rank,
354 chunk_index,
355 gram,
356 } = partial;
357 if rank >= self.partition.n_ranks {
358 return Err(format!(
359 "CrossNodeGramReduction: rank {rank} out of range (n_ranks = {})",
360 self.partition.n_ranks
361 ));
362 }
363 if self.partition.owner_rank(chunk_index) != rank {
364 return Err(format!(
365 "CrossNodeGramReduction: chunk {chunk_index} is owned by rank {}, not rank {rank}",
366 self.partition.owner_rank(chunk_index)
367 ));
368 }
369 let cursor = self.received_per_rank[rank];
370 match self.partition.owned_chunk(rank, cursor) {
371 Some(expected) if expected == chunk_index => {}
372 Some(expected) => {
373 return Err(format!(
374 "CrossNodeGramReduction: rank {rank} shipped chunk {chunk_index} but its \
375 cursor expects chunk {expected} (ordinal {cursor}); a worker resumed from \
376 a stale or future checkpoint"
377 ));
378 }
379 None => {
380 return Err(format!(
381 "CrossNodeGramReduction: rank {rank} shipped chunk {chunk_index} past the \
382 end of its owned sequence"
383 ));
384 }
385 }
386 self.inner.submit_chunk_gram(chunk_index, gram)?;
387 self.received_per_rank[rank] = cursor + 1;
388 Ok(())
389 }
390
391 /// Serialize the full coordinator state. Resume-equals-straight-through is
392 /// inherited bit-for-bit from the inner accumulator; the per-rank cursors
393 /// resume receipt validation exactly where it stopped.
394 pub fn checkpoint(&self) -> CrossNodeCheckpoint {
395 CrossNodeCheckpoint {
396 partition: self.partition,
397 inner: self.inner.checkpoint(),
398 received_per_rank: self.received_per_rank.clone(),
399 }
400 }
401
402 /// Reconstruct a coordinator from a checkpoint, validating the cursor
403 /// structure against the partition so corruption is rejected loudly.
404 pub fn resume(state: CrossNodeCheckpoint) -> Result<Self, String> {
405 if state.received_per_rank.len() != state.partition.n_ranks {
406 return Err(format!(
407 "CrossNodeCheckpoint: {} rank cursors for n_ranks = {}",
408 state.received_per_rank.len(),
409 state.partition.n_ranks
410 ));
411 }
412 if state.inner.border_dim != state.partition.border_dim
413 || state.inner.n_rows != state.partition.n_rows
414 || state.inner.chunk_size != state.partition.chunk_size
415 {
416 return Err(
417 "CrossNodeCheckpoint: inner accumulator partition disagrees with the cross-node \
418 partition"
419 .to_string(),
420 );
421 }
422 for (rank, &cursor) in state.received_per_rank.iter().enumerate() {
423 if cursor > state.partition.chunks_owned_by(rank) {
424 return Err(format!(
425 "CrossNodeCheckpoint: rank {rank} cursor {cursor} exceeds its owned chunk \
426 count {}",
427 state.partition.chunks_owned_by(rank)
428 ));
429 }
430 }
431 let inner = StreamingBorderGram::resume(state.inner)?;
432 Ok(Self {
433 partition: state.partition,
434 inner,
435 received_per_rank: state.received_per_rank,
436 })
437 }
438
439 /// Finish the pass, returning the `k×k` border Gram. Errors if any rank's
440 /// sequence is incomplete. The result is a pure function of the row content
441 /// and `(n_rows, chunk_size)` — identical bits for any node count, any
442 /// arrival interleaving, and any checkpoint/resume history on either side.
443 pub fn finish(self) -> Result<Array2<f64>, String> {
444 for (rank, &cursor) in self.received_per_rank.iter().enumerate() {
445 let owned = self.partition.chunks_owned_by(rank);
446 if cursor != owned {
447 return Err(format!(
448 "CrossNodeGramReduction: finish() with rank {rank} at ordinal {cursor} of \
449 {owned} owned chunks"
450 ));
451 }
452 }
453 self.inner.finish()
454 }
455}
456
#[cfg(test)]
mod tests {
    //! Bit-level invariants of the cross-node reduction: node-count
    //! independence, death/resume equivalence, receipt validation, and the
    //! chunk→rank assignment being a partition of the chunk range.
    use super::*;
    use ndarray::s;

    /// Deterministic pseudo-random row matrix keyed purely by index (same
    /// recipe as the streaming_border tests, so cross-file comparisons hold).
    /// Entries lie in roughly [-1, 1).
    fn planted_rows(n: usize, k: usize) -> Array2<f64> {
        Array2::from_shape_fn((n, k), |(i, j)| {
            let x = (i as f64 + 1.0) * 0.7390851 + (j as f64 + 1.0) * 1.6180339;
            (x.sin() * 43_758.547).fract() * 2.0 - 1.0
        })
    }

    /// Assert two matrices have the same shape and the same bit pattern in
    /// every entry. `to_bits` comparison distinguishes `-0.0` from `0.0` and
    /// treats identical NaN payloads as equal, unlike `==`.
    fn assert_bit_identical(a: &Array2<f64>, b: &Array2<f64>, label: &str) {
        assert_eq!(a.dim(), b.dim(), "{label}: shape mismatch");
        // ndarray's `indexed_iter` and `iter` both visit elements in logical
        // order, so zipping pairs corresponding entries.
        for ((idx, x), y) in a.indexed_iter().zip(b.iter()) {
            assert_eq!(
                x.to_bits(),
                y.to_bits(),
                "{label}: entry {idx:?} differs bitwise: {x:?} vs {y:?}"
            );
        }
    }

    /// Run a whole fleet of `n_ranks` workers to completion against one
    /// coordinator, interleaving ranks round-robin with a deterministic skew so
    /// arrival order exercises the pending buffer. Returns the final Gram.
    fn run_fleet(rows: &Array2<f64>, chunk_size: usize, n_ranks: usize) -> Array2<f64> {
        let partition =
            CrossNodePartition::new(rows.ncols(), rows.nrows(), chunk_size, n_ranks).unwrap();
        let mut coordinator = CrossNodeGramReduction::new(partition).unwrap();
        let mut workers: Vec<NodeWorker> = (0..n_ranks)
            .map(|r| NodeWorker::new(partition, r).unwrap())
            .collect();
        // Deterministic skewed interleaving: each sweep lets rank r ship
        // (r % 3 + 1) chunks, so ranks run ahead/behind each other and the
        // coordinator's out-of-order pending path is exercised.
        let mut any_live = true;
        while any_live {
            any_live = false;
            for (r, worker) in workers.iter_mut().enumerate() {
                for _ in 0..(r % 3 + 1) {
                    // Stop this rank's turn early once its sequence is exhausted.
                    let Some((_, range)) = worker.next_chunk_rows() else {
                        break;
                    };
                    let partial = worker.emit(rows.slice(s![range, ..])).unwrap();
                    coordinator.receive(partial).unwrap();
                    any_live = true;
                }
                // Keep sweeping while any rank still has chunks left.
                if !worker.is_done() {
                    any_live = true;
                }
            }
        }
        assert!(coordinator.is_complete());
        coordinator.finish().unwrap()
    }

    #[test]
    fn node_count_never_changes_bits() {
        // The frontier invariant: 1, 3, and 5 nodes produce the identical bit
        // pattern, and all match the single-process StreamingBorderGram.
        let n = 977; // not a multiple of the chunk size
        let k = 4;
        let chunk_size = 7;
        let rows = planted_rows(n, k);

        // Reference: one in-process accumulator fed every chunk in index order.
        let mut single = StreamingBorderGram::new(k, n, chunk_size).unwrap();
        for j in 0..single.n_chunks() {
            let range = single.chunk_rows(j);
            single.submit_chunk(j, rows.slice(s![range, ..])).unwrap();
        }
        let reference = single.finish().unwrap();

        for n_ranks in [1usize, 3, 5] {
            let fleet = run_fleet(&rows, chunk_size, n_ranks);
            assert_bit_identical(
                &reference,
                &fleet,
                &format!("single-process vs {n_ranks}-node fleet"),
            );
        }
    }

    #[test]
    fn dead_node_resumes_from_cursor_bit_identically() {
        let n = 530;
        let k = 3;
        let chunk_size = 5;
        let n_ranks = 3;
        let rows = planted_rows(n, k);
        // Uninterrupted run used as the bit-level reference.
        let reference = run_fleet(&rows, chunk_size, n_ranks);

        let partition = CrossNodePartition::new(k, n, chunk_size, n_ranks).unwrap();
        let mut coordinator = CrossNodeGramReduction::new(partition).unwrap();
        let mut workers: Vec<NodeWorker> = (0..n_ranks)
            .map(|r| NodeWorker::new(partition, r).unwrap())
            .collect();

        // Rank 1 ships 4 chunks, checkpoints durably, then "dies". The other
        // ranks ship a few chunks too.
        let mut rank1_cursor = None;
        for (r, worker) in workers.iter_mut().enumerate() {
            let ship = if r == 1 { 4 } else { 2 };
            for _ in 0..ship {
                let Some((_, range)) = worker.next_chunk_rows() else {
                    break;
                };
                let partial = worker.emit(rows.slice(s![range, ..])).unwrap();
                coordinator.receive(partial).unwrap();
            }
            if r == 1 {
                // Round-trip through JSON to mimic a durable on-disk cursor.
                let json = serde_json::to_string(&worker.checkpoint()).unwrap();
                rank1_cursor = Some(json);
            }
        }
        workers.remove(1); // the death: rank-1 worker removed and dropped here

        // Coordinator also survives a checkpoint round-trip mid-pass.
        let coord_json = serde_json::to_string(&coordinator.checkpoint()).unwrap();
        drop(coordinator);
        let restored: CrossNodeCheckpoint = serde_json::from_str(&coord_json).unwrap();
        let mut coordinator = CrossNodeGramReduction::resume(restored).unwrap();

        // A replacement process resumes rank 1 from its serialized cursor; the
        // coordinator's own cursor agrees with it.
        let cp: NodeWorkerCheckpoint = serde_json::from_str(&rank1_cursor.unwrap()).unwrap();
        assert_eq!(coordinator.rank_cursor(1), Some(cp.next_ordinal));
        let replacement = NodeWorker::resume(cp).unwrap();
        // Re-insert at index 1 so the vector position matches the rank again.
        workers.insert(1, replacement);

        // Drain the fleet to completion.
        let mut any_live = true;
        while any_live {
            any_live = false;
            for worker in workers.iter_mut() {
                if let Some((_, range)) = worker.next_chunk_rows() {
                    let partial = worker.emit(rows.slice(s![range, ..])).unwrap();
                    coordinator.receive(partial).unwrap();
                    any_live = true;
                }
            }
        }
        let resumed = coordinator.finish().unwrap();
        assert_bit_identical(&reference, &resumed, "death-resume vs straight-through");
    }

    #[test]
    fn receipt_validation_rejects_misrouted_and_out_of_sequence_partials() {
        let n = 60;
        let k = 2;
        let chunk_size = 4; // 15 chunks
        let n_ranks = 3;
        // With 3 ranks, rank 0 owns global chunks 0, 3, 6, 9, 12.
        let rows = planted_rows(n, k);
        let partition = CrossNodePartition::new(k, n, chunk_size, n_ranks).unwrap();
        let mut coordinator = CrossNodeGramReduction::new(partition).unwrap();
        let mut w0 = NodeWorker::new(partition, 0).unwrap();

        let (idx, range) = w0.next_chunk_rows().unwrap();
        assert_eq!(idx, 0);
        let mut partial = w0.emit(rows.slice(s![range, ..])).unwrap();

        // Misrouted: claim the partial came from rank 1 (which does not own
        // chunk 0).
        partial.rank = 1;
        let err = coordinator.receive(partial.clone()).unwrap_err();
        assert!(err.contains("owned by rank 0"), "got: {err}");

        // Correctly routed: accepted.
        partial.rank = 0;
        coordinator.receive(partial.clone()).unwrap();

        // Duplicate (transport retry): rejected, never double-counted.
        let err = coordinator.receive(partial).unwrap_err();
        assert!(err.contains("cursor expects chunk 3"), "got: {err}");

        // Out of sequence: rank 0's cursor expects its ordinal-1 chunk (global
        // chunk 3), not its ordinal-2 chunk (global chunk 6).
        let (idx, range) = w0.next_chunk_rows().unwrap();
        assert_eq!(idx, 3);
        let skipped = w0.emit(rows.slice(s![range, ..])).unwrap();
        let (idx6, range6) = w0.next_chunk_rows().unwrap();
        assert_eq!(idx6, 6);
        let ahead = w0.emit(rows.slice(s![range6, ..])).unwrap();
        let err = coordinator.receive(ahead).unwrap_err();
        assert!(err.contains("expects chunk 3"), "got: {err}");
        // The rejected chunk-6 partial left the cursor at chunk 3, which is
        // still accepted.
        coordinator.receive(skipped).unwrap();
    }

    #[test]
    fn assignment_is_a_pure_partition() {
        // Every chunk is owned by exactly one rank, and the per-rank owned
        // sequences tile [0, n_chunks) — for several fleet sizes. The cases
        // include more ranks than chunks (3 rows / 10 per chunk / 8 ranks) and
        // an empty pass (0 rows).
        for (n_rows, chunk_size, n_ranks) in [(100, 7, 1), (100, 7, 4), (3, 10, 8), (0, 5, 3)] {
            let partition = CrossNodePartition::new(2, n_rows, chunk_size, n_ranks).unwrap();
            let n_chunks = partition.n_chunks();
            let mut seen = vec![false; n_chunks];
            let mut total = 0usize;
            for rank in 0..n_ranks {
                let owned = partition.chunks_owned_by(rank);
                for ordinal in 0..owned {
                    let idx = partition.owned_chunk(rank, ordinal).unwrap();
                    assert_eq!(partition.owner_rank(idx), rank);
                    assert!(!seen[idx], "chunk {idx} assigned twice");
                    seen[idx] = true;
                    total += 1;
                }
                // The sequence must end exactly at `chunks_owned_by`.
                assert!(partition.owned_chunk(rank, owned).is_none());
            }
            assert_eq!(total, n_chunks, "assignment must tile all chunks");
        }
    }
}