Skip to main content

gam_solve/
cross_node.rs

1//! Cross-node deterministic border-Gram reduction (#987, extending #973).
2//!
3//! [`crate::streaming_border`] made the in-process accumulation of the
4//! Schur border Gram `G = Σ_n x_n x_nᵀ` bit-reproducible by construction: the
5//! chunk partition is a pure function of `(n_rows, chunk_size)`, per-chunk
6//! partials are deterministic [`chunk_gram_flat`] reductions, and the
7//! cross-chunk fold is a fixed pairwise tree keyed by **chunk index** — never by
8//! arrival order, thread timing, or device count. This module extends that same
9//! fixed-shape-by-construction discipline **one level up**, to a fleet of
10//! worker nodes, with three properties the frontier corpus regime
11//! (10⁹–10¹¹ tokens, hundreds of TB of activations) demands:
12//!
13//! 1. **Node count never changes bits.** A node's partials are not leaves of a
14//!    *separate* per-node tree that then gets merged (that shape would depend
15//!    on the node count). Instead every node computes the **globally indexed**
16//!    per-chunk partials it owns and ships `(chunk_index, k·k partial)`
17//!    messages; the coordinator folds them through the *single* global cascade
18//!    of [`StreamingBorderGram`], which accepts any arrival order and folds in
19//!    chunk-index order. The reduction topology is therefore a pure function of
20//!    `(n_rows, chunk_size)` alone — running on 1 node, 3 nodes, or 64 nodes
21//!    yields the identical bit pattern, because the tree never saw the node
22//!    count. (The chunk→node *assignment* is rank-indexed and deterministic,
23//!    but it only decides who computes a partial, never how partials combine.)
24//! 2. **Checkpoint/resume is the job model, not an afterthought.** Any worker's
25//!    death resumes from its serialized [`NodeWorkerCheckpoint`] (a cursor into
26//!    its owned chunk sequence); the coordinator's full state — the in-order
27//!    fold forest, the pending out-of-order partials, and the per-rank receipt
//!    cursors — serializes to a [`CrossNodeCheckpoint`].
//!    Resume-equals-straight-through holds at the bit level on both sides
//!    because both cursors are positions in deterministic sequences.
31//! 3. **Partials, never rows, cross the wire.** A worker streams its shard rows
32//!    locally (object store / mmap — `gam_sae::corpus`) and ships
33//!    only `k·k` f64 partials. The coordinator's ingest seam is
34//!    [`StreamingBorderGram::submit_chunk_gram`]; both producers route through
35//!    the one [`chunk_gram_flat`] free function, so a shipped partial is
36//!    bit-identical to the partial the coordinator would have computed from the
37//!    same rows.
38//!
39//! ## Chunk→rank assignment
40//!
41//! Round-robin by chunk index: rank `r` of `n_ranks` owns chunks
42//! `{j : j ≡ r (mod n_ranks)}`, in increasing order. Round-robin (rather than
43//! contiguous ranges) keeps the coordinator's in-order fold frontier advancing
44//! steadily while all ranks make progress at similar rates, which bounds the
45//! pending out-of-order buffer by O(`n_ranks` × inter-node skew) instead of
46//! O(total chunks). The assignment is a pure function of
47//! `(chunk_index, n_ranks)`; no scheduler, no work stealing — work stealing
48//! would not change bits (the fold is index-keyed) but it *would* break the
49//! one-cursor-per-rank resume model, so it is deliberately absent.
50//!
51//! Pure library: no networking, no flags, no environment variables. The
52//! transport (MPI, gRPC, files on a shared filesystem) is the caller's; this
53//! module owns the deterministic topology, the cursors, and the validation.
54
55use crate::streaming_border::{BorderGramCheckpoint, StreamingBorderGram, chunk_gram_flat};
56use ndarray::{Array2, ArrayView2};
57use serde::{Deserialize, Serialize};
58
59/// The deterministic chunk partition + rank-indexed assignment shared by every
60/// participant of one cross-node pass. A pure function of its four fields; two
61/// participants constructed with the same fields agree on every derived
62/// quantity, with no communication.
63#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
64pub struct CrossNodePartition {
65    /// Border dimension `k` (columns of every chunk; partials are `k·k`).
66    pub border_dim: usize,
67    /// Total row count of the full pass.
68    pub n_rows: usize,
69    /// Fixed chunk size (rows per chunk; the last chunk may be shorter).
70    pub chunk_size: usize,
71    /// Number of worker ranks in the fleet.
72    pub n_ranks: usize,
73}
74
75impl CrossNodePartition {
76    pub fn new(
77        border_dim: usize,
78        n_rows: usize,
79        chunk_size: usize,
80        n_ranks: usize,
81    ) -> Result<Self, String> {
82        if border_dim == 0 {
83            return Err("CrossNodePartition: border_dim must be positive".to_string());
84        }
85        if chunk_size == 0 {
86            return Err("CrossNodePartition: chunk_size must be positive".to_string());
87        }
88        if n_ranks == 0 {
89            return Err("CrossNodePartition: n_ranks must be positive".to_string());
90        }
91        Ok(Self {
92            border_dim,
93            n_rows,
94            chunk_size,
95            n_ranks,
96        })
97    }
98
99    /// Total number of chunks of the pass: `ceil(n_rows / chunk_size)`.
100    /// Identical to [`StreamingBorderGram::n_chunks`] for the same partition
101    /// parameters — the global tree this assignment feeds.
102    pub fn n_chunks(&self) -> usize {
103        self.n_rows.div_ceil(self.chunk_size)
104    }
105
106    /// Row range covered by global chunk `chunk_index` — the same pure function
107    /// as [`StreamingBorderGram::chunk_rows`], duplicated here so a worker can
108    /// slice its rows without constructing a coordinator-side accumulator.
109    pub fn chunk_rows(&self, chunk_index: usize) -> std::ops::Range<usize> {
110        let lo = chunk_index * self.chunk_size;
111        let hi = ((chunk_index + 1) * self.chunk_size).min(self.n_rows);
112        lo..hi
113    }
114
115    /// Which rank owns global chunk `chunk_index`: round-robin by index.
116    #[inline]
117    pub fn owner_rank(&self, chunk_index: usize) -> usize {
118        chunk_index % self.n_ranks
119    }
120
121    /// Number of chunks rank `rank` owns.
122    pub fn chunks_owned_by(&self, rank: usize) -> usize {
123        let n = self.n_chunks();
124        if rank >= self.n_ranks || n == 0 {
125            return 0;
126        }
127        // Chunks r, r + n_ranks, r + 2·n_ranks, … below n.
128        if rank < n {
129            (n - rank - 1) / self.n_ranks + 1
130        } else {
131            0
132        }
133    }
134
135    /// The `ordinal`-th (0-based) global chunk index owned by `rank`, or `None`
136    /// past the end of the rank's sequence. The worker cursor is an ordinal
137    /// into exactly this sequence.
138    pub fn owned_chunk(&self, rank: usize, ordinal: usize) -> Option<usize> {
139        if rank >= self.n_ranks {
140            return None;
141        }
142        let idx = rank + ordinal * self.n_ranks;
143        if idx < self.n_chunks() {
144            Some(idx)
145        } else {
146            None
147        }
148    }
149}
150
151/// One shipped partial: the global chunk index plus the deterministic `k·k`
152/// per-chunk Gram. This is the only message that crosses the node boundary —
153/// `k·k` f64 values per chunk, never rows.
154#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
155pub struct NodePartial {
156    /// Rank that produced this partial (its assignment is validated on
157    /// receipt, so a misconfigured worker is rejected loudly).
158    pub rank: usize,
159    /// Global chunk index of the partial.
160    pub chunk_index: usize,
161    /// Flattened `k·k` row-major per-chunk Gram, as produced by
162    /// [`chunk_gram_flat`] over the chunk's rows.
163    pub gram: Vec<f64>,
164}
165
166/// Serialized cursor of one worker: everything needed for a **replacement**
167/// process (same rank, any host) to continue the dead worker's deterministic
168/// chunk sequence from where receipts stopped. Pure data; the worker's row
169/// source re-seeks by row range, which is a pure function of the partition.
170#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
171pub struct NodeWorkerCheckpoint {
172    pub partition: CrossNodePartition,
173    pub rank: usize,
174    /// Ordinal (into the rank's owned-chunk sequence) of the next chunk to
175    /// compute and ship. Everything below it has been durably shipped.
176    pub next_ordinal: usize,
177}
178
179/// Worker-side driver for one rank: walks the rank's deterministic owned-chunk
180/// sequence, turning row slices into shippable [`NodePartial`]s.
181///
182/// The worker does **not** do I/O: the caller streams rows (from its shards /
183/// object store) for the row range [`NodeWorker::next_chunk_rows`] names, hands
184/// them to [`NodeWorker::emit`], and ships the returned partial. The cursor
185/// advances only on `emit`, so "ship durably, then checkpoint" gives exactly-
186/// once production under crash-resume (re-shipping an already-folded chunk is
187/// rejected by the coordinator as a duplicate, which is the safe failure).
188#[derive(Clone, Debug)]
189pub struct NodeWorker {
190    partition: CrossNodePartition,
191    rank: usize,
192    next_ordinal: usize,
193}
194
195impl NodeWorker {
196    /// Fresh worker for `rank`, starting at the beginning of its sequence.
197    pub fn new(partition: CrossNodePartition, rank: usize) -> Result<Self, String> {
198        if rank >= partition.n_ranks {
199            return Err(format!(
200                "NodeWorker: rank {rank} out of range (n_ranks = {})",
201                partition.n_ranks
202            ));
203        }
204        Ok(Self {
205            partition,
206            rank,
207            next_ordinal: 0,
208        })
209    }
210
211    /// Resume a (replacement) worker from a serialized cursor. Validates the
212    /// cursor against the partition so a checkpoint from a different pass is
213    /// rejected loudly.
214    pub fn resume(state: NodeWorkerCheckpoint) -> Result<Self, String> {
215        if state.rank >= state.partition.n_ranks {
216            return Err(format!(
217                "NodeWorkerCheckpoint: rank {} out of range (n_ranks = {})",
218                state.rank, state.partition.n_ranks
219            ));
220        }
221        let owned = state.partition.chunks_owned_by(state.rank);
222        if state.next_ordinal > owned {
223            return Err(format!(
224                "NodeWorkerCheckpoint: next_ordinal {} exceeds owned chunk count {owned}",
225                state.next_ordinal
226            ));
227        }
228        Ok(Self {
229            partition: state.partition,
230            rank: state.rank,
231            next_ordinal: state.next_ordinal,
232        })
233    }
234
235    /// Serialize the cursor. Write this (durably) after each successful ship.
236    pub fn checkpoint(&self) -> NodeWorkerCheckpoint {
237        NodeWorkerCheckpoint {
238            partition: self.partition,
239            rank: self.rank,
240            next_ordinal: self.next_ordinal,
241        }
242    }
243
244    /// `true` once this rank's sequence is exhausted.
245    pub fn is_done(&self) -> bool {
246        self.partition
247            .owned_chunk(self.rank, self.next_ordinal)
248            .is_none()
249    }
250
251    /// Global chunk index and row range of the next chunk to compute, or
252    /// `None` when done. The caller fetches exactly these rows.
253    pub fn next_chunk_rows(&self) -> Option<(usize, std::ops::Range<usize>)> {
254        let idx = self.partition.owned_chunk(self.rank, self.next_ordinal)?;
255        Some((idx, self.partition.chunk_rows(idx)))
256    }
257
258    /// Compute the next chunk's deterministic partial from its rows and advance
259    /// the cursor. `rows` must be exactly the rows of
260    /// [`NodeWorker::next_chunk_rows`] (shape-validated here; content is the
261    /// caller's contract, same as the in-process path).
262    pub fn emit(&mut self, rows: ArrayView2<'_, f64>) -> Result<NodePartial, String> {
263        let (chunk_index, range) = self
264            .next_chunk_rows()
265            .ok_or_else(|| format!("NodeWorker rank {}: sequence exhausted", self.rank))?;
266        if rows.nrows() != range.len() || rows.ncols() != self.partition.border_dim {
267            return Err(format!(
268                "NodeWorker rank {}: chunk {chunk_index} has shape ({}, {}) but expected ({}, {})",
269                self.rank,
270                rows.nrows(),
271                rows.ncols(),
272                range.len(),
273                self.partition.border_dim
274            ));
275        }
276        let gram = chunk_gram_flat(rows);
277        self.next_ordinal += 1;
278        Ok(NodePartial {
279            rank: self.rank,
280            chunk_index,
281            gram,
282        })
283    }
284}
285
286/// Serializable coordinator state: the inner accumulation state plus the
287/// per-rank receipt cursors.
288#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
289pub struct CrossNodeCheckpoint {
290    pub partition: CrossNodePartition,
291    /// The wrapped [`StreamingBorderGram`] state (fold forest + pending +
292    /// chunk frontier).
293    pub inner: BorderGramCheckpoint,
294    /// Per-rank count of partials received so far — each is an ordinal cursor
295    /// into that rank's deterministic owned-chunk sequence, used both for
296    /// receipt validation (in-sequence, no gaps per rank) and to tell a
297    /// restarted fleet where each rank should resume.
298    pub received_per_rank: Vec<usize>,
299}
300
301/// Coordinator-side reduction: receives [`NodePartial`]s from the fleet and
302/// folds them into the single global fixed-tree accumulator.
303///
304/// Receipt validation is per rank and in-sequence: rank `r`'s `i`-th accepted
305/// partial must be its `i`-th owned chunk. This makes the per-rank cursor in
306/// [`CrossNodeCheckpoint::received_per_rank`] a complete description of what
307/// has been received, which is what lets a dead rank resume from a bare
308/// ordinal. Cross-rank arrival order is unconstrained (the inner accumulator
309/// buffers out-of-order chunks), so slow nodes never block fast ones.
310pub struct CrossNodeGramReduction {
311    partition: CrossNodePartition,
312    inner: StreamingBorderGram,
313    received_per_rank: Vec<usize>,
314}
315
316impl CrossNodeGramReduction {
317    /// Fresh coordinator for the given partition.
318    pub fn new(partition: CrossNodePartition) -> Result<Self, String> {
319        let inner =
320            StreamingBorderGram::new(partition.border_dim, partition.n_rows, partition.chunk_size)?;
321        Ok(Self {
322            received_per_rank: vec![0; partition.n_ranks],
323            partition,
324            inner,
325        })
326    }
327
328    /// The shared partition (workers must be constructed with an equal one).
329    pub fn partition(&self) -> CrossNodePartition {
330        self.partition
331    }
332
333    /// How many partials rank `rank` has had accepted — the ordinal a
334    /// replacement worker for that rank should resume from.
335    pub fn rank_cursor(&self, rank: usize) -> Option<usize> {
336        self.received_per_rank.get(rank).copied()
337    }
338
339    /// `true` once every chunk of every rank has been received and folded.
340    pub fn is_complete(&self) -> bool {
341        self.inner.is_complete()
342    }
343
344    /// Receive one shipped partial. Validates rank, ownership, and per-rank
345    /// sequence position, then folds through
346    /// [`StreamingBorderGram::submit_chunk_gram`] (which re-validates index
347    /// range, duplicates, and partial shape). A duplicate of an already-folded
348    /// chunk — the signature of an at-least-once transport retry or a worker
349    /// that resumed from a stale cursor — is rejected with an error naming the
350    /// chunk, never silently double-counted.
351    pub fn receive(&mut self, partial: NodePartial) -> Result<(), String> {
352        let NodePartial {
353            rank,
354            chunk_index,
355            gram,
356        } = partial;
357        if rank >= self.partition.n_ranks {
358            return Err(format!(
359                "CrossNodeGramReduction: rank {rank} out of range (n_ranks = {})",
360                self.partition.n_ranks
361            ));
362        }
363        if self.partition.owner_rank(chunk_index) != rank {
364            return Err(format!(
365                "CrossNodeGramReduction: chunk {chunk_index} is owned by rank {}, not rank {rank}",
366                self.partition.owner_rank(chunk_index)
367            ));
368        }
369        let cursor = self.received_per_rank[rank];
370        match self.partition.owned_chunk(rank, cursor) {
371            Some(expected) if expected == chunk_index => {}
372            Some(expected) => {
373                return Err(format!(
374                    "CrossNodeGramReduction: rank {rank} shipped chunk {chunk_index} but its \
375                     cursor expects chunk {expected} (ordinal {cursor}); a worker resumed from \
376                     a stale or future checkpoint"
377                ));
378            }
379            None => {
380                return Err(format!(
381                    "CrossNodeGramReduction: rank {rank} shipped chunk {chunk_index} past the \
382                     end of its owned sequence"
383                ));
384            }
385        }
386        self.inner.submit_chunk_gram(chunk_index, gram)?;
387        self.received_per_rank[rank] = cursor + 1;
388        Ok(())
389    }
390
391    /// Serialize the full coordinator state. Resume-equals-straight-through is
392    /// inherited bit-for-bit from the inner accumulator; the per-rank cursors
393    /// resume receipt validation exactly where it stopped.
394    pub fn checkpoint(&self) -> CrossNodeCheckpoint {
395        CrossNodeCheckpoint {
396            partition: self.partition,
397            inner: self.inner.checkpoint(),
398            received_per_rank: self.received_per_rank.clone(),
399        }
400    }
401
402    /// Reconstruct a coordinator from a checkpoint, validating the cursor
403    /// structure against the partition so corruption is rejected loudly.
404    pub fn resume(state: CrossNodeCheckpoint) -> Result<Self, String> {
405        if state.received_per_rank.len() != state.partition.n_ranks {
406            return Err(format!(
407                "CrossNodeCheckpoint: {} rank cursors for n_ranks = {}",
408                state.received_per_rank.len(),
409                state.partition.n_ranks
410            ));
411        }
412        if state.inner.border_dim != state.partition.border_dim
413            || state.inner.n_rows != state.partition.n_rows
414            || state.inner.chunk_size != state.partition.chunk_size
415        {
416            return Err(
417                "CrossNodeCheckpoint: inner accumulator partition disagrees with the cross-node \
418                 partition"
419                    .to_string(),
420            );
421        }
422        for (rank, &cursor) in state.received_per_rank.iter().enumerate() {
423            if cursor > state.partition.chunks_owned_by(rank) {
424                return Err(format!(
425                    "CrossNodeCheckpoint: rank {rank} cursor {cursor} exceeds its owned chunk \
426                     count {}",
427                    state.partition.chunks_owned_by(rank)
428                ));
429            }
430        }
431        let inner = StreamingBorderGram::resume(state.inner)?;
432        Ok(Self {
433            partition: state.partition,
434            inner,
435            received_per_rank: state.received_per_rank,
436        })
437    }
438
439    /// Finish the pass, returning the `k×k` border Gram. Errors if any rank's
440    /// sequence is incomplete. The result is a pure function of the row content
441    /// and `(n_rows, chunk_size)` — identical bits for any node count, any
442    /// arrival interleaving, and any checkpoint/resume history on either side.
443    pub fn finish(self) -> Result<Array2<f64>, String> {
444        for (rank, &cursor) in self.received_per_rank.iter().enumerate() {
445            let owned = self.partition.chunks_owned_by(rank);
446            if cursor != owned {
447                return Err(format!(
448                    "CrossNodeGramReduction: finish() with rank {rank} at ordinal {cursor} of \
449                     {owned} owned chunks"
450                ));
451            }
452        }
453        self.inner.finish()
454    }
455}
456
457#[cfg(test)]
458mod tests {
459    use super::*;
460    use ndarray::s;
461
462    /// Deterministic pseudo-random row matrix keyed purely by index (same
463    /// recipe as the streaming_border tests, so cross-file comparisons hold).
464    fn planted_rows(n: usize, k: usize) -> Array2<f64> {
465        Array2::from_shape_fn((n, k), |(i, j)| {
466            let x = (i as f64 + 1.0) * 0.7390851 + (j as f64 + 1.0) * 1.6180339;
467            (x.sin() * 43_758.547).fract() * 2.0 - 1.0
468        })
469    }
470
471    fn assert_bit_identical(a: &Array2<f64>, b: &Array2<f64>, label: &str) {
472        assert_eq!(a.dim(), b.dim(), "{label}: shape mismatch");
473        for ((idx, x), y) in a.indexed_iter().zip(b.iter()) {
474            assert_eq!(
475                x.to_bits(),
476                y.to_bits(),
477                "{label}: entry {idx:?} differs bitwise: {x:?} vs {y:?}"
478            );
479        }
480    }
481
482    /// Run a whole fleet of `n_ranks` workers to completion against one
483    /// coordinator, interleaving ranks round-robin with a deterministic skew so
484    /// arrival order exercises the pending buffer. Returns the final Gram.
485    fn run_fleet(rows: &Array2<f64>, chunk_size: usize, n_ranks: usize) -> Array2<f64> {
486        let partition =
487            CrossNodePartition::new(rows.ncols(), rows.nrows(), chunk_size, n_ranks).unwrap();
488        let mut coordinator = CrossNodeGramReduction::new(partition).unwrap();
489        let mut workers: Vec<NodeWorker> = (0..n_ranks)
490            .map(|r| NodeWorker::new(partition, r).unwrap())
491            .collect();
492        // Deterministic skewed interleaving: each sweep lets rank r ship
493        // (r % 3 + 1) chunks, so ranks run ahead/behind each other and the
494        // coordinator's out-of-order pending path is exercised.
495        let mut any_live = true;
496        while any_live {
497            any_live = false;
498            for (r, worker) in workers.iter_mut().enumerate() {
499                for _ in 0..(r % 3 + 1) {
500                    let Some((_, range)) = worker.next_chunk_rows() else {
501                        break;
502                    };
503                    let partial = worker.emit(rows.slice(s![range, ..])).unwrap();
504                    coordinator.receive(partial).unwrap();
505                    any_live = true;
506                }
507                if !worker.is_done() {
508                    any_live = true;
509                }
510            }
511        }
512        assert!(coordinator.is_complete());
513        coordinator.finish().unwrap()
514    }
515
516    #[test]
517    fn node_count_never_changes_bits() {
518        // The frontier invariant: 1, 3, and 5 nodes produce the identical bit
519        // pattern, and all match the single-process StreamingBorderGram.
520        let n = 977; // not a multiple of the chunk size
521        let k = 4;
522        let chunk_size = 7;
523        let rows = planted_rows(n, k);
524
525        let mut single = StreamingBorderGram::new(k, n, chunk_size).unwrap();
526        for j in 0..single.n_chunks() {
527            let range = single.chunk_rows(j);
528            single.submit_chunk(j, rows.slice(s![range, ..])).unwrap();
529        }
530        let reference = single.finish().unwrap();
531
532        for n_ranks in [1usize, 3, 5] {
533            let fleet = run_fleet(&rows, chunk_size, n_ranks);
534            assert_bit_identical(
535                &reference,
536                &fleet,
537                &format!("single-process vs {n_ranks}-node fleet"),
538            );
539        }
540    }
541
542    #[test]
543    fn dead_node_resumes_from_cursor_bit_identically() {
544        let n = 530;
545        let k = 3;
546        let chunk_size = 5;
547        let n_ranks = 3;
548        let rows = planted_rows(n, k);
549        let reference = run_fleet(&rows, chunk_size, n_ranks);
550
551        let partition = CrossNodePartition::new(k, n, chunk_size, n_ranks).unwrap();
552        let mut coordinator = CrossNodeGramReduction::new(partition).unwrap();
553        let mut workers: Vec<NodeWorker> = (0..n_ranks)
554            .map(|r| NodeWorker::new(partition, r).unwrap())
555            .collect();
556
557        // Rank 1 ships 4 chunks, checkpoints durably, then "dies". The other
558        // ranks ship a few chunks too.
559        let mut rank1_cursor = None;
560        for (r, worker) in workers.iter_mut().enumerate() {
561            let ship = if r == 1 { 4 } else { 2 };
562            for _ in 0..ship {
563                let Some((_, range)) = worker.next_chunk_rows() else {
564                    break;
565                };
566                let partial = worker.emit(rows.slice(s![range, ..])).unwrap();
567                coordinator.receive(partial).unwrap();
568            }
569            if r == 1 {
570                let json = serde_json::to_string(&worker.checkpoint()).unwrap();
571                rank1_cursor = Some(json);
572            }
573        }
574        workers.remove(1); // the death: rank-1 worker removed and dropped here
575
576        // Coordinator also survives a checkpoint round-trip mid-pass.
577        let coord_json = serde_json::to_string(&coordinator.checkpoint()).unwrap();
578        drop(coordinator);
579        let restored: CrossNodeCheckpoint = serde_json::from_str(&coord_json).unwrap();
580        let mut coordinator = CrossNodeGramReduction::resume(restored).unwrap();
581
582        // A replacement process resumes rank 1 from its serialized cursor; the
583        // coordinator's own cursor agrees with it.
584        let cp: NodeWorkerCheckpoint = serde_json::from_str(&rank1_cursor.unwrap()).unwrap();
585        assert_eq!(coordinator.rank_cursor(1), Some(cp.next_ordinal));
586        let replacement = NodeWorker::resume(cp).unwrap();
587        workers.insert(1, replacement);
588
589        // Drain the fleet to completion.
590        let mut any_live = true;
591        while any_live {
592            any_live = false;
593            for worker in workers.iter_mut() {
594                if let Some((_, range)) = worker.next_chunk_rows() {
595                    let partial = worker.emit(rows.slice(s![range, ..])).unwrap();
596                    coordinator.receive(partial).unwrap();
597                    any_live = true;
598                }
599            }
600        }
601        let resumed = coordinator.finish().unwrap();
602        assert_bit_identical(&reference, &resumed, "death-resume vs straight-through");
603    }
604
605    #[test]
606    fn receipt_validation_rejects_misrouted_and_out_of_sequence_partials() {
607        let n = 60;
608        let k = 2;
609        let chunk_size = 4; // 15 chunks
610        let n_ranks = 3;
611        let rows = planted_rows(n, k);
612        let partition = CrossNodePartition::new(k, n, chunk_size, n_ranks).unwrap();
613        let mut coordinator = CrossNodeGramReduction::new(partition).unwrap();
614        let mut w0 = NodeWorker::new(partition, 0).unwrap();
615
616        let (idx, range) = w0.next_chunk_rows().unwrap();
617        assert_eq!(idx, 0);
618        let mut partial = w0.emit(rows.slice(s![range, ..])).unwrap();
619
620        // Misrouted: claim the partial came from rank 1 (which does not own
621        // chunk 0).
622        partial.rank = 1;
623        let err = coordinator.receive(partial.clone()).unwrap_err();
624        assert!(err.contains("owned by rank 0"), "got: {err}");
625
626        // Correctly routed: accepted.
627        partial.rank = 0;
628        coordinator.receive(partial.clone()).unwrap();
629
630        // Duplicate (transport retry): rejected, never double-counted.
631        let err = coordinator.receive(partial).unwrap_err();
632        assert!(err.contains("cursor expects chunk 3"), "got: {err}");
633
634        // Out of sequence: rank 0's cursor expects its ordinal-1 chunk (global
635        // chunk 3), not its ordinal-2 chunk (global chunk 6).
636        let (idx, range) = w0.next_chunk_rows().unwrap();
637        assert_eq!(idx, 3);
638        let skipped = w0.emit(rows.slice(s![range, ..])).unwrap();
639        let (idx6, range6) = w0.next_chunk_rows().unwrap();
640        assert_eq!(idx6, 6);
641        let ahead = w0.emit(rows.slice(s![range6, ..])).unwrap();
642        let err = coordinator.receive(ahead).unwrap_err();
643        assert!(err.contains("expects chunk 3"), "got: {err}");
644        coordinator.receive(skipped).unwrap();
645    }
646
647    #[test]
648    fn assignment_is_a_pure_partition() {
649        // Every chunk is owned by exactly one rank, and the per-rank owned
650        // sequences tile [0, n_chunks) — for several fleet sizes.
651        for (n_rows, chunk_size, n_ranks) in [(100, 7, 1), (100, 7, 4), (3, 10, 8), (0, 5, 3)] {
652            let partition = CrossNodePartition::new(2, n_rows, chunk_size, n_ranks).unwrap();
653            let n_chunks = partition.n_chunks();
654            let mut seen = vec![false; n_chunks];
655            let mut total = 0usize;
656            for rank in 0..n_ranks {
657                let owned = partition.chunks_owned_by(rank);
658                for ordinal in 0..owned {
659                    let idx = partition.owned_chunk(rank, ordinal).unwrap();
660                    assert_eq!(partition.owner_rank(idx), rank);
661                    assert!(!seen[idx], "chunk {idx} assigned twice");
662                    seen[idx] = true;
663                    total += 1;
664                }
665                assert!(partition.owned_chunk(rank, owned).is_none());
666            }
667            assert_eq!(total, n_chunks, "assignment must tile all chunks");
668        }
669    }
670}