use std::{
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
};
use bytes::Bytes;
use futures::FutureExt;
use thiserror::Error;
use super::{
ack::ack_frame_with_ecn, add_address::be_add_address_frame,
connection_close::connection_close_frame_at_layer, crypto::be_crypto_frame,
data_blocked::be_data_blocked_frame, datagram::datagram_frame_with_flag,
max_data::be_max_data_frame, max_stream_data::be_max_stream_data_frame,
max_streams::max_streams_frame_with_dir, new_connection_id::be_new_connection_id_frame,
new_token::be_new_token_frame, path_challenge::be_path_challenge_frame,
path_response::be_path_response_frame, punch_done::be_punch_done_frame,
punch_hello::be_punch_hello_frame, punch_me_now::be_punch_me_now_frame,
remove_address::be_remove_address_frame, reset_stream::be_reset_stream_frame,
retire_connection_id::be_retire_connection_id_frame, stop_sending::be_stop_sending_frame,
stream::stream_frame_with_flag, stream_data_blocked::be_stream_data_blocked_frame,
streams_blocked::streams_blocked_frame_with_dir, *,
};
use crate::util::ContinuousData;
fn complete_frame(
frame_type: FrameType,
raw: Bytes,
) -> impl Fn(&[u8]) -> nom::IResult<&[u8], Frame> {
use nom::{Parser, combinator::map};
move |input: &[u8]| match frame_type {
FrameType::Padding => Ok((input, Frame::Padding(PaddingFrame))),
FrameType::Ping => Ok((input, Frame::Ping(PingFrame))),
FrameType::ConnectionClose(layer) => {
map(connection_close_frame_at_layer(layer), Frame::Close).parse(input)
}
FrameType::NewConnectionId => {
map(be_new_connection_id_frame, Frame::NewConnectionId).parse(input)
}
FrameType::RetireConnectionId => {
map(be_retire_connection_id_frame, Frame::RetireConnectionId).parse(input)
}
FrameType::DataBlocked => map(be_data_blocked_frame, Frame::DataBlocked).parse(input),
FrameType::MaxData => map(be_max_data_frame, Frame::MaxData).parse(input),
FrameType::PathChallenge => map(be_path_challenge_frame, Frame::PathChallenge).parse(input),
FrameType::PathResponse => map(be_path_response_frame, Frame::PathResponse).parse(input),
FrameType::HandshakeDone => Ok((input, Frame::HandshakeDone(HandshakeDoneFrame))),
FrameType::NewToken => map(be_new_token_frame, Frame::NewToken).parse(input),
FrameType::Ack(ecn) => map(ack_frame_with_ecn(ecn), Frame::Ack).parse(input),
FrameType::ResetStream => {
map(be_reset_stream_frame, |f| Frame::StreamCtl(f.into())).parse(input)
}
FrameType::StopSending => {
map(be_stop_sending_frame, |f| Frame::StreamCtl(f.into())).parse(input)
}
FrameType::MaxStreamData => {
map(be_max_stream_data_frame, |f| Frame::StreamCtl(f.into())).parse(input)
}
FrameType::MaxStreams(dir) => map(max_streams_frame_with_dir(dir), |f| {
Frame::StreamCtl(f.into())
})
.parse(input),
FrameType::StreamsBlocked(dir) => map(streams_blocked_frame_with_dir(dir), |f| {
Frame::StreamCtl(f.into())
})
.parse(input),
FrameType::StreamDataBlocked => {
map(be_stream_data_blocked_frame, |f| Frame::StreamCtl(f.into())).parse(input)
}
FrameType::Crypto => {
let (input, frame) = be_crypto_frame(input)?;
let start = raw.len() - input.len();
let len = frame.len() as usize;
if input.len() < len {
Err(nom::Err::Incomplete(nom::Needed::new(len - input.len())))
} else {
let data = raw.slice(start..start + len);
Ok((&input[len..], Frame::Crypto(frame, data)))
}
}
FrameType::Stream(offset, len, fin) => {
let (input, frame) = stream_frame_with_flag(offset, len, fin)(input)?;
let start = raw.len() - input.len();
let len = frame.len();
if input.len() < len {
Err(nom::Err::Incomplete(nom::Needed::new(len - input.len())))
} else {
let data = raw.slice(start..start + len);
Ok((&input[len..], Frame::Stream(frame, data)))
}
}
FrameType::Datagram(with_len) => {
let (input, frame) = datagram_frame_with_flag(with_len)(input)?;
let start = raw.len() - input.len();
match frame.encode_len() {
true if frame.len().into_inner() > input.len() as u64 => Err(nom::Err::Incomplete(
nom::Needed::new((frame.len().into_inner() - input.len() as u64) as usize),
)),
true => {
let data = raw.slice(start..start + frame.len().into_inner() as usize);
Ok((
&input[frame.len().into_inner() as usize..],
Frame::Datagram(frame, data),
))
}
false => {
let data = raw.slice(start..);
Ok((&[], Frame::Datagram(frame, data)))
}
}
}
FrameType::AddAddress(family) => {
map(be_add_address_frame(family), Frame::AddAddress).parse(input)
}
FrameType::RemoveAddress => map(be_remove_address_frame, Frame::RemoveAddress).parse(input),
FrameType::PunchMeNow(family) => {
map(be_punch_me_now_frame(family), Frame::PunchMeNow).parse(input)
}
FrameType::PunchHello => map(be_punch_hello_frame, Frame::PunchHello).parse(input),
FrameType::PunchDone => map(be_punch_done_frame, Frame::PunchDone).parse(input),
}
}
pub fn be_frame(raw: &Bytes, packet_type: Type) -> Result<(usize, Frame, FrameType), Error> {
let input = raw.as_ref();
let (remain, frame_type) = be_frame_type(input)?;
if !frame_type.belongs_to(packet_type) {
return Err(Error::WrongType(frame_type, packet_type));
}
let (remain, frame) = complete_frame(frame_type, raw.clone())(remain).map_err(|e| match e {
ne @ nom::Err::Incomplete(_) => {
nom::Err::Error(Error::IncompleteFrame(frame_type, ne.to_string()))
}
nom::Err::Error(ne) => {
nom::Err::Error(Error::ParseError(
frame_type,
ne.code.description().to_owned(),
))
}
_ => unreachable!("parsing frame never fails"),
})?;
Ok((input.len() - remain.len(), frame, frame_type))
}
/// A [`bytes::BufMut`] that can serialize frames of type `F`.
pub trait WriteFrame<F>: bytes::BufMut {
    /// Appends the encoding of `frame` to the buffer.
    fn put_frame(&mut self, frame: &F);
}
/// Serializes any `Frame<D>` into a `BufMut` by dispatching on the variant.
///
/// Frames without a separate payload go through their own `WriteFrame` impls.
/// `Stream`, `Crypto` and `Datagram` frames are written with `put_data_frame`
/// together with their data `D`.
impl<B: BufMut, D: ContinuousData> WriteFrame<Frame<D>> for B
where
    D: ContinuousData,
    B: BufMut + ?Sized,
    // Presumably needed by the data-frame writers, which emit the payload
    // through `WriteData` on `&mut B` — TODO confirm.
    for<'b> &'b mut B: crate::util::WriteData<D>,
{
    fn put_frame(&mut self, frame: &Frame<D>) {
        // Helper that lets each arm name the concrete frame type `F`, so the
        // matching `WriteFrame<F>` impl is selected.
        #[inline(always)]
        fn put<F, B: WriteFrame<F> + ?Sized>(buf: &mut B, frame: &F) {
            buf.put_frame(frame);
        }
        // `self` is `&mut B`. Calling `put(&mut buf, ..)` passes `&mut &mut B`, so the
        // per-frame impls are resolved for `&mut B` (a sized `BufMut`). Presumably
        // this is because those impls need a `Sized` buffer while `B` may be
        // unsized — NOTE(review): verify.
        let mut buf = self;
        match frame {
            Frame::Padding(f) => put(&mut buf, f),
            Frame::Ping(f) => put(&mut buf, f),
            Frame::Ack(f) => put(&mut buf, f),
            Frame::Close(f) => put(&mut buf, f),
            Frame::NewToken(f) => put(&mut buf, f),
            Frame::MaxData(f) => put(&mut buf, f),
            Frame::DataBlocked(f) => put(&mut buf, f),
            Frame::AddAddress(f) => put(&mut buf, f),
            Frame::RemoveAddress(f) => put(&mut buf, f),
            Frame::PunchMeNow(f) => put(&mut buf, f),
            Frame::PunchHello(f) => put(&mut buf, f),
            Frame::PunchDone(f) => put(&mut buf, f),
            Frame::NewConnectionId(f) => put(&mut buf, f),
            Frame::RetireConnectionId(f) => put(&mut buf, f),
            Frame::HandshakeDone(f) => put(&mut buf, f),
            Frame::PathChallenge(f) => put(&mut buf, f),
            Frame::PathResponse(f) => put(&mut buf, f),
            Frame::StreamCtl(f) => put(&mut buf, f),
            // Payload-carrying frames: header and data are written together.
            Frame::Stream(f, d) => buf.put_data_frame(f, d),
            Frame::Crypto(f, d) => buf.put_data_frame(f, d),
            Frame::Datagram(f, d) => buf.put_data_frame(f, d),
        }
    }
}
/// A [`bytes::BufMut`] that can serialize a frame of type `F` together with
/// its payload of type `D`.
pub trait WriteDataFrame<F, D: ContinuousData>: bytes::BufMut {
    /// Appends the encoding of `frame` followed by `data` to the buffer.
    fn put_data_frame(&mut self, frame: &F, data: &D);
}
/// A [`bytes::BufMut`] that can serialize a [`FrameType`].
pub trait WriteFrameType: bytes::BufMut {
    /// Appends the encoded `frame_type` to the buffer.
    fn put_frame_type(&mut self, frame_type: FrameType);
}
/// Every `BufMut` can write a frame type. The type is converted to a `VarInt`
/// and emitted with `WriteVarInt::put_varint`.
impl<T: BufMut> WriteFrameType for T {
    fn put_frame_type(&mut self, frame_type: FrameType) {
        use crate::varint::WriteVarInt;

        let encoded: VarInt = Into::into(frame_type);
        self.put_varint(&encoded);
    }
}
/// Accepts outgoing frames of type `T`.
pub trait SendFrame<T> {
    /// Hands every frame yielded by `iter` to the implementor. What "sending"
    /// means (queueing, packing, ...) is up to the implementation.
    fn send_frame<I: IntoIterator<Item = T>>(&self, iter: I);
}
/// Handles incoming frames of type `T`.
pub trait ReceiveFrame<T> {
    /// Value produced by handling one frame.
    type Output;
    /// Processes `frame`. Returns `Self::Output` on success, or a
    /// `crate::error::Error` if the frame cannot be accepted.
    fn recv_frame(&self, frame: &T) -> Result<Self::Output, crate::error::Error>;
}
/// State of a slot that holds a single received frame of type `F` and can be
/// awaited through its `Future` impl.
///
/// The `Default` state is `Pending`.
#[derive(Debug, Default)]
pub enum Receiving<F> {
    /// Nothing has been received yet.
    #[default]
    Pending,
    /// Nothing has been received yet; holds the waker of a task to notify on
    /// the next state change (`reset` wakes it).
    Waiting(Waker),
    /// A frame has arrived and has not been handed out yet.
    Rcvd(F),
    /// The received frame has already been handed out by `poll`.
    Read,
    /// The slot was reset; polling yields `ResetError`.
    Reset,
}
impl<F> Receiving<F> {
    /// Moves the slot into the `Reset` state, whatever it was before.
    ///
    /// If a task was registered (`Waiting`), its waker is woken so the task can
    /// observe the reset. Any previously received frame is dropped.
    pub fn reset(&mut self) {
        let Self::Waiting(waker) = std::mem::replace(self, Self::Reset) else {
            return;
        };
        waker.wake();
    }
}
// NOTE(review): not implemented yet — calling `recv_frame` panics via `todo!`.
// Intended behavior, translated from the `todo!` message below:
//   * `Pending` -> become `Rcvd`;
//   * `Waiting` -> wake the stored waker, then become `Rcvd`;
//   * `Rcvd`    -> log at debug level;
//   * `Reset`   -> log at warn level.
// NOTE(review): `recv_frame` takes `&self`, so a plain `Receiving<F>` cannot
// change its state here without interior mutability. It also only gets `&F`,
// so storing the frame would require a clone. Consider implementing this on
// the mutex-guarded `ArcReceiving` instead — TODO confirm the intended design.
impl<F> ReceiveFrame<F> for Receiving<F> {
    type Output = ();
    fn recv_frame(&self, _frame: &F) -> Result<Self::Output, crate::error::Error> {
        todo!(
            "Pending的时候,变为Rcvd;Waiting的时候,唤醒waker,变为Rcvd;Rcvd,打印debug;Reset,打印warn"
        )
    }
}
/// Error yielded by a `Receiving` future once its slot has been reset.
#[derive(Debug, Error)]
#[error("Reset")]
pub struct ResetError;
impl<F: Unpin> Future for Receiving<F> {
    type Output = Result<Option<F>, ResetError>;

    /// Polls the slot for its frame.
    ///
    /// * `Pending` / `Waiting`: stores (or refreshes) the current task's waker and
    ///   returns `Poll::Pending`. Whoever later moves the state out of `Waiting`
    ///   must wake it, as `Receiving::reset` does.
    /// * `Rcvd(frame)`: hands the frame out once and moves to `Read`.
    /// * `Read`: the frame was already taken, so returns `Ready(Ok(None))`.
    /// * `Reset`: returns `Ready(Err(ResetError))`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let state = self.get_mut();
        // Move the current state out (leaving `Pending`) so owned payloads can be
        // taken by value; every arm writes the next state back.
        match std::mem::take(state) {
            Self::Pending => {
                // Register interest; returning `Pending` without a stored waker
                // would leave this task asleep forever.
                *state = Self::Waiting(cx.waker().clone());
                Poll::Pending
            }
            Self::Waiting(mut waker) => {
                // The future may have moved to a different task since the last poll.
                if !waker.will_wake(cx.waker()) {
                    waker = cx.waker().clone();
                }
                *state = Self::Waiting(waker);
                Poll::Pending
            }
            Self::Rcvd(frame) => {
                *state = Self::Read;
                Poll::Ready(Ok(Some(frame)))
            }
            Self::Read => {
                *state = Self::Read;
                Poll::Ready(Ok(None))
            }
            Self::Reset => {
                *state = Self::Reset;
                Poll::Ready(Err(ResetError))
            }
        }
    }
}
#[derive(Debug, Default, Clone)]
pub struct ArcReceiving<F>(Arc<Mutex<Receiving<F>>>);
impl<F> ArcReceiving<F> {
    /// Resets the shared slot via `Receiving::reset`. The state becomes `Reset`,
    /// and a registered waiting task, if any, is woken.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    pub fn reset(&self) {
        let mut slot = self.0.lock().unwrap();
        slot.reset();
    }
}
/// Awaiting a shared handle polls the underlying `Receiving` slot while holding
/// its lock, so it yields exactly what that slot's future yields.
impl<F: Unpin> Future for ArcReceiving<F> {
    type Output = Result<Option<F>, ResetError>;

    /// # Panics
    ///
    /// Panics if the mutex is poisoned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut slot = self.0.lock().unwrap();
        slot.poll_unpin(cx)
    }
}