use std::sync::Arc;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use crate::grease::is_grease_value;
use crate::io::read_incremental;
use crate::{VarInt, VarIntUnexpectedEnd};
/// Capsule type identifier of the CLOSE_WEBTRANSPORT_SESSION capsule.
const CLOSE_WEBTRANSPORT_SESSION_TYPE: u64 = 0x2843;
/// Size limit, in bytes, enforced when decoding capsule payloads and
/// close-session reason strings.
const MAX_MESSAGE_SIZE: usize = 1024;
/// A single HTTP capsule as understood by this module.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Capsule {
    /// CLOSE_WEBTRANSPORT_SESSION: an application error code plus a UTF-8 reason.
    CloseWebTransportSession {
        /// Application error code, carried on the wire as a big-endian `u32`.
        code: u32,
        /// Human-readable close reason.
        reason: String,
    },
    /// Any capsule type this module does not interpret, preserved verbatim.
    Unknown {
        /// The capsule's type identifier.
        typ: VarInt,
        /// The raw, undecoded payload bytes.
        payload: Bytes,
    },
}
impl Capsule {
    /// Decodes one capsule from `buf`, transparently skipping any GREASE capsules
    /// that precede it.
    ///
    /// Wire format of each capsule: `type (varint) | length (varint) | payload`.
    /// On success `buf` has been advanced past the returned capsule and past every
    /// GREASE capsule skipped on the way.
    ///
    /// # Errors
    ///
    /// - [`CapsuleError::VarInt`] if `buf` ends inside a type or length varint.
    /// - [`CapsuleError::UnexpectedEnd`] if `buf` holds fewer bytes than the declared
    ///   payload length, or a close capsule is too short for its 4-byte error code.
    /// - [`CapsuleError::MessageTooLong`] if the declared payload length exceeds the
    ///   per-type limit (`MAX_MESSAGE_SIZE`, plus 4 bytes for the close error code).
    /// - [`CapsuleError::InvalidUtf8`] if a close reason is not valid UTF-8.
    pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, CapsuleError> {
        loop {
            let typ = VarInt::decode(buf)?;
            let length = VarInt::decode(buf)?;
            let type_id = typ.into_inner();

            // A close capsule is a 4-byte error code followed by a reason of up to
            // MAX_MESSAGE_SIZE bytes; every other capsule is capped at MAX_MESSAGE_SIZE.
            let limit = if type_id == CLOSE_WEBTRANSPORT_SESSION_TYPE {
                4 + MAX_MESSAGE_SIZE
            } else {
                MAX_MESSAGE_SIZE
            };

            // Check the *declared* length so an oversized capsule is rejected at once
            // rather than reported as incomplete while its bytes trickle in.
            // `try_from` also avoids silently truncating huge lengths on 32-bit targets.
            let length = usize::try_from(length.into_inner())
                .ok()
                .filter(|&len| len <= limit)
                .ok_or(CapsuleError::MessageTooLong)?;

            if buf.remaining() < length {
                return Err(CapsuleError::UnexpectedEnd);
            }
            let mut payload = buf.take(length);

            let capsule = match type_id {
                CLOSE_WEBTRANSPORT_SESSION_TYPE => {
                    // NOTE(review): the payload is complete at this point, so a length
                    // below 4 is malformed rather than truncated; UnexpectedEnd is kept
                    // for compatibility with existing callers and tests.
                    if length < 4 {
                        return Err(CapsuleError::UnexpectedEnd);
                    }
                    let code = payload.get_u32();

                    // The limit check above guarantees the reason is at most
                    // MAX_MESSAGE_SIZE bytes. Copy straight into a Vec to avoid an
                    // intermediate `Bytes` allocation.
                    let mut reason = vec![0u8; payload.remaining()];
                    payload.copy_to_slice(&mut reason);
                    let reason =
                        String::from_utf8(reason).map_err(|_| CapsuleError::InvalidUtf8)?;

                    Self::CloseWebTransportSession { code, reason }
                }
                t if is_grease(t) => {
                    // GREASE capsules are ignored. Consume their payload so the next
                    // iteration starts at the following capsule header instead of
                    // misreading these bytes as one.
                    payload.advance(length);
                    continue;
                }
                _ => Self::Unknown {
                    typ,
                    payload: payload.copy_to_bytes(length),
                },
            };

            return Ok(capsule);
        }
    }

    /// Reads and decodes one capsule from `stream` using `read_incremental`.
    ///
    /// Decode errors classified as "incomplete" cause more bytes to be read before
    /// decoding is retried (presumably; see `crate::io::read_incremental`). Both a
    /// short payload (`UnexpectedEnd`) and a varint header split across reads
    /// (`VarInt`) mean "need more data", so both are classified as incomplete.
    pub async fn read<S: AsyncRead + Unpin>(stream: &mut S) -> Result<Self, CapsuleError> {
        read_incremental(
            stream,
            |cursor| Self::decode(cursor),
            |err| matches!(err, CapsuleError::UnexpectedEnd | CapsuleError::VarInt(_)),
            CapsuleError::UnexpectedEnd,
        )
        .await
    }

    /// Appends the wire encoding of this capsule to `buf`.
    ///
    /// A close reason longer than `MAX_MESSAGE_SIZE` bytes is cut back to the longest
    /// prefix that fits and still ends on a UTF-8 character boundary, so the emitted
    /// capsule is always accepted by [`Capsule::decode`].
    ///
    /// # Panics
    ///
    /// Panics if an `Unknown` payload is too long to express as a varint length.
    pub fn encode<B: BufMut>(&self, buf: &mut B) {
        match self {
            Self::CloseWebTransportSession { code, reason } => {
                let mut end = reason.len().min(MAX_MESSAGE_SIZE);
                // Index 0 is always a char boundary, so this loop terminates.
                while !reason.is_char_boundary(end) {
                    end -= 1;
                }
                let reason = &reason[..end];

                VarInt::from_u64(CLOSE_WEBTRANSPORT_SESSION_TYPE)
                    .expect("capsule type constant fits in a varint")
                    .encode(buf);
                // 4 bytes of error code + at most MAX_MESSAGE_SIZE bytes of reason.
                VarInt::try_from(4 + reason.len())
                    .expect("bounded close payload length fits in a varint")
                    .encode(buf);
                buf.put_u32(*code);
                buf.put_slice(reason.as_bytes());
            }
            Self::Unknown { typ, payload } => {
                typ.encode(buf);
                VarInt::try_from(payload.len())
                    .expect("capsule payload length exceeds the varint range")
                    .encode(buf);
                buf.put_slice(payload);
            }
        }
    }

    /// Encodes this capsule and writes all of it to `stream`.
    ///
    /// # Errors
    ///
    /// Returns [`CapsuleError::Io`] if writing to `stream` fails.
    pub async fn write<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<(), CapsuleError> {
        let mut buf = BytesMut::new();
        self.encode(&mut buf);
        stream.write_all_buf(&mut buf).await?;
        Ok(())
    }
}
/// Reports whether `val` is a GREASE capsule type, as classified by
/// `crate::grease::is_grease_value`; such capsules are skipped by `Capsule::decode`.
#[inline]
fn is_grease(val: u64) -> bool {
    is_grease_value(val)
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum CapsuleError {
#[error("unexpected end of buffer")]
UnexpectedEnd,
#[error("invalid UTF-8")]
InvalidUtf8,
#[error("message too long")]
MessageTooLong,
#[error("unknown capsule type: {0:?}")]
UnknownType(VarInt),
#[error("varint decode error: {0:?}")]
VarInt(#[from] VarIntUnexpectedEnd),
#[error("io error: {0}")]
Io(Arc<std::io::Error>),
}
impl From<std::io::Error> for CapsuleError {
fn from(err: std::io::Error) -> Self {
CapsuleError::Io(Arc::new(err))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;

    /// CLOSE_WEBTRANSPORT_SESSION type, spelled out independently of the
    /// production constant.
    const CLOSE_TYPE: u64 = 0x2843;

    /// Builds a raw capsule: varint `typ`, varint `length`, then `body` verbatim.
    /// `length` is written as given so tests can declare any length they like.
    fn frame(typ: u64, length: u32, body: &[u8]) -> Vec<u8> {
        let mut wire = Vec::new();
        VarInt::from_u64(typ).unwrap().encode(&mut wire);
        VarInt::from_u32(length).encode(&mut wire);
        wire.extend_from_slice(body);
        wire
    }

    /// Encodes `capsule`, decodes the bytes again, checks that nothing was left
    /// over, and returns the decoded capsule.
    fn encode_then_decode(capsule: &Capsule) -> Capsule {
        let mut wire = Vec::new();
        capsule.encode(&mut wire);
        let mut cursor = wire.as_slice();
        let decoded = Capsule::decode(&mut cursor).unwrap();
        assert_eq!(cursor.len(), 0);
        decoded
    }

    #[test]
    fn test_close_webtransport_session_decode() {
        let wire = frame(CLOSE_TYPE, 8, b"\x00\x00\x01\xa4test");
        let mut cursor = wire.as_slice();
        match Capsule::decode(&mut cursor).unwrap() {
            Capsule::CloseWebTransportSession { code, reason } => {
                assert_eq!(code, 420);
                assert_eq!(reason, "test");
            }
            _ => panic!("Expected CloseWebTransportSession"),
        }
        assert_eq!(cursor.len(), 0);
    }

    #[test]
    fn test_close_webtransport_session_encode() {
        let capsule = Capsule::CloseWebTransportSession {
            code: 420,
            reason: "test".to_string(),
        };
        let mut wire = Vec::new();
        capsule.encode(&mut wire);
        assert_eq!(wire, b"\x68\x43\x08\x00\x00\x01\xa4test");
    }

    #[test]
    fn test_close_webtransport_session_roundtrip() {
        let original = Capsule::CloseWebTransportSession {
            code: 12345,
            reason: "Connection closed by application".to_string(),
        };
        assert_eq!(encode_then_decode(&original), original);
    }

    #[test]
    fn test_empty_error_message() {
        let capsule = Capsule::CloseWebTransportSession {
            code: 0,
            reason: String::new(),
        };
        let mut wire = Vec::new();
        capsule.encode(&mut wire);
        assert_eq!(wire, b"\x68\x43\x04\x00\x00\x00\x00");
        let decoded = Capsule::decode(&mut wire.as_slice()).unwrap();
        assert_eq!(capsule, decoded);
    }

    #[test]
    fn test_invalid_utf8() {
        let wire = frame(CLOSE_TYPE, 5, b"\x00\x00\x00\x00\xff");
        let outcome = Capsule::decode(&mut wire.as_slice());
        assert!(matches!(outcome, Err(CapsuleError::InvalidUtf8)));
    }

    #[test]
    fn test_truncated_error_code() {
        let wire = frame(CLOSE_TYPE, 3, b"\x00\x00\x00");
        let outcome = Capsule::decode(&mut wire.as_slice());
        assert!(matches!(outcome, Err(CapsuleError::UnexpectedEnd)));
    }

    #[test]
    fn test_unknown_capsule() {
        let unknown_type = 0x1234u64;
        let payload_data = b"unknown payload";
        let wire = frame(unknown_type, payload_data.len() as u32, payload_data);
        match Capsule::decode(&mut wire.as_slice()).unwrap() {
            Capsule::Unknown { typ, payload } => {
                assert_eq!(typ.into_inner(), unknown_type);
                assert_eq!(payload.as_ref(), payload_data);
            }
            _ => panic!("Expected Unknown capsule"),
        }
    }

    #[test]
    fn test_unknown_capsule_roundtrip() {
        let capsule = Capsule::Unknown {
            typ: VarInt::from_u64(0x9999).unwrap(),
            payload: Bytes::from("test payload"),
        };
        assert_eq!(encode_then_decode(&capsule), capsule);
    }
}