use crate::{http2::codec::Codec, ProtError, ProtResult};
use std::{
io,
pin::Pin,
task::{ready, Context, Poll},
};
use algorithm::buf::{Binary, Bt};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use webparse::http::http2::HTTP2_MAGIC;
/// Drives the HTTP/2 connection handshake: it first writes any pending
/// outbound bytes, then (server side only) reads and validates the peer's
/// connection preface against `HTTP2_MAGIC`.
pub struct StateHandshake {
    /// Current step of the handshake state machine.
    state: Handshaking,
    /// `true` for the client endpoint; clients finish right after flushing
    /// and never read a preface.
    is_client: bool,
}
/// Steps of the handshake state machine advanced by `StateHandshake::poll_handle`.
enum Handshaking {
    /// Not started; the next poll moves to `Flushing` with an empty buffer.
    None,
    /// Writing the queued outbound bytes to the transport.
    Flushing(Flush),
    /// Server side only: reading the incoming preface and comparing it to `HTTP2_MAGIC`.
    ReadingPreface(ReadPreface),
    /// Handshake finished; further polls return `Ready(Ok(()))` immediately.
    Done,
}
/// Outbound handshake bytes still to be written; the `Binary` cursor is
/// advanced by however many bytes each write accepts.
struct Flush(Binary);
/// Progress while reading the peer's connection preface.
struct ReadPreface {
    /// Number of leading `HTTP2_MAGIC` bytes already received and matched.
    pos: usize,
}
impl ReadPreface {
    /// Creates a reader that has not consumed any preface bytes yet.
    pub fn new() -> Self {
        ReadPreface { pos: 0 }
    }

    /// Reads from the codec's reader until all of `HTTP2_MAGIC` has arrived,
    /// checking every received chunk against the expected bytes.
    ///
    /// Progress lives in `self.pos`. If the reader returns `Poll::Pending`,
    /// the next call resumes where this one stopped, without re-reading bytes
    /// already matched.
    ///
    /// # Errors
    /// * any I/O error reported by the underlying reader;
    /// * `UnexpectedEof` if the reader reports end-of-stream before the
    ///   whole preface has been received;
    /// * `ProtError::Extension("handshake not match")` if the received bytes
    ///   differ from `HTTP2_MAGIC`.
    pub fn poll_handle<T>(
        &mut self,
        cx: &mut Context<'_>,
        codec: &mut Codec<T>,
    ) -> Poll<ProtResult<()>>
    where
        T: AsyncRead + AsyncWrite + Unpin,
    {
        // Scratch space for a single read. Each read is capped at both the
        // bytes still expected and the scratch size, so slicing can never go
        // out of bounds even if the preface were longer than this buffer.
        let mut scratch = [0u8; 24];
        while self.pos < HTTP2_MAGIC.len() {
            let want = (HTTP2_MAGIC.len() - self.pos).min(scratch.len());
            let mut read_buf = ReadBuf::new(&mut scratch[..want]);
            ready!(Pin::new(codec.get_reader()).poll_read(cx, &mut read_buf))
                .map_err(ProtError::from)?;
            let got = read_buf.filled();
            if got.is_empty() {
                // A ready read that fills nothing signals end-of-stream.
                return Poll::Ready(Err(ProtError::from(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "connection closed before reading preface",
                ))));
            }
            if &HTTP2_MAGIC[self.pos..self.pos + got.len()] != got {
                return Poll::Ready(Err(ProtError::Extension("handshake not match")));
            }
            self.pos += got.len();
        }
        Poll::Ready(Ok(()))
    }
}
impl StateHandshake {
    /// Creates the server-side handshake: after flushing any queued bytes it
    /// reads and validates the client's connection preface.
    pub fn new_server() -> StateHandshake {
        StateHandshake {
            state: Handshaking::None,
            is_client: false,
        }
    }

    /// Creates the client-side handshake: it completes as soon as any queued
    /// outbound bytes have been flushed.
    pub fn new_client() -> StateHandshake {
        StateHandshake {
            state: Handshaking::None,
            is_client: true,
        }
    }

    /// Advances the handshake as far as the transport currently allows.
    ///
    /// Returns `Poll::Pending` when the transport is not ready; the current
    /// step is kept, so the next call resumes it. Returns `Ready(Ok(()))`
    /// once the handshake is complete. Returns `Ready(Err(_))` if flushing
    /// the outbound bytes or reading the preface fails.
    pub fn poll_handle<T>(
        &mut self,
        cx: &mut Context<'_>,
        codec: &mut Codec<T>,
    ) -> Poll<ProtResult<()>>
    where
        T: AsyncRead + AsyncWrite + Unpin,
    {
        loop {
            match &mut self.state {
                Handshaking::None => {
                    // Nothing was queued via `set_handshake_status`; flush an empty buffer.
                    self.state = Handshaking::Flushing(Flush(Binary::new()));
                }
                Handshaking::Flushing(flush) => {
                    ready!(flush.poll_handle(cx, codec))?;
                    tracing::trace!(flush.poll = %"Ready");
                    // Only the server has to receive the peer's preface.
                    self.state = if self.is_client {
                        Handshaking::Done
                    } else {
                        Handshaking::ReadingPreface(ReadPreface::new())
                    };
                }
                Handshaking::ReadingPreface(read) => {
                    ready!(read.poll_handle(cx, codec))?;
                    tracing::trace!(read_preface.poll = %"Ready");
                    self.state = Handshaking::Done;
                }
                Handshaking::Done => return Poll::Ready(Ok(())),
            }
        }
    }

    /// Queues `binary` to be written during the handshake and sets the
    /// endpoint role, moving the state machine directly to the flushing step.
    pub fn set_handshake_status(&mut self, binary: Binary, is_client: bool) {
        self.is_client = is_client;
        self.state = Handshaking::Flushing(Flush(binary))
    }
}
impl Flush {
pub fn poll_handle<T>(
&mut self,
cx: &mut Context<'_>,
codec: &mut Codec<T>,
) -> Poll<ProtResult<()>>
where
T: AsyncRead + AsyncWrite + Unpin,
{
if !self.0.has_remaining() {
return Poll::Ready(Ok(()));
}
loop {
match ready!(Pin::new(codec.get_mut()).poll_write(cx, self.0.chunk())) {
Ok(n) => {
self.0.advance(n);
}
Err(e) => return Poll::Ready(Err(e.into())),
}
if !self.0.has_remaining() {
return Poll::Ready(Ok(()));
}
}
}
}
// SAFETY: NOTE(review): these impls assert thread-safety manually instead of
// letting the compiler derive it. `StateHandshake` holds a `bool` and a
// `Handshaking`, whose only payloads are a `Binary` (inside `Flush`) and a
// `usize` (inside `ReadPreface`). Whether `Binary` is actually safe to send
// or share across threads is not visible here — confirm against the
// `algorithm::buf::Binary` implementation. If `Binary` is already
// `Send + Sync`, these impls are redundant and can be removed.
unsafe impl Send for StateHandshake {}
unsafe impl Sync for StateHandshake {}