use crate::Request;
use crate::ResponseLike;
/// Default read-buffer size, in bytes, used by a newly created `Server` (8 KiB).
pub const DEFAULT_BUFFER_SIZE: usize = 8 * 1024;
use std::io;
use smol::net::{AsyncToSocketAddrs, SocketAddr, TcpListener, TcpStream};
#[cfg(feature = "tls")]
use smol::io::{AsyncRead, AsyncReadExt, AsyncWrite};
#[cfg(feature = "tls")]
use async_native_tls::{TlsAcceptor, TlsStream};
/// Per-connection stream the server reads requests from and writes responses to:
/// a TLS-wrapped TCP stream when the `tls` feature is enabled, otherwise the raw
/// TCP stream.
#[cfg(feature = "tls")]
pub type Stream = TlsStream<TcpStream>;
#[cfg(not(feature = "tls"))]
pub type Stream = TcpStream;
#[cfg(feature = "websocket")]
use crate::ws::maybe_websocket;
#[cfg(feature = "websocket")]
use async_tungstenite::WebSocketStream;
/// Type-erased websocket callback; it receives each upgraded connection.
#[cfg(feature = "websocket")]
type WsHandlerFn<S> = Arc<dyn Fn(WebSocketStream<S>) + Send + Sync + 'static>;
/// Optional websocket registration: the path it was registered for and its callback.
#[cfg(feature = "websocket")]
pub type WsHandler<S> = Option<(&'static str, WsHandlerFn<S>)>;
use std::future::Future;
#[cfg(feature = "websocket")]
use std::sync::Arc;
/// HTTP server that accepts connections from one TCP listener and serves each
/// connection on its own detached task.
pub struct Server {
    /// Listener that new connections are accepted from.
    acceptor: TcpListener,
    /// Acceptor used for the TLS handshake (only with the `tls` feature).
    #[cfg(feature = "tls")]
    tls_acceptor: TlsAcceptor,
    /// Buffer size passed to `Request::read_from` for every request.
    buffer_size: usize,
    /// Whether responses go through `maybe_add_defaults` with defaults enabled.
    insert_default_headers: bool,
    /// Registered `(path, handler)` for websocket upgrades, if any
    /// (only with the `websocket` feature).
    #[cfg(feature = "websocket")]
    ws_handler: WsHandler<Stream>,
}
impl Server {
    /// Binds a listener to `addr` and builds a server with default settings:
    /// [`DEFAULT_BUFFER_SIZE`] as the buffer size, default headers disabled and
    /// (with the `websocket` feature) no websocket handler.
    ///
    /// With the `tls` feature, `tls_acceptor` is used to complete the TLS
    /// handshake on accepted connections.
    ///
    /// # Errors
    ///
    /// Returns the error from [`TcpListener::bind`] if the address cannot be bound.
    pub async fn new(
        addr: impl AsyncToSocketAddrs,
        #[cfg(feature = "tls")] tls_acceptor: TlsAcceptor,
    ) -> io::Result<Self> {
        Ok(Self {
            acceptor: TcpListener::bind(addr).await?,
            buffer_size: DEFAULT_BUFFER_SIZE,
            #[cfg(feature = "websocket")]
            ws_handler: None,
            #[cfg(feature = "tls")]
            tls_acceptor,
            insert_default_headers: false,
        })
    }

    /// Enables default headers: every response is passed through
    /// `maybe_add_defaults(true)` before it is sent.
    pub fn with_default_headers(mut self) -> Self {
        self.insert_default_headers = true;
        self
    }

    /// Returns the local address the listener is bound to.
    #[inline]
    pub fn addr(&self) -> io::Result<SocketAddr> {
        self.acceptor.local_addr()
    }

    /// Returns [`Server::addr`] formatted by `crate::util::format_addr`.
    pub fn pretty_addr(&self) -> io::Result<String> {
        self.addr().map(crate::util::format_addr)
    }

    /// Sets the buffer size used when reading requests.
    pub fn set_buffer_size(&mut self, size: usize) {
        self.buffer_size = size;
    }

    /// Builder form of [`Server::set_buffer_size`].
    pub fn with_buffer_size(mut self, size: usize) -> Self {
        self.buffer_size = size;
        self
    }

    /// Registers `handler` for websocket connections on `path`, replacing any
    /// previously registered handler.
    ///
    /// The `(path, handler)` pair is consulted by `maybe_websocket` for every
    /// request (presumably matching `path` against the request path). The future
    /// returned by `handler` for each upgraded stream is spawned as a detached task.
    #[cfg(feature = "websocket")]
    pub fn on_websocket<F, R>(mut self, path: &'static str, handler: F) -> Self
    where
        // `Clone` is not required: the closure is shared through an `Arc`.
        F: Fn(WebSocketStream<Stream>) -> R + Send + Sync + 'static,
        R: Future<Output = ()> + Send + 'static,
    {
        let real_handler: WsHandlerFn<Stream> = Arc::new(move |s: WebSocketStream<Stream>| {
            smol::spawn(handler(s)).detach();
        });
        self.ws_handler = Some((path, real_handler));
        self
    }

    /// Runs the accept loop forever.
    ///
    /// Every accepted connection is served by its own detached task, which
    /// passes each request to a clone of `handler`.
    pub async fn run<F, T, R>(mut self, handler: F) -> !
    where
        F: Fn(Request) -> R + Send + 'static + Clone,
        R: Future<Output = T> + Send + 'static,
        T: ResponseLike + 'static,
    {
        let buffer_size = self.buffer_size;
        let should_insert_defaults = self.insert_default_headers;
        #[cfg(feature = "websocket")]
        let ws_handler = self.ws_handler.clone();
        loop {
            let (stream, addr) = self.next_stream().await;
            smol::spawn(Self::keep_handling(
                buffer_size,
                should_insert_defaults,
                stream,
                addr,
                handler.clone(),
                #[cfg(feature = "websocket")]
                ws_handler.clone(),
            ))
            .detach();
        }
    }

    /// Serves requests on one connection until it should be closed.
    ///
    /// The connection is closed in any of these cases:
    /// - the peer goes away;
    /// - a request cannot be parsed (a 400 is sent first);
    /// - an unexpected read error occurs (a 500 is sent first);
    /// - the request is not keep-alive;
    /// - the handler's response asks for `Connection: close`;
    /// - a response cannot be written;
    /// - the stream is handed off to the websocket handler.
    async fn keep_handling<F, T, R>(
        buffer_size: usize,
        should_insert_defaults: bool,
        mut stream: Stream,
        addr: SocketAddr,
        handler: F,
        #[cfg(feature = "websocket")] ws_handler: WsHandler<Stream>,
    ) where
        F: Fn(Request) -> R + Send + 'static,
        R: Future<Output = T> + Send + 'static,
        T: ResponseLike,
    {
        loop {
            #[cfg_attr(not(feature = "websocket"), expect(unused_mut))]
            let mut req = match Request::read_from(&mut stream, addr, buffer_size).await {
                Ok(req) => req,
                Err(e) if e.kind() == io::ErrorKind::InvalidInput => {
                    // After a malformed request, the start of the next request in
                    // the byte stream is unknown. Answer 400 and close
                    // (RFC 9112 §2.2) rather than parsing the leftovers as a new
                    // request.
                    crate::response!(bad_request)
                        .send_to(&mut stream)
                        .await
                        .ok();
                    break;
                }
                Err(e)
                    if matches!(
                        e.kind(),
                        io::ErrorKind::BrokenPipe
                            | io::ErrorKind::ConnectionReset
                            | io::ErrorKind::UnexpectedEof
                    ) =>
                {
                    break;
                }
                Err(e) => {
                    eprintln!("[INTERNAL ERROR] {}", e);
                    crate::response!(internal_server_error)
                        .send_to(&mut stream)
                        .await
                        .ok();
                    break;
                }
            };
            // `maybe_websocket` takes ownership of the stream. It hands the stream
            // back as `Err` when the request is not taken over as a websocket.
            #[cfg(feature = "websocket")]
            if let Err(new_stream) = maybe_websocket(ws_handler.clone(), stream, &mut req).await {
                stream = new_stream;
            } else {
                break;
            }
            let keep_alive = req.keep_alive();
            let mut response = handler(req)
                .await
                .to_response()
                .maybe_add_defaults(should_insert_defaults);
            // Honour a `Connection: close` that the handler set itself.
            let handler_wants_close = response
                .headers
                .get("connection")
                .is_some_and(|value| value.eq_ignore_ascii_case("close"));
            let close = !keep_alive || handler_wants_close;
            let connection = if close { "close" } else { "keep-alive" };
            response.set_header("connection", connection.into());
            // A server that announces `close` must actually close the connection
            // (RFC 9112 §9.6). A failed write also leaves the connection unusable.
            if response.send_to(&mut stream).await.is_err() || close {
                break;
            }
        }
    }
}
impl Server {
    /// Accepts one connection and returns it with the peer's address.
    ///
    /// With the `tls` feature, the TLS handshake is also performed here.
    /// Plaintext requests are answered with an `https://` redirect and reported
    /// as [`io::ErrorKind::ConnectionAborted`].
    ///
    /// # Errors
    ///
    /// Propagates errors from accepting or configuring the socket and, with
    /// `tls`, errors from sniffing or redirecting the connection.
    #[inline]
    pub async fn try_accept(&self) -> io::Result<(Stream, SocketAddr)> {
        self.try_accept_inner().await
    }

    /// Plain-TCP accept: disables Nagle's algorithm (`TCP_NODELAY`) on the socket.
    #[cfg(not(feature = "tls"))]
    #[inline]
    async fn try_accept_inner(&self) -> io::Result<(Stream, SocketAddr)> {
        let (stream, addr) = self.acceptor.accept().await?;
        stream.set_nodelay(true)?;
        Ok((stream, addr))
    }

    /// TLS accept: peeks at the first bytes to tell a TLS handshake from a
    /// plaintext HTTP request.
    ///
    /// NOTE(review): the peek and the handshake run inside the accept loop
    /// (`next_stream` / `run`). A client that connects and then stalls delays
    /// every other pending connection. Moving the handshake into the
    /// per-connection task would remove this bottleneck.
    #[cfg(feature = "tls")]
    async fn try_accept_inner(&self) -> io::Result<(Stream, SocketAddr)> {
        let (mut tcp_stream, ip) = self.acceptor.accept().await?;
        tcp_stream.set_nodelay(true)?;
        let mut first_bytes = [0; 2];
        // `peek` can return fewer bytes than requested, so only inspect what arrived.
        let peeked = tcp_stream.peek(&mut first_bytes).await?;
        match &first_bytes[..peeked] {
            // The peer closed before sending anything.
            [] => Err(io::Error::from(io::ErrorKind::ConnectionAborted)),
            // 0x16 = TLS handshake record type, 0x03 = major version byte.
            [0x16] | [0x16, 0x03] => match self.tls_acceptor.accept(tcp_stream).await {
                Ok(tls_stream) => Ok((tls_stream, ip)),
                Err(_) => Err(io::Error::from(io::ErrorKind::ConnectionAborted)),
            },
            _ => {
                self.handle_not_tls(&mut tcp_stream).await?;
                Err(io::Error::from(io::ErrorKind::ConnectionAborted))
            }
        }
    }

    /// Answers a plaintext HTTP request received on the TLS listener with a
    /// `301 Moved Permanently` to the same path over `https://`.
    ///
    /// The redirect host comes from the request's `Host` header when it is
    /// present and well-formed; otherwise [`Server::pretty_addr`] is used. The
    /// response says `Connection: close` because the caller drops the
    /// connection as soon as this returns.
    ///
    /// # Errors
    ///
    /// Propagates read and write errors on `stream`.
    #[cfg(feature = "tls")]
    async fn handle_not_tls<T: AsyncRead + AsyncWrite + Unpin>(
        &self,
        mut stream: T,
    ) -> io::Result<()> {
        let mut buffer: Vec<u8> = vec![0; self.buffer_size];
        let length = stream.read(&mut buffer).await?;
        if length == 0 {
            // The peer closed without sending a request; there is nobody to redirect.
            return Ok(());
        }
        let request = &buffer[..length];
        let path = Self::request_target(request);
        let host = Self::host_header(request)
            .unwrap_or_else(|| self.pretty_addr().unwrap_or_default());
        let mut res = crate::response!(
            moved_permanently,
            [],
            crate::headers! {
                "Location" => format!("https://{}{}", host, path),
                "Connection" => "close",
            }
        );
        res.send_to(&mut stream).await?;
        Ok(())
    }

    /// Extracts the request target: the text between the first and second
    /// space of the request line.
    ///
    /// The target is truncated at the first ASCII control byte (CR, LF, ...),
    /// so untrusted input cannot inject extra header lines into the
    /// `Location` value.
    #[cfg(feature = "tls")]
    fn request_target(request: &[u8]) -> String {
        let target = request.split(|&b| b == b' ').nth(1).unwrap_or_default();
        let end = target
            .iter()
            .position(u8::is_ascii_control)
            .unwrap_or(target.len());
        String::from_utf8_lossy(&target[..end]).into_owned()
    }

    /// Returns the value of the first `Host` header, if it is well-formed.
    ///
    /// The search stops at the blank line that ends the header section, or at
    /// an unterminated (possibly truncated) line. A value is accepted only if
    /// it is non-empty and consists of ASCII alphanumerics or `- . : [ ]`,
    /// i.e. a host name, IPv4, or bracketed IPv6 address with an optional port.
    #[cfg(feature = "tls")]
    fn host_header(request: &[u8]) -> Option<String> {
        request
            .split_inclusive(|&b| b == b'\n')
            .skip(1) // request line
            .map_while(|line| line.strip_suffix(b"\n"))
            .map(|line| line.strip_suffix(b"\r").unwrap_or(line))
            .take_while(|line| !line.is_empty())
            .find_map(|line| {
                let colon = line.iter().position(|&b| b == b':')?;
                let (name, value) = line.split_at(colon);
                if !name.eq_ignore_ascii_case(b"host") {
                    return None;
                }
                let value = value[1..].trim_ascii();
                let valid = !value.is_empty()
                    && value.iter().all(|&b| {
                        b.is_ascii_alphanumeric() || matches!(b, b'-' | b'.' | b':' | b'[' | b']')
                    });
                valid.then(|| String::from_utf8_lossy(value).into_owned())
            })
    }

    /// Waits until a connection is accepted successfully and returns it.
    ///
    /// Connection-aborted, connection-reset and invalid-input errors are
    /// skipped silently. Any other error is printed to stderr and accepting
    /// continues.
    pub async fn next_stream(&mut self) -> (Stream, SocketAddr) {
        loop {
            match self.try_accept().await {
                Ok(r) => return r,
                Err(e)
                    if matches!(
                        e.kind(),
                        io::ErrorKind::ConnectionAborted
                            | io::ErrorKind::ConnectionReset
                            | io::ErrorKind::InvalidInput
                    ) => {}
                Err(e) => {
                    eprintln!("[internal server error !!] {}", e);
                    eprintln!("[internal server error !!] {:#?}", e);
                }
            }
        }
    }
}