use crate::functional_traits::Router;
use crate::http::{Response, StatusCode};
use crate::stream::{ConnectionStream, IntoConnectionStream};
use crate::tii_builder::{ErrorHandler, NotFoundHandler, RouterWebSocketServingResponse};
use crate::tii_error::{TiiError, TiiResult};
use crate::{error_log, trace_log};
use crate::{warn_log, HttpHeaderName};
use crate::{ContinueHandler, RequestContext};
use crate::{HttpVersion, TypeSystem, TypeSystemBuilder};
use std::any::Any;
use std::fmt::{Debug, Formatter};
use std::io;
use std::io::ErrorKind;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering::SeqCst;
use std::sync::{Arc, Mutex};
use std::time::Duration;
/// Caller-supplied data attached to a connection via `Server::handle_connection_with_meta`.
///
/// The server wraps the value in an `Arc` and hands a clone of it to every request read from
/// that connection; the `Send + Sync` bounds make that sharing possible.
pub trait ConnectionStreamMetadata: Any + Debug + Send + Sync {
    /// Exposes the value as `&dyn Any` so consumers can recover the concrete type with
    /// `downcast_ref` (implementations are expected to simply return `self`).
    fn as_any(&self) -> &dyn Any;
}
/// Placeholder used only to name the metadata type parameter when a connection is handled
/// without metadata. Within this module it only ever appears as the type of a `None`, so no
/// value of it exists and `as_any` cannot actually be reached.
#[derive(Debug)]
struct PhantomStreamMetadata;

impl ConnectionStreamMetadata for PhantomStreamMetadata {
    fn as_any(&self) -> &dyn Any {
        // Unreachable by construction: there is never an instance to call this on.
        crate::util::unreachable()
    }
}
/// HTTP server holding the routing table, request/error handlers, timeouts and shutdown state.
/// Individual connections are served by `handle_connection` / `handle_connection_with_meta`.
#[derive(Debug)]
pub struct Server {
/// Built from the `TypeSystemBuilder` passed to `new`; a clone is given to every request.
type_system: TypeSystem,
/// Set by `shutdown()`; once true, new connections are rejected and keep-alive stops.
shutdown: AtomicBool,
/// Tried in order for each request; the first router that returns a response (or an error)
/// ends the search.
routers: Vec<Box<dyn Router>>,
/// Converts errors from routers and the not-found handler into a response.
error_handler: ErrorHandler,
/// Produces the response when no router handles a request.
not_found_handler: NotFoundHandler,
/// Passed to `RequestContext::read` as the request head buffer limit.
max_head_buffer_size: usize,
/// Read timeout while waiting for the first request on a new connection
/// (falls back to `read_timeout` when not configured).
connection_timeout: Option<Duration>,
/// Read timeout applied while each request is read by `RequestContext::read`.
read_timeout: Option<Duration>,
/// Read timeout while waiting for the next request on a kept-alive connection
/// (falls back to `read_timeout`; a zero duration disables keep-alive).
keep_alive_timeout: Option<Duration>,
/// Read timeout applied after the request head has been read
/// (falls back to `read_timeout`).
request_body_io_timeout: Option<Duration>,
/// Write timeout set once on every connection stream.
write_timeout: Option<Duration>,
/// Consulted for `Expect: 100-continue` requests; returning `true` sends `100 Continue`.
continue_handler: ContinueHandler,
/// Callbacks run by `shutdown()`.
shutdown_hooks: Hooks,
}
/// Registry of callbacks to run when the server shuts down.
///
/// Kept behind a `Mutex` so hooks can be registered through a shared `&Server`.
#[derive(Default)]
struct Hooks(Mutex<Vec<Box<dyn FnMut() + Send + Sync>>>);

// Boxed closures have no useful `Debug` representation, so only the type name is printed.
impl Debug for Hooks {
    fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("Hooks")
    }
}
impl Server {
#[expect(clippy::too_many_arguments)] pub(crate) fn new(
type_system: TypeSystemBuilder,
routers: Vec<Box<dyn Router>>,
error_handler: ErrorHandler,
not_found_handler: NotFoundHandler,
max_head_buffer_size: usize,
connection_timeout: Option<Duration>,
read_timeout: Option<Duration>,
keep_alive_timeout: Option<Duration>,
request_body_io_timeout: Option<Duration>,
write_timeout: Option<Duration>,
continue_handler: ContinueHandler,
) -> Self {
Server {
type_system: type_system.build(),
shutdown: AtomicBool::new(false),
routers,
error_handler,
not_found_handler,
max_head_buffer_size,
read_timeout,
connection_timeout: connection_timeout.or(read_timeout),
keep_alive_timeout: keep_alive_timeout.or(read_timeout),
request_body_io_timeout: request_body_io_timeout.or(read_timeout),
write_timeout,
continue_handler,
shutdown_hooks: Hooks::default(),
}
}
/// Serves every request on `stream` until the connection closes, keep-alive ends, or the
/// server shuts down. No connection metadata is attached to the requests.
///
/// # Errors
/// Fails immediately with `ConnectionAborted` if the server is already shut down; otherwise
/// propagates I/O and request-parsing errors from the connection.
pub fn handle_connection<S: IntoConnectionStream>(&self, stream: S) -> TiiResult<()> {
    let no_meta: Option<PhantomStreamMetadata> = None;
    self.handle_connection_inner(stream, no_meta)
}
/// Like `handle_connection`, but attaches `meta` to the connection.
///
/// `meta` is wrapped in an `Arc` once and a clone of that `Arc` is given to every request
/// read from this connection, so handlers can inspect it.
///
/// # Errors
/// Same as `handle_connection`.
pub fn handle_connection_with_meta<S: IntoConnectionStream, M: ConnectionStreamMetadata>(
    &self,
    stream: S,
    meta: M,
) -> TiiResult<()> {
    let attached = Some(meta);
    self.handle_connection_inner::<S, M>(stream, attached)
}
/// Marks the server as shut down and runs every registered shutdown hook.
///
/// After this call `is_shutdown` returns `true`. New connections are then rejected with
/// `ConnectionAborted`, and connections in a keep-alive loop close after their current
/// request. Hooks run most-recently-registered first, and each runs at most once. Calling
/// `shutdown` again (it is also called from `Drop`) is therefore harmless.
///
/// The hooks are removed from the registry and invoked only after the registry lock is
/// released. A hook may therefore call `add_shutdown_hook` or `shutdown` on this same server
/// without deadlocking on the non-reentrant `Mutex`.
pub fn shutdown(&self) {
    // Publish the flag before draining. `add_shutdown_hook` re-checks it under the same lock,
    // so a hook registered after the drain below is run by its registrar instead of being lost.
    self.shutdown.store(true, SeqCst);
    let hooks = {
        // Poisoning only means an earlier lock holder panicked; the Vec itself is intact, and
        // skipping the hooks would silently drop the cleanup they were registered to perform.
        let mut guard =
            self.shutdown_hooks.0.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
        std::mem::take(&mut *guard)
    };
    // `pop()` order in the registry is LIFO; preserve that.
    for mut hook in hooks.into_iter().rev() {
        hook();
    }
}

/// Returns `true` once `shutdown` has been called (including implicitly via `Drop`).
pub fn is_shutdown(&self) -> bool {
    self.shutdown.load(SeqCst)
}

/// Registers `hook` to be run once by `shutdown`.
///
/// If the server is already shut down, `hook` is instead run immediately on the calling
/// thread, after the registry lock has been released.
pub fn add_shutdown_hook<F: FnMut() + Sync + Send + 'static>(&self, mut hook: F) {
    let mut guard =
        self.shutdown_hooks.0.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
    if self.is_shutdown() {
        // `shutdown` may already have drained the registry, so storing the hook could lose it.
        drop(guard);
        hook();
        return;
    }
    guard.push(Box::new(hook));
}
fn handle_connection_inner<S: IntoConnectionStream, M: ConnectionStreamMetadata>(
&self,
stream: S,
meta: Option<M>,
) -> TiiResult<()> {
if self.shutdown.load(SeqCst) {
return Err(TiiError::from_io_kind(ErrorKind::ConnectionAborted));
}
trace_log!("tii: tii:Server -> New connection");
let stream = stream.into_connection_stream();
stream.set_read_timeout(self.connection_timeout)?;
stream.set_write_timeout(self.write_timeout)?;
if !stream.ensure_readable()? {
return Err(TiiError::from_io_kind(ErrorKind::UnexpectedEof));
}
let meta = meta.map(|a| Arc::new(a) as Arc<dyn ConnectionStreamMetadata>);
let mut count = 0u64;
loop {
if count > 0 && !self.handle_keep_alive(stream.as_ref())? {
break;
}
stream.set_read_timeout(self.read_timeout)?;
let mut context = RequestContext::read(
stream.as_ref(),
meta.as_ref().cloned(),
self.max_head_buffer_size,
self.type_system.clone(),
)?;
count += 1;
if let Some(value) = context.get_header(HttpHeaderName::Expect) {
if value == "100-continue" && (self.continue_handler)(&mut context)? {
match context.get_version() {
HttpVersion::Http10 => _ = stream.write("HTTP/1.0 100 Continue\r\n\r\n".as_bytes())?,
HttpVersion::Http11 => _ = stream.write("HTTP/1.1 100 Continue\r\n\r\n".as_bytes())?,
_ => (),
};
stream.flush()?;
}
}
stream.set_read_timeout(self.request_body_io_timeout)?;
if context.get_version() == HttpVersion::Http11
&& context.get_header(&HttpHeaderName::Upgrade) == Some("websocket")
{
trace_log!("tii: Request {} is a web socket connection request", context.id());
for router in self.routers.iter() {
match router.serve_websocket(stream.as_ref(), &mut context)? {
RouterWebSocketServingResponse::HandledWithProtocolSwitch => return Ok(()),
RouterWebSocketServingResponse::HandledWithoutProtocolSwitch(response) => {
self.write_response(stream.as_ref(), context, false, response)?;
return Ok(());
}
RouterWebSocketServingResponse::NotHandled => (), }
}
let response = match (self.not_found_handler)(&mut context) {
Ok(res) => res,
Err(error) => (self.error_handler)(&mut context, error)
.unwrap_or_else(|e| self.fallback_error_handler(&mut context, e)),
};
self.write_response(stream.as_ref(), context, false, response)?;
return Ok(());
}
let mut keep_alive = !self.is_shutdown()
&& context.get_version() == HttpVersion::Http11
&& self.keep_alive_timeout.as_ref().map(|a| !a.is_zero()).unwrap_or(true)
&& context
.get_header(&HttpHeaderName::Connection)
.map(|e| e.eq_ignore_ascii_case("keep-alive"))
.unwrap_or_default();
let mut response = None;
for router in self.routers.iter() {
response = Some(match router.serve(&mut context) {
Ok(Some(resp)) => resp,
Ok(None) => continue,
Err(error) => (self.error_handler)(&mut context, error)
.unwrap_or_else(|e| self.fallback_error_handler(&mut context, e)),
});
break;
}
let response = response.unwrap_or_else(|| match (self.not_found_handler)(&mut context) {
Ok(res) => res,
Err(error) => (self.error_handler)(&mut context, error)
.unwrap_or_else(|e| self.fallback_error_handler(&mut context, e)),
});
if response.omit_body {
context.force_connection_close();
}
keep_alive &= !context.is_connection_close_forced();
let id = context.id();
self.write_response(stream.as_ref(), context, keep_alive, response)?;
if !keep_alive {
trace_log!("tii: Request {} will NOT do keep alive", id);
break;
}
trace_log!("tii: Request {} will attempt to do keep alive", id);
}
trace_log!("tii: tii:Server -> Connection closed");
Ok(())
}
fn handle_keep_alive(&self, stream: &dyn ConnectionStream) -> TiiResult<bool> {
if self.is_shutdown() {
trace_log!("tii: Keep-alive server shutting down...");
return Ok(false);
}
if stream.available() > 0 {
trace_log!("tii: Keep-alive client sent data. Processing next request...");
return Ok(true);
}
stream.set_read_timeout(self.keep_alive_timeout)?;
match stream.ensure_readable() {
Ok(true) => {
trace_log!("tii: Keep-alive client sent data. Processing next request...");
Ok(true)
}
Ok(false) => {
trace_log!("tii: Keep-alive client disconnected before timeout expired.");
Ok(false)
}
Err(err) => match err.kind() {
ErrorKind::UnexpectedEof => {
trace_log!("tii: Keep-alive client disconnected before timeout expired.");
Ok(false)
}
ErrorKind::ConnectionReset | ErrorKind::ConnectionAborted | ErrorKind::BrokenPipe => {
trace_log!("tii: Keep-alive OS reset connection before timeout expired.");
Ok(false)
}
ErrorKind::TimedOut | ErrorKind::WouldBlock => {
trace_log!("tii: Keep-alive time out closing connection.");
Ok(false)
}
_ => {
error_log!("tii: Keep-alive unspecified error when waiting for data {}", &err);
Err(err.into())
}
},
}
}
/// Read timeout used while waiting for the first request on a new connection.
/// Equals `read_timeout` unless configured separately.
pub fn connection_timeout(&self) -> Option<Duration> {
    self.connection_timeout
}

/// Read timeout applied while each request is read from the connection.
pub fn read_timeout(&self) -> Option<Duration> {
    self.read_timeout
}

/// Read timeout used while waiting for the next request on a kept-alive connection.
/// Equals `read_timeout` unless configured separately; a zero duration disables keep-alive.
pub fn keep_alive_timeout(&self) -> Option<Duration> {
    self.keep_alive_timeout
}

/// Read timeout applied once a request head has been read (e.g. while its body is read).
/// Equals `read_timeout` unless configured separately.
pub fn request_body_io_timeout(&self) -> Option<Duration> {
    self.request_body_io_timeout
}

/// Write timeout set on every connection stream.
pub fn write_timeout(&self) -> Option<Duration> {
    self.write_timeout
}
/// Writes `response` to `stream`, then consumes whatever remains of the request body.
///
/// For HTTP/1.1 the server owns the `Connection` header. It is set to `Keep-Alive` or
/// `Close` according to `keep_alive`. If the endpoint had already set that header, nothing
/// is written and an `InvalidInput` error is returned.
///
/// # Errors
/// Returns one of:
/// * the banned-header error described above;
/// * any error from `Response::write_to`;
/// * any error from consuming the remaining request body.
fn write_response(
    &self,
    stream: &dyn ConnectionStream,
    request: RequestContext,
    keep_alive: bool,
    mut response: Response,
) -> TiiResult<()> {
    if request.get_version() == HttpVersion::Http11 {
        let connection = if keep_alive { "Keep-Alive" } else { "Close" };
        let overwritten = response.headers.replace_all(HttpHeaderName::Connection, connection);
        if !overwritten.is_empty() {
            error_log!(
                "tii: Request {} Endpoint has set banned header 'Connection' {:?}",
                request.id(),
                overwritten
            );
            return Err(TiiError::new_io(
                io::ErrorKind::InvalidInput,
                "Endpoint has set banned header 'Connection'",
            ));
        }
    }

    trace_log!(
        "tii: Request {} responding with HTTP {}",
        request.id(),
        response.status_code.code()
    );

    let sends_gzip = response
        .get_body()
        .and_then(|body| body.get_content_encoding())
        .is_some_and(|enc| enc == "gzip");
    if sends_gzip && !request.accepts_gzip() {
        warn_log!("tii: Request {} responding with gzip even tho client doesnt indicate that it can understand gzip.", request.id());
    }

    #[cfg(feature = "log")]
    let status = response.get_status_code_number();

    response.write_to(request.id(), request.get_version(), stream.as_stream_write()).inspect_err(
        |e| {
            error_log!("tii: Request {} response.write_to error={}", request.id(), e);
        },
    )?;

    #[cfg(feature = "log")]
    {
        // Wall-clock milliseconds since the request timestamp; clamps to 0 on clock skew.
        let elapsed_ms: u128 = std::time::SystemTime::now()
            .duration_since(std::time::SystemTime::UNIX_EPOCH)
            .map_or(0, |since_epoch| since_epoch.as_millis())
            .saturating_sub(request.get_timestamp());
        crate::info_log!(
            "tii: Request {} from {} to {} {} ({}) served in {}ms",
            request.id(),
            request.peer_address(),
            request.get_method(),
            request.get_path(),
            status,
            elapsed_ms
        );
    }

    request.consume_request_body()?;
    Ok(())
}
/// Last-resort response used when the configured error handler itself fails.
///
/// Logs the handler's error and returns a bare `500 Internal Server Error`. It also forces
/// the connection closed, which disables keep-alive for this request.
fn fallback_error_handler(&self, request: &mut RequestContext, error: TiiError) -> Response {
    let last_resort = Response::new(StatusCode::InternalServerError);
    request.force_connection_close();
    error_log!(
        "tii: Request {} Error handler failed. Will respond with empty Internal Server Error {} {} {:?}",
        request.id(),
        &request.get_method(),
        request.get_path(),
        error
    );
    last_resort
}
}
impl Drop for Server {
fn drop(&mut self) {
self.shutdown();
trace_log!("tii: Server::drop");
}
}