use std::{future::poll_fn, io, pin::Pin};
use futures_core::Stream;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::codec::{Decoder, Framed};
use crate::{
proto::{Config, Limits, Role},
upgrade::client_request,
Error, WebSocketStream,
};
const BAD_REQUEST: &[u8] = b"HTTP/1.1 400 Bad Request\r\n\r\n";
/// Builder for server-side WebSocket connections.
///
/// Stores the [`Limits`] and [`Config`] that are handed to every [`WebSocketStream`]
/// it creates (always in the server role).
pub struct Builder {
    /// The [`Limits`] passed to each created stream.
    limits: Limits,
    /// The [`Config`] passed to each created stream.
    config: Config,
}
impl Default for Builder {
fn default() -> Self {
Self::new()
}
}
impl Builder {
    /// Creates a builder that uses the default [`Config`] and default [`Limits`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: Config::default(),
            limits: Limits::default(),
        }
    }

    /// Replaces the [`Config`] given to every stream this builder creates.
    #[must_use]
    pub fn config(mut self, config: Config) -> Self {
        self.config = config;
        self
    }

    /// Replaces the [`Limits`] given to every stream this builder creates.
    #[must_use]
    pub fn limits(mut self, limits: Limits) -> Self {
        self.limits = limits;
        self
    }

    /// Performs the server side of the WebSocket opening handshake on `stream`.
    ///
    /// The client's upgrade request is read and decoded. The reply produced by the
    /// decoder is then written and flushed, and a [`WebSocketStream`] in the server
    /// role is returned. The `Framed` read buffer is carried over via
    /// `into_parts`/`from_parts`, so bytes the client sent right after its request
    /// are not dropped at this step.
    ///
    /// # Errors
    ///
    /// - The decoder's error, if the request cannot be decoded. A `400 Bad Request`
    ///   reply is attempted first. A failure to deliver that reply is ignored, so the
    ///   caller always sees the original decode error.
    /// - An I/O error (converted into [`Error`]) if writing or flushing the handshake
    ///   reply fails.
    /// - [`Error::Io`] with [`io::ErrorKind::UnexpectedEof`] if the stream ends before
    ///   any request is decoded.
    pub async fn accept<S: AsyncRead + AsyncWrite + Unpin>(
        &self,
        stream: S,
    ) -> Result<WebSocketStream<S>, Error> {
        let mut framed = client_request::Codec {}.framed(stream);
        let reply = poll_fn(|cx| Pin::new(&mut framed).poll_next(cx)).await;
        let mut parts = framed.into_parts();

        match reply {
            Some(Ok(response)) => {
                parts.io.write_all(response.as_bytes()).await?;
                // If `S` buffers writes (e.g. a `BufWriter`), the reply could otherwise sit
                // unsent while we wait for the client's first frame, and a client that waits
                // for the handshake reply before sending would then deadlock with us.
                parts.io.flush().await?;
                let framed = Framed::from_parts(parts);
                Ok(WebSocketStream::from_framed(
                    framed,
                    Role::Server,
                    self.config,
                    self.limits,
                ))
            }
            Some(Err(e)) => {
                // Best effort only: the peer may already have gone away. Failing to deliver
                // the 400 must not replace `e`, which is the real reason for the rejection.
                if parts.io.write_all(BAD_REQUEST).await.is_ok() {
                    let _ = parts.io.flush().await;
                }
                Err(e)
            }
            None => Err(Error::Io(io::ErrorKind::UnexpectedEof.into())),
        }
    }

    /// Wraps `stream` as a server-side [`WebSocketStream`] without any handshake.
    ///
    /// This method is synchronous and does no handshake I/O of its own. Presumably it is
    /// meant for connections whose HTTP upgrade was already completed elsewhere (e.g. by
    /// an HTTP server framework). NOTE(review): confirm against callers.
    pub fn serve<S: AsyncRead + AsyncWrite + Unpin>(&self, stream: S) -> WebSocketStream<S> {
        WebSocketStream::from_raw_stream(stream, Role::Server, self.config, self.limits)
    }
}