mpc_wallet_core/mpc/
mod.rs

1//! MPC coordination utilities
2//!
3//! This module provides the communication infrastructure for MPC protocol execution.
4//! The `Relay` trait abstracts message passing between parties, enabling different
5//! transport mechanisms (in-memory, WebSocket, REST API).
6
7use crate::{PartyId, Result, SessionId};
8use serde::{Serialize, de::DeserializeOwned};
9
10pub use async_trait::async_trait;
11
12pub mod memory;
13
14pub use memory::MemoryRelay;
15
16/// Message relay trait for MPC communication
17///
18/// Implementations of this trait handle the transport of messages between
19/// MPC protocol participants. The relay is responsible for:
20/// - Broadcasting messages to all parties
21/// - Sending direct (point-to-point) messages
22/// - Collecting and delivering messages by round
23#[async_trait]
24pub trait Relay: Send + Sync {
25    /// Broadcast a message to all parties in the session
26    ///
27    /// # Arguments
28    /// * `session_id` - Unique session identifier
29    /// * `round` - Protocol round number
30    /// * `message` - Message to broadcast (will be serialized)
31    async fn broadcast<T: Serialize + Send + Sync>(
32        &self,
33        session_id: &SessionId,
34        round: u32,
35        message: &T,
36    ) -> Result<()>;
37
38    /// Send a direct message to a specific party
39    ///
40    /// # Arguments
41    /// * `session_id` - Unique session identifier
42    /// * `round` - Protocol round number
43    /// * `to` - Target party ID
44    /// * `message` - Message to send (will be serialized)
45    async fn send_direct<T: Serialize + Send + Sync>(
46        &self,
47        session_id: &SessionId,
48        round: u32,
49        to: PartyId,
50        message: &T,
51    ) -> Result<()>;
52
53    /// Collect broadcast messages from all parties for a round
54    ///
55    /// This method blocks until `count` messages have been received.
56    ///
57    /// # Arguments
58    /// * `session_id` - Unique session identifier
59    /// * `round` - Protocol round number
60    /// * `count` - Number of messages to collect
61    async fn collect_broadcasts<T: DeserializeOwned + Send>(
62        &self,
63        session_id: &SessionId,
64        round: u32,
65        count: usize,
66    ) -> Result<Vec<T>>;
67
68    /// Collect direct messages sent to this party
69    ///
70    /// This method blocks until `count` messages have been received.
71    ///
72    /// # Arguments
73    /// * `session_id` - Unique session identifier
74    /// * `round` - Protocol round number
75    /// * `my_id` - This party's ID
76    /// * `count` - Number of messages to collect
77    async fn collect_direct<T: DeserializeOwned + Send>(
78        &self,
79        session_id: &SessionId,
80        round: u32,
81        my_id: PartyId,
82        count: usize,
83    ) -> Result<Vec<T>>;
84}
85
86/// Extension trait for relay with timeout support
87#[async_trait]
88pub trait RelayExt: Relay {
89    /// Broadcast with timeout
90    async fn broadcast_with_timeout<T: Serialize + Send + Sync>(
91        &self,
92        session_id: &SessionId,
93        round: u32,
94        message: &T,
95        timeout: std::time::Duration,
96    ) -> Result<()>;
97
98    /// Collect broadcasts with timeout
99    async fn collect_broadcasts_with_timeout<T: DeserializeOwned + Send>(
100        &self,
101        session_id: &SessionId,
102        round: u32,
103        count: usize,
104        timeout: std::time::Duration,
105    ) -> Result<Vec<T>>;
106}
107
108#[async_trait]
109impl<R: Relay + ?Sized> RelayExt for R {
110    async fn broadcast_with_timeout<T: Serialize + Send + Sync>(
111        &self,
112        session_id: &SessionId,
113        round: u32,
114        message: &T,
115        timeout: std::time::Duration,
116    ) -> Result<()> {
117        tokio::time::timeout(timeout, self.broadcast(session_id, round, message))
118            .await
119            .map_err(|_| crate::Error::Timeout("broadcast".to_string()))?
120    }
121
122    async fn collect_broadcasts_with_timeout<T: DeserializeOwned + Send>(
123        &self,
124        session_id: &SessionId,
125        round: u32,
126        count: usize,
127        timeout: std::time::Duration,
128    ) -> Result<Vec<T>> {
129        tokio::time::timeout(timeout, self.collect_broadcasts(session_id, round, count))
130            .await
131            .map_err(|_| crate::Error::Timeout("collect_broadcasts".to_string()))?
132    }
133}