// fips_core/transport/tcp/stream.rs
use tokio::io::{AsyncRead, AsyncReadExt};
10
/// Phase value (low nibble of the first prefix byte) of an established-session frame.
const PHASE_ESTABLISHED: u8 = 0x00;
/// Phase value of the first handshake message.
const PHASE_MSG1: u8 = 0x01;
/// Phase value of the second handshake message.
const PHASE_MSG2: u8 = 0x02;

/// Number of bytes in the fixed prefix that starts every frame.
const PREFIX_SIZE: usize = 4;

/// Header bytes that follow the prefix in an established frame.
const ESTABLISHED_REMAINING_HEADER: usize = 12;
/// Bytes that trail the payload of an established frame
/// (the AEAD authentication tag, going by the name).
const AEAD_TAG_SIZE: usize = 16;
25
/// Errors that can occur while reading a single FMP frame from a byte stream.
#[derive(Debug)]
pub enum StreamError {
    /// The version nibble (high four bits of the first prefix byte) is not supported.
    UnknownVersion(u8),
    /// The phase nibble (low four bits of the first prefix byte) is not a known phase.
    UnknownPhase(u8),
    /// An established frame declared a payload larger than the limit derived from the MTU.
    PayloadTooLarge {
        /// Payload length declared in the frame prefix.
        payload_len: u16,
        /// Largest payload length that was acceptable.
        max_payload_len: u16,
    },
    /// A handshake frame declared a payload length other than the fixed one for its phase.
    HandshakeSizeMismatch {
        /// Phase nibble of the offending frame.
        phase: u8,
        /// Payload length required for that phase.
        expected: u16,
        /// Payload length declared in the frame prefix.
        got: u16,
    },
    /// The underlying reader failed (including end-of-stream in the middle of a frame).
    Io(std::io::Error),
}

impl std::fmt::Display for StreamError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            StreamError::UnknownVersion(v) => write!(f, "unknown FMP version: {}", v),
            StreamError::UnknownPhase(p) => write!(f, "unknown FMP phase: 0x{:02x}", p),
            StreamError::PayloadTooLarge {
                payload_len,
                max_payload_len,
            } => write!(
                f,
                "payload_len {} exceeds max {}",
                payload_len, max_payload_len
            ),
            StreamError::HandshakeSizeMismatch {
                phase,
                expected,
                got,
            } => write!(
                f,
                "handshake phase 0x{:x}: expected payload_len {}, got {}",
                phase, expected, got
            ),
            StreamError::Io(e) => write!(f, "io: {}", e),
        }
    }
}

impl std::error::Error for StreamError {
    /// Exposes the wrapped I/O error so error-chain walkers can reach the root cause.
    ///
    /// Protocol-level variants have no underlying cause and return `None`.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Listed exhaustively (no `_` arm) so that adding a variant forces a decision here.
        match self {
            StreamError::Io(e) => Some(e),
            StreamError::UnknownVersion(_)
            | StreamError::UnknownPhase(_)
            | StreamError::PayloadTooLarge { .. }
            | StreamError::HandshakeSizeMismatch { .. } => None,
        }
    }
}

/// Lets `?` convert reader failures into [`StreamError::Io`].
impl From<std::io::Error> for StreamError {
    fn from(e: std::io::Error) -> Self {
        StreamError::Io(e)
    }
}
82
83const MSG1_WIRE_SIZE: usize = 114;
87const MSG2_WIRE_SIZE: usize = 69;
88
89const MSG1_PAYLOAD_LEN: u16 = (MSG1_WIRE_SIZE - PREFIX_SIZE) as u16;
91
92const MSG2_PAYLOAD_LEN: u16 = (MSG2_WIRE_SIZE - PREFIX_SIZE) as u16;
94
95pub async fn read_fmp_packet<R: AsyncRead + Unpin>(
113 reader: &mut R,
114 mtu: u16,
115) -> Result<Vec<u8>, StreamError> {
116 let mut prefix = [0u8; PREFIX_SIZE];
118 reader.read_exact(&mut prefix).await?;
119
120 let version = prefix[0] >> 4;
121 let phase = prefix[0] & 0x0F;
122
123 if version != 0 {
124 return Err(StreamError::UnknownVersion(version));
125 }
126
127 let payload_len = u16::from_le_bytes([prefix[2], prefix[3]]);
128
129 let remaining = match phase {
131 PHASE_ESTABLISHED => {
132 let max_payload_len = mtu.saturating_sub(
136 (ESTABLISHED_REMAINING_HEADER + PREFIX_SIZE + AEAD_TAG_SIZE) as u16,
137 );
138 if payload_len > max_payload_len {
139 return Err(StreamError::PayloadTooLarge {
140 payload_len,
141 max_payload_len,
142 });
143 }
144 ESTABLISHED_REMAINING_HEADER + payload_len as usize + AEAD_TAG_SIZE
146 }
147 PHASE_MSG1 => {
148 if payload_len != MSG1_PAYLOAD_LEN {
149 return Err(StreamError::HandshakeSizeMismatch {
150 phase,
151 expected: MSG1_PAYLOAD_LEN,
152 got: payload_len,
153 });
154 }
155 payload_len as usize
156 }
157 PHASE_MSG2 => {
158 if payload_len != MSG2_PAYLOAD_LEN {
159 return Err(StreamError::HandshakeSizeMismatch {
160 phase,
161 expected: MSG2_PAYLOAD_LEN,
162 got: payload_len,
163 });
164 }
165 payload_len as usize
166 }
167 _ => {
168 return Err(StreamError::UnknownPhase(phase));
169 }
170 };
171
172 let total = PREFIX_SIZE + remaining;
174 let mut packet = vec![0u8; total];
175 packet[..PREFIX_SIZE].copy_from_slice(&prefix);
176 reader.read_exact(&mut packet[PREFIX_SIZE..]).await?;
177
178 Ok(packet)
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// MTU used by every test that does not deliberately pick a smaller one.
    const TEST_MTU: u16 = 1400;

    /// Builds a frame prefix: version/phase byte, a zero second byte, little-endian length.
    fn prefix_bytes(first_byte: u8, payload_len: u16) -> [u8; PREFIX_SIZE] {
        let [lo, hi] = payload_len.to_le_bytes();
        [first_byte, 0x00, lo, hi]
    }

    /// Builds `total_len` zero bytes whose first four bytes are the given prefix.
    fn zeroed_with_prefix(total_len: usize, first_byte: u8, payload_len: u16) -> Vec<u8> {
        let mut bytes = vec![0u8; total_len];
        bytes[..PREFIX_SIZE].copy_from_slice(&prefix_bytes(first_byte, payload_len));
        bytes
    }

    /// Runs `read_fmp_packet` once over an in-memory byte buffer.
    async fn read_one(bytes: Vec<u8>, mtu: u16) -> Result<Vec<u8>, StreamError> {
        let mut cursor = Cursor::new(bytes);
        read_fmp_packet(&mut cursor, mtu).await
    }

    /// A complete established frame; each byte after the prefix is its own index modulo 256.
    fn build_established_frame(payload_len: u16) -> Vec<u8> {
        let total =
            PREFIX_SIZE + ESTABLISHED_REMAINING_HEADER + payload_len as usize + AEAD_TAG_SIZE;
        let mut frame: Vec<u8> = (0..total).map(|idx| (idx & 0xFF) as u8).collect();
        frame[..PREFIX_SIZE].copy_from_slice(&prefix_bytes(0x00, payload_len));
        frame
    }

    /// A complete handshake message 1 frame filled with `0xAA` after the prefix.
    fn build_msg1_frame() -> Vec<u8> {
        let mut frame = vec![0xAA; MSG1_WIRE_SIZE];
        frame[..PREFIX_SIZE].copy_from_slice(&prefix_bytes(0x01, MSG1_PAYLOAD_LEN));
        frame
    }

    /// A complete handshake message 2 frame filled with `0xBB` after the prefix.
    fn build_msg2_frame() -> Vec<u8> {
        let mut frame = vec![0xBB; MSG2_WIRE_SIZE];
        frame[..PREFIX_SIZE].copy_from_slice(&prefix_bytes(0x02, MSG2_PAYLOAD_LEN));
        frame
    }

    #[tokio::test]
    async fn test_read_established_frame() {
        let frame = build_established_frame(64);

        let packet = read_one(frame.clone(), TEST_MTU).await.unwrap();
        assert_eq!(packet, frame);
    }

    #[tokio::test]
    async fn test_read_msg1_frame() {
        let frame = build_msg1_frame();

        let packet = read_one(frame.clone(), TEST_MTU).await.unwrap();
        assert_eq!(packet.len(), MSG1_WIRE_SIZE);
        assert_eq!(packet, frame);
    }

    #[tokio::test]
    async fn test_read_msg2_frame() {
        let frame = build_msg2_frame();

        let packet = read_one(frame.clone(), TEST_MTU).await.unwrap();
        assert_eq!(packet.len(), MSG2_WIRE_SIZE);
        assert_eq!(packet, frame);
    }

    #[tokio::test]
    async fn test_read_multiple_packets() {
        let msg1 = build_msg1_frame();
        let est = build_established_frame(32);
        let msg2 = build_msg2_frame();

        // Three frames back to back on one stream; each read must stop at its frame boundary.
        let stream = [msg1.as_slice(), est.as_slice(), msg2.as_slice()].concat();
        let mut cursor = Cursor::new(stream);

        let first = read_fmp_packet(&mut cursor, TEST_MTU).await.unwrap();
        assert_eq!(first.len(), MSG1_WIRE_SIZE);

        let second = read_fmp_packet(&mut cursor, TEST_MTU).await.unwrap();
        assert_eq!(second, est);

        let third = read_fmp_packet(&mut cursor, TEST_MTU).await.unwrap();
        assert_eq!(third.len(), MSG2_WIRE_SIZE);
    }

    #[tokio::test]
    async fn test_unknown_version_error() {
        // High nibble 1 => version 1.
        let frame = zeroed_with_prefix(100, 0x16, 0);

        let err = read_one(frame, TEST_MTU).await.unwrap_err();
        assert!(matches!(err, StreamError::UnknownVersion(1)));
    }

    #[tokio::test]
    async fn test_unknown_phase_error() {
        let frame = zeroed_with_prefix(100, 0x05, 10);

        let err = read_one(frame, TEST_MTU).await.unwrap_err();
        assert!(matches!(err, StreamError::UnknownPhase(0x5)));
    }

    #[tokio::test]
    async fn test_payload_too_large() {
        // Prefix declaring 100 payload bytes, followed by 200 bytes of filler.
        let data = zeroed_with_prefix(PREFIX_SIZE + 200, 0x00, 100);

        let err = read_one(data, 100).await.unwrap_err();
        assert!(matches!(err, StreamError::PayloadTooLarge { .. }));
    }

    #[tokio::test]
    async fn test_handshake_size_mismatch_msg1() {
        let frame = zeroed_with_prefix(200, 0x01, 50);

        let err = read_one(frame, TEST_MTU).await.unwrap_err();
        assert!(matches!(
            err,
            StreamError::HandshakeSizeMismatch { phase: 0x1, .. }
        ));
    }

    #[tokio::test]
    async fn test_handshake_size_mismatch_msg2() {
        let frame = zeroed_with_prefix(200, 0x02, 50);

        let err = read_one(frame, TEST_MTU).await.unwrap_err();
        assert!(matches!(
            err,
            StreamError::HandshakeSizeMismatch { phase: 0x2, .. }
        ));
    }

    #[tokio::test]
    async fn test_eof_on_prefix() {
        // Only two of the four prefix bytes are available.
        let err = read_one(vec![0u8; 2], TEST_MTU).await.unwrap_err();
        assert!(matches!(err, StreamError::Io(_)));
    }

    #[tokio::test]
    async fn test_eof_on_body() {
        // Valid message 1 prefix, but the stream ends after 10 bytes in total.
        let data = zeroed_with_prefix(10, 0x01, MSG1_PAYLOAD_LEN);

        let err = read_one(data, TEST_MTU).await.unwrap_err();
        assert!(matches!(err, StreamError::Io(_)));
    }

    #[tokio::test]
    async fn test_zero_payload_established() {
        let frame = build_established_frame(0);
        let expected_len = PREFIX_SIZE + ESTABLISHED_REMAINING_HEADER + AEAD_TAG_SIZE;
        assert_eq!(frame.len(), expected_len);

        let packet = read_one(frame.clone(), TEST_MTU).await.unwrap();
        assert_eq!(packet.len(), expected_len);
        assert_eq!(packet, frame);
    }

    #[tokio::test]
    async fn test_max_payload_at_mtu_boundary() {
        // 32 = prefix (4) + remaining header (12) + tag (16).
        let max_payload = TEST_MTU - 32;
        let frame = build_established_frame(max_payload);

        let packet = read_one(frame.clone(), TEST_MTU).await.unwrap();
        assert_eq!(packet, frame);
    }

    #[tokio::test]
    async fn test_payload_one_over_mtu() {
        let over = TEST_MTU - 32 + 1;
        // Prefix followed by 2000 bytes of filler.
        let data = zeroed_with_prefix(PREFIX_SIZE + 2000, 0x00, over);

        let err = read_one(data, TEST_MTU).await.unwrap_err();
        assert!(matches!(err, StreamError::PayloadTooLarge { .. }));
    }
}