use std::{
net::SocketAddr,
task::{ready, Context, Poll},
time::Duration,
};
use algorithm::buf::{Binary, BinaryMut};
use tokio_stream::Stream;
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::mpsc::{channel, Receiver},
};
use webparse::{
http::http2::frame::{Reason, StreamIdentifier},
Version,
};
use crate::{
ws::ServerWsConnection, Builder, HeaderHelper, HttpHelper, HttpTrait, Initiator, Middleware,
ProtError, ProtResult, RecvRequest, RecvResponse, TimeoutLayer,
};
use super::{codec::Codec, control::ControlConfig, Control};
/// Server-side HTTP/2 connection over an I/O object of type `T`.
///
/// Groups the codec that owns the I/O object, the connection state and
/// control logic, and an optional timeout layer.
pub struct ServerH2Connection<T> {
/// Codec built from the I/O object passed to `new` (`Codec::new(io)`).
codec: Codec<T>,
/// Lifecycle state, the `Control` instance and the push-response receiver.
inner: InnerConnection,
/// Timeout configuration; `None` until a timeout setter or
/// `set_timeout_layer` installs one.
timeout: Option<TimeoutLayer>,
}
/// Connection-level state kept next to the codec.
struct InnerConnection {
/// Lifecycle state inspected and advanced by the `Stream` implementation.
state: State,
/// Control object that request polling, response sending and GOAWAY
/// handling are delegated to.
control: Control,
/// Receiving half of the channel whose sender is handed to `Control::new`.
/// Carries `(stream id, response)` pairs that `incoming` forwards to
/// `Control::send_response_may_push`.
receiver_push: Option<Receiver<(StreamIdentifier, RecvResponse)>>,
}
/// Lifecycle of the connection as driven by `Stream::poll_next`.
#[derive(Debug)]
enum State {
/// Requests are being polled from the control.
Open,
/// The codec is being shut down; once that completes the state becomes
/// `Closed` with the same reason and initiator.
Closing(Reason, Initiator),
/// Shutdown finished; the stream reports any outstanding error (see
/// `take_error`) and then ends.
Closed(Reason, Initiator),
}
// NOTE(review): both impls are unconditional — they assert `Sync`/`Send` for
// every `T`, including transports that are themselves neither. No SAFETY
// justification is visible in this file; confirm the invariant that makes this
// sound, or consider bounding the impls on `T: Sync` / `T: Send`.
unsafe impl<T> Sync for ServerH2Connection<T> {}
unsafe impl<T> Send for ServerH2Connection<T> {}
impl<T> ServerH2Connection<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
pub fn new(io: T, builder: Builder) -> ServerH2Connection<T> {
let (sender, receiver) = channel(10);
ServerH2Connection {
codec: Codec::new(io),
inner: InnerConnection {
state: State::Open,
control: Control::new(
ControlConfig {
next_stream_id: 2.into(),
initial_max_send_streams: 0,
max_send_buffer_size: builder.max_send_buffer_size,
reset_stream_duration: builder.reset_stream_duration,
reset_stream_max: builder.reset_stream_max,
remote_reset_stream_max: builder.pending_accept_reset_stream_max,
settings: builder.settings.clone(),
},
sender,
true,
),
receiver_push: Some(receiver),
},
timeout: None,
}
}
/// Consumes the connection and returns the I/O object held by the codec.
pub fn into_io(self) -> T {
self.codec.into_io()
}
pub fn into_ws(self, binary: Binary) -> ServerWsConnection<T> {
let (io, read_buf, write_buf) = self.codec.into_io_with_cache();
let mut connect = ServerWsConnection::new(io);
connect.set_cache_buf(read_buf, write_buf);
connect.set_handshake_status(binary);
connect.set_timeout_layer(self.timeout);
connect
}
/// Sets the read timeout, creating the timeout layer first if none exists.
///
/// `read_timeout` is forwarded unchanged to `TimeoutLayer::set_read_timeout`.
pub fn set_read_timeout(&mut self, read_timeout: Option<Duration>) {
    // `get_or_insert_with` creates the layer on demand and hands back a
    // mutable reference, so no `unwrap()` is needed.
    self.timeout
        .get_or_insert_with(TimeoutLayer::new)
        .set_read_timeout(read_timeout);
}
/// Sets the write timeout, creating the timeout layer first if none exists.
///
/// `write_timeout` is forwarded unchanged to `TimeoutLayer::set_write_timeout`.
pub fn set_write_timeout(&mut self, write_timeout: Option<Duration>) {
    // `get_or_insert_with` creates the layer on demand and hands back a
    // mutable reference, so no `unwrap()` is needed.
    self.timeout
        .get_or_insert_with(TimeoutLayer::new)
        .set_write_timeout(write_timeout);
}
/// Sets the general timeout, creating the timeout layer first if none exists.
///
/// `timeout` is forwarded unchanged to `TimeoutLayer::set_timeout`.
pub fn set_timeout(&mut self, timeout: Option<Duration>) {
    // `get_or_insert_with` creates the layer on demand and hands back a
    // mutable reference, so no `unwrap()` is needed.
    self.timeout
        .get_or_insert_with(TimeoutLayer::new)
        .set_timeout(timeout);
}
/// Sets the "ka" timeout, creating the timeout layer first if none exists.
///
/// `timeout` is forwarded unchanged to `TimeoutLayer::set_ka_timeout`
/// (presumably the keep-alive timeout — confirm against `TimeoutLayer`).
pub fn set_ka_timeout(&mut self, timeout: Option<Duration>) {
    // `get_or_insert_with` creates the layer on demand and hands back a
    // mutable reference, so no `unwrap()` is needed.
    self.timeout
        .get_or_insert_with(TimeoutLayer::new)
        .set_ka_timeout(timeout);
}
/// Replaces the whole timeout layer with `timeout_layer`; passing `None`
/// removes any previously configured layer.
pub fn set_timeout_layer(&mut self, timeout_layer: Option<TimeoutLayer>) {
self.timeout = timeout_layer;
}
/// Accept hook that currently always reports `Poll::Pending`.
///
/// NOTE(review): `_cx` is unused, so this method registers no waker; a task
/// that waits on it alone will never be woken by it. Confirm whether this is
/// an intentional stub.
pub fn pull_accept(&mut self, _cx: &mut Context<'_>) -> Poll<Option<ProtResult<()>>> {
Poll::Pending
}
/// Polls the control for the next incoming request.
///
/// When a timeout layer is installed, it is checked first with a snapshot of
/// the control's ready time and its read-end / write-end / idle flags; an
/// error from that check is returned immediately. Otherwise the call is
/// delegated to `Control::poll_request` with this connection's codec.
pub fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<ProtResult<RecvRequest>>> {
    if let Some(layer) = self.timeout.as_mut() {
        // Snapshot the connection status before handing it to the layer.
        let ready_time = *self.inner.control.get_ready_time();
        let is_read_end = self.inner.control.is_read_end();
        let is_write_end = self.inner.control.is_write_end(&self.codec);
        let is_idle = self.inner.control.is_idle(&self.codec);

        layer.poll_ready(cx, "server", ready_time, is_read_end, is_write_end, is_idle)?;
    }

    self.inner.control.poll_request(cx, &mut self.codec)
}
/// Drives pending writes by delegating to `Control::poll_write` with this
/// connection's codec.
// NOTE(review): the meaning of the trailing `false` flag is defined by
// `Control::poll_write` — confirm there.
pub fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<ProtResult<()>> {
self.inner.control.poll_write(cx, &mut self.codec, false)
}
/// Runs one request through the handler and middlewares and sends the
/// resulting response back on the request's stream.
///
/// The `StreamIdentifier` is removed from the request's extensions before the
/// request is handed to `HttpHelper::handle_request`. If the request carried
/// none, `StreamIdentifier::client_first()` is used as the target stream.
///
/// # Errors
/// Propagates any error from handling the request or from sending the
/// response.
///
/// # Returns
/// `Ok(None)` whenever the response was sent successfully.
pub async fn handle_request(
    &mut self,
    addr: &Option<SocketAddr>,
    mut r: RecvRequest,
    f: &mut Box<dyn HttpTrait>,
    middles: &mut Vec<Box<dyn Middleware>>,
) -> ProtResult<Option<bool>> {
    let tagged_stream = r.extensions_mut().remove::<StreamIdentifier>();
    let response = HttpHelper::handle_request(Version::Http2, addr, r, f, middles).await?;

    let target = match tagged_stream {
        Some(id) => id,
        None => StreamIdentifier::client_first(),
    };
    self.send_response(response, target).await?;

    Ok(None)
}
pub async fn incoming(&mut self) -> ProtResult<Option<RecvRequest>> {
use tokio_stream::StreamExt;
loop {
let mut receiver = self.inner.receiver_push.take().unwrap();
tokio::select! {
res = receiver.recv() => {
self.inner.receiver_push = Some(receiver);
if res.is_some() {
let res = res.unwrap();
let id = self.inner.control.next_stream_id();
self.inner.control.send_response_may_push(res.1, res.0, Some(id)).await?;
} else {
return Ok(None);
}
},
req = self.next() => {
self.inner.receiver_push = Some(receiver);
match req {
None => return Ok(None),
Some(Err(e)) => return Err(e),
Some(Ok(r)) => {
return Ok(Some(r));
}
};
}
}
}
}
/// Reacts to a non-request outcome of `poll_request`.
///
/// * `None` (stream ended): start closing with `NO_ERROR`.
/// * `GoAway` error: if the control's last GOAWAY reason equals this reason,
///   start closing with it; otherwise ask the control to send a GOAWAY with
///   the debug data and stay in the current state.
/// * Any other error: start closing with `NO_ERROR` and return the error.
///
/// # Panics
/// Panics if called with `Some(Ok(_))`; the caller returns successful
/// requests itself and never passes them here.
fn handle_poll_result(&mut self, result: Option<ProtResult<RecvRequest>>) -> ProtResult<()> {
    let failure = match result {
        None => {
            self.inner.state = State::Closing(Reason::NO_ERROR, Initiator::Library);
            return Ok(());
        }
        Some(Ok(_)) => unreachable!(),
        Some(Err(failure)) => failure,
    };

    match failure {
        ProtError::GoAway(debug_data, reason, initiator) => {
            // Rebuilt only so the whole error can be logged; `debug_data`
            // itself is still needed below.
            let e = ProtError::GoAway(debug_data.clone(), reason, initiator);
            tracing::debug!(error = ?e, "Connection::poll; connection error");

            if self.inner.control.last_goaway_reason() == &reason {
                self.inner.state = State::Closing(reason, initiator);
            } else {
                self.inner.control.go_away_now_data(reason, debug_data);
            }
            Ok(())
        }
        other => {
            self.inner.state = State::Closing(Reason::NO_ERROR, Initiator::Library);
            Err(other)
        }
    }
}
/// Takes the error frame stored on the control (leaving `None` behind) and
/// turns it, together with our own close reason, into the final result.
///
/// * A stored frame with a reason other than `NO_ERROR` wins: its reason and
///   debug data are reported as a `GoAway` initiated by the remote.
/// * Otherwise, if `ours` is not `NO_ERROR`, it is reported as a `GoAway`
///   with empty debug data and the given `initiator`.
/// * If both reasons are `NO_ERROR`, the result is `Ok(())`.
fn take_error(&mut self, ours: Reason, initiator: Initiator) -> ProtResult<()> {
    let (debug_data, theirs) = match self.inner.control.error.take() {
        Some(frame) => (frame.debug_data().clone(), frame.reason()),
        None => (Binary::new(), Reason::NO_ERROR),
    };

    if theirs != Reason::NO_ERROR {
        return Err(ProtError::GoAway(debug_data, theirs, Initiator::Remote));
    }
    if ours != Reason::NO_ERROR {
        return Err(ProtError::GoAway(Binary::new(), ours, initiator));
    }
    Ok(())
}
/// Hands the given read and write buffers to the codec as its cached buffers.
pub fn set_cache_buf(&mut self, read_buf: BinaryMut, write_buf: BinaryMut) {
self.codec.set_cache_buf(read_buf, write_buf)
}
/// Forwards `binary` to the control as the handshake status.
// NOTE(review): the meaning of the trailing `false` flag is defined by
// `Control::set_handshake_status` — confirm there.
pub fn set_handshake_status(&mut self, binary: Binary) {
self.inner.control.set_handshake_status(binary, false)
}
/// Sends `res` on the stream identified by `stream_id`.
///
/// The response headers are first processed for HTTP/2 by
/// `HeaderHelper::process_response_header`; the response is then handed to
/// the control for sending.
///
/// # Errors
/// Returns any error from header processing or from the control's send.
pub async fn send_response(
&mut self,
mut res: RecvResponse,
stream_id: StreamIdentifier,
) -> ProtResult<()> {
// NOTE(review): the meaning of the `false` argument is defined by
// `HeaderHelper::process_response_header` — confirm there.
HeaderHelper::process_response_header(Version::Http2, false, &mut res)?;
self.inner.control.send_response(res, stream_id).await
}
}
/// Exposes the connection as a stream of incoming requests.
///
/// The stream is a small state machine:
/// * `Open` — poll for a request; successful requests are yielded, everything
///   else goes through `handle_poll_result`, which may change the state or
///   return an error that is yielded to the caller.
/// * `Closing` — shut the codec down, then move to `Closed`.
/// * `Closed` — yield the pending error from `take_error`, if any; otherwise
///   end the stream.
impl<T> Stream for ServerH2Connection<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Item = ProtResult<RecvRequest>;

    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        loop {
            match this.inner.state {
                State::Open => match this.poll_request(cx) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(Some(Ok(request))) => return Poll::Ready(Some(Ok(request))),
                    // End of stream or an error: update the state (or bail
                    // out with the error) and go around the loop again.
                    Poll::Ready(outcome) => this.handle_poll_result(outcome)?,
                },
                State::Closing(reason, initiator) => {
                    ready!(this.codec.shutdown(cx))?;
                    this.inner.state = State::Closed(reason, initiator);
                }
                State::Closed(reason, initiator) => {
                    return Poll::Ready(match this.take_error(reason, initiator) {
                        Err(e) => Some(Err(e)),
                        Ok(()) => None,
                    });
                }
            }
        }
    }
}