mpc_wallet_core/mpc/mod.rs
//! MPC coordination utilities
//!
//! This module provides the communication infrastructure for MPC protocol execution.
//! The `Relay` trait abstracts message passing between parties, enabling different
//! transport mechanisms (in-memory, WebSocket, REST API). The `RelayExt` extension
//! trait adds timeout-bounded variants of the relay operations.
6
7use crate::{PartyId, Result, SessionId};
8use serde::{Serialize, de::DeserializeOwned};
9
10pub use async_trait::async_trait;
11
12pub mod memory;
13
14pub use memory::MemoryRelay;
15
16/// Message relay trait for MPC communication
17///
18/// Implementations of this trait handle the transport of messages between
19/// MPC protocol participants. The relay is responsible for:
20/// - Broadcasting messages to all parties
21/// - Sending direct (point-to-point) messages
22/// - Collecting and delivering messages by round
23#[async_trait]
24pub trait Relay: Send + Sync {
25 /// Broadcast a message to all parties in the session
26 ///
27 /// # Arguments
28 /// * `session_id` - Unique session identifier
29 /// * `round` - Protocol round number
30 /// * `message` - Message to broadcast (will be serialized)
31 async fn broadcast<T: Serialize + Send + Sync>(
32 &self,
33 session_id: &SessionId,
34 round: u32,
35 message: &T,
36 ) -> Result<()>;
37
38 /// Send a direct message to a specific party
39 ///
40 /// # Arguments
41 /// * `session_id` - Unique session identifier
42 /// * `round` - Protocol round number
43 /// * `to` - Target party ID
44 /// * `message` - Message to send (will be serialized)
45 async fn send_direct<T: Serialize + Send + Sync>(
46 &self,
47 session_id: &SessionId,
48 round: u32,
49 to: PartyId,
50 message: &T,
51 ) -> Result<()>;
52
53 /// Collect broadcast messages from all parties for a round
54 ///
55 /// This method blocks until `count` messages have been received.
56 ///
57 /// # Arguments
58 /// * `session_id` - Unique session identifier
59 /// * `round` - Protocol round number
60 /// * `count` - Number of messages to collect
61 async fn collect_broadcasts<T: DeserializeOwned + Send>(
62 &self,
63 session_id: &SessionId,
64 round: u32,
65 count: usize,
66 ) -> Result<Vec<T>>;
67
68 /// Collect direct messages sent to this party
69 ///
70 /// This method blocks until `count` messages have been received.
71 ///
72 /// # Arguments
73 /// * `session_id` - Unique session identifier
74 /// * `round` - Protocol round number
75 /// * `my_id` - This party's ID
76 /// * `count` - Number of messages to collect
77 async fn collect_direct<T: DeserializeOwned + Send>(
78 &self,
79 session_id: &SessionId,
80 round: u32,
81 my_id: PartyId,
82 count: usize,
83 ) -> Result<Vec<T>>;
84}
85
86/// Extension trait for relay with timeout support
87#[async_trait]
88pub trait RelayExt: Relay {
89 /// Broadcast with timeout
90 async fn broadcast_with_timeout<T: Serialize + Send + Sync>(
91 &self,
92 session_id: &SessionId,
93 round: u32,
94 message: &T,
95 timeout: std::time::Duration,
96 ) -> Result<()>;
97
98 /// Collect broadcasts with timeout
99 async fn collect_broadcasts_with_timeout<T: DeserializeOwned + Send>(
100 &self,
101 session_id: &SessionId,
102 round: u32,
103 count: usize,
104 timeout: std::time::Duration,
105 ) -> Result<Vec<T>>;
106}
107
108#[async_trait]
109impl<R: Relay + ?Sized> RelayExt for R {
110 async fn broadcast_with_timeout<T: Serialize + Send + Sync>(
111 &self,
112 session_id: &SessionId,
113 round: u32,
114 message: &T,
115 timeout: std::time::Duration,
116 ) -> Result<()> {
117 tokio::time::timeout(timeout, self.broadcast(session_id, round, message))
118 .await
119 .map_err(|_| crate::Error::Timeout("broadcast".to_string()))?
120 }
121
122 async fn collect_broadcasts_with_timeout<T: DeserializeOwned + Send>(
123 &self,
124 session_id: &SessionId,
125 round: u32,
126 count: usize,
127 timeout: std::time::Duration,
128 ) -> Result<Vec<T>> {
129 tokio::time::timeout(timeout, self.collect_broadcasts(session_id, round, count))
130 .await
131 .map_err(|_| crate::Error::Timeout("collect_broadcasts".to_string()))?
132 }
133}