use crate::bytes::BytesMut;
use crate::codec::Encoder;
use crate::net::atp::protocol::codec::AtpFrameCodec;
use crate::net::atp::protocol::varint::{VarInt, VarIntError};
use crate::types::outcome::Outcome;
use std::collections::HashMap;
use std::fmt;
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ProtocolVersion(pub u32);
impl ProtocolVersion {
pub const V0: Self = ProtocolVersion(0);
pub const CURRENT: Self = Self::V0;
}
impl fmt::Display for ProtocolVersion {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "ATP/{}", self.0)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
pub enum FrameType {
Handshake = 0x0001,
HandshakeAck = 0x0002,
Capabilities = 0x0003,
CapabilitiesAck = 0x0004,
ObjectManifest = 0x0100,
ObjectRequest = 0x0101,
ObjectData = 0x0102,
ObjectComplete = 0x0103,
ObjectError = 0x0104,
PathUpdate = 0x0200,
PathChallenge = 0x0201,
PathResponse = 0x0202,
KeepAlive = 0x0203,
Cancel = 0x0300,
Error = 0x0301,
Close = 0x0302,
Control = 0x0400,
Data = 0x0401,
Proof = 0x0402,
Repair = 0x0403,
Session = 0x0404,
Manifest = 0x0405,
}
impl FrameType {
pub fn to_varint(self) -> VarInt {
match VarInt::new(self as u64) {
Outcome::Ok(varint) => varint,
_ => panic!("frame type fits in varint"),
}
}
pub fn from_varint(varint: VarInt) -> Result<Self, FrameError> {
match varint.value() {
0x0001 => Ok(FrameType::Handshake),
0x0002 => Ok(FrameType::HandshakeAck),
0x0003 => Ok(FrameType::Capabilities),
0x0004 => Ok(FrameType::CapabilitiesAck),
0x0100 => Ok(FrameType::ObjectManifest),
0x0101 => Ok(FrameType::ObjectRequest),
0x0102 => Ok(FrameType::ObjectData),
0x0103 => Ok(FrameType::ObjectComplete),
0x0104 => Ok(FrameType::ObjectError),
0x0200 => Ok(FrameType::PathUpdate),
0x0201 => Ok(FrameType::PathChallenge),
0x0202 => Ok(FrameType::PathResponse),
0x0203 => Ok(FrameType::KeepAlive),
0x0300 => Ok(FrameType::Cancel),
0x0301 => Ok(FrameType::Error),
0x0302 => Ok(FrameType::Close),
0x0400 => Ok(FrameType::Control),
0x0401 => Ok(FrameType::Data),
0x0402 => Ok(FrameType::Proof),
0x0403 => Ok(FrameType::Repair),
0x0404 => Ok(FrameType::Session),
0x0405 => Ok(FrameType::Manifest),
other => Err(FrameError::UnknownFrameType(other)),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FrameHeader {
pub version: ProtocolVersion,
pub frame_type: FrameType,
pub payload_length: VarInt,
pub extensions: HashMap<u16, Vec<u8>>,
}
impl FrameHeader {
pub fn new(
version: ProtocolVersion,
frame_type: FrameType,
payload_length: u64,
) -> Result<Self, FrameError> {
let payload_varint = match VarInt::new(payload_length) {
Outcome::Ok(varint) => varint,
_ => {
return Err(FrameError::InvalidFormat(
"Invalid payload length".to_string(),
));
}
};
Ok(FrameHeader {
version,
frame_type,
payload_length: payload_varint,
extensions: HashMap::new(),
})
}
pub fn with_extension(mut self, extension_id: u16, data: Vec<u8>) -> Self {
self.extensions.insert(extension_id, data);
self
}
pub fn encoded_len(&self) -> usize {
let mut len = 0;
len += match VarInt::new(self.version.0 as u64) {
Outcome::Ok(varint) => varint.encoded_len(),
_ => panic!("version fits in varint"),
};
len += self.frame_type.to_varint().encoded_len();
len += self.payload_length.encoded_len();
len += VarInt::new(self.extensions.len() as u64)
.unwrap()
.encoded_len();
for (extension_id, data) in &self.extensions {
len += match VarInt::new(*extension_id as u64) {
Outcome::Ok(varint) => varint.encoded_len(),
_ => panic!("u16 extension id fits in varint"),
};
len += match VarInt::new(data.len() as u64) {
Outcome::Ok(varint) => varint.encoded_len(),
_ => panic!("data length fits in varint"),
};
len += data.len(); }
len
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Frame {
pub header: FrameHeader,
pub payload: Vec<u8>,
}
impl Frame {
pub fn new(
version: ProtocolVersion,
frame_type: FrameType,
payload: Vec<u8>,
) -> Result<Self, FrameError> {
let header = FrameHeader::new(version, frame_type, payload.len() as u64)?;
Ok(Frame { header, payload })
}
pub fn encoded_len(&self) -> usize {
self.header.encoded_len() + self.payload.len()
}
pub fn frame_type(&self) -> FrameType {
self.header.frame_type
}
pub fn version(&self) -> ProtocolVersion {
self.header.version
}
pub fn payload(&self) -> &[u8] {
&self.payload
}
pub fn empty(frame_type: FrameType) -> Result<Self, FrameError> {
Self::new(ProtocolVersion::CURRENT, frame_type, Vec::new())
}
pub fn to_wire_bytes(&self) -> Result<Vec<u8>, FrameError> {
let mut codec = AtpFrameCodec::new();
let mut encoded = BytesMut::with_capacity(self.encoded_len());
codec.encode(self.clone(), &mut encoded)?;
Ok(encoded.to_vec())
}
}
#[derive(Debug, thiserror::Error)]
pub enum FrameError {
#[error("varint encoding error: {0}")]
VarInt(#[from] VarIntError),
#[error("unknown frame type: {0}")]
UnknownFrameType(u64),
#[error("unsupported protocol version: {0}")]
UnsupportedVersion(u32),
#[error("frame too large: {size} bytes (max: {max})")]
FrameTooLarge {
size: u64,
max: u64,
},
#[error("invalid frame format: {0}")]
InvalidFormat(String),
#[error("unexpected end of frame data")]
UnexpectedEof,
#[error("extension too large: {size} bytes")]
ExtensionTooLarge {
size: u64,
},
}
pub const MAX_FRAME_SIZE: u64 = 1024 * 1024;
pub const MAX_EXTENSION_SIZE: u64 = 4096;
pub const MAX_EXTENSION_COUNT: u64 = 64;
pub const MAX_HEADER_SIZE: u64 = 32 * 1024;
impl From<std::io::Error> for FrameError {
fn from(err: std::io::Error) -> Self {
FrameError::InvalidFormat(format!("I/O error: {err}"))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_frame_type_roundtrip() {
let frame_types = [
FrameType::Handshake,
FrameType::HandshakeAck,
FrameType::Capabilities,
FrameType::ObjectManifest,
FrameType::ObjectData,
FrameType::PathUpdate,
FrameType::Cancel,
FrameType::Error,
FrameType::Close,
];
for frame_type in frame_types {
let varint = frame_type.to_varint();
let parsed = FrameType::from_varint(varint).unwrap();
assert_eq!(parsed, frame_type);
}
}
#[test]
fn test_frame_creation() {
let payload = b"Hello, ATP!".to_vec();
let frame = Frame::new(ProtocolVersion::V0, FrameType::Handshake, payload.clone()).unwrap();
assert_eq!(frame.version(), ProtocolVersion::V0);
assert_eq!(frame.frame_type(), FrameType::Handshake);
assert_eq!(frame.payload(), payload);
}
#[test]
fn test_frame_header_with_extensions() {
let header = FrameHeader::new(ProtocolVersion::V0, FrameType::Capabilities, 100)
.unwrap()
.with_extension(1, b"ext1".to_vec())
.with_extension(2, b"extension2".to_vec());
assert_eq!(header.extensions.len(), 2);
assert_eq!(header.extensions[&1], b"ext1");
assert_eq!(header.extensions[&2], b"extension2");
}
#[test]
fn frame_empty_uses_protocol_wire_bytes_not_marker_payloads() {
let frame = Frame::empty(FrameType::Control).unwrap();
assert_eq!(frame.payload(), b"");
assert_eq!(
frame.to_wire_bytes().unwrap(),
vec![0x00, 0x44, 0x00, 0x00, 0x00]
);
}
#[test]
fn test_protocol_version_display() {
assert_eq!(ProtocolVersion::V0.to_string(), "ATP/0");
assert_eq!(ProtocolVersion(42).to_string(), "ATP/42");
}
}