1use bytes::{Buf, BufMut, Bytes, BytesMut};
2use hmac::{Hmac, Mac};
3use pot::{from_slice, to_vec};
4use serde::{Deserialize, Serialize};
5use sha2::{Digest, Sha256};
6use std::collections::HashMap;
7use std::error::Error;
8use std::time::{SystemTime, UNIX_EPOCH};
9use tls_helpers::privkey_from_base64;
10
11type HmacSha256 = Hmac<Sha256>;
12
/// Upper bound on the size of one packet emitted by `to_packets`, header included.
/// NOTE(review): 1316 = 7 × 188; presumably chosen to carry 7 MPEG-TS-sized units
/// per datagram — confirm.
const MAX_PACKET_SIZE: usize = 1316;

/// Per-packet header: flags (`u8`) + message sequence (`u64`) + packet index (`u32`).
const PACKET_HEADER_SIZE: usize = 1 + 8 + 4;

/// Envelope bytes carried by a single packet once the header is accounted for.
const MAX_PAYLOAD_SIZE: usize = MAX_PACKET_SIZE - PACKET_HEADER_SIZE;
16
17pub trait SignableMessage: Serialize + for<'de> Deserialize<'de> {
19 fn validate(&self) -> Result<(), Box<dyn Error>> {
21 Ok(()) }
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct SignedMessageEnvelope {
28 pub sequence: u64,
29 pub content: Vec<u8>,
30 pub timestamp: u64,
31 pub signature: Vec<u8>,
32}
33
34impl SignedMessageEnvelope {
35 pub fn to_bytes(&self) -> Bytes {
36 let mut buf = BytesMut::new();
37 buf.put_u64(self.sequence);
38 buf.put_u64(self.timestamp);
39 buf.put_u32(self.content.len() as u32);
40 buf.extend_from_slice(&self.content);
41 buf.put_u32(self.signature.len() as u32);
42 buf.extend_from_slice(&self.signature);
43 buf.freeze()
44 }
45
46 pub fn from_bytes(bytes: &[u8]) -> Result<Self, Box<dyn Error>> {
47 if bytes.len() < 16 {
48 return Err("Buffer too small".into());
49 }
50
51 let mut buf = &bytes[..];
52 let sequence = buf.get_u64();
53 let timestamp = buf.get_u64();
54
55 let content_len = buf.get_u32() as usize;
56 if buf.remaining() < content_len {
57 return Err("Invalid content length".into());
58 }
59 let content = buf[..content_len].to_vec();
60 buf.advance(content_len);
61
62 let signature_len = buf.get_u32() as usize;
63 if buf.remaining() != signature_len {
64 return Err("Invalid signature length".into());
65 }
66 let signature = buf[..signature_len].to_vec();
67
68 Ok(SignedMessageEnvelope {
69 sequence,
70 content,
71 timestamp,
72 signature,
73 })
74 }
75
76 pub fn to_packets(&self) -> Vec<Bytes> {
77 let full_data = self.to_bytes();
78 let mut packets = Vec::new();
79 let mut remaining = full_data.as_ref();
80 let mut packet_sequence = 0u32;
81
82 while !remaining.is_empty() {
83 let chunk_size = remaining.len().min(MAX_PAYLOAD_SIZE);
84 let (chunk, rest) = remaining.split_at(chunk_size);
85
86 let mut packet = BytesMut::with_capacity(PACKET_HEADER_SIZE + chunk_size);
87 packet.put_u8(if rest.is_empty() { 1 } else { 0 }); packet.put_u64(self.sequence); packet.put_u32(packet_sequence); packet.extend_from_slice(chunk);
91
92 packets.push(packet.freeze());
93 remaining = rest;
94 packet_sequence += 1;
95 }
96
97 packets
98 }
99}
100
/// Signs outgoing messages and verifies incoming envelopes with HMAC-SHA256.
pub struct MessageSigner {
    /// Sequence number the next call to `sign` will stamp on its envelope.
    sequence: u64,
    /// HMAC key: the SHA-256 digest of the private key bytes passed to `new`.
    signing_key: Vec<u8>,
}
105
106impl MessageSigner {
107 pub fn new(base64_encoded_pem_key: &str) -> Result<Self, Box<dyn Error + Send + Sync>> {
108 let private_key = privkey_from_base64(base64_encoded_pem_key)?;
109 let mut hasher = Sha256::new();
110 hasher.update(&private_key.0);
111 let signing_key = hasher.finalize().to_vec();
112
113 Ok(Self {
114 signing_key,
115 sequence: 0,
116 })
117 }
118
119 pub fn sign<T: SignableMessage>(
120 &mut self,
121 message: &T,
122 ) -> Result<SignedMessageEnvelope, Box<dyn Error>> {
123 message.validate()?;
124 let content = to_vec(message)?;
125 let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
126
127 let sequence = self.sequence;
128 self.sequence = self.sequence.wrapping_add(1);
129
130 let mut data = Vec::with_capacity(content.len() + 16);
131 data.extend_from_slice(&sequence.to_be_bytes());
132 data.extend_from_slice(&content);
133 data.extend_from_slice(×tamp.to_be_bytes());
134
135 let mut mac = HmacSha256::new_from_slice(&self.signing_key)?;
136 mac.update(&data);
137 let signature = mac.finalize().into_bytes();
138
139 Ok(SignedMessageEnvelope {
140 sequence,
141 content,
142 timestamp,
143 signature: signature.to_vec(),
144 })
145 }
146
147 pub fn verify<T: SignableMessage>(
148 &self,
149 envelope: &SignedMessageEnvelope,
150 ) -> Result<T, Box<dyn Error>> {
151 let mut data = Vec::with_capacity(envelope.content.len() + 16);
152 data.extend_from_slice(&envelope.sequence.to_be_bytes());
153 data.extend_from_slice(&envelope.content);
154 data.extend_from_slice(&envelope.timestamp.to_be_bytes());
155
156 let mut mac = HmacSha256::new_from_slice(&self.signing_key)?;
157 mac.update(&data);
158 mac.verify_slice(&envelope.signature)?;
159
160 let message: T = from_slice(&envelope.content)?;
161 message.validate()?;
162 Ok(message)
163 }
164}
165
166struct PartialMessage {
167 packets: Vec<(u32, BytesMut)>, total_size: usize,
169 got_last: bool,
170}
171
172pub struct SignedMessageDemuxer {
173 partial_messages: HashMap<u64, PartialMessage>,
174}
175
176impl SignedMessageDemuxer {
177 pub fn new() -> Self {
178 Self {
179 partial_messages: HashMap::new(),
180 }
181 }
182}
183
/// Problems reported by `SignedMessageDemuxer::process_packet`.
#[derive(Debug)]
pub enum DemuxError {
    /// The packet could not be interpreted (e.g. it is shorter than the packet header).
    InvalidPacket(String),
    /// A partial message was discarded because its packets were inconsistent
    /// (e.g. a packet index arrived twice).
    MessageCorrupted {
        /// Sequence number of the discarded message.
        sequence: u64,
        /// Human-readable description of the inconsistency.
        reason: String,
    },
    /// Every packet of a message arrived, but the reassembled bytes are not a
    /// valid envelope.
    EnvelopeParseError {
        /// Sequence number of the message that failed to parse.
        sequence: u64,
        /// The error returned by `SignedMessageEnvelope::from_bytes`.
        error: Box<dyn Error>,
    },
}
196
197impl std::fmt::Display for DemuxError {
198 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199 match self {
200 DemuxError::InvalidPacket(msg) => write!(f, "Invalid packet: {}", msg),
201 DemuxError::MessageCorrupted { sequence, reason } => {
202 write!(f, "Message {} corrupted: {}", sequence, reason)
203 }
204 DemuxError::EnvelopeParseError { sequence, error } => {
205 write!(f, "Failed to parse message {}: {}", sequence, error)
206 }
207 }
208 }
209}
210
211impl Error for DemuxError {}
212
213#[derive(Debug)]
214pub struct DemuxResult {
215 pub messages: Vec<SignedMessageEnvelope>,
216 pub errors: Vec<DemuxError>,
217}
218
219impl DemuxResult {
220 fn new() -> Self {
221 Self {
222 messages: Vec::new(),
223 errors: Vec::new(),
224 }
225 }
226}
227
228impl SignedMessageDemuxer {
229 pub fn process_packet(&mut self, packet: &[u8]) -> DemuxResult {
230 let mut result = DemuxResult::new();
231
232 if packet.len() < PACKET_HEADER_SIZE {
233 result
234 .errors
235 .push(DemuxError::InvalidPacket("Packet too small".into()));
236 return result;
237 }
238
239 let mut buf = &packet[..];
240 let flags = buf.get_u8();
241 let msg_sequence = buf.get_u64();
242 let packet_sequence = buf.get_u32();
243 let payload = BytesMut::from(&packet[PACKET_HEADER_SIZE..]);
244 let is_last = (flags & 1) == 1;
245
246 let message = self
247 .partial_messages
248 .entry(msg_sequence)
249 .or_insert_with(|| PartialMessage {
250 packets: Vec::new(),
251 total_size: 0,
252 got_last: false,
253 });
254
255 if message
257 .packets
258 .iter()
259 .any(|(seq, _)| *seq == packet_sequence)
260 {
261 result.errors.push(DemuxError::MessageCorrupted {
262 sequence: msg_sequence,
263 reason: format!("Duplicate packet sequence {}", packet_sequence),
264 });
265 self.partial_messages.remove(&msg_sequence);
266 return result;
267 }
268
269 message.packets.push((packet_sequence, payload.clone()));
270 message.total_size += payload.len();
271 if is_last {
272 message.got_last = true;
273 }
274
275 let mut complete_sequences = Vec::new();
277 for (&sequence, message) in &mut self.partial_messages {
278 if message.got_last {
279 message.packets.sort_by_key(|(seq, _)| *seq);
280 let expected_sequences: Vec<_> = (0..message.packets.len() as u32).collect();
281 let actual_sequences: Vec<_> =
282 message.packets.iter().map(|(seq, _)| *seq).collect();
283 if expected_sequences == actual_sequences {
284 complete_sequences.push(sequence);
285 }
286 }
287 }
288
289 for sequence in complete_sequences {
291 if let Some(message) = self.partial_messages.remove(&sequence) {
292 let mut combined = BytesMut::with_capacity(message.total_size);
293 for (_, payload) in message.packets {
294 combined.extend_from_slice(&payload);
295 }
296
297 match SignedMessageEnvelope::from_bytes(&combined) {
298 Ok(envelope) => result.messages.push(envelope),
299 Err(e) => {
300 result
301 .errors
302 .push(DemuxError::EnvelopeParseError { sequence, error: e });
303 }
304 }
305 }
306 }
307
308 result
309 }
310
311 pub fn pending_message_count(&self) -> usize {
312 self.partial_messages.len()
313 }
314
315 pub fn clear(&mut self) {
316 self.partial_messages.clear();
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestMessage {
        data: String,
    }

    impl SignableMessage for TestMessage {}

    /// Builds a signer from the `PRIVKEY_PEM` environment variable, failing with
    /// an explanatory message rather than a bare `unwrap` panic.
    fn signer_from_env() -> MessageSigner {
        let key = std::env::var("PRIVKEY_PEM")
            .expect("PRIVKEY_PEM must hold a base64-encoded PEM private key for signing tests");
        MessageSigner::new(&key).expect("PRIVKEY_PEM could not be loaded as a private key")
    }

    /// Builds an envelope with a dummy signature, for tests that exercise only
    /// framing and reassembly and therefore need no key material.
    fn unsigned_envelope(sequence: u64, content_len: usize) -> SignedMessageEnvelope {
        SignedMessageEnvelope {
            sequence,
            // Truncation to u8 is intentional: it yields a repeating byte pattern.
            content: (0..content_len).map(|i| i as u8).collect(),
            timestamp: 1_700_000_000,
            signature: vec![0xA5; 32],
        }
    }

    #[test]
    fn test_message_roundtrip() -> Result<(), Box<dyn Error>> {
        let mut signer = signer_from_env();
        let mut demuxer = SignedMessageDemuxer::new();

        let msg1 = TestMessage {
            data: "first".repeat(500),
        };
        let msg2 = TestMessage {
            data: "second".repeat(500),
        };

        let env1 = signer.sign(&msg1)?;
        let env2 = signer.sign(&msg2)?;

        let packets1 = env1.to_packets();
        let packets2 = env2.to_packets();

        assert!(packets1.len() > 1);
        assert!(packets2.len() > 1);

        // Interleave the two messages' packets to exercise concurrent reassembly.
        for i in 0..packets1.len().max(packets2.len()) {
            if i < packets1.len() {
                let result = demuxer.process_packet(&packets1[i]);
                assert!(result.errors.is_empty());
                if i == packets1.len() - 1 {
                    assert_eq!(result.messages.len(), 1);
                    let decoded: TestMessage = signer.verify(&result.messages[0])?;
                    assert_eq!(decoded, msg1);
                } else {
                    assert!(result.messages.is_empty());
                }
            }

            if i < packets2.len() {
                let result = demuxer.process_packet(&packets2[i]);
                assert!(result.errors.is_empty());
                if i == packets2.len() - 1 {
                    assert_eq!(result.messages.len(), 1);
                    let decoded: TestMessage = signer.verify(&result.messages[0])?;
                    assert_eq!(decoded, msg2);
                } else {
                    assert!(result.messages.is_empty());
                }
            }
        }

        assert_eq!(demuxer.pending_message_count(), 0);
        Ok(())
    }

    #[test]
    fn test_envelope_bytes_roundtrip() -> Result<(), Box<dyn Error>> {
        let env = unsigned_envelope(42, 100);
        let decoded = SignedMessageEnvelope::from_bytes(&env.to_bytes())?;
        assert_eq!(decoded.sequence, env.sequence);
        assert_eq!(decoded.timestamp, env.timestamp);
        assert_eq!(decoded.content, env.content);
        assert_eq!(decoded.signature, env.signature);
        Ok(())
    }

    #[test]
    fn test_out_of_order_reassembly() {
        // 24 bytes of framing + 2 full payloads + 32-byte signature → 3 packets.
        let env = unsigned_envelope(9, MAX_PAYLOAD_SIZE * 2);
        let packets = env.to_packets();
        assert_eq!(packets.len(), 3);

        let mut demuxer = SignedMessageDemuxer::new();
        let mut delivered = Vec::new();
        for packet in packets.iter().rev() {
            let result = demuxer.process_packet(packet);
            assert!(result.errors.is_empty());
            delivered.extend(result.messages);
        }

        assert_eq!(delivered.len(), 1);
        assert_eq!(delivered[0].sequence, 9);
        assert_eq!(delivered[0].content, env.content);
        assert_eq!(delivered[0].signature, env.signature);
        assert_eq!(demuxer.pending_message_count(), 0);
    }

    #[test]
    fn test_error_handling() {
        let mut demuxer = SignedMessageDemuxer::new();

        let result = demuxer.process_packet(&[1, 2, 3]);
        assert_eq!(result.messages.len(), 0);
        assert_eq!(result.errors.len(), 1);
        match &result.errors[0] {
            DemuxError::InvalidPacket(_) => (),
            other => panic!("Expected InvalidPacket error, got {:?}", other),
        }

        // One full payload of content plus envelope framing spans two packets.
        let env = unsigned_envelope(5, MAX_PAYLOAD_SIZE);
        let packets = env.to_packets();
        assert!(packets.len() > 1);

        let first = demuxer.process_packet(&packets[0]);
        assert!(first.errors.is_empty());
        let duplicate = demuxer.process_packet(&packets[0]);
        assert_eq!(duplicate.errors.len(), 1);
        match &duplicate.errors[0] {
            DemuxError::MessageCorrupted { sequence, .. } => {
                assert_eq!(*sequence, env.sequence);
            }
            other => panic!("Expected MessageCorrupted error, got {:?}", other),
        }
        // The corrupted partial message is discarded rather than kept pending.
        assert_eq!(demuxer.pending_message_count(), 0);
    }
}