blueprint_networking_round_based_extension/
lib.rs

1use blueprint_core::trace;
2use blueprint_crypto::KeyType;
3use blueprint_networking::{service_handle::NetworkServiceHandle, types::ProtocolMessage};
4use futures::{Sink, Stream};
5use libp2p::PeerId;
6use round_based::{Delivery, Incoming, MessageDestination, MessageType, Outgoing, PartyIndex};
7use serde::{Serialize, de::DeserializeOwned};
8use std::{
9    collections::HashMap,
10    pin::Pin,
11    sync::{
12        Arc,
13        atomic::{AtomicU64, Ordering},
14    },
15    task::{Context, Poll},
16};
17
18/// Wrapper to adapt [`NetworkServiceHandle`] to round-based protocols
19pub struct RoundBasedNetworkAdapter<M, K: KeyType> {
20    /// The underlying network handle
21    handle: NetworkServiceHandle<K>,
22    /// Counter for message IDs
23    next_msg_id: Arc<AtomicU64>,
24    /// Protocol identifier
25    protocol_id: String,
26    _phantom: std::marker::PhantomData<M>,
27}
28
29impl<M, K: KeyType> RoundBasedNetworkAdapter<M, K>
30where
31    M: Clone + Send + Sync + Unpin + 'static,
32    M: Serialize + DeserializeOwned,
33    M: round_based::ProtocolMessage,
34{
35    pub fn new(
36        handle: NetworkServiceHandle<K>,
37        _party_index: PartyIndex,
38        _parties: &HashMap<PartyIndex, PeerId>,
39        protocol_id: impl Into<String>,
40    ) -> Self {
41        Self {
42            handle,
43            next_msg_id: Arc::new(AtomicU64::new(0)),
44            protocol_id: protocol_id.into(),
45            _phantom: std::marker::PhantomData,
46        }
47    }
48}
49
50impl<M, K: KeyType> Delivery<M> for RoundBasedNetworkAdapter<M, K>
51where
52    M: Clone + Send + Sync + Unpin + 'static,
53    M: Serialize + DeserializeOwned,
54    M: round_based::ProtocolMessage,
55    K::Public: Unpin,
56    K::Secret: Unpin,
57{
58    type Send = RoundBasedSender<M, K>;
59    type Receive = RoundBasedReceiver<M, K>;
60    type SendError = NetworkError;
61    type ReceiveError = NetworkError;
62
63    fn split(self) -> (Self::Receive, Self::Send) {
64        let RoundBasedNetworkAdapter {
65            handle,
66            next_msg_id,
67            protocol_id,
68            ..
69        } = self;
70
71        let sender = RoundBasedSender {
72            handle: handle.clone(),
73            next_msg_id: next_msg_id.clone(),
74            protocol_id: protocol_id.clone(),
75            _phantom: std::marker::PhantomData,
76        };
77
78        let receiver = RoundBasedReceiver::new(handle);
79
80        (receiver, sender)
81    }
82}
83
84pub struct RoundBasedSender<M, K: KeyType> {
85    handle: NetworkServiceHandle<K>,
86    next_msg_id: Arc<AtomicU64>,
87    protocol_id: String,
88    _phantom: std::marker::PhantomData<M>,
89}
90
91impl<M, K: KeyType> Sink<Outgoing<M>> for RoundBasedSender<M, K>
92where
93    M: Serialize + round_based::ProtocolMessage + Clone + Unpin,
94    K::Public: Unpin,
95    K::Secret: Unpin,
96{
97    type Error = NetworkError;
98
99    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
100        Poll::Ready(Ok(()))
101    }
102
103    fn start_send(self: Pin<&mut Self>, outgoing: Outgoing<M>) -> Result<(), Self::Error> {
104        let this = self.get_mut();
105        let msg_id = this.next_msg_id.fetch_add(1, Ordering::Relaxed);
106        let round = outgoing.msg.round();
107        let party_index = this
108            .handle
109            .peer_manager
110            .get_whitelist_index_from_peer_id(&this.handle.local_peer_id)
111            .unwrap_or_default();
112
113        trace!(
114            i = %party_index,
115            recipient = ?outgoing.recipient,
116            %round,
117            %msg_id,
118            protocol_id = %this.protocol_id,
119            "Sending message",
120        );
121
122        let (recipient, _) = match outgoing.recipient {
123            MessageDestination::AllParties => (None, None),
124            MessageDestination::OneParty(p) => {
125                let key = this
126                    .handle
127                    .peer_manager
128                    .get_peer_id_from_whitelist_index(p as usize);
129                (Some(p), key)
130            }
131        };
132
133        let protocol_message = ProtocolMessage {
134            protocol: format!("{}/{}", this.protocol_id, round),
135            routing: blueprint_networking::types::MessageRouting {
136                message_id: msg_id,
137                round_id: round,
138                sender: this.handle.local_peer_id,
139                recipient: recipient.and_then(|p| {
140                    this.handle
141                        .peer_manager
142                        .get_peer_id_from_whitelist_index(p as usize)
143                }),
144            },
145            payload: serde_json::to_vec(&outgoing.msg).map_err(NetworkError::Serialization)?,
146        };
147
148        trace!(
149            %round,
150            %msg_id,
151            protocol_id = %this.protocol_id,
152            "Sending message to network",
153        );
154
155        this.handle
156            .send(protocol_message.routing, protocol_message.payload)
157            .map_err(NetworkError::Send)
158    }
159
160    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
161        Poll::Ready(Ok(()))
162    }
163
164    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
165        Poll::Ready(Ok(()))
166    }
167}
168
169pub struct RoundBasedReceiver<M, K: KeyType> {
170    handle: NetworkServiceHandle<K>,
171    _phantom: std::marker::PhantomData<M>,
172}
173
174impl<M, K: KeyType> RoundBasedReceiver<M, K> {
175    fn new(handle: NetworkServiceHandle<K>) -> Self {
176        Self {
177            handle,
178            _phantom: std::marker::PhantomData,
179        }
180    }
181}
182
183impl<M, K: KeyType> Stream for RoundBasedReceiver<M, K>
184where
185    M: DeserializeOwned + round_based::ProtocolMessage + Unpin,
186    K::Public: Unpin,
187    K::Secret: Unpin,
188{
189    type Item = Result<Incoming<M>, NetworkError>;
190
191    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
192        // Get a mutable reference to self
193        let this = self.get_mut();
194        let party_index = this
195            .handle
196            .peer_manager
197            .get_whitelist_index_from_peer_id(&this.handle.local_peer_id)
198            .unwrap_or_default();
199        let next_protocol_message = this.handle.next_protocol_message();
200        match next_protocol_message {
201            Some(protocol_message) => {
202                let msg_type = if protocol_message.routing.recipient.is_some() {
203                    MessageType::P2P
204                } else {
205                    MessageType::Broadcast
206                };
207
208                let sender = protocol_message.routing.sender;
209                let sender_index = this
210                    .handle
211                    .peer_manager
212                    .get_whitelist_index_from_peer_id(&sender);
213                let id = protocol_message.routing.message_id;
214
215                match sender_index {
216                    Some(sender_index) => match serde_json::from_slice(&protocol_message.payload) {
217                        Ok(msg) => {
218                            trace!(
219                                i = %party_index,
220                                sender = ?sender_index,
221                                %id,
222                                protocol_id = %protocol_message.protocol,
223                                ?msg_type,
224                                size = %protocol_message.payload.len(),
225                                "Received message",
226                            );
227                            Poll::Ready(Some(Ok(Incoming {
228                                msg,
229                                sender: u16::try_from(sender_index).unwrap_or(0),
230                                id,
231                                msg_type,
232                            })))
233                        }
234                        Err(e) => Poll::Ready(Some(Err(NetworkError::Serialization(e)))),
235                    },
236                    None => {
237                        trace!(
238                            i = %party_index,
239                            sender = ?sender,
240                            %id,
241                            protocol_id = %protocol_message.protocol,
242                            "Received message from unknown sender; ignoring",
243                        );
244                        cx.waker().wake_by_ref();
245                        Poll::Pending
246                    }
247                }
248            }
249            None => {
250                //trace!(i = %this.party_index, "No message received; the waker will wake us up when there is a new message");
251                // In this case, tell the waker to wake us up when there is a new message
252                cx.waker().wake_by_ref();
253                Poll::Pending
254            }
255        }
256    }
257}
258
259#[derive(Debug, thiserror::Error)]
260pub enum NetworkError {
261    #[error("Failed to serialize message: {0}")]
262    Serialization(#[from] serde_json::Error),
263    #[error("Network error: {0}")]
264    Send(String),
265}