use std::borrow::Cow;
use derive_more::From;
use nom::bytes::complete::take;
use super::FrameType;
use crate::{
error::{ErrorFrameType, ErrorKind},
frame::{GetFrameType, be_frame_type, io::WriteFrameType},
varint::{VarInt, be_varint},
};
/// Selects the flavour of CONNECTION_CLOSE frame: `Quic` pairs with
/// [`QuicCloseFrame`] and `App` pairs with [`AppCloseFrame`].
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Layer {
    /// QUIC-layer close (carried by a `QuicCloseFrame`).
    Quic,
    /// Application-layer close (carried by an `AppCloseFrame`).
    App,
}
impl From<Layer> for u8 {
    /// Encodes the layer as a single bit: `Quic` is 0 and `App` is 1.
    fn from(layer: Layer) -> u8 {
        match layer {
            Layer::App => 1,
            Layer::Quic => 0,
        }
    }
}
impl From<u8> for Layer {
    /// Decodes a layer from the least-significant bit of `value` only; every
    /// other bit is ignored, so e.g. 0x1c maps to `Quic` and 0x1d to `App`.
    fn from(value: u8) -> Self {
        if (value & 0x01) == 0 {
            Layer::Quic
        } else {
            Layer::App
        }
    }
}
/// Body of an application-layer CONNECTION_CLOSE frame: a raw application
/// error code plus a reason phrase.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AppCloseFrame {
    error_code: VarInt,
    reason: Cow<'static, str>,
}
impl AppCloseFrame {
    /// Returns the application error code as a plain integer.
    pub fn error_code(&self) -> u64 {
        self.error_code.into_inner()
    }
    /// Returns the reason phrase (may be empty).
    pub fn reason(&self) -> &str {
        self.reason.as_ref()
    }
    /// Builds a QUIC-layer close frame that reveals nothing about `self`:
    /// error kind `ErrorKind::Application`, frame type `Padding`, empty reason.
    ///
    /// Neither the error code nor the reason of `self` is copied. Presumably
    /// used where an application close must not be exposed (RFC 9000 §10.2.3
    /// requires this replacement in Initial/Handshake packets) — confirm at
    /// call sites.
    pub fn conceal(&self) -> QuicCloseFrame {
        let frame_type = ErrorFrameType::V1(FrameType::Padding);
        QuicCloseFrame {
            frame_type,
            error_kind: ErrorKind::Application,
            reason: Cow::Borrowed(""),
        }
    }
}
impl From<AppCloseFrame> for QuicCloseFrame {
fn from(_: AppCloseFrame) -> Self {
QuicCloseFrame {
error_kind: ErrorKind::Application,
frame_type: ErrorFrameType::V1(FrameType::Padding),
reason: Cow::Borrowed(""),
}
}
}
/// Body of a QUIC-layer CONNECTION_CLOSE frame: an [`ErrorKind`], the frame
/// type reported with the error, and a reason phrase.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QuicCloseFrame {
    error_kind: ErrorKind,
    frame_type: ErrorFrameType,
    reason: Cow<'static, str>,
}
impl QuicCloseFrame {
    /// Returns the error kind carried by this frame.
    pub fn error_kind(&self) -> ErrorKind {
        let kind = self.error_kind;
        kind
    }
    /// Returns the frame type reported alongside the error.
    pub fn frame_type(&self) -> ErrorFrameType {
        let ty = self.frame_type;
        ty
    }
    /// Returns the reason phrase (may be empty).
    pub fn reason(&self) -> &str {
        self.reason.as_ref()
    }
}
/// A CONNECTION_CLOSE frame at either the application layer ([`AppCloseFrame`])
/// or the QUIC layer ([`QuicCloseFrame`]).
///
/// `From` is derived, so either payload type converts into this enum with `.into()`.
#[derive(Debug, Clone, From, PartialEq, Eq)]
pub enum ConnectionCloseFrame {
    /// Application close; its frame type is `FrameType::ConnectionClose(Layer::App)`.
    App(AppCloseFrame),
    /// QUIC-layer close; its frame type is `FrameType::ConnectionClose(Layer::Quic)`.
    Quic(QuicCloseFrame),
}
impl super::GetFrameType for ConnectionCloseFrame {
    /// Reports `FrameType::ConnectionClose` tagged with the layer matching the variant.
    fn frame_type(&self) -> FrameType {
        let layer = match self {
            Self::App(_) => Layer::App,
            Self::Quic(_) => Layer::Quic,
        };
        FrameType::ConnectionClose(layer)
    }
}
impl super::EncodeSize for ConnectionCloseFrame {
fn max_encoding_size(&self) -> usize {
match self {
ConnectionCloseFrame::App(frame) => 1 + 8 + 2 + frame.reason.len(),
ConnectionCloseFrame::Quic(frame) => 1 + 8 + 8 + 2 + frame.reason.len(),
}
}
fn encoding_size(&self) -> usize {
match self {
ConnectionCloseFrame::App(frame) => {
1 + frame.error_code.encoding_size()
+ VarInt::try_from(frame.reason.len()).unwrap().encoding_size()
+ frame.reason.len()
}
ConnectionCloseFrame::Quic(frame) => {
1 + VarInt::from(frame.error_kind).encoding_size() + 1
+ VarInt::try_from(frame.reason.len()).unwrap().encoding_size()
+ frame.reason.len()
}
}
}
}
impl ConnectionCloseFrame {
    /// Creates a QUIC-layer close frame.
    ///
    /// `frame_type` is reported alongside the error (per RFC 9000, the frame that
    /// triggered it). `reason` accepts anything convertible into
    /// `Cow<'static, str>`, such as a `&'static str` or a `String`.
    pub fn new_quic(
        error_kind: ErrorKind,
        frame_type: ErrorFrameType,
        reason: impl Into<Cow<'static, str>>,
    ) -> Self {
        let reason = reason.into();
        let body = QuicCloseFrame {
            error_kind,
            frame_type,
            reason,
        };
        Self::Quic(body)
    }
    /// Creates an application-layer close frame with the given error code and reason.
    pub fn new_app(error_code: VarInt, reason: impl Into<Cow<'static, str>>) -> Self {
        let reason = reason.into();
        let body = AppCloseFrame { error_code, reason };
        Self::App(body)
    }
}
fn be_app_close_frame(input: &[u8]) -> nom::IResult<&[u8], AppCloseFrame> {
let (remain, error_code) = be_varint(input)?;
let (remain, reason_length) = be_varint(remain)?;
let (remain, reason) = take(reason_length.into_inner() as usize)(remain)?;
let cow = String::from_utf8_lossy(reason).into_owned();
Ok((
remain,
AppCloseFrame {
error_code,
reason: Cow::Owned(cow),
},
))
}
fn be_quic_close_frame(input: &[u8]) -> nom::IResult<&[u8], QuicCloseFrame> {
let (remain, error_code) = be_varint(input)?;
let error_kind = ErrorKind::try_from(error_code)
.map_err(|_e| nom::Err::Error(nom::error::make_error(input, nom::error::ErrorKind::Alt)))?;
let (remain, frame_type) = be_frame_type(remain)
.map_err(|_e| nom::Err::Error(nom::error::make_error(input, nom::error::ErrorKind::Alt)))?;
let (remain, reason_length) = be_varint(remain)?;
let (remain, reason) = take(reason_length.into_inner() as usize)(remain)?;
let cow = String::from_utf8_lossy(reason).into_owned();
Ok((
remain,
QuicCloseFrame {
error_kind,
frame_type: frame_type.into(),
reason: Cow::Owned(cow),
},
))
}
/// Returns a parser for a CONNECTION_CLOSE payload at the given `layer`.
///
/// The frame type must already have been consumed by the caller. `Layer::App`
/// dispatches to the application parser and `Layer::Quic` to the QUIC-layer
/// parser; each result is wrapped in the matching `ConnectionCloseFrame` variant.
pub fn connection_close_frame_at_layer(
    layer: Layer,
) -> impl Fn(&[u8]) -> nom::IResult<&[u8], ConnectionCloseFrame> {
    move |input: &[u8]| match layer {
        Layer::Quic => {
            let parsed = be_quic_close_frame(input);
            parsed.map(|(rest, body)| (rest, ConnectionCloseFrame::Quic(body)))
        }
        Layer::App => {
            let parsed = be_app_close_frame(input);
            parsed.map(|(rest, body)| (rest, ConnectionCloseFrame::App(body)))
        }
    }
}
impl<T: bytes::BufMut> super::io::WriteFrame<ConnectionCloseFrame> for T {
    /// Writes `frame`: its frame type, the error fields, then a length-prefixed
    /// reason phrase.
    ///
    /// The fields before the reason are written without a capacity check. If the
    /// remaining capacity cannot hold the whole reason phrase, the phrase is
    /// truncated so that the length prefix and the phrase bytes together fit. The
    /// cut is moved back to a UTF-8 character boundary so the emitted phrase stays
    /// valid UTF-8.
    fn put_frame(&mut self, frame: &ConnectionCloseFrame) {
        use crate::varint::WriteVarInt;

        /// Longest prefix length of `reason` that ends on a char boundary and fits in
        /// `capacity` bytes together with its varint length prefix.
        fn fitting_reason_len(reason: &str, capacity: usize) -> usize {
            // Minimal QUIC varint width for `n` (RFC 9000 §16). Assumes `put_varint`
            // emits the minimal encoding, as `encoding_size` does.
            let prefix_width = |n: usize| match n {
                0..=0x3f => 1,
                0x40..=0x3fff => 2,
                0x4000..=0x3fff_ffff => 4,
                _ => 8,
            };
            let mut len = reason.len().min(capacity);
            // Only a few steps are ever needed: at most the prefix width, plus up to
            // three bytes to back off to a character boundary.
            while len > 0
                && (prefix_width(len) + len > capacity || !reason.is_char_boundary(len))
            {
                len -= 1;
            }
            len
        }

        self.put_frame_type(frame.frame_type());
        let reason = match frame {
            ConnectionCloseFrame::App(frame) => {
                self.put_varint(&frame.error_code);
                &frame.reason
            }
            ConnectionCloseFrame::Quic(frame) => {
                self.put_varint(&frame.error_kind.into());
                self.put_varint(&frame.frame_type.into());
                &frame.reason
            }
        };
        // Measure after the preceding fields are written, so the budget covers
        // exactly the length prefix plus the phrase.
        let len = fitting_reason_len(reason, self.remaining_mut());
        self.put_varint(&VarInt::try_from(len).expect("reason phrase length fits in a varint"));
        self.put_slice(&reason.as_bytes()[..len]);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        error::ErrorKind,
        frame::{
            EncodeSize, FrameType, GetFrameType,
            io::{WriteFrame, WriteFrameType},
            stream::{Fin, Len, Offset},
        },
        varint::VarInt,
    };

    // An application close reports the App layer and the expected size figures.
    #[test]
    fn test_connection_close_frame() {
        let close = ConnectionCloseFrame::new_app(VarInt::from_u32(0x1234), "wrong");
        assert_eq!(close.frame_type(), FrameType::ConnectionClose(Layer::App));
        // type byte + 8-byte error-code bound + 2-byte length bound + 5 reason bytes
        assert_eq!(close.max_encoding_size(), 1 + 8 + 2 + 5);
        // type byte + 2-byte varint for 0x1234 + 1-byte length + 5 reason bytes
        assert_eq!(close.encoding_size(), 1 + 2 + 1 + 5);
    }

    // Decoding an application close consumes the whole input.
    #[test]
    fn test_read_connection_close_frame() {
        use crate::varint::be_varint;
        use nom::{Parser, combinator::flat_map};

        let app_type = VarInt::from(FrameType::ConnectionClose(Layer::App));
        let mut wire = Vec::new();
        wire.put_frame_type(FrameType::ConnectionClose(Layer::App));
        // error code 0x0c, reason length 5, reason "wrong"
        wire.extend_from_slice(&[0x0c, 5, b'w', b'r', b'o', b'n', b'g']);

        let (rest, parsed) = flat_map(be_varint, |frame_type| {
            assert!(frame_type == app_type, "wrong frame type: {frame_type}");
            connection_close_frame_at_layer(Layer::App)
        })
        .parse(wire.as_slice())
        .unwrap();

        assert!(rest.is_empty());
        let expected = ConnectionCloseFrame::new_app(VarInt::from_u32(0x0c), "wrong");
        assert_eq!(parsed, expected);
    }

    // Encoding a QUIC-layer close produces the expected byte layout.
    #[test]
    fn test_write_connection_close_frame() {
        let close = ConnectionCloseFrame::new_quic(
            ErrorKind::FlowControl,
            FrameType::Stream(Offset::NonZero, Len::Explicit, Fin::No).into(),
            "wrong",
        );
        let mut written = Vec::<u8>::new();
        written.put_frame(&close);

        let mut expected = Vec::new();
        expected.put_frame_type(FrameType::ConnectionClose(Layer::Quic));
        // FLOW_CONTROL_ERROR (0x03), STREAM type with OFF|LEN bits (0x0e),
        // reason length 5, reason "wrong"
        expected.extend_from_slice(&[0x03, 0xe, 5, b'w', b'r', b'o', b'n', b'g']);
        assert_eq!(written, expected);
    }
}