Skip to main content

dkls23_core/protocols/
messages.rs

1use bincode::Options;
2use serde::{de::DeserializeOwned, Deserialize, Serialize};
3use std::collections::BTreeMap;
4
5pub trait MessageTag: Serialize + DeserializeOwned {
6    const TAG: u8;
7}
8
/// Errors reported while encoding, locating, or decoding framed messages.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare errors
/// directly (e.g. with `assert_eq!`) instead of pattern-matching each time.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageError {
    /// A message could not be turned into a frame; holds a description of
    /// the failure.
    Serialization(String),
    /// A payload could not be decoded into the requested message type; holds
    /// the decoder's error text.
    Deserialization(String),
    /// A tag other than the expected one was encountered.
    ///
    /// NOTE(review): not constructed anywhere in this file — presumably
    /// produced by callers elsewhere; confirm it is still needed.
    TagMismatch { expected: u8, found: u8 },
    /// No message of the requested type is available from `sender`.
    NotFound { sender: u8 },
    /// The byte stream is not a well-formed sequence of frames; holds a
    /// short description of what was malformed.
    InvalidFrame(String),
}

impl std::fmt::Display for MessageError {
    /// Writes a one-line, human-readable description of the error.
    ///
    /// Tags are rendered as zero-padded lowercase hex (e.g. `0x0b`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Serialization(s) => write!(f, "serialization error: {s}"),
            Self::Deserialization(s) => write!(f, "deserialization error: {s}"),
            Self::TagMismatch { expected, found } => {
                write!(
                    f,
                    "tag mismatch: expected {expected:#04x}, found {found:#04x}"
                )
            }
            Self::NotFound { sender } => write!(f, "message not found for sender {sender}"),
            Self::InvalidFrame(s) => write!(f, "invalid frame: {s}"),
        }
    }
}

// No `source()` override: every variant stores plain text or integers, so
// there is no underlying error value to expose.
impl std::error::Error for MessageError {}
36
37#[derive(Clone, Debug, Serialize, Deserialize)]
38pub struct PhaseOutput {
39    pub broadcasts: Vec<Vec<u8>>,
40    pub p2p: BTreeMap<u8, Vec<u8>>,
41}
42
43#[derive(Clone, Debug, Serialize, Deserialize)]
44pub struct PhaseInput {
45    pub broadcasts: BTreeMap<u8, Vec<u8>>,
46    pub p2p: BTreeMap<u8, Vec<u8>>,
47}
48
/// Size in bytes of a frame header: a one-byte tag followed by a four-byte
/// payload length.
const FRAME_HEADER_LEN: usize = std::mem::size_of::<u8>() + std::mem::size_of::<u32>();
50
51fn encode_frame<T: MessageTag>(message: &T) -> Result<Vec<u8>, MessageError> {
52    let payload =
53        bincode::serialize(message).map_err(|e| MessageError::Serialization(e.to_string()))?;
54    let len = u32::try_from(payload.len())
55        .map_err(|_| MessageError::Serialization("payload exceeds u32::MAX".into()))?;
56    let mut buf = Vec::with_capacity(FRAME_HEADER_LEN + payload.len());
57    buf.push(T::TAG);
58    buf.extend_from_slice(&len.to_be_bytes());
59    buf.extend_from_slice(&payload);
60    Ok(buf)
61}
62
63fn find_in_stream<T: MessageTag>(stream: &[u8]) -> Result<T, MessageError> {
64    let mut offset = 0;
65    while offset < stream.len() {
66        if offset + FRAME_HEADER_LEN > stream.len() {
67            return Err(MessageError::InvalidFrame("truncated header".into()));
68        }
69        let tag = stream[offset];
70        let len = u32::from_be_bytes([
71            stream[offset + 1],
72            stream[offset + 2],
73            stream[offset + 3],
74            stream[offset + 4],
75        ]) as usize;
76        offset += FRAME_HEADER_LEN;
77        if offset + len > stream.len() {
78            return Err(MessageError::InvalidFrame("truncated payload".into()));
79        }
80        if tag == T::TAG {
81            // Limit deserialization to the actual payload size to prevent
82            // internal length-prefix attacks (e.g. a tiny frame claiming a
83            // multi-GB string). We cap at the payload length already validated
84            // by the frame header.
85            let payload = &stream[offset..offset + len];
86            return bincode::options()
87                .with_fixint_encoding()
88                .with_limit(len as u64)
89                .allow_trailing_bytes()
90                .deserialize(payload)
91                .map_err(|e| MessageError::Deserialization(e.to_string()));
92        }
93        offset += len;
94    }
95    // Tag never matched — caller wraps with the sender index.
96    Err(MessageError::NotFound { sender: 0 })
97}
98
99impl Default for PhaseOutput {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105impl PhaseOutput {
106    #[must_use]
107    pub fn new() -> Self {
108        PhaseOutput {
109            broadcasts: Vec::new(),
110            p2p: BTreeMap::new(),
111        }
112    }
113
114    pub fn add_broadcast<T: MessageTag>(&mut self, message: &T) -> Result<(), MessageError> {
115        self.broadcasts.push(encode_frame(message)?);
116        Ok(())
117    }
118
119    pub fn add_p2p<T: MessageTag>(
120        &mut self,
121        receiver: u8,
122        message: &T,
123    ) -> Result<(), MessageError> {
124        self.p2p
125            .entry(receiver)
126            .or_default()
127            .extend_from_slice(&encode_frame(message)?);
128        Ok(())
129    }
130}
131
132impl PhaseInput {
133    pub fn get_broadcast<T: MessageTag>(&self, sender: u8) -> Result<T, MessageError> {
134        let stream = self
135            .broadcasts
136            .get(&sender)
137            .ok_or(MessageError::NotFound { sender })?;
138        find_in_stream::<T>(stream).map_err(|e| match e {
139            MessageError::NotFound { .. } => MessageError::NotFound { sender },
140            other => other,
141        })
142    }
143
144    pub fn get_p2p<T: MessageTag>(&self, sender: u8) -> Result<T, MessageError> {
145        let stream = self
146            .p2p
147            .get(&sender)
148            .ok_or(MessageError::NotFound { sender })?;
149        find_in_stream::<T>(stream).map_err(|e| match e {
150            MessageError::NotFound { .. } => MessageError::NotFound { sender },
151            other => other,
152        })
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    // ---- message fixtures -------------------------------------------------

    #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
    struct MsgA {
        value: u32,
    }

    const MSG_A_TAG: u8 = 0xA0;
    impl MessageTag for MsgA {
        const TAG: u8 = MSG_A_TAG;
    }

    #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
    struct MsgB {
        name: String,
    }

    const MSG_B_TAG: u8 = 0xB0;
    impl MessageTag for MsgB {
        const TAG: u8 = MSG_B_TAG;
    }

    // ---- simulated network ------------------------------------------------

    /// Index under which the simulated remote party's data is stored.
    const PEER: u8 = 1;
    /// Index of the party that the p2p messages are addressed to.
    const LOCAL: u8 = 2;

    /// An input with no broadcast and no p2p streams.
    fn empty_input() -> PhaseInput {
        PhaseInput {
            broadcasts: BTreeMap::new(),
            p2p: BTreeMap::new(),
        }
    }

    /// Simulates the network: every broadcast frame in `output` is
    /// concatenated into the single stream the receiver holds for `PEER`.
    fn receive_broadcasts(output: &PhaseOutput) -> PhaseInput {
        let mut input = empty_input();
        for frame in &output.broadcasts {
            input
                .broadcasts
                .entry(PEER)
                .or_default()
                .extend_from_slice(frame);
        }
        input
    }

    /// Simulates the network: the stream `output` addressed to `LOCAL`
    /// becomes the p2p stream the receiver holds for `PEER`.
    fn receive_p2p(output: &PhaseOutput) -> PhaseInput {
        let mut input = empty_input();
        input.p2p.insert(PEER, output.p2p[&LOCAL].clone());
        input
    }

    /// An input whose broadcast stream for `PEER` is exactly `bytes`.
    fn input_with_raw_broadcast(bytes: Vec<u8>) -> PhaseInput {
        let mut input = empty_input();
        input.broadcasts.insert(PEER, bytes);
        input
    }

    // ---- round trips ------------------------------------------------------

    #[test]
    fn test_broadcast_round_trip() {
        const TEST_VALUE_A: u32 = 42;
        let sent = MsgA {
            value: TEST_VALUE_A,
        };
        let mut output = PhaseOutput::new();
        output.add_broadcast(&sent).unwrap();

        let input = receive_broadcasts(&output);

        let received: MsgA = input.get_broadcast(PEER).unwrap();
        assert_eq!(received, sent);
    }

    #[test]
    fn test_p2p_round_trip() {
        const TEST_VALUE_B: u32 = 99;
        let sent = MsgA {
            value: TEST_VALUE_B,
        };
        let mut output = PhaseOutput::new();
        output.add_p2p(LOCAL, &sent).unwrap();

        let input = receive_p2p(&output);

        let received: MsgA = input.get_p2p(PEER).unwrap();
        assert_eq!(received, sent);
    }

    #[test]
    fn test_large_u32_round_trip() {
        // Use a value > 250 to distinguish varint from fixint encoding.
        let sent = MsgA { value: 100_000 };
        let mut output = PhaseOutput::new();
        output.add_broadcast(&sent).unwrap();

        let input = receive_broadcasts(&output);

        let received: MsgA = input.get_broadcast(PEER).unwrap();
        assert_eq!(received, sent);
    }

    // ---- several frames in one stream -------------------------------------

    #[test]
    fn test_multiple_p2p_same_receiver() {
        let sent_a = MsgA { value: 7 };
        let sent_b = MsgB {
            name: "hello".into(),
        };

        // Both frames end up in the same byte stream for the receiver.
        let mut output = PhaseOutput::new();
        output.add_p2p(LOCAL, &sent_a).unwrap();
        output.add_p2p(LOCAL, &sent_b).unwrap();

        let input = receive_p2p(&output);

        let received_a: MsgA = input.get_p2p(PEER).unwrap();
        let received_b: MsgB = input.get_p2p(PEER).unwrap();
        assert_eq!(received_a, sent_a);
        assert_eq!(received_b, sent_b);
    }

    #[test]
    fn test_multiple_broadcasts_same_sender() {
        let sent_a = MsgA { value: 10 };
        let sent_b = MsgB {
            name: "world".into(),
        };

        let mut output = PhaseOutput::new();
        output.add_broadcast(&sent_a).unwrap();
        output.add_broadcast(&sent_b).unwrap();

        // All broadcast blobs are concatenated into one stream for the peer.
        let input = receive_broadcasts(&output);

        let received_a: MsgA = input.get_broadcast(PEER).unwrap();
        let received_b: MsgB = input.get_broadcast(PEER).unwrap();
        assert_eq!(received_a, sent_a);
        assert_eq!(received_b, sent_b);
    }

    // ---- lookups that fail ------------------------------------------------

    #[test]
    fn test_tag_mismatch() {
        let mut output = PhaseOutput::new();
        output.add_broadcast(&MsgA { value: 1 }).unwrap();

        let input = receive_broadcasts(&output);

        // The stream only holds a MsgA frame, so asking for MsgB reports
        // NotFound for that sender.
        let outcome = input.get_broadcast::<MsgB>(PEER);
        assert!(matches!(
            outcome,
            Err(MessageError::NotFound { sender: PEER })
        ));
    }

    #[test]
    fn test_missing_sender() {
        const UNKNOWN_SENDER: u8 = 99;
        let input = empty_input();

        let outcome = input.get_broadcast::<MsgA>(UNKNOWN_SENDER);
        assert!(matches!(
            outcome,
            Err(MessageError::NotFound {
                sender: UNKNOWN_SENDER
            })
        ));
    }

    // ---- malformed streams ------------------------------------------------

    #[test]
    fn test_truncated_frame() {
        // Only 3 bytes — less than the 5-byte header.
        let input = input_with_raw_broadcast(vec![0x00, 0x01, 0x02]);

        let outcome = input.get_broadcast::<MsgA>(PEER);
        assert!(matches!(outcome, Err(MessageError::InvalidFrame(_))));
    }

    #[test]
    fn test_truncated_payload() {
        // Header says 100 bytes of payload but only 2 follow.
        let mut raw = vec![MSG_A_TAG];
        raw.extend_from_slice(&100u32.to_be_bytes());
        raw.extend_from_slice(&[0x00, 0x01]);
        let input = input_with_raw_broadcast(raw);

        let outcome = input.get_broadcast::<MsgA>(PEER);
        assert!(matches!(outcome, Err(MessageError::InvalidFrame(_))));
    }
}