mpc_wallet_core/mpc/
memory.rs

1//! In-memory relay implementation for testing and local development
2
3use super::{Relay, async_trait};
4use crate::{Error, PartyId, Result, SessionId};
5use dashmap::DashMap;
6use serde::{Serialize, de::DeserializeOwned};
7use std::sync::Arc;
8use tokio::sync::broadcast;
9
10/// In-memory message relay for local testing
11///
12/// This relay stores all messages in memory and uses channels for notification.
13/// It's useful for:
14/// - Unit and integration testing
15/// - Local development
16/// - Single-process multi-party simulation
17#[derive(Debug)]
18pub struct MemoryRelay {
19    /// Broadcast messages: (session_id, round) -> Vec<message_bytes>
20    broadcasts: Arc<DashMap<(SessionId, u32), Vec<Vec<u8>>>>,
21    /// Direct messages: (session_id, round, to) -> Vec<message_bytes>
22    directs: Arc<DashMap<(SessionId, u32, PartyId), Vec<Vec<u8>>>>,
23    /// Notification channel for new messages
24    notify: broadcast::Sender<()>,
25    /// Timeout for waiting on messages (milliseconds)
26    timeout_ms: u64,
27}
28
29impl MemoryRelay {
30    /// Create a new in-memory relay with default timeout
31    pub fn new() -> Self {
32        Self::with_timeout(30_000) // 30 seconds default
33    }
34
35    /// Create a new in-memory relay with custom timeout
36    pub fn with_timeout(timeout_ms: u64) -> Self {
37        let (notify, _) = broadcast::channel(1000);
38        Self {
39            broadcasts: Arc::new(DashMap::new()),
40            directs: Arc::new(DashMap::new()),
41            notify,
42            timeout_ms,
43        }
44    }
45
46    /// Clear all messages (useful for test cleanup)
47    pub fn clear(&self) {
48        self.broadcasts.clear();
49        self.directs.clear();
50    }
51
52    /// Get the number of broadcast messages for a session/round
53    pub fn broadcast_count(&self, session_id: &SessionId, round: u32) -> usize {
54        self.broadcasts
55            .get(&(*session_id, round))
56            .map(|v| v.len())
57            .unwrap_or(0)
58    }
59
60    /// Get the number of direct messages for a party
61    pub fn direct_count(&self, session_id: &SessionId, round: u32, to: PartyId) -> usize {
62        self.directs
63            .get(&(*session_id, round, to))
64            .map(|v| v.len())
65            .unwrap_or(0)
66    }
67}
68
69impl Default for MemoryRelay {
70    fn default() -> Self {
71        Self::new()
72    }
73}
74
75impl Clone for MemoryRelay {
76    fn clone(&self) -> Self {
77        Self {
78            broadcasts: Arc::clone(&self.broadcasts),
79            directs: Arc::clone(&self.directs),
80            notify: self.notify.clone(),
81            timeout_ms: self.timeout_ms,
82        }
83    }
84}
85
86fn serialize<T: Serialize>(value: &T) -> Result<Vec<u8>> {
87    serde_json::to_vec(value).map_err(|e| Error::Serialization(e.to_string()))
88}
89
90fn deserialize<T: DeserializeOwned>(bytes: &[u8]) -> Result<T> {
91    serde_json::from_slice(bytes).map_err(|e| Error::Deserialization(e.to_string()))
92}
93
94#[async_trait]
95impl Relay for MemoryRelay {
96    async fn broadcast<T: Serialize + Send + Sync>(
97        &self,
98        session_id: &SessionId,
99        round: u32,
100        message: &T,
101    ) -> Result<()> {
102        let bytes = serialize(message)?;
103
104        self.broadcasts
105            .entry((*session_id, round))
106            .or_default()
107            .push(bytes);
108
109        // Notify waiting collectors
110        let _ = self.notify.send(());
111        Ok(())
112    }
113
114    async fn send_direct<T: Serialize + Send + Sync>(
115        &self,
116        session_id: &SessionId,
117        round: u32,
118        to: PartyId,
119        message: &T,
120    ) -> Result<()> {
121        let bytes = serialize(message)?;
122
123        self.directs
124            .entry((*session_id, round, to))
125            .or_default()
126            .push(bytes);
127
128        // Notify waiting collectors
129        let _ = self.notify.send(());
130        Ok(())
131    }
132
133    async fn collect_broadcasts<T: DeserializeOwned + Send>(
134        &self,
135        session_id: &SessionId,
136        round: u32,
137        count: usize,
138    ) -> Result<Vec<T>> {
139        let mut rx = self.notify.subscribe();
140        let deadline =
141            std::time::Instant::now() + std::time::Duration::from_millis(self.timeout_ms);
142
143        loop {
144            // Check if we have enough messages
145            if let Some(messages) = self.broadcasts.get(&(*session_id, round)) {
146                if messages.len() >= count {
147                    let result: Result<Vec<T>> = messages
148                        .iter()
149                        .take(count)
150                        .map(|bytes| deserialize(bytes))
151                        .collect();
152                    return result;
153                }
154            }
155
156            // Check timeout
157            let remaining = deadline.saturating_duration_since(std::time::Instant::now());
158            if remaining.is_zero() {
159                return Err(Error::Timeout(format!(
160                    "Waiting for {} broadcast messages in round {}",
161                    count, round
162                )));
163            }
164
165            // Wait for notification or timeout
166            tokio::select! {
167                _ = rx.recv() => continue,
168                _ = tokio::time::sleep(std::time::Duration::from_millis(100).min(remaining)) => continue,
169            }
170        }
171    }
172
173    async fn collect_direct<T: DeserializeOwned + Send>(
174        &self,
175        session_id: &SessionId,
176        round: u32,
177        my_id: PartyId,
178        count: usize,
179    ) -> Result<Vec<T>> {
180        let mut rx = self.notify.subscribe();
181        let deadline =
182            std::time::Instant::now() + std::time::Duration::from_millis(self.timeout_ms);
183
184        loop {
185            // Check if we have enough messages
186            if let Some(messages) = self.directs.get(&(*session_id, round, my_id)) {
187                if messages.len() >= count {
188                    let result: Result<Vec<T>> = messages
189                        .iter()
190                        .take(count)
191                        .map(|bytes| deserialize(bytes))
192                        .collect();
193                    return result;
194                }
195            }
196
197            // Check timeout
198            let remaining = deadline.saturating_duration_since(std::time::Instant::now());
199            if remaining.is_zero() {
200                return Err(Error::Timeout(format!(
201                    "Waiting for {} direct messages to party {} in round {}",
202                    count, my_id, round
203                )));
204            }
205
206            // Wait for notification or timeout
207            tokio::select! {
208                _ = rx.recv() => continue,
209                _ = tokio::time::sleep(std::time::Duration::from_millis(100).min(remaining)) => continue,
210            }
211        }
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Minimal serializable payload used by the relay tests.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestMessage {
        value: u32,
        data: String,
    }

    /// Shorthand constructor for a [`TestMessage`].
    fn msg(value: u32, data: &str) -> TestMessage {
        TestMessage {
            value,
            data: data.to_string(),
        }
    }

    /// Two broadcasts in the same round come back in the order they were sent.
    #[tokio::test]
    async fn test_broadcast() {
        let relay = MemoryRelay::new();
        let session_id = [0u8; 32];

        let first = msg(42, "hello");
        let second = msg(43, "world");
        relay.broadcast(&session_id, 1, &first).await.unwrap();
        relay.broadcast(&session_id, 1, &second).await.unwrap();

        let received: Vec<TestMessage> =
            relay.collect_broadcasts(&session_id, 1, 2).await.unwrap();

        assert_eq!(received.len(), 2);
        assert_eq!(received[0].value, 42);
        assert_eq!(received[1].value, 43);
    }

    /// A direct message is delivered to the party it was addressed to.
    #[tokio::test]
    async fn test_direct() {
        let relay = MemoryRelay::new();
        let session_id = [0u8; 32];

        let payload = msg(100, "direct");
        relay
            .send_direct(&session_id, 1, 0, &payload)
            .await
            .unwrap();

        let received: Vec<TestMessage> =
            relay.collect_direct(&session_id, 1, 0, 1).await.unwrap();

        assert_eq!(received.len(), 1);
        assert_eq!(received[0].value, 100);
    }

    /// Broadcasts issued from several tasks through cloned relays all land in
    /// the shared store.
    #[tokio::test]
    async fn test_concurrent_broadcast() {
        let relay = MemoryRelay::new();
        let session_id = [0u8; 32];

        // Spawn multiple broadcasters, each with its own clone of the relay
        let mut handles = Vec::new();
        for i in 0..3 {
            let r = relay.clone();
            let sid = session_id;
            handles.push(tokio::spawn(async move {
                r.broadcast(&sid, 1, &msg(i, &format!("msg-{}", i))).await
            }));
        }

        // Wait for all broadcasts
        for handle in handles {
            handle.await.unwrap().unwrap();
        }

        // Collect all messages
        let received: Vec<TestMessage> =
            relay.collect_broadcasts(&session_id, 1, 3).await.unwrap();
        assert_eq!(received.len(), 3);
    }

    /// Asking for more messages than were sent ends in a timeout error.
    #[tokio::test]
    async fn test_timeout() {
        let relay = MemoryRelay::with_timeout(100); // 100ms timeout
        let session_id = [0u8; 32];

        // Only send 1 message but request 2
        let payload = msg(1, "only one");
        relay.broadcast(&session_id, 1, &payload).await.unwrap();

        let outcome: Result<Vec<TestMessage>> =
            relay.collect_broadcasts(&session_id, 1, 2).await;
        assert!(matches!(outcome, Err(Error::Timeout(_))));
    }

    /// Messages of different sessions are kept apart.
    #[tokio::test]
    async fn test_separate_sessions() {
        let relay = MemoryRelay::new();
        let session1 = [1u8; 32];
        let session2 = [2u8; 32];

        let for_s1 = msg(1, "s1");
        let for_s2 = msg(2, "s2");
        relay.broadcast(&session1, 1, &for_s1).await.unwrap();
        relay.broadcast(&session2, 1, &for_s2).await.unwrap();

        let from_s1: Vec<TestMessage> =
            relay.collect_broadcasts(&session1, 1, 1).await.unwrap();
        let from_s2: Vec<TestMessage> =
            relay.collect_broadcasts(&session2, 1, 1).await.unwrap();

        assert_eq!(from_s1[0].value, 1);
        assert_eq!(from_s2[0].value, 2);
    }

    /// Messages of different rounds within one session are kept apart.
    #[tokio::test]
    async fn test_separate_rounds() {
        let relay = MemoryRelay::new();
        let session_id = [0u8; 32];

        let for_r1 = msg(1, "r1");
        let for_r2 = msg(2, "r2");
        relay.broadcast(&session_id, 1, &for_r1).await.unwrap();
        relay.broadcast(&session_id, 2, &for_r2).await.unwrap();

        let from_r1: Vec<TestMessage> =
            relay.collect_broadcasts(&session_id, 1, 1).await.unwrap();
        let from_r2: Vec<TestMessage> =
            relay.collect_broadcasts(&session_id, 2, 1).await.unwrap();

        assert_eq!(from_r1[0].value, 1);
        assert_eq!(from_r2[0].value, 2);
    }

    /// `clear` removes previously stored broadcasts.
    #[test]
    fn test_clear() {
        let relay = MemoryRelay::new();
        let session_id = [0u8; 32];

        // Add a message synchronously by writing to the underlying map
        let encoded = serde_json::to_vec(&msg(1, "test")).unwrap();
        relay.broadcasts.insert((session_id, 1), vec![encoded]);

        assert_eq!(relay.broadcast_count(&session_id, 1), 1);

        relay.clear();

        assert_eq!(relay.broadcast_count(&session_id, 1), 0);
    }
}
439}