blueprint_networking_round_based_extension/
lib.rs1use blueprint_core::trace;
2use blueprint_crypto::KeyType;
3use blueprint_networking::{service_handle::NetworkServiceHandle, types::ProtocolMessage};
4use futures::{Sink, Stream};
5use libp2p::PeerId;
6use round_based::{Delivery, Incoming, MessageDestination, MessageType, Outgoing, PartyIndex};
7use serde::{Serialize, de::DeserializeOwned};
8use std::{
9 collections::HashMap,
10 pin::Pin,
11 sync::{
12 Arc,
13 atomic::{AtomicU64, Ordering},
14 },
15 task::{Context, Poll},
16};
17
18pub struct RoundBasedNetworkAdapter<M, K: KeyType> {
20 handle: NetworkServiceHandle<K>,
22 next_msg_id: Arc<AtomicU64>,
24 protocol_id: String,
26 _phantom: std::marker::PhantomData<M>,
27}
28
29impl<M, K: KeyType> RoundBasedNetworkAdapter<M, K>
30where
31 M: Clone + Send + Sync + Unpin + 'static,
32 M: Serialize + DeserializeOwned,
33 M: round_based::ProtocolMessage,
34{
35 pub fn new(
36 handle: NetworkServiceHandle<K>,
37 _party_index: PartyIndex,
38 _parties: &HashMap<PartyIndex, PeerId>,
39 protocol_id: impl Into<String>,
40 ) -> Self {
41 Self {
42 handle,
43 next_msg_id: Arc::new(AtomicU64::new(0)),
44 protocol_id: protocol_id.into(),
45 _phantom: std::marker::PhantomData,
46 }
47 }
48}
49
50impl<M, K: KeyType> Delivery<M> for RoundBasedNetworkAdapter<M, K>
51where
52 M: Clone + Send + Sync + Unpin + 'static,
53 M: Serialize + DeserializeOwned,
54 M: round_based::ProtocolMessage,
55 K::Public: Unpin,
56 K::Secret: Unpin,
57{
58 type Send = RoundBasedSender<M, K>;
59 type Receive = RoundBasedReceiver<M, K>;
60 type SendError = NetworkError;
61 type ReceiveError = NetworkError;
62
63 fn split(self) -> (Self::Receive, Self::Send) {
64 let RoundBasedNetworkAdapter {
65 handle,
66 next_msg_id,
67 protocol_id,
68 ..
69 } = self;
70
71 let sender = RoundBasedSender {
72 handle: handle.clone(),
73 next_msg_id: next_msg_id.clone(),
74 protocol_id: protocol_id.clone(),
75 _phantom: std::marker::PhantomData,
76 };
77
78 let receiver = RoundBasedReceiver::new(handle);
79
80 (receiver, sender)
81 }
82}
83
84pub struct RoundBasedSender<M, K: KeyType> {
85 handle: NetworkServiceHandle<K>,
86 next_msg_id: Arc<AtomicU64>,
87 protocol_id: String,
88 _phantom: std::marker::PhantomData<M>,
89}
90
91impl<M, K: KeyType> Sink<Outgoing<M>> for RoundBasedSender<M, K>
92where
93 M: Serialize + round_based::ProtocolMessage + Clone + Unpin,
94 K::Public: Unpin,
95 K::Secret: Unpin,
96{
97 type Error = NetworkError;
98
99 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
100 Poll::Ready(Ok(()))
101 }
102
103 fn start_send(self: Pin<&mut Self>, outgoing: Outgoing<M>) -> Result<(), Self::Error> {
104 let this = self.get_mut();
105 let msg_id = this.next_msg_id.fetch_add(1, Ordering::Relaxed);
106 let round = outgoing.msg.round();
107 let party_index = this
108 .handle
109 .peer_manager
110 .get_whitelist_index_from_peer_id(&this.handle.local_peer_id)
111 .unwrap_or_default();
112
113 trace!(
114 i = %party_index,
115 recipient = ?outgoing.recipient,
116 %round,
117 %msg_id,
118 protocol_id = %this.protocol_id,
119 "Sending message",
120 );
121
122 let (recipient, _) = match outgoing.recipient {
123 MessageDestination::AllParties => (None, None),
124 MessageDestination::OneParty(p) => {
125 let key = this
126 .handle
127 .peer_manager
128 .get_peer_id_from_whitelist_index(p as usize);
129 (Some(p), key)
130 }
131 };
132
133 let protocol_message = ProtocolMessage {
134 protocol: format!("{}/{}", this.protocol_id, round),
135 routing: blueprint_networking::types::MessageRouting {
136 message_id: msg_id,
137 round_id: round,
138 sender: this.handle.local_peer_id,
139 recipient: recipient.and_then(|p| {
140 this.handle
141 .peer_manager
142 .get_peer_id_from_whitelist_index(p as usize)
143 }),
144 },
145 payload: serde_json::to_vec(&outgoing.msg).map_err(NetworkError::Serialization)?,
146 };
147
148 trace!(
149 %round,
150 %msg_id,
151 protocol_id = %this.protocol_id,
152 "Sending message to network",
153 );
154
155 this.handle
156 .send(protocol_message.routing, protocol_message.payload)
157 .map_err(NetworkError::Send)
158 }
159
160 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
161 Poll::Ready(Ok(()))
162 }
163
164 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
165 Poll::Ready(Ok(()))
166 }
167}
168
169pub struct RoundBasedReceiver<M, K: KeyType> {
170 handle: NetworkServiceHandle<K>,
171 _phantom: std::marker::PhantomData<M>,
172}
173
174impl<M, K: KeyType> RoundBasedReceiver<M, K> {
175 fn new(handle: NetworkServiceHandle<K>) -> Self {
176 Self {
177 handle,
178 _phantom: std::marker::PhantomData,
179 }
180 }
181}
182
183impl<M, K: KeyType> Stream for RoundBasedReceiver<M, K>
184where
185 M: DeserializeOwned + round_based::ProtocolMessage + Unpin,
186 K::Public: Unpin,
187 K::Secret: Unpin,
188{
189 type Item = Result<Incoming<M>, NetworkError>;
190
191 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
192 let this = self.get_mut();
194 let party_index = this
195 .handle
196 .peer_manager
197 .get_whitelist_index_from_peer_id(&this.handle.local_peer_id)
198 .unwrap_or_default();
199 let next_protocol_message = this.handle.next_protocol_message();
200 match next_protocol_message {
201 Some(protocol_message) => {
202 let msg_type = if protocol_message.routing.recipient.is_some() {
203 MessageType::P2P
204 } else {
205 MessageType::Broadcast
206 };
207
208 let sender = protocol_message.routing.sender;
209 let sender_index = this
210 .handle
211 .peer_manager
212 .get_whitelist_index_from_peer_id(&sender);
213 let id = protocol_message.routing.message_id;
214
215 match sender_index {
216 Some(sender_index) => match serde_json::from_slice(&protocol_message.payload) {
217 Ok(msg) => {
218 trace!(
219 i = %party_index,
220 sender = ?sender_index,
221 %id,
222 protocol_id = %protocol_message.protocol,
223 ?msg_type,
224 size = %protocol_message.payload.len(),
225 "Received message",
226 );
227 Poll::Ready(Some(Ok(Incoming {
228 msg,
229 sender: u16::try_from(sender_index).unwrap_or(0),
230 id,
231 msg_type,
232 })))
233 }
234 Err(e) => Poll::Ready(Some(Err(NetworkError::Serialization(e)))),
235 },
236 None => {
237 trace!(
238 i = %party_index,
239 sender = ?sender,
240 %id,
241 protocol_id = %protocol_message.protocol,
242 "Received message from unknown sender; ignoring",
243 );
244 cx.waker().wake_by_ref();
245 Poll::Pending
246 }
247 }
248 }
249 None => {
250 cx.waker().wake_by_ref();
253 Poll::Pending
254 }
255 }
256 }
257}
258
259#[derive(Debug, thiserror::Error)]
260pub enum NetworkError {
261 #[error("Failed to serialize message: {0}")]
262 Serialization(#[from] serde_json::Error),
263 #[error("Network error: {0}")]
264 Send(String),
265}