use crate::Request;
use crate::ResponseLike;
/// Default size, in bytes, of the buffer a request is read into (8 KiB).
///
/// Both `Server` constructors use it as the initial `buffer_size`.
pub const DEFAULT_BUFFER_SIZE: usize = 1024 * 8;
use std::{
io,
net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs},
};
#[cfg(feature = "tls")]
use native_tls::{TlsAcceptor, TlsStream};
/// Connection type handed out by the server when the `tls` feature is
/// disabled: a plain TCP stream.
#[cfg(not(feature = "tls"))]
pub type Stream = TcpStream;
/// Connection type handed out by the server when the `tls` feature is
/// enabled: a TLS stream layered over TCP.
#[cfg(feature = "tls")]
pub type Stream = TlsStream<TcpStream>;
#[cfg(feature = "websocket")]
use crate::ws::{maybe_websocket, WebSocket};
#[cfg(feature = "async")]
use std::future::Future;
/// An HTTP server built around a bound [`TcpListener`].
///
/// Connections can be pulled one at a time with `try_accept`, by iterating
/// over the server, or by handing control to `run` / `run_async`.
pub struct Server {
/// Listening socket that incoming connections are accepted from.
acceptor: TcpListener,
/// Size, in bytes, of the buffer each request is read into.
buffer_size: usize,
/// Flag forwarded to `maybe_add_defaults` for every response sent by
/// `run` / `run_async`.
insert_default_headers: bool,
/// Acceptor used to perform the TLS handshake on accepted connections.
#[cfg(feature = "tls")]
tls_acceptor: TlsAcceptor,
/// Optional `(path, handler)` pair registered through `on_websocket`.
#[cfg(feature = "websocket")]
ws_handler: Option<(&'static str, fn(WebSocket<&mut Stream>))>,
}
impl Server {
/// Binds a listener to `addr` and returns a server with default settings.
///
/// The request buffer starts at [`DEFAULT_BUFFER_SIZE`], default headers are
/// disabled and, with the `websocket` feature, no handler is registered.
///
/// # Errors
///
/// Returns the error reported by [`TcpListener::bind`].
#[cfg(not(feature = "tls"))]
pub fn new(addr: impl ToSocketAddrs) -> io::Result<Self> {
    let acceptor = TcpListener::bind(addr)?;
    let server = Self {
        acceptor,
        buffer_size: DEFAULT_BUFFER_SIZE,
        insert_default_headers: false,
        #[cfg(feature = "websocket")]
        ws_handler: None,
    };
    Ok(server)
}
/// Binds a listener to `addr` and returns a server that uses `tls_acceptor`
/// for incoming connections.
///
/// All other settings start at their defaults: a [`DEFAULT_BUFFER_SIZE`]
/// request buffer, default headers disabled and, with the `websocket`
/// feature, no handler registered.
///
/// # Errors
///
/// Returns the error reported by [`TcpListener::bind`].
#[cfg(feature = "tls")]
pub fn new_with_tls(addr: impl ToSocketAddrs, tls_acceptor: TlsAcceptor) -> io::Result<Self> {
    TcpListener::bind(addr).map(|acceptor| Self {
        acceptor,
        tls_acceptor,
        buffer_size: DEFAULT_BUFFER_SIZE,
        insert_default_headers: false,
        #[cfg(feature = "websocket")]
        ws_handler: None,
    })
}
/// Turns on default headers and returns the server for chaining.
///
/// When set, `run` and `run_async` pass `true` to `maybe_add_defaults` for
/// every response they send.
pub fn with_default_headers(self) -> Self {
    let mut server = self;
    server.insert_default_headers = true;
    server
}
#[inline]
pub fn addr(&self) -> io::Result<SocketAddr> {
self.acceptor.local_addr()
}
/// Returns the server's local address rendered by `crate::util::format_addr`.
///
/// # Errors
///
/// Returns the error from [`Server::addr`] if the local address cannot be
/// obtained.
pub fn pretty_addr(&self) -> io::Result<String> {
    let local = self.addr()?;
    Ok(crate::util::format_addr(local))
}
/// Sets the size, in bytes, of the buffer each incoming request is read into.
///
/// `handle_request` allocates a buffer of this size for every connection.
pub fn set_buffer_size(&mut self, size: usize) {
self.buffer_size = size;
}
/// Builder-style variant of `set_buffer_size`: stores `size` as the request
/// buffer size and returns the server for chaining.
pub fn with_buffer_size(mut self, size: usize) -> Self {
    self.set_buffer_size(size);
    self
}
/// Registers `handler` as the WebSocket handler for `path` and returns the
/// server for chaining.
///
/// Only one `(path, handler)` pair is stored; a later call replaces the
/// previous registration.
#[cfg(feature = "websocket")]
pub fn on_websocket(mut self, path: &'static str, handler: fn(WebSocket<&mut Stream>)) -> Self {
    let registration = (path, handler);
    self.ws_handler = Some(registration);
    self
}
/// Accepts connections forever, answering each one on its own OS thread.
///
/// For every `(stream, request)` pair produced by iterating the server, a
/// thread is spawned that calls `handler`, converts the result with
/// `to_response`, applies `maybe_add_defaults` with the server's
/// default-headers flag and sends the response on the stream. With the
/// `websocket` feature, `maybe_websocket` runs first; when it returns `true`
/// the handler is not called.
///
/// The thread handles are dropped, so whatever the thread returns
/// (including an error from `send_to`) is never observed.
pub fn run<T: ResponseLike>(
    self,
    handler: impl Fn(Request) -> T + Send + 'static + Clone,
) -> ! {
    let add_defaults = self.insert_default_headers;
    #[cfg(feature = "websocket")]
    let websocket = self.ws_handler.clone();

    // `mut` is only needed when the websocket upgrade path borrows mutably.
    #[cfg_attr(not(feature = "websocket"), allow(unused_mut))]
    for (mut connection, mut incoming) in self {
        let respond = handler.clone();
        std::thread::spawn(move || {
            #[cfg(feature = "websocket")]
            if maybe_websocket(websocket, &mut connection, &mut incoming) {
                return Ok(());
            };

            respond(incoming)
                .to_response()
                .maybe_add_defaults(add_defaults)
                .send_to(&mut connection)
        });
    }

    unreachable!("Server::run() should never return")
}
/// Accepts connections forever, answering each one in a spawned `async_std`
/// task.
///
/// Works like `run`, except that `handler` returns a future which is awaited
/// inside the task before the response is converted with `to_response`,
/// passed through `maybe_add_defaults` and sent. With the `websocket`
/// feature, `maybe_websocket` runs first; when it returns `true` the handler
/// is not called.
///
/// The value returned by `spawn` is dropped without being awaited, so the
/// task's result (including an error from `send_to`) is never observed.
///
/// NOTE(review): `send_to` is called on the std-based `Stream` inside the
/// task; presumably that is blocking I/O on the executor — confirm.
#[cfg(feature = "async")]
pub fn run_async<F, T, R>(self, handler: F) -> !
where
    F: Fn(Request) -> R + Send + 'static + Clone,
    R: Future<Output = T> + Send + 'static,
    T: ResponseLike,
{
    #[cfg(feature = "websocket")]
    let ws_handler = self.ws_handler.clone();
    let should_insert = self.insert_default_headers;

    // `mut` is only needed when the websocket upgrade path borrows mutably.
    #[cfg_attr(not(feature = "websocket"), allow(unused_mut))]
    for (mut stream, mut request) in self {
        let handler = handler.clone();
        async_std::task::spawn(async move {
            #[cfg(feature = "websocket")]
            if maybe_websocket(ws_handler, &mut stream, &mut request) {
                return Ok(());
            };

            handler(request)
                .await
                .to_response()
                .maybe_add_defaults(should_insert)
                .send_to(&mut stream)
        });
    }

    // Name this function, not `run`, so a panic here points at the right place.
    unreachable!("Server::run_async() should never return")
}
}
impl Server {
#[inline]
pub fn try_accept(&self) -> io::Result<(Stream, Request)> {
self.try_accept_inner()
}
/// Plain-TCP accept path: takes the next connection from the listener and
/// passes it, with the peer address, to `handle_request`.
///
/// # Errors
///
/// Returns the error from [`TcpListener::accept`] or from `handle_request`.
#[cfg(not(feature = "tls"))]
#[inline]
fn try_accept_inner(&self) -> io::Result<(Stream, Request)> {
    self.acceptor
        .accept()
        .and_then(|(connection, peer)| self.handle_request(connection, peer))
}
/// TLS accept path: takes the next connection, peeks at its first two bytes
/// and either performs the TLS handshake or answers with a redirect.
///
/// A connection starting with `0x16 0x03` (the beginning of a TLS handshake
/// record) is handed to the TLS acceptor and then to `handle_request`.
/// Anything else is passed to `handle_not_tls`, after which the connection
/// is reported as aborted.
///
/// # Errors
///
/// Returns I/O errors from accepting, peeking, `handle_not_tls` or
/// `handle_request`. A failed handshake and every non-TLS connection yield
/// `ErrorKind::ConnectionAborted`.
#[cfg(feature = "tls")]
fn try_accept_inner(&self) -> io::Result<(Stream, Request)> {
    let (mut socket, peer) = self.acceptor.accept()?;

    let mut prefix = [0u8; 2];
    socket.peek(&mut prefix)?;

    if prefix != [0x16, 0x03] {
        self.handle_not_tls(&mut socket)?;
        return Err(io::Error::from(io::ErrorKind::ConnectionAborted));
    }

    self.tls_acceptor
        .accept(socket)
        // The handshake error detail is dropped; callers only see an abort.
        .map_err(|_| io::Error::from(io::ErrorKind::ConnectionAborted))
        .and_then(|secured| self.handle_request(secured, peer))
}
/// Reads one request from `stream` and parses it into a [`Request`].
///
/// A single `read` is issued into a buffer sized from `buffer_size`; the
/// bytes received are handed to `Request::new` together with `ip`.
///
/// # Errors
///
/// * Any I/O error from reading the stream or sending an error response.
/// * `ErrorKind::InvalidInput` ("Payload too large") when more than
///   `buffer_size` bytes were received; a `payload_too_large` response is
///   sent first.
/// * `ErrorKind::InvalidInput` ("Empty request") when nothing was received;
///   a `bad_request` response is sent first.
/// * `ErrorKind::InvalidInput` when `Request::new` returns `None`.
fn handle_request<T: io::Write + io::Read>(
    &self,
    mut stream: T,
    ip: SocketAddr,
) -> io::Result<(T, Request)> {
    // Reserve one byte beyond the configured limit. `read` never returns more
    // than the buffer length, so with a buffer of exactly `buffer_size` bytes
    // the oversize check below could never fire.
    let mut buffer: Vec<u8> = vec![0; self.buffer_size.saturating_add(1)];
    let payload_size = stream.read(&mut buffer)?;

    if payload_size > self.buffer_size {
        crate::response!(payload_too_large).send_to(&mut stream)?;
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "Payload too large",
        ));
    }

    if payload_size == 0 {
        crate::response!(bad_request).send_to(&mut stream)?;
        return Err(io::Error::new(io::ErrorKind::InvalidInput, "Empty request"));
    }

    let req = Request::new(&buffer[..payload_size], ip)
        .ok_or_else(|| io::Error::from(io::ErrorKind::InvalidInput))?;

    Ok((stream, req))
}
/// Answers a non-TLS connection with a `moved_permanently` redirect to the
/// `https://` form of the requested path.
///
/// One `read` is issued on `stream`. The request target is taken as the
/// second space-separated field of the received bytes, and the redirect
/// location is `https://` followed by `pretty_addr()` and that target.
///
/// # Errors
///
/// Returns any I/O error from reading the request or sending the response.
#[cfg(feature = "tls")]
fn handle_not_tls<T: io::Read + io::Write>(&self, mut stream: T) -> io::Result<()> {
    let mut buffer: Vec<u8> = vec![0; self.buffer_size];
    // Only the bytes actually received are inspected; the rest of the buffer
    // is zero padding and must not leak into the redirect target.
    let received = stream.read(&mut buffer)?;

    // Request line layout is `METHOD SP TARGET SP VERSION`, so the target is
    // the second field. No space at all yields an empty target.
    let target = buffer[..received]
        .split(|byte| *byte == b' ')
        .nth(1)
        .unwrap_or_default();

    // The target is client-controlled and ends up in a response header: cut
    // it at the first control byte (CR, LF, NUL, ...) so it cannot inject
    // additional header lines.
    let end = target
        .iter()
        .position(|byte| byte.is_ascii_control())
        .unwrap_or(target.len());
    let path = String::from_utf8_lossy(&target[..end]);

    crate::response!(
        moved_permanently,
        [],
        crate::headers! {
            "Location" => format!("https://{}{}", self.pretty_addr().unwrap_or_default(), path),
            "Connection" => "keep-alive",
            "Content-Length" => 0
        }
    )
    .send_to(&mut stream)?;

    Ok(())
}
}
/// Iterating a [`Server`] yields one `(stream, request)` pair per
/// successfully accepted and parsed connection.
impl Iterator for Server {
    type Item = (Stream, Request);

    /// Blocks until `try_accept` succeeds and returns its result.
    ///
    /// Failed attempts are retried: `ConnectionAborted`, `ConnectionReset`
    /// and `InvalidInput` errors are skipped silently, any other error is
    /// printed to stderr first. This method never returns `None`.
    fn next(&mut self) -> Option<Self::Item> {
        // Retry in a loop rather than by recursion: a run of failed accepts
        // (or a persistent listener error) must not grow the stack.
        loop {
            match self.try_accept() {
                Ok(accepted) => return Some(accepted),
                Err(e) => match e.kind() {
                    io::ErrorKind::ConnectionAborted
                    | io::ErrorKind::ConnectionReset
                    | io::ErrorKind::InvalidInput => {}
                    _ => eprintln!("Server generated error: {:#?}", e),
                },
            }
        }
    }
}