use tokio::io::{AsyncRead, AsyncReadExt};
/// Phase nibble (low 4 bits of the first prefix byte) selecting the
/// "established" frame layout: header + variable payload + trailing tag.
const PHASE_ESTABLISHED: u8 = 0x0;
/// Phase nibble for the first fixed-size handshake message.
const PHASE_MSG1: u8 = 0x1;
/// Phase nibble for the second fixed-size handshake message.
const PHASE_MSG2: u8 = 0x2;
/// Number of bytes in the common prefix read before anything else.
/// Byte 0 carries version (high nibble) and phase (low nibble); bytes 2..4
/// carry the little-endian `payload_len`. Byte 1 is not interpreted here.
const PREFIX_SIZE: usize = 4;
/// Extra header bytes an established-phase frame carries in addition to the
/// prefix; they are counted in the bytes read after the prefix.
const ESTABLISHED_REMAINING_HEADER: usize = 12;
/// Extra bytes an established-phase frame carries beyond header and payload.
/// NOTE(review): presumably the AEAD authentication tag — confirm against the
/// protocol specification.
const AEAD_TAG_SIZE: usize = 16;
/// Errors produced while reading one framed FMP packet from a byte stream.
#[derive(Debug)]
pub enum StreamError {
    /// The version nibble of the first prefix byte was not a supported value.
    /// Carries the version that was found.
    UnknownVersion(u8),
    /// The phase nibble of the first prefix byte did not match any known
    /// phase. Carries the phase that was found.
    UnknownPhase(u8),
    /// An established-phase frame announced a payload larger than the limit
    /// derived from the MTU.
    PayloadTooLarge {
        /// Payload length announced by the frame prefix.
        payload_len: u16,
        /// Largest payload length that would have been accepted.
        max_payload_len: u16,
    },
    /// A handshake frame announced a payload length other than the fixed
    /// length required for its phase.
    HandshakeSizeMismatch {
        /// Phase nibble of the offending frame.
        phase: u8,
        /// Payload length required for that phase.
        expected: u16,
        /// Payload length announced by the frame prefix.
        got: u16,
    },
    /// The underlying reader failed (including running out of data mid-frame).
    Io(std::io::Error),
}

impl std::fmt::Display for StreamError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            StreamError::UnknownVersion(v) => write!(f, "unknown FMP version: {}", v),
            StreamError::UnknownPhase(p) => write!(f, "unknown FMP phase: 0x{:02x}", p),
            StreamError::PayloadTooLarge {
                payload_len,
                max_payload_len,
            } => write!(
                f,
                "payload_len {} exceeds max {}",
                payload_len, max_payload_len
            ),
            StreamError::HandshakeSizeMismatch {
                phase,
                expected,
                got,
            } => write!(
                f,
                "handshake phase 0x{:x}: expected payload_len {}, got {}",
                phase, expected, got
            ),
            StreamError::Io(e) => write!(f, "io: {}", e),
        }
    }
}

impl std::error::Error for StreamError {
    /// Exposes the wrapped `std::io::Error` for the `Io` variant so callers
    /// walking the error chain (or downcasting) can inspect its `ErrorKind`.
    /// All protocol-level variants have no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Variants are listed explicitly (no `_` arm) so that adding a new
        // variant forces a decision about whether it has a source.
        match self {
            StreamError::Io(e) => Some(e),
            StreamError::UnknownVersion(_)
            | StreamError::UnknownPhase(_)
            | StreamError::PayloadTooLarge { .. }
            | StreamError::HandshakeSizeMismatch { .. } => None,
        }
    }
}

/// Lets `?` convert reader failures into `StreamError::Io`.
impl From<std::io::Error> for StreamError {
    fn from(e: std::io::Error) -> Self {
        StreamError::Io(e)
    }
}
/// Total size in bytes of a phase-1 handshake frame, prefix included.
const MSG1_WIRE_SIZE: usize = 114;
/// Total size in bytes of a phase-2 handshake frame, prefix included.
const MSG2_WIRE_SIZE: usize = 69;
/// The only `payload_len` value accepted for a phase-1 frame: the wire size
/// minus the prefix.
const MSG1_PAYLOAD_LEN: u16 = (MSG1_WIRE_SIZE - PREFIX_SIZE) as u16;
/// The only `payload_len` value accepted for a phase-2 frame: the wire size
/// minus the prefix.
const MSG2_PAYLOAD_LEN: u16 = (MSG2_WIRE_SIZE - PREFIX_SIZE) as u16;
pub async fn read_fmp_packet<R: AsyncRead + Unpin>(
reader: &mut R,
mtu: u16,
) -> Result<Vec<u8>, StreamError> {
let mut prefix = [0u8; PREFIX_SIZE];
reader.read_exact(&mut prefix).await?;
let version = prefix[0] >> 4;
let phase = prefix[0] & 0x0F;
if version != 0 {
return Err(StreamError::UnknownVersion(version));
}
let payload_len = u16::from_le_bytes([prefix[2], prefix[3]]);
let remaining = match phase {
PHASE_ESTABLISHED => {
let max_payload_len = mtu.saturating_sub(
(ESTABLISHED_REMAINING_HEADER + PREFIX_SIZE + AEAD_TAG_SIZE) as u16,
);
if payload_len > max_payload_len {
return Err(StreamError::PayloadTooLarge {
payload_len,
max_payload_len,
});
}
ESTABLISHED_REMAINING_HEADER + payload_len as usize + AEAD_TAG_SIZE
}
PHASE_MSG1 => {
if payload_len != MSG1_PAYLOAD_LEN {
return Err(StreamError::HandshakeSizeMismatch {
phase,
expected: MSG1_PAYLOAD_LEN,
got: payload_len,
});
}
payload_len as usize
}
PHASE_MSG2 => {
if payload_len != MSG2_PAYLOAD_LEN {
return Err(StreamError::HandshakeSizeMismatch {
phase,
expected: MSG2_PAYLOAD_LEN,
got: payload_len,
});
}
payload_len as usize
}
_ => {
return Err(StreamError::UnknownPhase(phase));
}
};
let total = PREFIX_SIZE + remaining;
let mut packet = vec![0u8; total];
packet[..PREFIX_SIZE].copy_from_slice(&prefix);
reader.read_exact(&mut packet[PREFIX_SIZE..]).await?;
Ok(packet)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// MTU used by every test that does not exercise a smaller limit.
    const TEST_MTU: u16 = 1400;

    /// Bytes an established frame occupies besides its payload.
    const ESTABLISHED_OVERHEAD: usize =
        PREFIX_SIZE + ESTABLISHED_REMAINING_HEADER + AEAD_TAG_SIZE;

    /// Feeds `bytes` to `read_fmp_packet` through an in-memory cursor.
    async fn read_one(bytes: Vec<u8>, mtu: u16) -> Result<Vec<u8>, StreamError> {
        let mut cursor = Cursor::new(bytes);
        read_fmp_packet(&mut cursor, mtu).await
    }

    /// Builds `total_len` zero bytes whose prefix carries `first_byte` and a
    /// little-endian `payload_len`.
    fn zeroed_input(total_len: usize, first_byte: u8, payload_len: u16) -> Vec<u8> {
        let mut bytes = vec![0u8; total_len];
        bytes[0] = first_byte;
        bytes[2..4].copy_from_slice(&payload_len.to_le_bytes());
        bytes
    }

    /// Builds a fixed-size handshake frame filled with `fill` after the prefix.
    fn handshake_frame(phase: u8, fill: u8, wire_size: usize, payload_len: u16) -> Vec<u8> {
        let mut frame = vec![fill; wire_size];
        frame[0] = phase;
        frame[1] = 0x00;
        frame[2..4].copy_from_slice(&payload_len.to_le_bytes());
        frame
    }

    /// Builds a well-formed established frame whose bytes after the prefix
    /// equal their own offset modulo 256.
    fn build_established_frame(payload_len: u16) -> Vec<u8> {
        let total = ESTABLISHED_OVERHEAD + usize::from(payload_len);
        let mut frame: Vec<u8> = (0..total).map(|offset| (offset & 0xFF) as u8).collect();
        frame[0] = 0x00;
        frame[1] = 0x00;
        frame[2..4].copy_from_slice(&payload_len.to_le_bytes());
        frame
    }

    fn build_msg1_frame() -> Vec<u8> {
        handshake_frame(0x01, 0xAA, MSG1_WIRE_SIZE, MSG1_PAYLOAD_LEN)
    }

    fn build_msg2_frame() -> Vec<u8> {
        handshake_frame(0x02, 0xBB, MSG2_WIRE_SIZE, MSG2_PAYLOAD_LEN)
    }

    #[tokio::test]
    async fn test_read_established_frame() {
        let frame = build_established_frame(64);
        let packet = read_one(frame.clone(), TEST_MTU).await.unwrap();
        assert_eq!(packet, frame);
    }

    #[tokio::test]
    async fn test_read_msg1_frame() {
        let frame = build_msg1_frame();
        let packet = read_one(frame.clone(), TEST_MTU).await.unwrap();
        assert_eq!(packet.len(), MSG1_WIRE_SIZE);
        assert_eq!(packet, frame);
    }

    #[tokio::test]
    async fn test_read_msg2_frame() {
        let frame = build_msg2_frame();
        let packet = read_one(frame.clone(), TEST_MTU).await.unwrap();
        assert_eq!(packet.len(), MSG2_WIRE_SIZE);
        assert_eq!(packet, frame);
    }

    #[tokio::test]
    async fn test_read_multiple_packets() {
        let est = build_established_frame(32);
        // Three frames back to back in one stream: msg1, established, msg2.
        let stream = [build_msg1_frame(), est.clone(), build_msg2_frame()].concat();
        let mut cursor = Cursor::new(stream);

        let first = read_fmp_packet(&mut cursor, TEST_MTU).await.unwrap();
        assert_eq!(first.len(), MSG1_WIRE_SIZE);

        let second = read_fmp_packet(&mut cursor, TEST_MTU).await.unwrap();
        assert_eq!(second, est);

        let third = read_fmp_packet(&mut cursor, TEST_MTU).await.unwrap();
        assert_eq!(third.len(), MSG2_WIRE_SIZE);
    }

    #[tokio::test]
    async fn test_unknown_version_error() {
        // 0x16: version nibble 1, phase nibble 6.
        let input = zeroed_input(100, 0x16, 0);
        let err = read_one(input, TEST_MTU).await.unwrap_err();
        assert!(matches!(err, StreamError::UnknownVersion(1)));
    }

    #[tokio::test]
    async fn test_unknown_phase_error() {
        let input = zeroed_input(100, 0x05, 10);
        let err = read_one(input, TEST_MTU).await.unwrap_err();
        assert!(matches!(err, StreamError::UnknownPhase(0x5)));
    }

    #[tokio::test]
    async fn test_payload_too_large() {
        // Prefix announcing 100 payload bytes, followed by 200 zero bytes.
        let input = zeroed_input(PREFIX_SIZE + 200, 0x00, 100);
        let err = read_one(input, 100).await.unwrap_err();
        assert!(matches!(err, StreamError::PayloadTooLarge { .. }));
    }

    #[tokio::test]
    async fn test_handshake_size_mismatch_msg1() {
        let input = zeroed_input(200, 0x01, 50);
        let err = read_one(input, TEST_MTU).await.unwrap_err();
        assert!(matches!(
            err,
            StreamError::HandshakeSizeMismatch { phase: 0x1, .. }
        ));
    }

    #[tokio::test]
    async fn test_handshake_size_mismatch_msg2() {
        let input = zeroed_input(200, 0x02, 50);
        let err = read_one(input, TEST_MTU).await.unwrap_err();
        assert!(matches!(
            err,
            StreamError::HandshakeSizeMismatch { phase: 0x2, .. }
        ));
    }

    #[tokio::test]
    async fn test_eof_on_prefix() {
        // Only two bytes available: fewer than the prefix needs.
        let err = read_one(vec![0u8; 2], TEST_MTU).await.unwrap_err();
        assert!(matches!(err, StreamError::Io(_)));
    }

    #[tokio::test]
    async fn test_eof_on_body() {
        // Valid msg1 prefix, but the stream ends long before the body does.
        let input = zeroed_input(10, 0x01, MSG1_PAYLOAD_LEN);
        let err = read_one(input, TEST_MTU).await.unwrap_err();
        assert!(matches!(err, StreamError::Io(_)));
    }

    #[tokio::test]
    async fn test_zero_payload_established() {
        let frame = build_established_frame(0);
        assert_eq!(frame.len(), ESTABLISHED_OVERHEAD);

        let packet = read_one(frame.clone(), TEST_MTU).await.unwrap();
        assert_eq!(packet.len(), ESTABLISHED_OVERHEAD);
        assert_eq!(packet, frame);
    }

    #[tokio::test]
    async fn test_max_payload_at_mtu_boundary() {
        let max_payload = TEST_MTU - ESTABLISHED_OVERHEAD as u16;
        let frame = build_established_frame(max_payload);
        let packet = read_one(frame.clone(), TEST_MTU).await.unwrap();
        assert_eq!(packet, frame);
    }

    #[tokio::test]
    async fn test_payload_one_over_mtu() {
        let over = TEST_MTU - ESTABLISHED_OVERHEAD as u16 + 1;
        // Prefix announcing one byte too many, followed by 2000 zero bytes.
        let input = zeroed_input(PREFIX_SIZE + 2000, 0x00, over);
        let err = read_one(input, TEST_MTU).await.unwrap_err();
        assert!(matches!(err, StreamError::PayloadTooLarge { .. }));
    }
}