use crate::varint::{self, VarInt};
/// Stream-type value that prefixes the session ID on WebTransport unidirectional streams.
///
/// Written by `StreamHeader::encode_unidirectional`. Matched by the unidirectional decoders and
/// by `classify_uni_stream`, which treats any other stream type as HTTP/3.
///
/// NOTE(review): presumably the value assigned by the WebTransport-over-HTTP/3 draft; confirm
/// against the current draft revision.
pub const UNIDIRECTIONAL_STREAM_TYPE: u64 = 0x54;
/// Signal value that prefixes the session ID on WebTransport bidirectional streams.
///
/// Written by `StreamHeader::encode_bidirectional` and matched by the bidirectional decoders.
pub const BIDIRECTIONAL_SIGNAL_VALUE: u64 = 0x41;
/// [`UNIDIRECTIONAL_STREAM_TYPE`] pre-built as a `VarInt` so encoding needs no runtime check.
const UNI_TYPE_VARINT: VarInt = VarInt::from_static(UNIDIRECTIONAL_STREAM_TYPE);
/// [`BIDIRECTIONAL_SIGNAL_VALUE`] pre-built as a `VarInt` so encoding needs no runtime check.
const BIDI_SIGNAL_VARINT: VarInt = VarInt::from_static(BIDIRECTIONAL_SIGNAL_VALUE);
/// Reads one variable-length integer from the front of `buf`.
///
/// Returns the decoded value together with the number of bytes it occupied. Any error reported
/// by `varint::decode` is collapsed into `None`; callers that need a reason map it themselves.
fn decode_varint(buf: &[u8]) -> Option<(u64, usize)> {
    let (decoded, consumed) = varint::decode(buf).ok()?;
    Some((decoded.get(), consumed))
}
/// Converts a session ID into a `VarInt` ready for encoding.
///
/// # Panics
///
/// Panics if `VarInt::new` rejects `session_id` (the value lies outside the encodable varint
/// range). IDs produced by `StreamHeader::new` or by the decoders never trigger this.
fn session_id_to_varint(session_id: u64) -> VarInt {
    let converted = VarInt::new(session_id);
    converted.expect("session_id fits in VarInt")
}
/// WebTransport stream header: the session ID carried after a stream's prefix value.
///
/// Construct with `StreamHeader::new`, or obtain one from the `decode_*` functions; both paths
/// validate the session ID.
///
/// NOTE(review): `session_id` is a public field, so a struct literal can bypass that validation.
/// Encoding such a header panics if the ID is outside the varint range.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StreamHeader {
/// The session ID. It is always a multiple of 4 when the header was produced by `new` or by a decoder.
pub session_id: u64,
}
/// Errors produced while building or decoding a [`StreamHeader`].
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`], so it can be boxed or propagated
/// with `?` into generic error types.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamHeaderDecodeError {
    /// A varint (prefix or session ID) could not be decoded; the buffer most likely ended early.
    BufferTooShort,
    /// The leading varint was not the prefix expected for the requested stream direction.
    InvalidFormat,
    /// The session ID is not a multiple of 4.
    InvalidSessionId,
    /// The session ID cannot be represented as a varint.
    SessionIdOutOfRange,
}

impl std::fmt::Display for StreamHeaderDecodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::BufferTooShort => "buffer too short for WebTransport stream header",
            Self::InvalidFormat => "unexpected WebTransport stream header prefix",
            Self::InvalidSessionId => "session ID is not a multiple of 4",
            Self::SessionIdOutOfRange => "session ID exceeds the varint range",
        };
        f.write_str(message)
    }
}

impl std::error::Error for StreamHeaderDecodeError {}
impl StreamHeader {
pub fn new(session_id: u64) -> Result<Self, StreamHeaderDecodeError> {
if VarInt::new(session_id).is_err() {
return Err(StreamHeaderDecodeError::SessionIdOutOfRange);
}
if !session_id.is_multiple_of(4) {
return Err(StreamHeaderDecodeError::InvalidSessionId);
}
Ok(Self { session_id })
}
pub fn encode_unidirectional(&self, buf: &mut Vec<u8>) {
varint::encode_into_vec(buf, UNI_TYPE_VARINT);
varint::encode_into_vec(buf, session_id_to_varint(self.session_id));
}
pub fn encode_bidirectional(&self, buf: &mut Vec<u8>) {
varint::encode_into_vec(buf, BIDI_SIGNAL_VARINT);
varint::encode_into_vec(buf, session_id_to_varint(self.session_id));
}
pub fn decode_unidirectional(buf: &[u8]) -> Option<(Self, usize)> {
Self::decode_unidirectional_checked(buf).ok()
}
pub fn decode_unidirectional_checked(
buf: &[u8],
) -> Result<(Self, usize), StreamHeaderDecodeError> {
let mut offset = 0;
let (stream_type, len) =
decode_varint(&buf[offset..]).ok_or(StreamHeaderDecodeError::BufferTooShort)?;
offset += len;
if stream_type != UNIDIRECTIONAL_STREAM_TYPE {
return Err(StreamHeaderDecodeError::InvalidFormat);
}
let (session_id, len) =
decode_varint(&buf[offset..]).ok_or(StreamHeaderDecodeError::BufferTooShort)?;
offset += len;
if !session_id.is_multiple_of(4) {
return Err(StreamHeaderDecodeError::InvalidSessionId);
}
Ok((Self { session_id }, offset))
}
pub fn decode_bidirectional(buf: &[u8]) -> Option<(Self, usize)> {
Self::decode_bidirectional_checked(buf).ok()
}
pub fn decode_bidirectional_checked(
buf: &[u8],
) -> Result<(Self, usize), StreamHeaderDecodeError> {
let mut offset = 0;
let (signal_value, len) =
decode_varint(&buf[offset..]).ok_or(StreamHeaderDecodeError::BufferTooShort)?;
offset += len;
if signal_value != BIDIRECTIONAL_SIGNAL_VALUE {
return Err(StreamHeaderDecodeError::InvalidFormat);
}
let (session_id, len) =
decode_varint(&buf[offset..]).ok_or(StreamHeaderDecodeError::BufferTooShort)?;
offset += len;
if !session_id.is_multiple_of(4) {
return Err(StreamHeaderDecodeError::InvalidSessionId);
}
Ok((Self { session_id }, offset))
}
pub fn encoded_size(&self) -> usize {
BIDI_SIGNAL_VARINT.encoded_len() + session_id_to_varint(self.session_id).encoded_len()
}
}
/// Per-stream bookkeeping: identity, header-exchange flags, and byte counters.
#[derive(Debug)]
pub struct Stream {
    stream_id: u64,
    session_id: u64,
    bidirectional: bool,
    header_sent: bool,
    header_received: bool,
    bytes_sent: u64,
    bytes_received: u64,
}

impl Stream {
    /// Creates a record for `stream_id` belonging to `session_id`.
    ///
    /// No header is marked as sent or received, and both byte counters start at zero.
    pub fn new(stream_id: u64, session_id: u64, bidirectional: bool) -> Self {
        Self {
            stream_id,
            session_id,
            bidirectional,
            header_sent: false,
            header_received: false,
            bytes_sent: 0,
            bytes_received: 0,
        }
    }

    /// The stream ID this record was created with.
    pub fn stream_id(&self) -> u64 {
        self.stream_id
    }

    /// The session ID this record was created with.
    pub fn session_id(&self) -> u64 {
        self.session_id
    }

    /// Whether the record was created as bidirectional.
    pub fn is_bidirectional(&self) -> bool {
        self.bidirectional
    }

    /// Whether [`Stream::set_header_sent`] has been called.
    pub fn is_header_sent(&self) -> bool {
        self.header_sent
    }

    /// Whether [`Stream::set_header_received`] has been called.
    pub fn is_header_received(&self) -> bool {
        self.header_received
    }

    /// Marks the header as sent. Calling it again has no further effect.
    pub fn set_header_sent(&mut self) {
        self.header_sent = true;
    }

    /// Marks the header as received. Calling it again has no further effect.
    pub fn set_header_received(&mut self) {
        self.header_received = true;
    }

    /// Total bytes recorded through [`Stream::add_bytes_sent`].
    pub fn bytes_sent(&self) -> u64 {
        self.bytes_sent
    }

    /// Total bytes recorded through [`Stream::add_bytes_received`].
    pub fn bytes_received(&self) -> u64 {
        self.bytes_received
    }

    /// Adds `bytes` to the sent counter, saturating at `u64::MAX`.
    ///
    /// Saturating instead of `+=` avoids a panic on overflow in debug builds and a silent
    /// wrap-around in release builds.
    pub fn add_bytes_sent(&mut self, bytes: u64) {
        self.bytes_sent = self.bytes_sent.saturating_add(bytes);
    }

    /// Adds `bytes` to the received counter, saturating at `u64::MAX`.
    ///
    /// The rationale is the same as for [`Stream::add_bytes_sent`].
    pub fn add_bytes_received(&mut self, bytes: u64) {
        self.bytes_received = self.bytes_received.saturating_add(bytes);
    }
}
/// Result of inspecting the leading bytes of an incoming unidirectional stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClassifiedUniStream {
/// The stream began with `UNIDIRECTIONAL_STREAM_TYPE` followed by a session ID.
WebTransport {
/// The session ID decoded after the stream type.
session_id: u64,
/// Offset in the inspected buffer where the payload begins (past both varints).
data_offset: usize,
},
/// Any other stream type. Presumably left for regular HTTP/3 handling; confirm with callers.
Http3 {
/// The decoded stream-type value.
stream_type: u64,
/// Offset in the inspected buffer just past the stream-type varint.
data_offset: usize,
},
}
/// Classifies a unidirectional stream from its leading bytes.
///
/// If the first varint equals [`UNIDIRECTIONAL_STREAM_TYPE`], a session ID is read next and
/// the stream is reported as WebTransport. Otherwise, the stream type is reported as HTTP/3.
///
/// This variant does not validate the session ID; see `classify_uni_stream_checked` for one
/// that does.
///
/// # Errors
///
/// Propagates the `varint::DecodeError` from whichever varint failed to decode.
pub fn classify_uni_stream(buf: &[u8]) -> Result<ClassifiedUniStream, varint::DecodeError> {
    let (type_varint, type_len) = varint::decode(buf)?;
    let stream_type = type_varint.get();
    if stream_type != UNIDIRECTIONAL_STREAM_TYPE {
        return Ok(ClassifiedUniStream::Http3 {
            stream_type,
            data_offset: type_len,
        });
    }
    let (session_varint, session_len) = varint::decode(&buf[type_len..])?;
    Ok(ClassifiedUniStream::WebTransport {
        session_id: session_varint.get(),
        data_offset: type_len + session_len,
    })
}
/// Classifies a unidirectional stream from its leading bytes, validating the session ID.
///
/// This behaves like `classify_uni_stream`, but reports failures as
/// [`StreamHeaderDecodeError`] and rejects WebTransport session IDs that are not a multiple
/// of 4.
///
/// # Errors
///
/// * [`StreamHeaderDecodeError::BufferTooShort`] if either varint cannot be decoded.
/// * [`StreamHeaderDecodeError::InvalidSessionId`] if a WebTransport session ID is not a
///   multiple of 4.
pub fn classify_uni_stream_checked(
    buf: &[u8],
) -> Result<ClassifiedUniStream, StreamHeaderDecodeError> {
    let (type_varint, type_len) =
        varint::decode(buf).map_err(|_| StreamHeaderDecodeError::BufferTooShort)?;
    match type_varint.get() {
        UNIDIRECTIONAL_STREAM_TYPE => {
            let (session_varint, session_len) = varint::decode(&buf[type_len..])
                .map_err(|_| StreamHeaderDecodeError::BufferTooShort)?;
            let session_id = session_varint.get();
            if session_id % 4 != 0 {
                return Err(StreamHeaderDecodeError::InvalidSessionId);
            }
            Ok(ClassifiedUniStream::WebTransport {
                session_id,
                data_offset: type_len + session_len,
            })
        }
        other => Ok(ClassifiedUniStream::Http3 {
            stream_type: other,
            data_offset: type_len,
        }),
    }
}
/// Predicates over the two low bits of a QUIC stream ID.
///
/// Bit 0 selects the initiator (clear means client) and bit 1 selects the direction (clear
/// means bidirectional).
pub mod stream_type {
    /// Bit 0: set when the server opened the stream.
    const INITIATOR_BIT: u64 = 0x01;
    /// Bit 1: set when the stream is unidirectional.
    const DIRECTION_BIT: u64 = 0x02;

    /// `true` when bit 0 of `stream_id` is clear.
    pub fn is_client_initiated(stream_id: u64) -> bool {
        !is_server_initiated(stream_id)
    }

    /// `true` when bit 0 of `stream_id` is set.
    pub fn is_server_initiated(stream_id: u64) -> bool {
        (stream_id & INITIATOR_BIT) == INITIATOR_BIT
    }

    /// `true` when bit 1 of `stream_id` is clear.
    pub fn is_bidirectional(stream_id: u64) -> bool {
        !is_unidirectional(stream_id)
    }

    /// `true` when bit 1 of `stream_id` is set.
    pub fn is_unidirectional(stream_id: u64) -> bool {
        (stream_id & DIRECTION_BIT) == DIRECTION_BIT
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stream_header_new_rejects_invalid_session_id() {
        assert_eq!(
            StreamHeader::new(1),
            Err(StreamHeaderDecodeError::InvalidSessionId)
        );
        assert_eq!(
            StreamHeader::new(2),
            Err(StreamHeaderDecodeError::InvalidSessionId)
        );
        assert_eq!(
            StreamHeader::new(3),
            Err(StreamHeaderDecodeError::InvalidSessionId)
        );
        assert!(StreamHeader::new(0).is_ok());
        assert!(StreamHeader::new(4).is_ok());
    }

    #[test]
    fn test_stream_header_new_rejects_session_id_out_of_range() {
        let too_large = (1u64 << 62) & !0x3;
        assert_eq!(
            StreamHeader::new(too_large),
            Err(StreamHeaderDecodeError::SessionIdOutOfRange)
        );
    }

    #[test]
    fn test_unidirectional_round_trip() {
        let header = StreamHeader::new(8).unwrap();
        let mut buf = Vec::new();
        header.encode_unidirectional(&mut buf);
        assert_eq!(buf.len(), header.encoded_size());
        let header_len = buf.len();
        buf.extend_from_slice(b"payload");
        assert_eq!(
            StreamHeader::decode_unidirectional(&buf),
            Some((header, header_len))
        );
    }

    #[test]
    fn test_bidirectional_round_trip() {
        let header = StreamHeader::new(1 << 20).unwrap();
        let mut buf = Vec::new();
        header.encode_bidirectional(&mut buf);
        assert_eq!(buf.len(), header.encoded_size());
        let header_len = buf.len();
        buf.extend_from_slice(b"payload");
        assert_eq!(
            StreamHeader::decode_bidirectional_checked(&buf),
            Ok((header, header_len))
        );
    }

    #[test]
    fn test_decode_rejects_wrong_prefix() {
        let header = StreamHeader::new(0).unwrap();

        let mut uni = Vec::new();
        header.encode_unidirectional(&mut uni);
        assert_eq!(
            StreamHeader::decode_bidirectional_checked(&uni),
            Err(StreamHeaderDecodeError::InvalidFormat)
        );
        assert_eq!(StreamHeader::decode_bidirectional(&uni), None);

        let mut bidi = Vec::new();
        header.encode_bidirectional(&mut bidi);
        assert_eq!(
            StreamHeader::decode_unidirectional_checked(&bidi),
            Err(StreamHeaderDecodeError::InvalidFormat)
        );
    }

    #[test]
    fn test_decode_checked_truncated_session_id() {
        let mut buf = Vec::new();
        varint::encode_into_vec(&mut buf, UNI_TYPE_VARINT);
        assert_eq!(
            StreamHeader::decode_unidirectional_checked(&buf),
            Err(StreamHeaderDecodeError::BufferTooShort)
        );
        assert_eq!(StreamHeader::decode_unidirectional(&buf), None);

        let mut buf = Vec::new();
        varint::encode_into_vec(&mut buf, BIDI_SIGNAL_VARINT);
        assert_eq!(
            StreamHeader::decode_bidirectional_checked(&buf),
            Err(StreamHeaderDecodeError::BufferTooShort)
        );
    }

    #[test]
    fn test_decode_checked_empty_buffer() {
        assert_eq!(
            StreamHeader::decode_unidirectional_checked(&[]),
            Err(StreamHeaderDecodeError::BufferTooShort)
        );
        assert_eq!(
            StreamHeader::decode_bidirectional_checked(&[]),
            Err(StreamHeaderDecodeError::BufferTooShort)
        );
    }

    #[test]
    fn test_stream_creation() {
        let stream = Stream::new(4, 0, true);
        assert_eq!(stream.stream_id(), 4);
        assert_eq!(stream.session_id(), 0);
        assert!(stream.is_bidirectional());
        assert!(!stream.is_header_sent());
        assert!(!stream.is_header_received());
        assert_eq!(stream.bytes_sent(), 0);
        assert_eq!(stream.bytes_received(), 0);
    }

    #[test]
    fn test_stream_bytes_tracking() {
        let mut stream = Stream::new(2, 0, false);
        stream.add_bytes_sent(100);
        stream.add_bytes_received(50);
        assert_eq!(stream.bytes_sent(), 100);
        assert_eq!(stream.bytes_received(), 50);
        stream.add_bytes_sent(200);
        assert_eq!(stream.bytes_sent(), 300);
    }

    #[test]
    fn test_stream_type_helpers() {
        assert!(stream_type::is_client_initiated(0));
        assert!(stream_type::is_bidirectional(0));
        assert!(stream_type::is_server_initiated(1));
        assert!(stream_type::is_bidirectional(1));
        assert!(stream_type::is_client_initiated(2));
        assert!(stream_type::is_unidirectional(2));
        assert!(stream_type::is_server_initiated(3));
        assert!(stream_type::is_unidirectional(3));
    }

    #[test]
    fn test_decode_unidirectional_checked_invalid_session_id() {
        let mut buf = Vec::new();
        varint::encode_into_vec(&mut buf, UNI_TYPE_VARINT);
        varint::encode_into_vec(&mut buf, VarInt::from_static(5));
        let result = StreamHeader::decode_unidirectional_checked(&buf);
        assert_eq!(result, Err(StreamHeaderDecodeError::InvalidSessionId));
    }

    #[test]
    fn test_decode_bidirectional_checked_invalid_session_id() {
        let mut buf = Vec::new();
        varint::encode_into_vec(&mut buf, BIDI_SIGNAL_VARINT);
        varint::encode_into_vec(&mut buf, VarInt::from_static(7));
        let result = StreamHeader::decode_bidirectional_checked(&buf);
        assert_eq!(result, Err(StreamHeaderDecodeError::InvalidSessionId));
    }

    #[test]
    fn test_classify_uni_stream_unchecked_skips_session_id_validation() {
        let mut buf = Vec::new();
        varint::encode_into_vec(&mut buf, UNI_TYPE_VARINT);
        varint::encode_into_vec(&mut buf, VarInt::from_static(5));
        let expected_offset = buf.len();
        buf.extend_from_slice(b"payload");
        assert_eq!(
            classify_uni_stream(&buf).ok(),
            Some(ClassifiedUniStream::WebTransport {
                session_id: 5,
                data_offset: expected_offset,
            })
        );
    }

    #[test]
    fn test_classify_uni_stream_checked_webtransport_valid() {
        let mut buf = Vec::new();
        varint::encode_into_vec(&mut buf, UNI_TYPE_VARINT);
        varint::encode_into_vec(&mut buf, VarInt::ZERO);
        let expected_offset = buf.len();
        buf.extend_from_slice(b"payload");
        let result = classify_uni_stream_checked(&buf).unwrap();
        assert_eq!(
            result,
            ClassifiedUniStream::WebTransport {
                session_id: 0,
                data_offset: expected_offset,
            }
        );
    }

    #[test]
    fn test_classify_uni_stream_checked_webtransport_invalid_session_id() {
        let mut buf = Vec::new();
        varint::encode_into_vec(&mut buf, UNI_TYPE_VARINT);
        varint::encode_into_vec(&mut buf, VarInt::from_static(5));
        let result = classify_uni_stream_checked(&buf);
        assert_eq!(result, Err(StreamHeaderDecodeError::InvalidSessionId));
    }

    #[test]
    fn test_classify_uni_stream_checked_http3() {
        let mut buf = Vec::new();
        varint::encode_into_vec(&mut buf, VarInt::ZERO);
        buf.extend_from_slice(b"data");
        let result = classify_uni_stream_checked(&buf).unwrap();
        assert_eq!(
            result,
            ClassifiedUniStream::Http3 {
                stream_type: 0x00,
                data_offset: 1,
            }
        );
    }

    #[test]
    fn test_classify_uni_stream_checked_buffer_too_short() {
        let result = classify_uni_stream_checked(&[]);
        assert_eq!(result, Err(StreamHeaderDecodeError::BufferTooShort));
    }

    #[test]
    fn test_classify_uni_stream_checked_session_id_buffer_too_short() {
        let mut buf = Vec::new();
        varint::encode_into_vec(&mut buf, UNI_TYPE_VARINT);
        let result = classify_uni_stream_checked(&buf);
        assert_eq!(result, Err(StreamHeaderDecodeError::BufferTooShort));
    }
}