1use std::sync::{Arc, RwLock};
43
/// Identifier of one participant ("rank") in a collective.
///
/// A thin newtype over `u32`. Ordering, equality and hashing all follow the
/// wrapped integer, so ranks can be sorted or used as map keys directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Rank(
    /// Zero-based rank index.
    pub u32,
);
49
50#[derive(Debug, Clone, Copy, PartialEq, Eq)]
54pub struct SymmetricBuffer {
55 pub rank: Rank,
56 pub offset: usize,
57 pub len: usize,
58}
59
60pub trait SymmetricTransport {
65 fn num_ranks(&self) -> u32;
67 fn this_rank(&self) -> Rank;
69
70 fn put(&self, buf: SymmetricBuffer, src: &[u8]) -> Result<(), CollectiveError>;
72 fn get(&self, buf: SymmetricBuffer, dst: &mut [u8]) -> Result<(), CollectiveError>;
74
75 fn barrier(&self) -> Result<(), CollectiveError>;
79}
80
81#[derive(Debug, Clone)]
82pub enum CollectiveError {
83 OutOfBounds {
85 rank: Rank,
86 offset: usize,
87 len: usize,
88 heap_size: usize,
89 },
90 LengthMismatch { expected: usize, got: usize },
92 UnknownRank { rank: Rank, num_ranks: u32 },
94 TransportError { reason: String },
96}
97
98impl std::fmt::Display for CollectiveError {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 match self {
101 Self::OutOfBounds {
102 rank,
103 offset,
104 len,
105 heap_size,
106 } => write!(
107 f,
108 "OOB on rank {}: offset {offset} + len {len} > heap_size {heap_size}",
109 rank.0
110 ),
111 Self::LengthMismatch { expected, got } => {
112 write!(f, "length mismatch: expected {expected}, got {got}")
113 }
114 Self::UnknownRank { rank, num_ranks } => {
115 write!(f, "unknown rank {} (have {num_ranks})", rank.0)
116 }
117 Self::TransportError { reason } => write!(f, "transport: {reason}"),
118 }
119 }
120}
121
// Marks `CollectiveError` as a standard error type so it works with `?` and
// `Box<dyn Error>`. All methods use the trait defaults; the required
// `Debug` and `Display` implementations are provided elsewhere in this file.
impl std::error::Error for CollectiveError {}
123
/// Backing storage for every rank: one equally sized byte buffer per rank.
#[derive(Debug)]
pub struct SymmetricHeap {
    /// One lock-protected byte buffer per rank, indexed by rank number.
    /// Each buffer is wrapped in an `Arc` so handles to it can be given out.
    storage: Vec<Arc<RwLock<Vec<u8>>>>,
    /// Size in bytes that each per-rank buffer was created with.
    pub heap_size: usize,
}
131
132impl SymmetricHeap {
133 pub fn new(num_ranks: u32, heap_size: usize) -> Self {
134 let storage = (0..num_ranks)
135 .map(|_| Arc::new(RwLock::new(vec![0u8; heap_size])))
136 .collect();
137 Self { storage, heap_size }
138 }
139
140 pub fn num_ranks(&self) -> u32 {
141 self.storage.len() as u32
142 }
143
144 pub fn rank_view(&self, rank: Rank) -> Option<Arc<RwLock<Vec<u8>>>> {
145 self.storage.get(rank.0 as usize).cloned()
146 }
147}
148
149#[derive(Debug, Clone)]
155pub struct LocalTransport {
156 heap: Arc<SymmetricHeap>,
157 me: Rank,
158 barrier_count: Arc<std::sync::atomic::AtomicU32>,
159 barrier_target: u32,
160}
161
162impl LocalTransport {
163 pub fn new(num_ranks: u32, heap_size: usize, this_rank: Rank) -> Self {
164 let heap = Arc::new(SymmetricHeap::new(num_ranks, heap_size));
165 Self::with_heap(heap, this_rank)
166 }
167
168 pub fn fan_out(num_ranks: u32, heap_size: usize) -> Vec<Self> {
172 let heap = Arc::new(SymmetricHeap::new(num_ranks, heap_size));
173 (0..num_ranks)
174 .map(|i| Self::with_heap(heap.clone(), Rank(i)))
175 .collect()
176 }
177
178 fn with_heap(heap: Arc<SymmetricHeap>, me: Rank) -> Self {
179 let n = heap.num_ranks();
180 Self {
181 heap,
182 me,
183 barrier_count: Arc::new(std::sync::atomic::AtomicU32::new(0)),
184 barrier_target: n,
185 }
186 }
187
188 fn check_buf(&self, buf: SymmetricBuffer) -> Result<(), CollectiveError> {
189 if buf.rank.0 >= self.heap.num_ranks() {
190 return Err(CollectiveError::UnknownRank {
191 rank: buf.rank,
192 num_ranks: self.heap.num_ranks(),
193 });
194 }
195 if buf.offset + buf.len > self.heap.heap_size {
196 return Err(CollectiveError::OutOfBounds {
197 rank: buf.rank,
198 offset: buf.offset,
199 len: buf.len,
200 heap_size: self.heap.heap_size,
201 });
202 }
203 Ok(())
204 }
205}
206
207impl SymmetricTransport for LocalTransport {
208 fn num_ranks(&self) -> u32 {
209 self.heap.num_ranks()
210 }
211 fn this_rank(&self) -> Rank {
212 self.me
213 }
214
215 fn put(&self, buf: SymmetricBuffer, src: &[u8]) -> Result<(), CollectiveError> {
216 self.check_buf(buf)?;
217 if src.len() != buf.len {
218 return Err(CollectiveError::LengthMismatch {
219 expected: buf.len,
220 got: src.len(),
221 });
222 }
223 let view = self.heap.rank_view(buf.rank).expect("checked above");
224 let mut guard = view.write().unwrap();
225 guard[buf.offset..buf.offset + buf.len].copy_from_slice(src);
226 Ok(())
227 }
228
229 fn get(&self, buf: SymmetricBuffer, dst: &mut [u8]) -> Result<(), CollectiveError> {
230 self.check_buf(buf)?;
231 if dst.len() != buf.len {
232 return Err(CollectiveError::LengthMismatch {
233 expected: buf.len,
234 got: dst.len(),
235 });
236 }
237 let view = self.heap.rank_view(buf.rank).expect("checked above");
238 let guard = view.read().unwrap();
239 dst.copy_from_slice(&guard[buf.offset..buf.offset + buf.len]);
240 Ok(())
241 }
242
243 fn barrier(&self) -> Result<(), CollectiveError> {
244 use std::sync::atomic::Ordering;
250 self.barrier_count.fetch_add(1, Ordering::AcqRel);
251 while self.barrier_count.load(Ordering::Acquire) < self.barrier_target {
254 std::hint::spin_loop();
255 }
256 Ok(())
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the buffer descriptors used throughout these tests.
    fn region(rank: u32, offset: usize, len: usize) -> SymmetricBuffer {
        SymmetricBuffer {
            rank: Rank(rank),
            offset,
            len,
        }
    }

    /// Bytes written with `put` come back unchanged from `get`.
    #[test]
    fn put_then_get_round_trips() {
        let transport = LocalTransport::new(4, 1024, Rank(0));
        let target = region(2, 16, 8);
        let written = [1u8, 2, 3, 4, 5, 6, 7, 8];

        transport.put(target, &written).unwrap();

        let mut read_back = [0u8; 8];
        transport.get(target, &mut read_back).unwrap();
        assert_eq!(read_back, written);
    }

    /// `fan_out` produces transports for ranks 0..n, each aware of n ranks.
    #[test]
    fn fan_out_yields_one_transport_per_rank() {
        let transports = LocalTransport::fan_out(3, 64);
        assert_eq!(transports.len(), 3);
        for (expected_rank, transport) in (0u32..).zip(&transports) {
            assert_eq!(transport.this_rank(), Rank(expected_rank));
            assert_eq!(transport.num_ranks(), 3);
        }
    }

    /// A write through one rank's transport is readable through another's.
    #[test]
    fn put_visible_to_other_rank_via_shared_heap() {
        let transports = LocalTransport::fan_out(3, 32);
        let payload = [9u8, 9, 9, 9];
        let target = region(2, 0, 4);

        transports[0].put(target, &payload).unwrap();

        let mut received = [0u8; 4];
        transports[2].get(target, &mut received).unwrap();
        assert_eq!(received, payload);
    }

    /// A range running past the end of the heap is rejected.
    #[test]
    fn oob_offset_errors() {
        let transport = LocalTransport::new(2, 8, Rank(0));
        let outcome = transport.put(region(1, 4, 8), &[0u8; 8]);
        assert!(matches!(
            outcome.unwrap_err(),
            CollectiveError::OutOfBounds { .. }
        ));
    }

    /// Addressing a rank the heap does not have is rejected.
    #[test]
    fn unknown_rank_errors() {
        let transport = LocalTransport::new(2, 8, Rank(0));
        let mut scratch = [0u8; 4];
        let outcome = transport.get(region(99, 0, 4), &mut scratch);
        assert!(matches!(
            outcome.unwrap_err(),
            CollectiveError::UnknownRank { .. }
        ));
    }

    /// A source slice whose length differs from the descriptor is rejected.
    #[test]
    fn length_mismatch_errors() {
        let transport = LocalTransport::new(2, 32, Rank(0));
        let outcome = transport.put(region(1, 0, 4), &[0u8; 8]);
        assert!(matches!(
            outcome.unwrap_err(),
            CollectiveError::LengthMismatch { .. }
        ));
    }
}