message_packetizer/
lib.rs

1use bytes::{Buf, BufMut, Bytes, BytesMut};
2use hmac::{Hmac, Mac};
3use pot::{from_slice, to_vec};
4use serde::{Deserialize, Serialize};
5use sha2::{Digest, Sha256};
6use std::collections::HashMap;
7use std::error::Error;
8use std::time::{SystemTime, UNIX_EPOCH};
9use tls_helpers::privkey_from_base64;
10
11type HmacSha256 = Hmac<Sha256>;
12
/// Largest packet emitted on the wire, header included (the SRT MTU size).
const MAX_PACKET_SIZE: usize = 1316;
/// Per-packet header: flags (u8) + message sequence (u64) + packet sequence (u32).
const PACKET_HEADER_SIZE: usize = 1 + 8 + 4;
/// Envelope bytes that fit in one packet after the header.
const MAX_PAYLOAD_SIZE: usize = MAX_PACKET_SIZE - PACKET_HEADER_SIZE;
16
17/// Trait for messages that can be signed
18pub trait SignableMessage: Serialize + for<'de> Deserialize<'de> {
19    /// Optional validation logic for the message content
20    fn validate(&self) -> Result<(), Box<dyn Error>> {
21        Ok(()) // Default implementation does no validation
22    }
23}
24
25/// A signed message envelope that can contain any SignableMessage
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct SignedMessageEnvelope {
28    pub sequence: u64,
29    pub content: Vec<u8>,
30    pub timestamp: u64,
31    pub signature: Vec<u8>,
32}
33
34impl SignedMessageEnvelope {
35    pub fn to_bytes(&self) -> Bytes {
36        let mut buf = BytesMut::new();
37        buf.put_u64(self.sequence);
38        buf.put_u64(self.timestamp);
39        buf.put_u32(self.content.len() as u32);
40        buf.extend_from_slice(&self.content);
41        buf.put_u32(self.signature.len() as u32);
42        buf.extend_from_slice(&self.signature);
43        buf.freeze()
44    }
45
46    pub fn from_bytes(bytes: &[u8]) -> Result<Self, Box<dyn Error>> {
47        if bytes.len() < 16 {
48            return Err("Buffer too small".into());
49        }
50
51        let mut buf = &bytes[..];
52        let sequence = buf.get_u64();
53        let timestamp = buf.get_u64();
54
55        let content_len = buf.get_u32() as usize;
56        if buf.remaining() < content_len {
57            return Err("Invalid content length".into());
58        }
59        let content = buf[..content_len].to_vec();
60        buf.advance(content_len);
61
62        let signature_len = buf.get_u32() as usize;
63        if buf.remaining() != signature_len {
64            return Err("Invalid signature length".into());
65        }
66        let signature = buf[..signature_len].to_vec();
67
68        Ok(SignedMessageEnvelope {
69            sequence,
70            content,
71            timestamp,
72            signature,
73        })
74    }
75
76    pub fn to_packets(&self) -> Vec<Bytes> {
77        let full_data = self.to_bytes();
78        let mut packets = Vec::new();
79        let mut remaining = full_data.as_ref();
80        let mut packet_sequence = 0u32;
81
82        while !remaining.is_empty() {
83            let chunk_size = remaining.len().min(MAX_PAYLOAD_SIZE);
84            let (chunk, rest) = remaining.split_at(chunk_size);
85
86            let mut packet = BytesMut::with_capacity(PACKET_HEADER_SIZE + chunk_size);
87            packet.put_u8(if rest.is_empty() { 1 } else { 0 }); // flags
88            packet.put_u64(self.sequence); // message sequence
89            packet.put_u32(packet_sequence); // packet sequence
90            packet.extend_from_slice(chunk);
91
92            packets.push(packet.freeze());
93            remaining = rest;
94            packet_sequence += 1;
95        }
96
97        packets
98    }
99}
100
/// Tags messages with HMAC-SHA256 and numbers them with a per-signer sequence counter.
pub struct MessageSigner {
    /// HMAC key: the SHA-256 digest of the private key given to `new`.
    signing_key: Vec<u8>,
    /// Sequence number that the next successful `sign` call will assign.
    sequence: u64,
}
105
106impl MessageSigner {
107    pub fn new(base64_encoded_pem_key: &str) -> Result<Self, Box<dyn Error + Send + Sync>> {
108        let private_key = privkey_from_base64(base64_encoded_pem_key)?;
109        let mut hasher = Sha256::new();
110        hasher.update(&private_key.0);
111        let signing_key = hasher.finalize().to_vec();
112
113        Ok(Self {
114            signing_key,
115            sequence: 0,
116        })
117    }
118
119    pub fn sign<T: SignableMessage>(
120        &mut self,
121        message: &T,
122    ) -> Result<SignedMessageEnvelope, Box<dyn Error>> {
123        message.validate()?;
124        let content = to_vec(message)?;
125        let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
126
127        let sequence = self.sequence;
128        self.sequence = self.sequence.wrapping_add(1);
129
130        let mut data = Vec::with_capacity(content.len() + 16);
131        data.extend_from_slice(&sequence.to_be_bytes());
132        data.extend_from_slice(&content);
133        data.extend_from_slice(&timestamp.to_be_bytes());
134
135        let mut mac = HmacSha256::new_from_slice(&self.signing_key)?;
136        mac.update(&data);
137        let signature = mac.finalize().into_bytes();
138
139        Ok(SignedMessageEnvelope {
140            sequence,
141            content,
142            timestamp,
143            signature: signature.to_vec(),
144        })
145    }
146
147    pub fn verify<T: SignableMessage>(
148        &self,
149        envelope: &SignedMessageEnvelope,
150    ) -> Result<T, Box<dyn Error>> {
151        let mut data = Vec::with_capacity(envelope.content.len() + 16);
152        data.extend_from_slice(&envelope.sequence.to_be_bytes());
153        data.extend_from_slice(&envelope.content);
154        data.extend_from_slice(&envelope.timestamp.to_be_bytes());
155
156        let mut mac = HmacSha256::new_from_slice(&self.signing_key)?;
157        mac.update(&data);
158        mac.verify_slice(&envelope.signature)?;
159
160        let message: T = from_slice(&envelope.content)?;
161        message.validate()?;
162        Ok(message)
163    }
164}
165
166struct PartialMessage {
167    packets: Vec<(u32, BytesMut)>, // (packet_sequence, payload)
168    total_size: usize,
169    got_last: bool,
170}
171
172pub struct SignedMessageDemuxer {
173    partial_messages: HashMap<u64, PartialMessage>,
174}
175
176impl SignedMessageDemuxer {
177    pub fn new() -> Self {
178        Self {
179            partial_messages: HashMap::new(),
180        }
181    }
182}
183
/// Problems reported by `SignedMessageDemuxer::process_packet` in `DemuxResult::errors`.
#[derive(Debug)]
pub enum DemuxError {
    /// A packet could not be interpreted at all; the string says why
    /// (e.g. it was shorter than the packet header).
    InvalidPacket(String),
    /// The packets received for message `sequence` were inconsistent, as described by
    /// `reason` (e.g. a repeated packet sequence number).
    MessageCorrupted {
        sequence: u64,
        reason: String,
    },
    /// Message `sequence` was fully reassembled, but its bytes did not parse as a
    /// `SignedMessageEnvelope`; `error` is the parser's error.
    EnvelopeParseError {
        sequence: u64,
        error: Box<dyn Error>,
    },
}

impl std::fmt::Display for DemuxError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DemuxError::InvalidPacket(msg) => write!(f, "Invalid packet: {}", msg),
            DemuxError::MessageCorrupted { sequence, reason } => {
                write!(f, "Message {} corrupted: {}", sequence, reason)
            }
            DemuxError::EnvelopeParseError { sequence, error } => {
                write!(f, "Failed to parse message {}: {}", sequence, error)
            }
        }
    }
}

impl Error for DemuxError {
    /// Exposes the wrapped envelope parse error so callers can walk the error chain.
    /// The other variants do not wrap an error.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            DemuxError::EnvelopeParseError { error, .. } => Some(&**error),
            DemuxError::InvalidPacket(_) | DemuxError::MessageCorrupted { .. } => None,
        }
    }
}
212
213#[derive(Debug)]
214pub struct DemuxResult {
215    pub messages: Vec<SignedMessageEnvelope>,
216    pub errors: Vec<DemuxError>,
217}
218
219impl DemuxResult {
220    fn new() -> Self {
221        Self {
222            messages: Vec::new(),
223            errors: Vec::new(),
224        }
225    }
226}
227
228impl SignedMessageDemuxer {
229    pub fn process_packet(&mut self, packet: &[u8]) -> DemuxResult {
230        let mut result = DemuxResult::new();
231
232        if packet.len() < PACKET_HEADER_SIZE {
233            result
234                .errors
235                .push(DemuxError::InvalidPacket("Packet too small".into()));
236            return result;
237        }
238
239        let mut buf = &packet[..];
240        let flags = buf.get_u8();
241        let msg_sequence = buf.get_u64();
242        let packet_sequence = buf.get_u32();
243        let payload = BytesMut::from(&packet[PACKET_HEADER_SIZE..]);
244        let is_last = (flags & 1) == 1;
245
246        let message = self
247            .partial_messages
248            .entry(msg_sequence)
249            .or_insert_with(|| PartialMessage {
250                packets: Vec::new(),
251                total_size: 0,
252                got_last: false,
253            });
254
255        // Check for duplicate packet sequence
256        if message
257            .packets
258            .iter()
259            .any(|(seq, _)| *seq == packet_sequence)
260        {
261            result.errors.push(DemuxError::MessageCorrupted {
262                sequence: msg_sequence,
263                reason: format!("Duplicate packet sequence {}", packet_sequence),
264            });
265            self.partial_messages.remove(&msg_sequence);
266            return result;
267        }
268
269        message.packets.push((packet_sequence, payload.clone()));
270        message.total_size += payload.len();
271        if is_last {
272            message.got_last = true;
273        }
274
275        // Check all messages for completeness
276        let mut complete_sequences = Vec::new();
277        for (&sequence, message) in &mut self.partial_messages {
278            if message.got_last {
279                message.packets.sort_by_key(|(seq, _)| *seq);
280                let expected_sequences: Vec<_> = (0..message.packets.len() as u32).collect();
281                let actual_sequences: Vec<_> =
282                    message.packets.iter().map(|(seq, _)| *seq).collect();
283                if expected_sequences == actual_sequences {
284                    complete_sequences.push(sequence);
285                }
286            }
287        }
288
289        // Process all complete messages
290        for sequence in complete_sequences {
291            if let Some(message) = self.partial_messages.remove(&sequence) {
292                let mut combined = BytesMut::with_capacity(message.total_size);
293                for (_, payload) in message.packets {
294                    combined.extend_from_slice(&payload);
295                }
296
297                match SignedMessageEnvelope::from_bytes(&combined) {
298                    Ok(envelope) => result.messages.push(envelope),
299                    Err(e) => {
300                        result
301                            .errors
302                            .push(DemuxError::EnvelopeParseError { sequence, error: e });
303                    }
304                }
305            }
306        }
307
308        result
309    }
310
311    pub fn pending_message_count(&self) -> usize {
312        self.partial_messages.len()
313    }
314
315    pub fn clear(&mut self) {
316        self.partial_messages.clear();
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal payload type used to exercise signing and packetization.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestMessage {
        data: String,
    }

    impl SignableMessage for TestMessage {}

    /// Builds a signer from the `PRIVKEY_PEM` environment variable (passed to
    /// `MessageSigner::new`). Panics with an actionable message if the variable is
    /// unset or the key is rejected.
    fn signer_from_env() -> MessageSigner {
        let key = std::env::var("PRIVKEY_PEM")
            .expect("set PRIVKEY_PEM to a base64-encoded PEM private key to run these tests");
        MessageSigner::new(&key).expect("PRIVKEY_PEM does not hold a usable private key")
    }

    #[test]
    fn test_message_roundtrip() -> Result<(), Box<dyn Error>> {
        let mut signer = signer_from_env();
        let mut demuxer = SignedMessageDemuxer::new();

        // Create and sign multiple large messages
        let msg1 = TestMessage {
            data: "first".repeat(500),
        };
        let msg2 = TestMessage {
            data: "second".repeat(500),
        };

        let env1 = signer.sign(&msg1)?;
        let env2 = signer.sign(&msg2)?;

        // Split both into packets
        let packets1 = env1.to_packets();
        let packets2 = env2.to_packets();

        assert!(packets1.len() > 1);
        assert!(packets2.len() > 1);

        // Process packets, interleaving between messages
        for i in 0..packets1.len().max(packets2.len()) {
            if i < packets1.len() {
                let result = demuxer.process_packet(&packets1[i]);
                assert!(result.errors.is_empty());
                if i == packets1.len() - 1 {
                    assert_eq!(result.messages.len(), 1);
                    let decoded: TestMessage = signer.verify(&result.messages[0])?;
                    assert_eq!(decoded, msg1);
                } else {
                    assert!(result.messages.is_empty());
                }
            }

            if i < packets2.len() {
                let result = demuxer.process_packet(&packets2[i]);
                assert!(result.errors.is_empty());
                if i == packets2.len() - 1 {
                    assert_eq!(result.messages.len(), 1);
                    let decoded: TestMessage = signer.verify(&result.messages[0])?;
                    assert_eq!(decoded, msg2);
                } else {
                    assert!(result.messages.is_empty());
                }
            }
        }

        assert_eq!(demuxer.pending_message_count(), 0);
        Ok(())
    }

    #[test]
    fn test_error_handling() -> Result<(), Box<dyn Error>> {
        let mut demuxer = SignedMessageDemuxer::new();

        // Test invalid packet
        let result = demuxer.process_packet(&[1, 2, 3]);
        assert_eq!(result.messages.len(), 0);
        assert_eq!(result.errors.len(), 1);
        match &result.errors[0] {
            DemuxError::InvalidPacket(_) => (),
            _ => panic!("Expected InvalidPacket error"),
        }

        // Test duplicate packet sequence
        let mut signer = signer_from_env();
        let msg = TestMessage {
            data: "test".repeat(500),
        };
        let env = signer.sign(&msg)?;
        let packets = env.to_packets();

        // Send first packet twice
        let result1 = demuxer.process_packet(&packets[0]);
        assert!(result1.errors.is_empty());
        let result2 = demuxer.process_packet(&packets[0]);
        assert_eq!(result2.errors.len(), 1);
        match &result2.errors[0] {
            DemuxError::MessageCorrupted { sequence, .. } => {
                assert_eq!(*sequence, env.sequence);
            }
            _ => panic!("Expected MessageCorrupted error"),
        }

        Ok(())
    }
}