1use super::{Relay, async_trait};
4use crate::{Error, PartyId, Result, SessionId};
5use dashmap::DashMap;
6use serde::{Serialize, de::DeserializeOwned};
7use std::sync::Arc;
8use tokio::sync::broadcast;
9
/// In-memory [`Relay`] implementation that keeps every message in process-local maps.
///
/// Cloning a `MemoryRelay` yields another handle onto the same underlying
/// storage and notification channel, so clones can exchange messages.
#[derive(Debug)]
pub struct MemoryRelay {
    /// Serialized broadcast messages, keyed by `(session id, round)`, in arrival order.
    broadcasts: Arc<DashMap<(SessionId, u32), Vec<Vec<u8>>>>,
    /// Serialized direct messages, keyed by `(session id, round, recipient)`, in arrival order.
    directs: Arc<DashMap<(SessionId, u32, PartyId), Vec<Vec<u8>>>>,
    /// Wake-up channel: a unit value is sent after every stored message so
    /// that waiting collectors re-check the maps.
    notify: broadcast::Sender<()>,
    /// Maximum time, in milliseconds, that a collect call waits for messages.
    timeout_ms: u64,
}
28
29impl MemoryRelay {
30 pub fn new() -> Self {
32 Self::with_timeout(30_000) }
34
35 pub fn with_timeout(timeout_ms: u64) -> Self {
37 let (notify, _) = broadcast::channel(1000);
38 Self {
39 broadcasts: Arc::new(DashMap::new()),
40 directs: Arc::new(DashMap::new()),
41 notify,
42 timeout_ms,
43 }
44 }
45
46 pub fn clear(&self) {
48 self.broadcasts.clear();
49 self.directs.clear();
50 }
51
52 pub fn broadcast_count(&self, session_id: &SessionId, round: u32) -> usize {
54 self.broadcasts
55 .get(&(*session_id, round))
56 .map(|v| v.len())
57 .unwrap_or(0)
58 }
59
60 pub fn direct_count(&self, session_id: &SessionId, round: u32, to: PartyId) -> usize {
62 self.directs
63 .get(&(*session_id, round, to))
64 .map(|v| v.len())
65 .unwrap_or(0)
66 }
67}
68
69impl Default for MemoryRelay {
70 fn default() -> Self {
71 Self::new()
72 }
73}
74
75impl Clone for MemoryRelay {
76 fn clone(&self) -> Self {
77 Self {
78 broadcasts: Arc::clone(&self.broadcasts),
79 directs: Arc::clone(&self.directs),
80 notify: self.notify.clone(),
81 timeout_ms: self.timeout_ms,
82 }
83 }
84}
85
86fn serialize<T: Serialize>(value: &T) -> Result<Vec<u8>> {
87 serde_json::to_vec(value).map_err(|e| Error::Serialization(e.to_string()))
88}
89
90fn deserialize<T: DeserializeOwned>(bytes: &[u8]) -> Result<T> {
91 serde_json::from_slice(bytes).map_err(|e| Error::Deserialization(e.to_string()))
92}
93
94#[async_trait]
95impl Relay for MemoryRelay {
96 async fn broadcast<T: Serialize + Send + Sync>(
97 &self,
98 session_id: &SessionId,
99 round: u32,
100 message: &T,
101 ) -> Result<()> {
102 let bytes = serialize(message)?;
103
104 self.broadcasts
105 .entry((*session_id, round))
106 .or_default()
107 .push(bytes);
108
109 let _ = self.notify.send(());
111 Ok(())
112 }
113
114 async fn send_direct<T: Serialize + Send + Sync>(
115 &self,
116 session_id: &SessionId,
117 round: u32,
118 to: PartyId,
119 message: &T,
120 ) -> Result<()> {
121 let bytes = serialize(message)?;
122
123 self.directs
124 .entry((*session_id, round, to))
125 .or_default()
126 .push(bytes);
127
128 let _ = self.notify.send(());
130 Ok(())
131 }
132
133 async fn collect_broadcasts<T: DeserializeOwned + Send>(
134 &self,
135 session_id: &SessionId,
136 round: u32,
137 count: usize,
138 ) -> Result<Vec<T>> {
139 let mut rx = self.notify.subscribe();
140 let deadline =
141 std::time::Instant::now() + std::time::Duration::from_millis(self.timeout_ms);
142
143 loop {
144 if let Some(messages) = self.broadcasts.get(&(*session_id, round)) {
146 if messages.len() >= count {
147 let result: Result<Vec<T>> = messages
148 .iter()
149 .take(count)
150 .map(|bytes| deserialize(bytes))
151 .collect();
152 return result;
153 }
154 }
155
156 let remaining = deadline.saturating_duration_since(std::time::Instant::now());
158 if remaining.is_zero() {
159 return Err(Error::Timeout(format!(
160 "Waiting for {} broadcast messages in round {}",
161 count, round
162 )));
163 }
164
165 tokio::select! {
167 _ = rx.recv() => continue,
168 _ = tokio::time::sleep(std::time::Duration::from_millis(100).min(remaining)) => continue,
169 }
170 }
171 }
172
173 async fn collect_direct<T: DeserializeOwned + Send>(
174 &self,
175 session_id: &SessionId,
176 round: u32,
177 my_id: PartyId,
178 count: usize,
179 ) -> Result<Vec<T>> {
180 let mut rx = self.notify.subscribe();
181 let deadline =
182 std::time::Instant::now() + std::time::Duration::from_millis(self.timeout_ms);
183
184 loop {
185 if let Some(messages) = self.directs.get(&(*session_id, round, my_id)) {
187 if messages.len() >= count {
188 let result: Result<Vec<T>> = messages
189 .iter()
190 .take(count)
191 .map(|bytes| deserialize(bytes))
192 .collect();
193 return result;
194 }
195 }
196
197 let remaining = deadline.saturating_duration_since(std::time::Instant::now());
199 if remaining.is_zero() {
200 return Err(Error::Timeout(format!(
201 "Waiting for {} direct messages to party {} in round {}",
202 count, my_id, round
203 )));
204 }
205
206 tokio::select! {
208 _ = rx.recv() => continue,
209 _ = tokio::time::sleep(std::time::Duration::from_millis(100).min(remaining)) => continue,
210 }
211 }
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Small serde-compatible payload exchanged in the tests below.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestMessage {
        value: u32,
        data: String,
    }

    /// Builds a [`TestMessage`] from a number and a borrowed label.
    fn payload(value: u32, data: &str) -> TestMessage {
        TestMessage {
            value,
            data: data.to_string(),
        }
    }

    /// Two broadcasts in the same round are collected in send order.
    #[tokio::test]
    async fn test_broadcast() {
        let relay = MemoryRelay::new();
        let sid = [0u8; 32];

        let first = payload(42, "hello");
        relay.broadcast(&sid, 1, &first).await.unwrap();

        let second = payload(43, "world");
        relay.broadcast(&sid, 1, &second).await.unwrap();

        let received: Vec<TestMessage> = relay.collect_broadcasts(&sid, 1, 2).await.unwrap();

        assert_eq!(received.len(), 2);
        assert_eq!(received[0].value, 42);
        assert_eq!(received[1].value, 43);
    }

    /// A direct message reaches the party it was addressed to.
    #[tokio::test]
    async fn test_direct() {
        let relay = MemoryRelay::new();
        let sid = [0u8; 32];

        let outgoing = payload(100, "direct");
        relay.send_direct(&sid, 1, 0, &outgoing).await.unwrap();

        let received: Vec<TestMessage> = relay.collect_direct(&sid, 1, 0, 1).await.unwrap();

        assert_eq!(received.len(), 1);
        assert_eq!(received[0].value, 100);
    }

    /// Broadcasts sent from several spawned tasks through cloned handles all
    /// land in the shared relay.
    #[tokio::test]
    async fn test_concurrent_broadcast() {
        let relay = MemoryRelay::new();
        let session_id = [0u8; 32];

        let mut tasks = Vec::new();
        for i in 0..3 {
            let r = relay.clone();
            let sid = session_id;
            tasks.push(tokio::spawn(async move {
                r.broadcast(
                    &sid,
                    1,
                    &TestMessage {
                        value: i,
                        data: format!("msg-{}", i),
                    },
                )
                .await
            }));
        }

        for task in tasks {
            task.await.unwrap().unwrap();
        }

        let received: Vec<TestMessage> =
            relay.collect_broadcasts(&session_id, 1, 3).await.unwrap();
        assert_eq!(received.len(), 3);
    }

    /// Asking for more messages than will ever arrive ends in a timeout error.
    #[tokio::test]
    async fn test_timeout() {
        let relay = MemoryRelay::with_timeout(100);
        let sid = [0u8; 32];

        let lone = payload(1, "only one");
        relay.broadcast(&sid, 1, &lone).await.unwrap();

        let outcome: Result<Vec<TestMessage>> = relay.collect_broadcasts(&sid, 1, 2).await;
        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), Error::Timeout(_)));
    }

    /// Messages of different sessions are kept apart.
    #[tokio::test]
    async fn test_separate_sessions() {
        let relay = MemoryRelay::new();
        let session1 = [1u8; 32];
        let session2 = [2u8; 32];

        let for_first = payload(1, "s1");
        relay.broadcast(&session1, 1, &for_first).await.unwrap();

        let for_second = payload(2, "s2");
        relay.broadcast(&session2, 1, &for_second).await.unwrap();

        let from_first: Vec<TestMessage> =
            relay.collect_broadcasts(&session1, 1, 1).await.unwrap();
        let from_second: Vec<TestMessage> =
            relay.collect_broadcasts(&session2, 1, 1).await.unwrap();

        assert_eq!(from_first[0].value, 1);
        assert_eq!(from_second[0].value, 2);
    }

    /// Messages of different rounds within one session are kept apart.
    #[tokio::test]
    async fn test_separate_rounds() {
        let relay = MemoryRelay::new();
        let sid = [0u8; 32];

        let round_one = payload(1, "r1");
        relay.broadcast(&sid, 1, &round_one).await.unwrap();

        let round_two = payload(2, "r2");
        relay.broadcast(&sid, 2, &round_two).await.unwrap();

        let from_one: Vec<TestMessage> = relay.collect_broadcasts(&sid, 1, 1).await.unwrap();
        let from_two: Vec<TestMessage> = relay.collect_broadcasts(&sid, 2, 1).await.unwrap();

        assert_eq!(from_one[0].value, 1);
        assert_eq!(from_two[0].value, 2);
    }

    /// `clear` drops previously stored broadcasts.
    #[test]
    fn test_clear() {
        let relay = MemoryRelay::new();
        let sid = [0u8; 32];

        // Seed the map directly so the test needs no async runtime.
        let encoded = serde_json::to_vec(&payload(1, "test")).unwrap();
        relay.broadcasts.insert((sid, 1), vec![encoded]);

        assert_eq!(relay.broadcast_count(&sid, 1), 1);

        relay.clear();

        assert_eq!(relay.broadcast_count(&sid, 1), 0);
    }
}