1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
use crate::error::{NexarError, Result};
use crate::rpc::registry::RpcRegistry;
use crate::types::Rank;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::{Mutex, RwLock};
use super::NexarClient;
use super::async_client::RawRecvSource;
use super::hash::fnv1a_comm_id;
impl NexarClient {
    /// Split this communicator into sub-groups.
    ///
    /// All ranks must call `split` with the same arguments simultaneously.
    /// Ranks with the same `color` end up in the same sub-communicator.
    /// Within each group, ranks are ordered by `key` (ties broken by original rank).
    ///
    /// The returned client has new rank/world_size within the sub-group and uses
    /// a unique `comm_id` for its raw stream traffic, so collectives on the
    /// sub-communicator don't interfere with the parent or other sub-groups.
    ///
    /// The parent client's routers demux raw streams by `comm_id`, so the parent
    /// must remain alive for the duration of the split client's use.
    ///
    /// # Errors
    ///
    /// Returns [`NexarError::CollectiveFailed`] when a peer reports a different
    /// `split_generation` (i.e. ranks have not called `split` in lockstep) or when
    /// this rank is missing from its own color group, [`NexarError::DecodeFailed`]
    /// if the gathered exchange buffer cannot be sliced into fixed-size records,
    /// and [`NexarError::UnknownPeer`] if a group member has no router entry.
    pub async fn split(&self, color: u32, key: u32) -> Result<NexarClient> {
        let world = self.world_size as usize;
        let rank = self.rank;

        // Step 1: Exchange (color, key, split_generation) tuples with all peers.
        // Encode as 16 bytes: [color: u32 LE][key: u32 LE][generation: u64 LE].
        let current_gen = self.split_generation.load(Ordering::Relaxed);
        let mut my_info = [0u8; 16];
        my_info[..4].copy_from_slice(&color.to_le_bytes());
        my_info[4..8].copy_from_slice(&key.to_le_bytes());
        my_info[8..16].copy_from_slice(&current_gen.to_le_bytes());

        // AllGather the info from all ranks. Pre-seed our own slot so the local
        // entry is valid regardless of how the ring treats the self-rank segment.
        let mut all_info = vec![0u8; 16 * world];
        all_info[rank as usize * 16..(rank as usize + 1) * 16].copy_from_slice(&my_info);

        // Use the existing allgather collective. We pass raw pointers to our
        // stack-allocated buffers; both outlive the awaited call below.
        let send_ptr = my_info.as_ptr() as u64;
        let recv_ptr = all_info.as_mut_ptr() as u64;
        let tag = self.next_collective_tag();
        // SAFETY: `send_ptr`/`recv_ptr` point into `my_info`/`all_info`, which are
        // alive and correctly sized (16 bytes per rank) for the duration of the call.
        unsafe {
            crate::collective::ring_allgather(
                self,
                send_ptr,
                recv_ptr,
                16, // 16 bytes per rank
                crate::types::DataType::U8,
                Some(tag),
            )
            .await?;
        }

        // Step 2: Parse all (color, key) tuples and validate generation consistency.
        // A generation mismatch means some rank has performed a different number of
        // splits, so the derived comm_id would diverge across ranks — fail fast.
        let mut entries: Vec<(Rank, u32, u32)> = Vec::with_capacity(world);
        for r in 0..world {
            let off = r * 16;
            let c = u32::from_le_bytes(
                all_info[off..off + 4]
                    .try_into()
                    .map_err(|_| NexarError::DecodeFailed("split color bytes".into()))?,
            );
            let k = u32::from_le_bytes(
                all_info[off + 4..off + 8]
                    .try_into()
                    .map_err(|_| NexarError::DecodeFailed("split key bytes".into()))?,
            );
            let peer_gen = u64::from_le_bytes(
                all_info[off + 8..off + 16]
                    .try_into()
                    .map_err(|_| NexarError::DecodeFailed("split generation bytes".into()))?,
            );
            if peer_gen != current_gen {
                return Err(NexarError::CollectiveFailed {
                    operation: "split",
                    rank: r as Rank,
                    reason: format!(
                        "split_generation mismatch: local={current_gen}, rank {r}={peer_gen}"
                    ),
                });
            }
            entries.push((r as Rank, c, k));
        }

        // Step 3: Find our group (same color), sort by (key, original_rank).
        let my_color = color;
        let mut group: Vec<(Rank, u32)> = entries
            .iter()
            .filter(|&&(_, c, _)| c == my_color)
            .map(|&(r, _, k)| (r, k))
            .collect();
        group.sort_by_key(|&(orig_rank, k)| (k, orig_rank));
        let new_world_size = group.len() as u32;
        let new_rank = group
            .iter()
            .position(|&(r, _)| r == rank)
            .ok_or_else(|| NexarError::CollectiveFailed {
                operation: "split",
                rank,
                reason: "rank not found in its own color group".into(),
            })? as Rank;

        // Step 4: Generate a deterministic comm_id agreed upon by all ranks.
        // All ranks in this communicator advance split_generation in lockstep
        // (split is collective). Combine parent comm_id, generation, and color
        // to produce a unique comm_id per split group.
        let split_gen = self.split_generation.fetch_add(1, Ordering::Relaxed);
        let new_comm_id = fnv1a_comm_id([
            &self.comm_id.to_le_bytes()[..],
            &split_gen.to_le_bytes(),
            &my_color.to_le_bytes(),
        ]);

        // Step 5: Build rank_map (new_rank -> original_rank) and peer subset.
        let mut rank_map = HashMap::new();
        let mut new_peers = HashMap::new();
        let mut comm_receivers = HashMap::new();
        for (new_r, &(orig_rank, _)) in group.iter().enumerate() {
            let new_r = new_r as Rank;
            rank_map.insert(new_r, orig_rank);
            if orig_rank != rank {
                // Share the parent's PeerConnection (keyed by original rank).
                let peer = self.peer(orig_rank)?;
                new_peers.insert(new_r, Arc::clone(peer));
                // Register a per-comm_id channel on the parent's router for this peer.
                let original_rank_key = orig_rank;
                let router = self
                    .routers
                    .get(&original_rank_key)
                    .ok_or(NexarError::UnknownPeer {
                        rank: original_rank_key,
                    })?;
                let rx = router.register_comm(new_comm_id).await;
                comm_receivers.insert(new_r, Mutex::new(rx));
            }
        }

        // Step 6: Build the split client. It shares the parent's routers
        // but uses comm-specific raw channels.
        // Note: The split client doesn't own routers or router handles — it
        // borrows the parent's routers indirectly through the registered comm channels.
        // Control/data/RPC lanes are still on the parent's routers.
        // Split clients don't have their own routers — they use comm channels for
        // collective raw data. Barrier uses comm_id in the message to avoid cross-talk.
        Ok(NexarClient {
            rank: new_rank,
            world_size: new_world_size,
            comm_id: new_comm_id,
            peers: new_peers,
            routers: HashMap::new(), // Split clients don't own routers
            raw_recv: RawRecvSource::Comm(comm_receivers),
            _router_handles: Vec::new(),
            adapter: Arc::clone(&self.adapter),
            _pool: Arc::clone(&self._pool),
            barrier_epoch: AtomicU64::new(0),
            rpc_registry: Arc::new(RwLock::new(RpcRegistry::new())),
            rpc_req_id: AtomicU64::new(0),
            split_generation: AtomicU64::new(0),
            rank_map,
            collective_tag: AtomicU64::new(1),
            tagged_receivers: tokio::sync::Mutex::new(HashMap::new()),
            config: Arc::clone(&self.config),
            failure_tx: Arc::clone(&self.failure_tx),
            failure_rx: self.failure_rx.clone(),
            _monitor_handle: None, // Split clients share the parent's monitor
            routing_table: self.routing_table.clone(),
            relay_deliveries: self.relay_deliveries.clone(),
            _relay_handles: Vec::new(), // Split clients share the parent's relay
            _endpoints: Vec::new(),
        })
    }
}