// fips_core/transport/tcp/stream.rs

//! FMP-Aware Stream Reader
//!
//! Recovers FIPS packet boundaries from a TCP byte stream using the
//! existing 4-byte FMP common prefix `[ver+phase:1][flags:1][payload_len:2 LE]`.
//!
//! This module is deliberately separate from the TCP transport so it can
//! be reused by the future Tor transport.
8
use tokio::io::{AsyncRead, AsyncReadExt};
10
/// FMP phase value (low nibble of byte 0) for an established-session frame.
const PHASE_ESTABLISHED: u8 = 0x0;
/// FMP phase value (low nibble of byte 0) for handshake message 1.
const PHASE_MSG1: u8 = 0x1;
/// FMP phase value (low nibble of byte 0) for handshake message 2.
const PHASE_MSG2: u8 = 0x2;

/// Size of the FMP common prefix: `[ver+phase:1][flags:1][payload_len:2 LE]`.
const PREFIX_SIZE: usize = 4;

/// Header bytes of an established frame that follow the common prefix.
///
/// The full established header is 16 bytes (`PREFIX_SIZE` + 12), so after
/// reading the 4-byte prefix, 12 more header bytes remain. They are followed
/// by `payload_len` bytes of ciphertext and then the AEAD tag.
const ESTABLISHED_REMAINING_HEADER: usize = 12;
/// Size of the AEAD tag that trails the ciphertext of an established frame.
const AEAD_TAG_SIZE: usize = 16;
25
/// Errors from the FMP stream reader.
#[derive(Debug)]
pub enum StreamError {
    /// Unknown FMP version — not a FIPS connection (e.g., TLS ClientHello).
    UnknownVersion(u8),
    /// Unknown FMP phase byte — protocol error, close connection.
    UnknownPhase(u8),
    /// Payload length exceeds the connection's MTU — corrupted or malicious.
    PayloadTooLarge {
        /// The `payload_len` announced in the packet prefix.
        payload_len: u16,
        /// The largest `payload_len` the reader was willing to accept.
        max_payload_len: u16,
    },
    /// Handshake packet has unexpected payload_len for its phase.
    HandshakeSizeMismatch {
        /// Phase nibble of the offending packet.
        phase: u8,
        /// The fixed `payload_len` required for that phase.
        expected: u16,
        /// The `payload_len` actually announced in the prefix.
        got: u16,
    },
    /// I/O error (including EOF).
    Io(std::io::Error),
}

impl std::fmt::Display for StreamError {
    /// Writes a short, single-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnknownVersion(v) => write!(f, "unknown FMP version: {}", v),
            Self::UnknownPhase(p) => write!(f, "unknown FMP phase: 0x{:02x}", p),
            Self::PayloadTooLarge {
                payload_len,
                max_payload_len,
            } => write!(
                f,
                "payload_len {} exceeds max {}",
                payload_len, max_payload_len
            ),
            Self::HandshakeSizeMismatch {
                phase,
                expected,
                got,
            } => write!(
                f,
                "handshake phase 0x{:x}: expected payload_len {}, got {}",
                phase, expected, got
            ),
            Self::Io(e) => write!(f, "io: {}", e),
        }
    }
}

impl std::error::Error for StreamError {
    /// Returns the wrapped [`std::io::Error`] for the `Io` variant so callers
    /// can walk the error chain (e.g. to inspect its `ErrorKind`); every other
    /// variant has no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(e) => Some(e),
            Self::UnknownVersion(_)
            | Self::UnknownPhase(_)
            | Self::PayloadTooLarge { .. }
            | Self::HandshakeSizeMismatch { .. } => None,
        }
    }
}

impl From<std::io::Error> for StreamError {
    /// Wraps an I/O error so `?` can be used directly on read calls.
    fn from(e: std::io::Error) -> Self {
        Self::Io(e)
    }
}
82
83/// Known wire sizes for handshake messages.
84/// msg1: 4 (prefix) + 4 (sender_idx) + 106 (noise_msg1) = 114 bytes
85/// msg2: 4 (prefix) + 4 (sender_idx) + 4 (receiver_idx) + 57 (noise_msg2) = 69 bytes
86const MSG1_WIRE_SIZE: usize = 114;
87const MSG2_WIRE_SIZE: usize = 69;
88
89/// Expected payload_len for msg1: sender_idx(4) + noise_msg1(106) = 110.
90const MSG1_PAYLOAD_LEN: u16 = (MSG1_WIRE_SIZE - PREFIX_SIZE) as u16;
91
92/// Expected payload_len for msg2: sender_idx(4) + receiver_idx(4) + noise_msg2(57) = 65.
93const MSG2_PAYLOAD_LEN: u16 = (MSG2_WIRE_SIZE - PREFIX_SIZE) as u16;
94
95/// Read one complete FMP packet from an async reader.
96///
97/// Uses the 4-byte FMP common prefix to determine the total packet size,
98/// then reads the remaining bytes. Returns the complete packet as a `Vec<u8>`.
99///
100/// # Arguments
101///
102/// * `reader` - Any async reader (typically an `OwnedReadHalf`)
103/// * `mtu` - The connection's MTU for validation of established frame sizes
104///
105/// # Errors
106///
107/// * `UnknownVersion` — non-zero version nibble (not a FIPS connection)
108/// * `UnknownPhase` — unrecognized phase nibble (protocol error)
109/// * `PayloadTooLarge` — established frame exceeds MTU
110/// * `HandshakeSizeMismatch` — handshake packet has wrong payload_len
111/// * `Io` — underlying read error (including EOF)
112pub async fn read_fmp_packet<R: AsyncRead + Unpin>(
113    reader: &mut R,
114    mtu: u16,
115) -> Result<Vec<u8>, StreamError> {
116    // Read the 4-byte FMP common prefix
117    let mut prefix = [0u8; PREFIX_SIZE];
118    reader.read_exact(&mut prefix).await?;
119
120    let version = prefix[0] >> 4;
121    let phase = prefix[0] & 0x0F;
122
123    if version != 0 {
124        return Err(StreamError::UnknownVersion(version));
125    }
126
127    let payload_len = u16::from_le_bytes([prefix[2], prefix[3]]);
128
129    // Compute remaining bytes based on phase
130    let remaining = match phase {
131        PHASE_ESTABLISHED => {
132            // Validate payload_len against MTU:
133            // total packet = 16 (header) + payload_len + 16 (tag) = payload_len + 32
134            // max_payload_len = mtu - 32
135            let max_payload_len = mtu.saturating_sub(
136                (ESTABLISHED_REMAINING_HEADER + PREFIX_SIZE + AEAD_TAG_SIZE) as u16,
137            );
138            if payload_len > max_payload_len {
139                return Err(StreamError::PayloadTooLarge {
140                    payload_len,
141                    max_payload_len,
142                });
143            }
144            // remaining = 12 (rest of header) + payload_len + 16 (AEAD tag)
145            ESTABLISHED_REMAINING_HEADER + payload_len as usize + AEAD_TAG_SIZE
146        }
147        PHASE_MSG1 => {
148            if payload_len != MSG1_PAYLOAD_LEN {
149                return Err(StreamError::HandshakeSizeMismatch {
150                    phase,
151                    expected: MSG1_PAYLOAD_LEN,
152                    got: payload_len,
153                });
154            }
155            payload_len as usize
156        }
157        PHASE_MSG2 => {
158            if payload_len != MSG2_PAYLOAD_LEN {
159                return Err(StreamError::HandshakeSizeMismatch {
160                    phase,
161                    expected: MSG2_PAYLOAD_LEN,
162                    got: payload_len,
163                });
164            }
165            payload_len as usize
166        }
167        _ => {
168            return Err(StreamError::UnknownPhase(phase));
169        }
170    };
171
172    // Allocate buffer for the complete packet (prefix + remaining)
173    let total = PREFIX_SIZE + remaining;
174    let mut packet = vec![0u8; total];
175    packet[..PREFIX_SIZE].copy_from_slice(&prefix);
176    reader.read_exact(&mut packet[PREFIX_SIZE..]).await?;
177
178    Ok(packet)
179}
180
// ============================================================================
// Tests
// ============================================================================
184
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Build a minimal established frame with the given payload_len.
    /// Layout: [ver+phase:1][flags:1][payload_len:2 LE][12 bytes header][payload_len bytes][16 bytes tag]
    fn build_established_frame(payload_len: u16) -> Vec<u8> {
        let total =
            PREFIX_SIZE + ESTABLISHED_REMAINING_HEADER + payload_len as usize + AEAD_TAG_SIZE;
        let mut frame = vec![0u8; total];
        frame[0] = 0x00; // ver=0, phase=0 (established)
        frame[1] = 0x00; // flags
        frame[2..4].copy_from_slice(&payload_len.to_le_bytes());
        // Fill everything after the prefix with an offset-derived pattern so a
        // misaligned read is caught by the equality assertions.
        for (offset, byte) in frame.iter_mut().enumerate().skip(PREFIX_SIZE) {
            *byte = (offset & 0xFF) as u8;
        }
        frame
    }

    /// Build a msg1 frame (114 bytes total).
    fn build_msg1_frame() -> Vec<u8> {
        let mut frame = vec![0xAA; MSG1_WIRE_SIZE];
        frame[0] = 0x01; // ver=0, phase=1
        frame[1] = 0x00; // flags
        frame[2..4].copy_from_slice(&MSG1_PAYLOAD_LEN.to_le_bytes());
        frame
    }

    /// Build a msg2 frame (69 bytes total).
    fn build_msg2_frame() -> Vec<u8> {
        let mut frame = vec![0xBB; MSG2_WIRE_SIZE];
        frame[0] = 0x02; // ver=0, phase=2
        frame[1] = 0x00; // flags
        frame[2..4].copy_from_slice(&MSG2_PAYLOAD_LEN.to_le_bytes());
        frame
    }

    #[tokio::test]
    async fn test_read_established_frame() {
        let frame = build_established_frame(64);

        let mut cursor = Cursor::new(frame.clone());
        let packet = read_fmp_packet(&mut cursor, 1400).await.unwrap();
        assert_eq!(packet, frame);
    }

    #[tokio::test]
    async fn test_read_msg1_frame() {
        let frame = build_msg1_frame();

        let mut cursor = Cursor::new(frame.clone());
        let packet = read_fmp_packet(&mut cursor, 1400).await.unwrap();
        assert_eq!(packet.len(), MSG1_WIRE_SIZE);
        assert_eq!(packet, frame);
    }

    #[tokio::test]
    async fn test_read_msg2_frame() {
        let frame = build_msg2_frame();

        let mut cursor = Cursor::new(frame.clone());
        let packet = read_fmp_packet(&mut cursor, 1400).await.unwrap();
        assert_eq!(packet.len(), MSG2_WIRE_SIZE);
        assert_eq!(packet, frame);
    }

    #[tokio::test]
    async fn test_read_multiple_packets() {
        // Three packets back to back: the reader must stop exactly at each
        // packet boundary so the next call starts on a fresh prefix.
        let msg1 = build_msg1_frame();
        let est = build_established_frame(32);
        let msg2 = build_msg2_frame();
        let data = [msg1.as_slice(), est.as_slice(), msg2.as_slice()].concat();

        let mut cursor = Cursor::new(data);
        let p1 = read_fmp_packet(&mut cursor, 1400).await.unwrap();
        assert_eq!(p1.len(), MSG1_WIRE_SIZE);

        let p2 = read_fmp_packet(&mut cursor, 1400).await.unwrap();
        assert_eq!(p2, est);

        let p3 = read_fmp_packet(&mut cursor, 1400).await.unwrap();
        assert_eq!(p3.len(), MSG2_WIRE_SIZE);
    }

    #[tokio::test]
    async fn test_unknown_version_error() {
        // TLS ClientHello starts with 0x16 (record type "Handshake"),
        // which parses as FMP version=1, phase=6.
        let mut frame = vec![0u8; 100];
        frame[0] = 0x16;
        let mut cursor = Cursor::new(frame);
        let err = read_fmp_packet(&mut cursor, 1400).await.unwrap_err();
        assert!(matches!(err, StreamError::UnknownVersion(1)));
    }

    #[tokio::test]
    async fn test_unknown_phase_error() {
        let mut frame = vec![0u8; 100];
        frame[0] = 0x05; // unknown phase
        frame[2..4].copy_from_slice(&10u16.to_le_bytes());

        let mut cursor = Cursor::new(frame);
        let err = read_fmp_packet(&mut cursor, 1400).await.unwrap_err();
        assert!(matches!(err, StreamError::UnknownPhase(0x5)));
    }

    #[tokio::test]
    async fn test_payload_too_large() {
        // mtu=100, max_payload_len = 100 - 32 = 68
        let payload_len = 100u16; // exceeds max of 68
        let mut prefix = [0u8; 4];
        prefix[0] = 0x00; // established
        prefix[2..4].copy_from_slice(&payload_len.to_le_bytes());

        // Provide more than enough bytes so the failure cannot be an EOF.
        let mut data = prefix.to_vec();
        data.extend_from_slice(&[0u8; 200]);

        let mut cursor = Cursor::new(data);
        let err = read_fmp_packet(&mut cursor, 100).await.unwrap_err();
        assert!(matches!(
            err,
            StreamError::PayloadTooLarge {
                payload_len: 100,
                max_payload_len: 68,
            }
        ));
    }

    #[tokio::test]
    async fn test_handshake_size_mismatch_msg1() {
        let mut frame = vec![0u8; 200];
        frame[0] = 0x01; // msg1
        // Wrong payload_len (should be 110)
        frame[2..4].copy_from_slice(&50u16.to_le_bytes());

        let mut cursor = Cursor::new(frame);
        let err = read_fmp_packet(&mut cursor, 1400).await.unwrap_err();
        assert!(matches!(
            err,
            StreamError::HandshakeSizeMismatch {
                phase: 0x1,
                expected: MSG1_PAYLOAD_LEN,
                got: 50,
            }
        ));
    }

    #[tokio::test]
    async fn test_handshake_size_mismatch_msg2() {
        let mut frame = vec![0u8; 200];
        frame[0] = 0x02; // msg2
        // Wrong payload_len (should be 65)
        frame[2..4].copy_from_slice(&50u16.to_le_bytes());

        let mut cursor = Cursor::new(frame);
        let err = read_fmp_packet(&mut cursor, 1400).await.unwrap_err();
        assert!(matches!(
            err,
            StreamError::HandshakeSizeMismatch {
                phase: 0x2,
                expected: MSG2_PAYLOAD_LEN,
                got: 50,
            }
        ));
    }

    #[tokio::test]
    async fn test_eof_on_prefix() {
        // Only 2 bytes available (need 4 for prefix)
        let data = vec![0u8; 2];
        let mut cursor = Cursor::new(data);
        let err = read_fmp_packet(&mut cursor, 1400).await.unwrap_err();
        assert!(matches!(err, StreamError::Io(_)));
    }

    #[tokio::test]
    async fn test_eof_on_body() {
        // Valid msg1 prefix but truncated body
        let mut data = vec![0u8; 10]; // need 114 total
        data[0] = 0x01; // msg1
        data[2..4].copy_from_slice(&MSG1_PAYLOAD_LEN.to_le_bytes());

        let mut cursor = Cursor::new(data);
        let err = read_fmp_packet(&mut cursor, 1400).await.unwrap_err();
        assert!(matches!(err, StreamError::Io(_)));
    }

    #[tokio::test]
    async fn test_zero_payload_established() {
        // payload_len = 0 is valid (header-only encrypted frame with tag)
        let frame = build_established_frame(0);
        let expected_len = PREFIX_SIZE + ESTABLISHED_REMAINING_HEADER + AEAD_TAG_SIZE;
        assert_eq!(frame.len(), expected_len);

        let mut cursor = Cursor::new(frame.clone());
        let packet = read_fmp_packet(&mut cursor, 1400).await.unwrap();
        assert_eq!(packet.len(), expected_len);
        assert_eq!(packet, frame);
    }

    #[tokio::test]
    async fn test_max_payload_at_mtu_boundary() {
        // mtu=1400, max_payload_len = 1400 - 32 = 1368
        let max_payload = 1400u16 - 32;
        let frame = build_established_frame(max_payload);

        let mut cursor = Cursor::new(frame.clone());
        let packet = read_fmp_packet(&mut cursor, 1400).await.unwrap();
        assert_eq!(packet, frame);
    }

    #[tokio::test]
    async fn test_payload_one_over_mtu() {
        // mtu=1400, max_payload_len = 1368, try 1369
        let over = 1400u16 - 32 + 1;
        let mut prefix = [0u8; 4];
        prefix[0] = 0x00; // established
        prefix[2..4].copy_from_slice(&over.to_le_bytes());

        // Provide more than enough bytes so the failure cannot be an EOF.
        let mut data = prefix.to_vec();
        data.extend_from_slice(&[0u8; 2000]);

        let mut cursor = Cursor::new(data);
        let err = read_fmp_packet(&mut cursor, 1400).await.unwrap_err();
        assert!(matches!(
            err,
            StreamError::PayloadTooLarge {
                payload_len: 1369,
                max_payload_len: 1368,
            }
        ));
    }
}
407}