use std::{future::poll_fn, io, pin::Pin};
use futures_core::Stream;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::codec::FramedRead;
use crate::{
proto::{Config, Limits, Role},
upgrade::client_request,
Error, WebSocketStream,
};
/// Minimal HTTP reply (status line followed by an empty header section) that is
/// sent back when a client's upgrade request is rejected.
const BAD_REQUEST: &[u8] = "HTTP/1.1 400 Bad Request\r\n\r\n".as_bytes();
/// Builder for server-side WebSocket streams.
///
/// Holds the [`Config`] and [`Limits`] that are applied to every stream
/// produced by [`Builder::accept`] or [`Builder::serve`].
// Both fields are `Copy` (the builder's methods move them out of `&self`),
// so deriving `Clone` adds no new requirements.
#[derive(Clone)]
pub struct Builder {
    /// Configuration handed to each created stream.
    config: Config,
    /// Limits handed to each created stream.
    limits: Limits,
}
impl Default for Builder {
fn default() -> Self {
Self::new()
}
}
impl Builder {
    /// Creates a builder using `Config::default()` and `Limits::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: Config::default(),
            limits: Limits::default(),
        }
    }

    /// Replaces the [`Config`] applied to streams created by this builder.
    #[must_use]
    pub fn config(mut self, config: Config) -> Self {
        self.config = config;
        self
    }

    /// Replaces the [`Limits`] applied to streams created by this builder.
    #[must_use]
    pub fn limits(mut self, limits: Limits) -> Self {
        self.limits = limits;
        self
    }

    /// Performs the server side of the opening handshake on `stream`.
    ///
    /// Reads one client upgrade request with the `client_request` codec.
    /// If the codec accepts it, writes and flushes the codec's response, then
    /// returns a server-role [`WebSocketStream`]. That stream is built from the
    /// framed reader, so bytes already buffered past the request are handed
    /// over with it rather than dropped here.
    ///
    /// # Errors
    ///
    /// - The codec's error, if the request is rejected. A `400 Bad Request`
    ///   reply is attempted first on a best-effort basis; failure to deliver it
    ///   does not replace this error.
    /// - `Error::Io` with [`io::ErrorKind::UnexpectedEof`], if the stream ends
    ///   before a request arrives.
    /// - Any I/O error raised while writing or flushing the handshake response.
    pub async fn accept<S: AsyncRead + AsyncWrite + Unpin>(
        &self,
        stream: S,
    ) -> Result<WebSocketStream<S>, Error> {
        let mut framed = FramedRead::new(stream, client_request::Codec {});
        let reply = poll_fn(|cx| Pin::new(&mut framed).poll_next(cx)).await;

        match reply {
            Some(Ok(response)) => {
                let io = framed.get_mut();
                io.write_all(response.as_bytes()).await?;
                // `write_all` only hands the bytes to the writer. A buffered or
                // TLS-wrapped `S` may hold them back, leaving the client waiting
                // for the handshake response, so flush explicitly.
                io.flush().await?;
                Ok(WebSocketStream::from_framed(
                    framed,
                    Role::Server,
                    self.config,
                    self.limits,
                ))
            }
            Some(Err(e)) => {
                // Best effort: tell the peer its request was rejected, but
                // always report the handshake error itself. A peer that sent
                // garbage has often already hung up, and an I/O error from this
                // courtesy reply would hide the real cause.
                let io = framed.get_mut();
                if io.write_all(BAD_REQUEST).await.is_ok() {
                    // Without a flush, a buffered writer would drop the reply
                    // when the stream is dropped on return.
                    let _ = io.flush().await;
                }
                Err(e)
            }
            None => Err(Error::Io(io::ErrorKind::UnexpectedEof.into())),
        }
    }

    /// Wraps `stream` as a server-role [`WebSocketStream`] without performing
    /// any handshake I/O.
    ///
    /// NOTE(review): presumably intended for connections whose HTTP upgrade
    /// was already completed elsewhere. Confirm against callers.
    pub fn serve<S: AsyncRead + AsyncWrite + Unpin>(&self, stream: S) -> WebSocketStream<S> {
        WebSocketStream::from_raw_stream(stream, Role::Server, self.config, self.limits)
    }
}