use std::{
collections::HashMap,
fmt::Debug,
ops::{Deref, DerefMut},
sync::Arc,
};
use bytes::{Buf, BufMut, BytesMut};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use super::{Frame, UniStream, VarInt, VarIntUnexpectedEnd};
use crate::grease::is_grease_value;
use crate::io::read_incremental;
/// A setting identifier as carried in a SETTINGS frame, stored as its raw [`VarInt`].
///
/// Equality and hashing are on the raw identifier, so values can key a map and be matched
/// against the associated constants defined on this type.
#[derive(Copy, Clone, Hash, Eq, PartialEq)]
pub struct Setting(pub VarInt);
impl Setting {
pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, VarIntUnexpectedEnd> {
Ok(Setting(VarInt::decode(buf)?))
}
pub fn encode<B: BufMut>(&self, buf: &mut B) {
self.0.encode(buf)
}
pub fn size(&self) -> usize {
self.0.size()
}
pub fn is_grease(&self) -> bool {
is_grease_value(self.0.into_inner())
}
}
impl Debug for Setting {
    // Known identifiers print their constant name. Anything else prints its raw
    // identifier in hex, labelled as GREASE or unknown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match *self {
            Setting::QPACK_MAX_TABLE_CAPACITY => "QPACK_MAX_TABLE_CAPACITY",
            Setting::MAX_FIELD_SECTION_SIZE => "MAX_FIELD_SECTION_SIZE",
            Setting::QPACK_BLOCKED_STREAMS => "QPACK_BLOCKED_STREAMS",
            Setting::ENABLE_CONNECT_PROTOCOL => "ENABLE_CONNECT_PROTOCOL",
            Setting::ENABLE_DATAGRAM => "ENABLE_DATAGRAM",
            Setting::ENABLE_DATAGRAM_DEPRECATED => "ENABLE_DATAGRAM_DEPRECATED",
            Setting::WEBTRANSPORT_ENABLE_DEPRECATED => "WEBTRANSPORT_ENABLE_DEPRECATED",
            Setting::WEBTRANSPORT_MAX_SESSIONS_DEPRECATED => {
                "WEBTRANSPORT_MAX_SESSIONS_DEPRECATED"
            }
            Setting::WEBTRANSPORT_MAX_SESSIONS => "WEBTRANSPORT_MAX_SESSIONS",
            unnamed => {
                let raw = unnamed.0.into_inner();
                return if unnamed.is_grease() {
                    write!(f, "GREASE SETTING [{:x?}]", raw)
                } else {
                    write!(f, "UNKNOWN_SETTING [{:x?}]", raw)
                };
            }
        };
        f.write_str(name)
    }
}
impl Setting {
pub const fn from_u32(value: u32) -> Self {
Self(VarInt::from_u32(value))
}
pub const QPACK_MAX_TABLE_CAPACITY: Setting = Setting::from_u32(0x1); pub const MAX_FIELD_SECTION_SIZE: Setting = Setting::from_u32(0x6);
pub const QPACK_BLOCKED_STREAMS: Setting = Setting::from_u32(0x7);
pub const ENABLE_CONNECT_PROTOCOL: Setting = Setting::from_u32(0x8);
pub const ENABLE_DATAGRAM: Setting = Setting::from_u32(0x33);
pub const ENABLE_DATAGRAM_DEPRECATED: Setting = Setting::from_u32(0xFFD277);
pub const WEBTRANSPORT_ENABLE_DEPRECATED: Setting = Setting::from_u32(0x2b603742);
pub const WEBTRANSPORT_MAX_SESSIONS_DEPRECATED: Setting = Setting::from_u32(0x2b603743);
pub const WEBTRANSPORT_MAX_SESSIONS: Setting = Setting::from_u32(0xc671706a);
}
#[derive(Error, Debug, Clone)]
pub enum SettingsError {
#[error("unexpected end of input")]
UnexpectedEnd,
#[error("unexpected stream type {0:?}")]
UnexpectedStreamType(UniStream),
#[error("unexpected frame {0:?}")]
UnexpectedFrame(Frame),
#[error("invalid size")]
InvalidSize,
#[error("io error: {0}")]
Io(Arc<std::io::Error>),
}
impl From<std::io::Error> for SettingsError {
fn from(err: std::io::Error) -> Self {
SettingsError::Io(Arc::new(err))
}
}
/// Setting identifiers mapped to their values, as exchanged in a SETTINGS frame.
///
/// Derefs to the inner `HashMap`, so ordinary map methods (`get`, `insert`, ...) can be
/// called on it directly.
#[derive(Debug, Default)]
pub struct Settings(HashMap<Setting, VarInt>);
impl Settings {
pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, SettingsError> {
let typ = UniStream::decode(buf).map_err(|_| SettingsError::UnexpectedEnd)?;
if typ != UniStream::CONTROL {
return Err(SettingsError::UnexpectedStreamType(typ));
}
let (typ, mut data) = Frame::read(buf).map_err(|_| SettingsError::UnexpectedEnd)?;
if typ != Frame::SETTINGS {
return Err(SettingsError::UnexpectedFrame(typ));
}
let mut settings = Settings::default();
while data.has_remaining() {
let id = Setting::decode(&mut data).map_err(|_| SettingsError::InvalidSize)?;
let value = VarInt::decode(&mut data).map_err(|_| SettingsError::InvalidSize)?;
if !id.is_grease() {
settings.0.insert(id, value);
}
}
Ok(settings)
}
pub async fn read<S: AsyncRead + Unpin>(stream: &mut S) -> Result<Self, SettingsError> {
read_incremental(
stream,
|cursor| Self::decode(cursor),
|err| matches!(err, SettingsError::UnexpectedEnd),
SettingsError::UnexpectedEnd,
)
.await
}
pub fn encode<B: BufMut>(&self, buf: &mut B) {
UniStream::CONTROL.encode(buf);
Frame::SETTINGS.encode(buf);
let payload_len = self.payload_len();
VarInt::try_from(payload_len as u64)
.expect("settings payload length exceeds VarInt bounds")
.encode(buf);
for (id, value) in &self.0 {
id.encode(buf);
value.encode(buf);
}
}
pub async fn write<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<(), SettingsError> {
let mut buf = BytesMut::with_capacity(self.encoded_len());
self.encode(&mut buf);
stream.write_all_buf(&mut buf).await?;
Ok(())
}
pub fn enable_webtransport(&mut self, max_sessions: u32) {
self.enable_webtransport_internal(max_sessions, true);
}
pub fn enable_webtransport_latest(&mut self, max_sessions: u32) {
self.enable_webtransport_internal(max_sessions, false);
}
fn enable_webtransport_internal(&mut self, max_sessions: u32, include_deprecated: bool) {
let max = VarInt::from_u32(max_sessions);
self.insert(Setting::ENABLE_CONNECT_PROTOCOL, VarInt::from_u32(1));
self.insert(Setting::ENABLE_DATAGRAM, VarInt::from_u32(1));
self.insert(Setting::ENABLE_DATAGRAM_DEPRECATED, VarInt::from_u32(1));
self.insert(Setting::WEBTRANSPORT_MAX_SESSIONS, max);
if include_deprecated {
self.insert(Setting::WEBTRANSPORT_MAX_SESSIONS_DEPRECATED, max);
self.insert(Setting::WEBTRANSPORT_ENABLE_DEPRECATED, VarInt::from_u32(1));
} else {
self.0
.remove(&Setting::WEBTRANSPORT_MAX_SESSIONS_DEPRECATED);
self.0.remove(&Setting::WEBTRANSPORT_ENABLE_DEPRECATED);
}
}
pub fn supports_webtransport(&self) -> u64 {
let datagram = self
.get(&Setting::ENABLE_DATAGRAM)
.or(self.get(&Setting::ENABLE_DATAGRAM_DEPRECATED))
.map(|v| v.into_inner());
if datagram != Some(1) {
return 0;
}
if let Some(max) = self.get(&Setting::WEBTRANSPORT_MAX_SESSIONS) {
return max.into_inner();
}
let enabled = self
.get(&Setting::WEBTRANSPORT_ENABLE_DEPRECATED)
.map(|v| v.into_inner());
if enabled != Some(1) {
return 0;
}
self.get(&Setting::WEBTRANSPORT_MAX_SESSIONS_DEPRECATED)
.map(|v| v.into_inner())
.unwrap_or(1)
}
fn payload_len(&self) -> usize {
self.0
.iter()
.map(|(id, value)| id.size() + value.size())
.sum()
}
fn encoded_len(&self) -> usize {
let payload_len = self.payload_len();
UniStream::CONTROL.0.size()
+ Frame::SETTINGS.0.size()
+ VarInt::try_from(payload_len as u64)
.expect("settings payload length exceeds VarInt bounds")
.size()
+ payload_len
}
}
impl Deref for Settings {
type Target = HashMap<Setting, VarInt>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for Settings {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}