use std::cmp;
use std::convert::TryFrom;
use std::error::Error;
use std::fmt::Debug;
use std::sync::Arc;
use multimap::MultiMap;
use quiche;
use quiche::h3::frame::Frame as QFrame;
use quiche::h3::Header;
use quiche::h3::NameValue;
use serde::ser::SerializeStruct;
use serde::ser::Serializer;
use serde::Serialize;
use crate::client::connection_summary::MAX_SERIALIZED_BUFFER_LEN;
use crate::encode_header_block;
/// Boxed, thread-safe error type returned by the fallible conversions in this
/// module (e.g. `TryFrom<QFrame> for EnrichedHeaders`).
pub type BoxError = Box<dyn Error + Send + Sync + 'static>;
/// A frame that can be serialized and matched against a [`CloseTriggerFrame`].
///
/// Wraps either a raw quiche HTTP/3 frame, a set of headers (with their
/// encoded header block), or a stream reset.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum H3iFrame {
/// A raw quiche HTTP/3 frame.
QuicheH3(QFrame),
/// Headers together with their encoded header block and a lookup map.
Headers(EnrichedHeaders),
/// A stream reset.
ResetStream(ResetStream),
}
impl H3iFrame {
    /// Returns a clone of the [`EnrichedHeaders`] carried by an
    /// [`H3iFrame::Headers`] frame, or `None` for every other variant.
    pub fn to_enriched_headers(&self) -> Option<EnrichedHeaders> {
        match self {
            H3iFrame::Headers(enriched) => Some(enriched.clone()),
            H3iFrame::QuicheH3(_) | H3iFrame::ResetStream(_) => None,
        }
    }
}
impl Serialize for H3iFrame {
    /// Serializes the frame as a struct with exactly one field.
    ///
    /// - Raw quiche frames become a struct named `"frame"`, keyed by the
    ///   frame's name (see `frame_name`).
    /// - Headers use `"enriched_headers"` as both the struct name and the
    ///   field key.
    /// - Resets use `"reset_stream"` as both the struct name and the field key.
    fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        /// Emits `value` as the sole field `key` of a struct named
        /// `struct_name`.
        fn single_field<Ser, T>(
            ser: Ser, struct_name: &'static str, key: &'static str, value: &T,
        ) -> Result<Ser::Ok, Ser::Error>
        where
            Ser: Serializer,
            T: Serialize + ?Sized,
        {
            let mut wrapper = ser.serialize_struct(struct_name, 1)?;
            wrapper.serialize_field(key, value)?;
            wrapper.end()
        }

        match self {
            H3iFrame::QuicheH3(frame) => single_field(
                s,
                "frame",
                frame_name(frame),
                &SerializableQFrame(frame),
            ),
            H3iFrame::Headers(headers) =>
                single_field(s, "enriched_headers", "enriched_headers", headers),
            H3iFrame::ResetStream(reset) =>
                single_field(s, "reset_stream", "reset_stream", reset),
        }
    }
}
impl From<QFrame> for H3iFrame {
fn from(value: QFrame) -> Self {
Self::QuicheH3(value)
}
}
impl From<Vec<Header>> for H3iFrame {
fn from(value: Vec<Header>) -> Self {
Self::Headers(EnrichedHeaders::from(value))
}
}
/// Multi-valued map from header name bytes to header value bytes; a name that
/// occurs more than once keeps all of its values.
pub type HeaderMap = MultiMap<Vec<u8>, Vec<u8>>;
/// A list of headers plus two derived views of it: the encoded header block
/// (produced by `encode_header_block`) and a name -> values lookup map.
#[derive(Clone, PartialEq, Eq)]
pub struct EnrichedHeaders {
// Encoded form of `headers`; only its length is serialized.
header_block: Vec<u8>,
// The headers in the order they were supplied.
headers: Vec<Header>,
// Lookup map built from `headers`.
header_map: HeaderMap,
}
/// Serialization adapter that renders a borrowed [`Header`] as a struct with
/// `name` and `value` fields, both decoded as lossy UTF-8.
pub struct SerializableHeader<'a>(&'a Header);

impl Serialize for SerializableHeader<'_> {
    fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let SerializableHeader(header) = self;
        // Invalid UTF-8 sequences become U+FFFD instead of failing.
        let name = String::from_utf8_lossy(header.name());
        let value = String::from_utf8_lossy(header.value());

        let mut fields = s.serialize_struct("header", 2)?;
        fields.serialize_field("name", &name)?;
        fields.serialize_field("value", &value)?;
        fields.end()
    }
}
impl EnrichedHeaders {
    /// Returns the headers in the order they were supplied.
    pub fn headers(&self) -> &[Header] {
        self.headers.as_slice()
    }

    /// Returns the name -> values lookup map built from the headers.
    pub fn header_map(&self) -> &HeaderMap {
        &self.header_map
    }

    /// Returns the value of the `:status` pseudo-header, if present.
    ///
    /// If `:status` occurs more than once, the first recorded value is
    /// returned.
    pub fn status_code(&self) -> Option<&Vec<u8>> {
        const STATUS: &[u8] = b":status";
        self.header_map.get(STATUS)
    }
}
impl Serialize for EnrichedHeaders {
fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut state = s.serialize_struct("enriched_headers", 2)?;
state.serialize_field("header_block_len", &self.header_block.len())?;
let x: Vec<SerializableHeader> =
self.headers.iter().map(SerializableHeader).collect();
state.serialize_field("headers", &x)?;
state.end()
}
}
impl From<Vec<Header>> for EnrichedHeaders {
fn from(headers: Vec<Header>) -> Self {
let header_block = encode_header_block(&headers).unwrap();
let mut header_map: HeaderMap = MultiMap::with_capacity(headers.len());
for header in headers.iter() {
header_map.insert(header.name().to_vec(), header.value().to_vec());
}
Self {
header_block,
headers,
header_map,
}
}
}
impl TryFrom<QFrame> for EnrichedHeaders {
    type Error = BoxError;

    /// Decodes a HEADERS frame's header block into [`EnrichedHeaders`].
    ///
    /// # Errors
    ///
    /// Returns an error in either of these cases:
    /// - `value` is not a HEADERS frame.
    /// - Its header block cannot be QPACK-decoded, e.g. a malformed block
    ///   received from the peer.
    fn try_from(value: QFrame) -> Result<Self, Self::Error> {
        let QFrame::Headers { header_block } = value else {
            return Err(
                "Cannot convert non-Headers frame into HeadersFrame".into()
            );
        };

        let mut qpack_decoder = quiche::h3::qpack::Decoder::new();
        // The header block comes off the wire and may be malformed, so a
        // decode failure is reported to the caller instead of panicking.
        // `u64::MAX` places no limit on the decoded header list size.
        let headers = qpack_decoder
            .decode(&header_block, u64::MAX)
            .map_err(|e| format!("Failed to decode QPACK header block: {e:?}"))?;

        Ok(EnrichedHeaders::from(headers))
    }
}
impl Debug for EnrichedHeaders {
    /// Formats only the header list.
    ///
    /// The encoded block and the lookup map are both derived from the headers,
    /// so they are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let headers = &self.headers;
        write!(f, "{headers:?}")
    }
}
/// A stream reset, identified by the stream it targets and the error code it
/// carries.
#[derive(Debug, Clone, Eq, PartialEq, Serialize)]
pub struct ResetStream {
/// ID of the stream being reset.
pub stream_id: u64,
/// Error code carried by the reset.
pub error_code: u64,
}
/// Returns the name of a quiche HTTP/3 frame.
///
/// The name is used as the serialized field key in `H3iFrame`'s `Serialize`
/// impl and as the struct name in `SerializableQFrame`'s. These strings
/// therefore appear in serialized output, so changing them changes the output
/// format.
fn frame_name(frame: &QFrame) -> &'static str {
match frame {
QFrame::Data { .. } => "DATA",
QFrame::Headers { .. } => "HEADERS",
QFrame::CancelPush { .. } => "CANCEL_PUSH",
QFrame::Settings { .. } => "SETTINGS",
QFrame::PushPromise { .. } => "PUSH_PROMISE",
QFrame::GoAway { .. } => "GO_AWAY",
QFrame::MaxPushId { .. } => "MAX_PUSH_ID",
QFrame::PriorityUpdateRequest { .. } => "PRIORITY_UPDATE(REQUEST)",
QFrame::PriorityUpdatePush { .. } => "PRIORITY_UPDATE(PUSH)",
QFrame::Unknown { .. } => "UNKNOWN",
}
}
/// Returns at most `MAX_SERIALIZED_BUFFER_LEN` leading bytes of `buf`, so that
/// large opaque payloads do not bloat serialized output.
fn truncate_for_serialization(buf: &[u8]) -> &[u8] {
    &buf[..cmp::min(buf.len(), MAX_SERIALIZED_BUFFER_LEN)]
}

/// Serialization adapter for a borrowed quiche HTTP/3 frame.
///
/// Each frame becomes a struct named after its frame type (see `frame_name`).
/// Byte buffers are handled in one of two ways:
/// - Buffers whose contents are not needed are summarized by their length.
/// - Buffers whose contents are emitted are truncated to
///   `MAX_SERIALIZED_BUFFER_LEN` bytes.
pub struct SerializableQFrame<'a>(&'a QFrame);

impl Serialize for SerializableQFrame<'_> {
    fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let name = frame_name(self.0);
        // Every `serialize_struct(name, len)` below must declare exactly as
        // many fields as the arm then emits: serde defines `len` as the number
        // of fields, and length-prefixed formats rely on it being accurate.
        match self.0 {
            QFrame::Data { payload } => {
                let mut state = s.serialize_struct(name, 1)?;
                state.serialize_field("payload_len", &payload.len())?;
                state.end()
            },
            QFrame::Headers { header_block } => {
                let mut state = s.serialize_struct(name, 1)?;
                state.serialize_field("header_block_len", &header_block.len())?;
                state.end()
            },
            QFrame::CancelPush { push_id } => {
                let mut state = s.serialize_struct(name, 1)?;
                state.serialize_field("push_id", &push_id)?;
                state.end()
            },
            QFrame::Settings {
                max_field_section_size,
                qpack_max_table_capacity,
                qpack_blocked_streams,
                connect_protocol_enabled,
                h3_datagram,
                grease: _,
                additional_settings,
                raw: _,
            } => {
                let mut state = s.serialize_struct(name, 6)?;
                state.serialize_field(
                    "max_field_section_size",
                    &max_field_section_size,
                )?;
                state.serialize_field(
                    "qpack_max_table_capacity",
                    &qpack_max_table_capacity,
                )?;
                state.serialize_field(
                    "qpack_blocked_streams",
                    &qpack_blocked_streams,
                )?;
                state.serialize_field(
                    "connect_protocol_enabled",
                    &connect_protocol_enabled,
                )?;
                state.serialize_field("h3_datagram", &h3_datagram)?;
                state.serialize_field(
                    "additional_settings",
                    &additional_settings,
                )?;
                state.end()
            },
            QFrame::PushPromise {
                push_id,
                header_block,
            } => {
                let mut state = s.serialize_struct(name, 2)?;
                state.serialize_field("push_id", &push_id)?;
                state.serialize_field("header_block_len", &header_block.len())?;
                state.end()
            },
            QFrame::GoAway { id } => {
                let mut state = s.serialize_struct(name, 1)?;
                state.serialize_field("id", &id)?;
                state.end()
            },
            QFrame::MaxPushId { push_id } => {
                let mut state = s.serialize_struct(name, 1)?;
                state.serialize_field("push_id", &push_id)?;
                state.end()
            },
            QFrame::PriorityUpdateRequest {
                prioritized_element_id,
                priority_field_value,
            } => {
                let mut state = s.serialize_struct(name, 2)?;
                state.serialize_field(
                    "prioritized_element_id",
                    &prioritized_element_id,
                )?;
                state.serialize_field(
                    "priority_field_value",
                    &String::from_utf8_lossy(truncate_for_serialization(
                        priority_field_value,
                    )),
                )?;
                state.end()
            },
            QFrame::PriorityUpdatePush {
                prioritized_element_id,
                priority_field_value,
            } => {
                // Two fields follow (previously mis-declared as 1).
                let mut state = s.serialize_struct(name, 2)?;
                state.serialize_field(
                    "prioritized_element_id",
                    &prioritized_element_id,
                )?;
                state.serialize_field(
                    "priority_field_value",
                    &String::from_utf8_lossy(truncate_for_serialization(
                        priority_field_value,
                    )),
                )?;
                state.end()
            },
            QFrame::Unknown { raw_type, payload } => {
                // Two fields follow (previously mis-declared as 1).
                let mut state = s.serialize_struct(name, 2)?;
                state.serialize_field("raw_type", &raw_type)?;
                state.serialize_field(
                    "payload",
                    &qlog::HexSlice::maybe_string(Some(
                        truncate_for_serialization(payload),
                    )),
                )?;
                state.end()
            },
        }
    }
}
/// User-supplied predicate that decides whether an incoming [`H3iFrame`]
/// matches a [`CloseTriggerFrame`].
type CustomEquivalenceHandler =
Box<dyn for<'f> Fn(&'f H3iFrame) -> bool + Send + Sync + 'static>;
/// How a [`CloseTriggerFrame`] decides whether an incoming frame matches.
#[derive(Clone)]
enum Comparator {
/// Compare against this expected frame (rules in
/// `CloseTriggerFrame::is_equivalent`).
Frame(H3iFrame),
/// Delegate the decision to a custom predicate. The `Arc` wrapper makes the
/// boxed closure cheaply cloneable so the enum can derive `Clone`.
Fn(Arc<CustomEquivalenceHandler>),
}
impl Serialize for Comparator {
    /// Serializes an expected frame as a struct with a single `frame` field.
    ///
    /// A custom predicate cannot be represented, so it is emitted as the
    /// placeholder string `"<comparator_fn>"`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let expected = match self {
            Self::Frame(frame) => frame,
            Self::Fn(_) => return serializer.serialize_str("<comparator_fn>"),
        };

        let mut wrapper = serializer.serialize_struct("frame", 1)?;
        wrapper.serialize_field("frame", expected)?;
        wrapper.end()
    }
}
/// A frame expected on a specific stream, used as a close trigger.
///
/// NOTE(review): the close action itself is taken by callers outside this
/// file; here the trigger only reports matches via `is_equivalent`.
#[derive(Serialize, Clone)]
pub struct CloseTriggerFrame {
// Stream the trigger applies to; `is_equivalent` does not check it.
stream_id: u64,
// How incoming frames are matched against this trigger.
comparator: Comparator,
}
impl CloseTriggerFrame {
pub fn new(stream_id: u64, frame: impl Into<H3iFrame>) -> Self {
Self {
stream_id,
comparator: Comparator::Frame(frame.into()),
}
}
pub fn new_with_comparator<F>(stream_id: u64, comparator_fn: F) -> Self
where
F: Fn(&H3iFrame) -> bool + Send + Sync + 'static,
{
Self {
stream_id,
comparator: Comparator::Fn(Arc::new(Box::new(comparator_fn))),
}
}
pub(crate) fn stream_id(&self) -> u64 {
self.stream_id
}
pub(crate) fn is_equivalent(&self, other: &H3iFrame) -> bool {
let frame = match &self.comparator {
Comparator::Fn(compare) => return compare(other),
Comparator::Frame(frame) => frame,
};
match frame {
H3iFrame::Headers(me) => {
let H3iFrame::Headers(other) = other else {
return false;
};
me.headers().iter().all(|m| other.headers().contains(m))
},
H3iFrame::QuicheH3(me) => match other {
H3iFrame::QuicheH3(other) => me == other,
_ => false,
},
H3iFrame::ResetStream(me) => match other {
H3iFrame::ResetStream(rs) => me == rs,
_ => false,
},
}
}
}
impl Debug for CloseTriggerFrame {
    /// Formats the stream ID and the comparator.
    ///
    /// Closure comparators cannot be inspected, so they are shown as
    /// `closure`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let stream_id = self.stream_id;
        let comparator = match &self.comparator {
            Comparator::Fn(_) => String::from("closure"),
            Comparator::Frame(frame) => format!("{frame:?}"),
        };
        write!(
            f,
            "CloseTriggerFrame {{ stream_id: {stream_id}, comparator: {comparator} }}"
        )
    }
}
impl PartialEq for CloseTriggerFrame {
    /// Two triggers are equal only when both use frame comparators with the
    /// same stream ID and equal expected frames.
    ///
    /// Closure-based triggers never compare equal, not even to themselves,
    /// because closures cannot be compared.
    fn eq(&self, other: &Self) -> bool {
        if let (Comparator::Frame(lhs), Comparator::Frame(rhs)) =
            (&self.comparator, &other.comparator)
        {
            self.stream_id == other.stream_id && lhs == rhs
        } else {
            false
        }
    }
}
// Unit tests for `CloseTriggerFrame::is_equivalent` across every comparator
// kind.
#[cfg(test)]
mod tests {
use super::*;
use quiche::h3::frame::Frame;
// A headers trigger matches when every expected header is present in the
// incoming frame, even if the incoming frame carries extra headers.
#[test]
fn test_header_equivalence() {
let this = CloseTriggerFrame::new(0, vec![
Header::new(b"hello", b"world"),
Header::new(b"go", b"jets"),
]);
let other: H3iFrame = vec![
Header::new(b"hello", b"world"),
Header::new(b"go", b"jets"),
Header::new(b"go", b"devils"),
]
.into();
assert!(this.is_equivalent(&other));
}
// A headers trigger does not match when any expected header is missing from
// the incoming frame.
#[test]
fn test_header_non_equivalence() {
let this = CloseTriggerFrame::new(0, vec![
Header::new(b"hello", b"world"),
Header::new(b"go", b"jets"),
Header::new(b"go", b"devils"),
]);
let other: H3iFrame =
vec![Header::new(b"hello", b"world"), Header::new(b"go", b"jets")]
.into();
assert!(!this.is_equivalent(&other));
}
// A reset-stream trigger requires an exact match on both stream ID and error
// code.
#[test]
fn test_rst_stream_equivalence() {
let mut rs = ResetStream {
stream_id: 0,
error_code: 57,
};
let this = CloseTriggerFrame::new(0, H3iFrame::ResetStream(rs.clone()));
let incoming = H3iFrame::ResetStream(rs.clone());
assert!(this.is_equivalent(&incoming));
rs.stream_id = 57;
let incoming = H3iFrame::ResetStream(rs);
assert!(!this.is_equivalent(&incoming));
}
// Raw quiche frames are compared with full equality, payload included.
#[test]
fn test_frame_equivalence() {
let mut d = Frame::Data {
payload: b"57".to_vec(),
};
let this = CloseTriggerFrame::new(0, H3iFrame::QuicheH3(d.clone()));
let incoming = H3iFrame::QuicheH3(d.clone());
assert!(this.is_equivalent(&incoming));
d = Frame::Data {
payload: b"go jets".to_vec(),
};
let incoming = H3iFrame::QuicheH3(d.clone());
assert!(!this.is_equivalent(&incoming));
}
// A custom comparator closure decides the match itself. This one matches a
// headers frame whose `cookie` value contains "cookie", compared
// case-insensitively.
#[test]
fn test_comparator() {
let this = CloseTriggerFrame::new_with_comparator(0, |frame| {
if let H3iFrame::Headers(..) = frame {
frame
.to_enriched_headers()
.unwrap()
.header_map()
.get(&b"cookie".to_vec())
.is_some_and(|v| {
std::str::from_utf8(v)
.map(|s| s.to_lowercase())
.unwrap()
.contains("cookie")
})
} else {
false
}
});
let incoming: H3iFrame =
vec![Header::new(b"cookie", b"SomeRandomCookie1234")].into();
assert!(this.is_equivalent(&incoming));
}
}