1use bincode::Options;
2use serde::{de::DeserializeOwned, Deserialize, Serialize};
3use std::collections::BTreeMap;
4
/// Associates a serializable message type with the one-byte tag that identifies
/// it inside a frame.
///
/// `encode_frame` writes [`MessageTag::TAG`] as the first byte of every frame, and
/// `find_in_stream` returns the first frame whose tag byte equals `T::TAG`.
/// Nothing in this file enforces that tags are unique across message types; two
/// types sharing a tag would be indistinguishable to the decoder.
pub trait MessageTag: Serialize + DeserializeOwned {
    /// Tag byte written at the start of each encoded frame of this message type.
    const TAG: u8;
}
8
/// Errors produced while encoding messages into frames or locating and decoding
/// them from a received byte stream.
///
/// The type is comparable (`PartialEq`/`Eq`) so callers and tests can assert on
/// an exact error value instead of pattern-matching only.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageError {
    /// The message could not be serialized, or its serialized payload does not
    /// fit in the frame's 32-bit length field. Carries a description.
    Serialization(String),
    /// A frame with the requested tag was found but its payload could not be
    /// decoded into the requested type. Carries the decoder's description.
    Deserialization(String),
    /// A frame carried a different tag than the one expected.
    ///
    /// NOTE(review): no code visible in this file constructs this variant; the
    /// stream scanner skips non-matching frames and reports `NotFound` instead.
    TagMismatch { expected: u8, found: u8 },
    /// No stream exists for `sender`, or the sender's stream contains no frame
    /// with the requested tag.
    NotFound { sender: u8 },
    /// The byte stream is structurally malformed (for example a truncated
    /// header or payload). Carries a description.
    InvalidFrame(String),
}
17
18impl std::fmt::Display for MessageError {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 match self {
21 Self::Serialization(s) => write!(f, "serialization error: {s}"),
22 Self::Deserialization(s) => write!(f, "deserialization error: {s}"),
23 Self::TagMismatch { expected, found } => {
24 write!(
25 f,
26 "tag mismatch: expected {expected:#04x}, found {found:#04x}"
27 )
28 }
29 Self::NotFound { sender } => write!(f, "message not found for sender {sender}"),
30 Self::InvalidFrame(s) => write!(f, "invalid frame: {s}"),
31 }
32 }
33}
34
// Marks `MessageError` as a standard error type. The body is empty, so every
// `std::error::Error` method keeps its default implementation.
impl std::error::Error for MessageError {}
36
/// Encoded messages produced locally, split into broadcasts and
/// point-to-point messages (presumably for one protocol phase — confirm with
/// callers).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PhaseOutput {
    /// Encoded broadcast frames; `add_broadcast` pushes one element per call.
    pub broadcasts: Vec<Vec<u8>>,
    /// Encoded point-to-point frames keyed by receiver id; `add_p2p` appends
    /// each new frame to the end of that receiver's byte stream.
    pub p2p: BTreeMap<u8, Vec<u8>>,
}
42
43#[derive(Clone, Debug, Serialize, Deserialize)]
44pub struct PhaseInput {
45 pub broadcasts: BTreeMap<u8, Vec<u8>>,
46 pub p2p: BTreeMap<u8, Vec<u8>>,
47}
48
49const FRAME_HEADER_LEN: usize = 5; fn encode_frame<T: MessageTag>(message: &T) -> Result<Vec<u8>, MessageError> {
52 let payload =
53 bincode::serialize(message).map_err(|e| MessageError::Serialization(e.to_string()))?;
54 let len = u32::try_from(payload.len())
55 .map_err(|_| MessageError::Serialization("payload exceeds u32::MAX".into()))?;
56 let mut buf = Vec::with_capacity(FRAME_HEADER_LEN + payload.len());
57 buf.push(T::TAG);
58 buf.extend_from_slice(&len.to_be_bytes());
59 buf.extend_from_slice(&payload);
60 Ok(buf)
61}
62
63fn find_in_stream<T: MessageTag>(stream: &[u8]) -> Result<T, MessageError> {
64 let mut offset = 0;
65 while offset < stream.len() {
66 if offset + FRAME_HEADER_LEN > stream.len() {
67 return Err(MessageError::InvalidFrame("truncated header".into()));
68 }
69 let tag = stream[offset];
70 let len = u32::from_be_bytes([
71 stream[offset + 1],
72 stream[offset + 2],
73 stream[offset + 3],
74 stream[offset + 4],
75 ]) as usize;
76 offset += FRAME_HEADER_LEN;
77 if offset + len > stream.len() {
78 return Err(MessageError::InvalidFrame("truncated payload".into()));
79 }
80 if tag == T::TAG {
81 let payload = &stream[offset..offset + len];
86 return bincode::options()
87 .with_fixint_encoding()
88 .with_limit(len as u64)
89 .allow_trailing_bytes()
90 .deserialize(payload)
91 .map_err(|e| MessageError::Deserialization(e.to_string()));
92 }
93 offset += len;
94 }
95 Err(MessageError::NotFound { sender: 0 })
97}
98
/// The default `PhaseOutput` is whatever `PhaseOutput::new` returns.
impl Default for PhaseOutput {
    fn default() -> Self {
        // Delegate so construction logic lives in one place.
        Self::new()
    }
}
104
105impl PhaseOutput {
106 #[must_use]
107 pub fn new() -> Self {
108 PhaseOutput {
109 broadcasts: Vec::new(),
110 p2p: BTreeMap::new(),
111 }
112 }
113
114 pub fn add_broadcast<T: MessageTag>(&mut self, message: &T) -> Result<(), MessageError> {
115 self.broadcasts.push(encode_frame(message)?);
116 Ok(())
117 }
118
119 pub fn add_p2p<T: MessageTag>(
120 &mut self,
121 receiver: u8,
122 message: &T,
123 ) -> Result<(), MessageError> {
124 self.p2p
125 .entry(receiver)
126 .or_default()
127 .extend_from_slice(&encode_frame(message)?);
128 Ok(())
129 }
130}
131
132impl PhaseInput {
133 pub fn get_broadcast<T: MessageTag>(&self, sender: u8) -> Result<T, MessageError> {
134 let stream = self
135 .broadcasts
136 .get(&sender)
137 .ok_or(MessageError::NotFound { sender })?;
138 find_in_stream::<T>(stream).map_err(|e| match e {
139 MessageError::NotFound { .. } => MessageError::NotFound { sender },
140 other => other,
141 })
142 }
143
144 pub fn get_p2p<T: MessageTag>(&self, sender: u8) -> Result<T, MessageError> {
145 let stream = self
146 .p2p
147 .get(&sender)
148 .ok_or(MessageError::NotFound { sender })?;
149 find_in_stream::<T>(stream).map_err(|e| match e {
150 MessageError::NotFound { .. } => MessageError::NotFound { sender },
151 other => other,
152 })
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    const MSG_A_TAG: u8 = 0xA0;
    const MSG_B_TAG: u8 = 0xB0;

    /// Test message with a single fixed-width field.
    #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
    struct MsgA {
        value: u32,
    }

    impl MessageTag for MsgA {
        const TAG: u8 = MSG_A_TAG;
    }

    /// Test message with a variable-length field.
    #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
    struct MsgB {
        name: String,
    }

    impl MessageTag for MsgB {
        const TAG: u8 = MSG_B_TAG;
    }

    /// An input with no broadcast and no point-to-point streams.
    fn empty_input() -> PhaseInput {
        PhaseInput {
            broadcasts: BTreeMap::new(),
            p2p: BTreeMap::new(),
        }
    }

    /// Builds an input in which every broadcast frame of `output` is appended,
    /// in order, to the broadcast stream attributed to `sender`.
    fn input_with_broadcasts(output: &PhaseOutput, sender: u8) -> PhaseInput {
        let mut input = empty_input();
        for frame in &output.broadcasts {
            input
                .broadcasts
                .entry(sender)
                .or_default()
                .extend_from_slice(frame);
        }
        input
    }

    /// Builds an input whose point-to-point stream from `sender` is the stream
    /// that `output` addressed to `receiver`.
    fn input_with_p2p(output: &PhaseOutput, receiver: u8, sender: u8) -> PhaseInput {
        let mut input = empty_input();
        input.p2p.insert(sender, output.p2p[&receiver].clone());
        input
    }

    /// Encodes `msg` as the only broadcast and returns the input seen from
    /// sender 1.
    fn broadcast_from_sender_one(msg: &MsgA) -> PhaseInput {
        let mut output = PhaseOutput::new();
        output.add_broadcast(msg).unwrap();
        input_with_broadcasts(&output, 1)
    }

    #[test]
    fn test_broadcast_round_trip() {
        const TEST_VALUE_A: u32 = 42;
        let msg = MsgA {
            value: TEST_VALUE_A,
        };

        let input = broadcast_from_sender_one(&msg);

        assert_eq!(input.get_broadcast::<MsgA>(1).unwrap(), msg);
    }

    #[test]
    fn test_p2p_round_trip() {
        const TEST_VALUE_B: u32 = 99;
        let msg = MsgA {
            value: TEST_VALUE_B,
        };
        let mut output = PhaseOutput::new();
        output.add_p2p(2, &msg).unwrap();

        let input = input_with_p2p(&output, 2, 1);

        assert_eq!(input.get_p2p::<MsgA>(1).unwrap(), msg);
    }

    // Requesting a type whose tag is absent from the stream is reported as
    // `NotFound` for that sender, because non-matching frames are skipped.
    #[test]
    fn test_tag_mismatch() {
        let input = broadcast_from_sender_one(&MsgA { value: 1 });

        assert!(matches!(
            input.get_broadcast::<MsgB>(1),
            Err(MessageError::NotFound { sender: 1 })
        ));
    }

    #[test]
    fn test_multiple_p2p_same_receiver() {
        let msg_a = MsgA { value: 7 };
        let msg_b = MsgB {
            name: "hello".into(),
        };
        let mut output = PhaseOutput::new();
        output.add_p2p(2, &msg_a).unwrap();
        output.add_p2p(2, &msg_b).unwrap();

        let input = input_with_p2p(&output, 2, 1);

        assert_eq!(input.get_p2p::<MsgA>(1).unwrap(), msg_a);
        assert_eq!(input.get_p2p::<MsgB>(1).unwrap(), msg_b);
    }

    #[test]
    fn test_multiple_broadcasts_same_sender() {
        let msg_a = MsgA { value: 10 };
        let msg_b = MsgB {
            name: "world".into(),
        };
        let mut output = PhaseOutput::new();
        output.add_broadcast(&msg_a).unwrap();
        output.add_broadcast(&msg_b).unwrap();

        let input = input_with_broadcasts(&output, 1);

        assert_eq!(input.get_broadcast::<MsgA>(1).unwrap(), msg_a);
        assert_eq!(input.get_broadcast::<MsgB>(1).unwrap(), msg_b);
    }

    #[test]
    fn test_missing_sender() {
        const UNKNOWN_SENDER: u8 = 99;
        let input = empty_input();

        assert!(matches!(
            input.get_broadcast::<MsgA>(UNKNOWN_SENDER),
            Err(MessageError::NotFound {
                sender: UNKNOWN_SENDER
            })
        ));
    }

    // Three bytes cannot hold a complete five-byte frame header.
    #[test]
    fn test_truncated_frame() {
        let mut input = empty_input();
        input.broadcasts.insert(1, vec![0x00, 0x01, 0x02]);

        assert!(matches!(
            input.get_broadcast::<MsgA>(1),
            Err(MessageError::InvalidFrame(_))
        ));
    }

    #[test]
    fn test_large_u32_round_trip() {
        let msg = MsgA { value: 100_000 };

        let input = broadcast_from_sender_one(&msg);

        assert_eq!(input.get_broadcast::<MsgA>(1).unwrap(), msg);
    }

    // The header declares 100 payload bytes but only two follow it.
    #[test]
    fn test_truncated_payload() {
        let mut stream = vec![MSG_A_TAG];
        stream.extend_from_slice(&100u32.to_be_bytes());
        stream.extend_from_slice(&[0x00, 0x01]);
        let mut input = empty_input();
        input.broadcasts.insert(1, stream);

        assert!(matches!(
            input.get_broadcast::<MsgA>(1),
            Err(MessageError::InvalidFrame(_))
        ));
    }
}