use crate::{
Acceptor, ArcHandler, BoxedAcceptor, BoxedQuicConfig, ListenerConfig, QuicConfig, RuntimeTrait,
Server, ServerHandle,
server::{PreboundListener, resolve_listener},
};
use async_cell::sync::AsyncCell;
use futures_lite::StreamExt;
use std::{
cell::OnceCell,
net::{SocketAddr, UdpSocket as StdUdpSocket},
pin::pin,
sync::Arc,
};
use trillium::{
Handler, Headers, HttpConfig, Info, KnownHeaderName, Listener, Listeners, Swansong, TypeSet,
};
use trillium_http::HttpContext;
use url::Url;
/// Builder-style configuration for a single server.
///
/// `ServerType` selects the server implementation (and, through it, the runtime
/// and transport), `AcceptorType` is the acceptor applied to that transport, and
/// `QuicType` is the QUIC configuration, which defaults to `()`.
#[derive(Debug)]
pub struct Config<ServerType: Server, AcceptorType, QuicType: QuicConfig<ServerType> = ()> {
/// Acceptor for the transport; wrapped in a `BoxedAcceptor` when the server is run.
pub(crate) acceptor: AcceptorType,
/// QUIC configuration; consulted via `is_configured()` when the server is run.
pub(crate) quic: QuicType,
/// Pre-bound server, if any. When present, `host` and `port` are not used for binding.
pub(crate) binding: Option<ServerType>,
/// Explicit host. When `None`, `run_async` falls back to the `HOST` environment
/// variable and then to `"localhost"`.
pub(crate) host: Option<String>,
/// Shared cell for the `Arc<HttpContext>`; a clone of it is handed to every `ServerHandle`.
pub(crate) context_cell: Arc<AsyncCell<Arc<HttpContext>>>,
/// Optional connection limit, forwarded to `ListenerConfig::from_global`.
pub(crate) max_connections: Option<usize>,
/// Nodelay flag, forwarded to `ListenerConfig::from_global` (presumably TCP_NODELAY —
/// TODO confirm against the listener implementation). Defaults to `false`.
pub(crate) nodelay: bool,
/// Explicit port. When `None`, `run_async` falls back to the `PORT` environment
/// variable and then to `8080`.
pub(crate) port: Option<u16>,
/// Whether signal handling is requested. Defaults to `cfg!(unix)`.
pub(crate) register_signals: bool,
/// Runtime used to block on or spawn the server future.
pub(crate) runtime: ServerType::Runtime,
/// Http context carrying the swansong, the `HttpConfig`, and shared state.
pub(crate) context: HttpContext,
}
impl<ServerType, AcceptorType, QuicType> Config<ServerType, AcceptorType, QuicType>
where
ServerType: Server,
AcceptorType: Acceptor<ServerType::Transport>,
QuicType: QuicConfig<ServerType>,
{
pub fn run(self, handler: impl Handler) {
self.runtime.clone().block_on(self.run_async(handler));
}
pub async fn run_async(self, handler: impl Handler) {
let Self {
runtime,
acceptor,
quic,
max_connections,
nodelay,
binding,
host,
port,
register_signals,
context,
context_cell,
} = self;
let builder = ListenerConfig::<ServerType>::from_global(
context,
context_cell,
runtime,
max_connections,
nodelay,
register_signals,
);
let builder = match binding {
Some(server) => {
if quic.is_configured() {
log::warn!(
"QUIC configuration is ignored when a prebound server is supplied; use a \
multi-listener ListenerConfig to bind QUIC explicitly"
);
}
log::debug!("taking prebound listener");
builder.bind_server_boxed(server, BoxedAcceptor::new(acceptor))
}
None => {
let host = host
.or_else(|| std::env::var("HOST").ok())
.unwrap_or_else(|| "localhost".into());
let port = port
.or_else(|| {
std::env::var("PORT")
.ok()
.map(|x| x.parse().expect("PORT must be an unsigned integer"))
})
.unwrap_or(8080);
if quic.is_configured() {
match resolve_listener(&host, port)
.unwrap_or_else(|e| panic!("failed to bind {host}:{port}: {e}"))
{
PreboundListener::Tcp(tcp) => {
let addr = tcp
.local_addr()
.expect("a bound tcp listener has a local address");
let builder = builder.push_listener(
PreboundListener::Tcp(tcp),
BoxedAcceptor::new(acceptor),
);
let socket = StdUdpSocket::bind(addr).unwrap_or_else(|e| {
panic!("failed to bind QUIC UDP socket at {addr}: {e}")
});
builder.push_quic_listener(socket, BoxedQuicConfig::new(quic))
}
#[cfg(unix)]
PreboundListener::Unix(unix) => {
log::warn!("QUIC configuration is ignored on a unix-domain listener");
builder.push_listener(
PreboundListener::Unix(unix),
BoxedAcceptor::new(acceptor),
)
}
}
} else {
let server = ServerType::from_host_and_port(&host, port);
builder.bind_server_boxed(server, BoxedAcceptor::new(acceptor))
}
}
};
builder.run_async(handler).await;
}
pub fn spawn(self, handler: impl Handler) -> ServerHandle {
let server_handle = self.handle();
self.runtime.clone().spawn(self.run_async(handler));
server_handle
}
pub fn handle(&self) -> ServerHandle {
ServerHandle {
swansong: self.context.swansong().clone(),
context: self.context_cell.clone(),
received_context: OnceCell::new(),
runtime: self.runtime().into(),
}
}
pub fn with_port(mut self, port: u16) -> Self {
if self.has_binding() {
log::warn!(
"constructing a config with both a port and a pre-bound listener will ignore the \
port"
);
}
self.port = Some(port);
self
}
pub fn with_host(mut self, host: &str) -> Self {
if self.has_binding() {
log::warn!(
"constructing a config with both a host and a pre-bound listener will ignore the \
host"
);
}
self.host = Some(host.into());
self
}
pub fn without_signals(mut self) -> Self {
self.register_signals = false;
self
}
pub fn with_nodelay(mut self) -> Self {
self.nodelay = true;
self
}
pub fn with_socketaddr(self, socketaddr: SocketAddr) -> Self {
self.with_host(&socketaddr.ip().to_string())
.with_port(socketaddr.port())
}
pub fn with_acceptor<A: Acceptor<ServerType::Transport>>(
self,
acceptor: A,
) -> Config<ServerType, A, QuicType> {
Config {
acceptor,
quic: self.quic,
host: self.host,
port: self.port,
nodelay: self.nodelay,
register_signals: self.register_signals,
max_connections: self.max_connections,
context_cell: self.context_cell,
context: self.context,
binding: self.binding,
runtime: self.runtime,
}
}
pub fn with_quic<Q: QuicConfig<ServerType>>(
self,
quic: Q,
) -> Config<ServerType, AcceptorType, Q> {
Config {
acceptor: self.acceptor,
quic,
host: self.host,
port: self.port,
nodelay: self.nodelay,
register_signals: self.register_signals,
max_connections: self.max_connections,
context_cell: self.context_cell,
context: self.context,
binding: self.binding,
runtime: self.runtime,
}
}
pub fn with_swansong(mut self, swansong: Swansong) -> Self {
self.context.set_swansong(swansong);
self
}
pub fn with_max_connections(mut self, max_connections: Option<usize>) -> Self {
self.max_connections = max_connections;
self
}
pub fn with_http_config(mut self, config: HttpConfig) -> Self {
*self.context.config_mut() = config;
self
}
pub fn with_prebound_server(mut self, server: impl Into<ServerType>) -> Self {
if self.host.is_some() {
log::warn!(
"constructing a config with both a host and a pre-bound listener will ignore the \
host"
);
}
if self.port.is_some() {
log::warn!(
"constructing a config with both a port and a pre-bound listener will ignore the \
port"
);
}
self.binding = Some(server.into());
self
}
fn has_binding(&self) -> bool {
self.binding.is_some()
}
pub fn runtime(&self) -> ServerType::Runtime {
self.runtime.clone()
}
pub fn port(&self) -> Option<u16> {
self.port
}
pub fn host(&self) -> Option<&str> {
self.host.as_deref()
}
pub fn with_shared_state<T: Send + Sync + 'static>(mut self, state: T) -> Self {
self.context.shared_state_mut().insert(state);
self
}
pub fn set_shared_state<T: Send + Sync + 'static>(&mut self, state: T) -> &mut Self {
self.context.shared_state_mut().insert(state);
self
}
}
impl<ServerType: Server> Config<ServerType, ()> {
    /// Creates a config with default settings; identical to `Config::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Converts this config into a [`ListenerConfig`] carrying the global
    /// settings (context, context cell, runtime, connection limit, nodelay,
    /// signal registration).
    ///
    /// Host, port, and prebound-server settings are not carried over; a warning
    /// is logged if any of them was set.
    pub fn listeners(self) -> ListenerConfig<ServerType> {
        let Self {
            host,
            port,
            binding,
            context,
            context_cell,
            runtime,
            max_connections,
            nodelay,
            register_signals,
            ..
        } = self;

        let binding_was_configured = host.is_some() || port.is_some() || binding.is_some();
        if binding_was_configured {
            log::warn!(
                "Config::listeners() does not carry over host/port/prebound-server configuration; \
                 bind listeners explicitly on the returned ListenerConfig"
            );
        }

        ListenerConfig::from_global(
            context,
            context_cell,
            runtime,
            max_connections,
            nodelay,
            register_signals,
        )
    }
}
impl<ServerType: Server> Default for Config<ServerType, ()> {
fn default() -> Self {
Self {
acceptor: (),
quic: (),
port: None,
host: None,
nodelay: false,
register_signals: cfg!(unix),
max_connections: None,
context_cell: AsyncCell::shared(),
binding: None,
runtime: ServerType::runtime(),
context: Default::default(),
}
}
}
/// Builds an [`Info`] from `context`, stores the runtime in shared state (both
/// a value converted from it via `Into` and a plain clone), and calls
/// `try_insert` for the `Server` header on the shared default [`Headers`].
pub(crate) fn info_with_server_header<ServerType: Server>(
    context: HttpContext,
    runtime: &ServerType::Runtime,
) -> Info {
    let seeded = Info::from(context).with_shared_state(runtime.clone().into());
    let mut info = seeded.with_shared_state(runtime.clone());

    let default_headers = info.shared_state_entry::<Headers>().or_default();
    default_headers.try_insert(KnownHeaderName::Server, trillium::headers::server_header());

    info
}
/// Performs the initialization shared by every listener setup: picks a
/// connection limit, optionally binds QUIC, records listener/URL metadata on
/// `info`, initializes the handler, and freezes everything into shareable
/// handles.
///
/// Steps, in order:
/// 1. On unix, if `max_connections` is `None`, default it to three quarters
///    (rounded down) of the soft `RLIMIT_NOFILE` limit, when that limit can be
///    read and fits in a `usize`.
/// 2. If `info` has a TCP socket address, ask `quic` to bind there. When that
///    yields an endpoint, offer an `Alt-Svc: h3=":<port>"` entry to the shared
///    default headers via `try_insert_with`.
/// 3. If no [`Listeners`] are in shared state yet and a primary listener can
///    be derived, store it (plus a QUIC listener when one was bound).
/// 4. Insert a `Url` derived from the socket address (see `insert_url`).
/// 5. Run `handler.init(&mut info)`.
///
/// Returns the shared http context, the handler wrapped in an [`ArcHandler`],
/// the QUIC endpoint (if any), and the effective connection limit.
///
/// # Panics
///
/// Panics if `quic.bind` reports an error.
// `max_connections` is only reassigned inside the `#[cfg(unix)]` block below,
// so its `mut` would otherwise trigger an `unused_mut` warning elsewhere.
#[cfg_attr(not(unix), allow(unused_mut))]
pub(crate) async fn init_shared<ServerType, QuicType, H>(
    mut info: Info,
    runtime: ServerType::Runtime,
    quic: QuicType,
    mut max_connections: Option<usize>,
    is_secure: bool,
    mut handler: H,
) -> (
    Arc<HttpContext>,
    ArcHandler<H>,
    Option<QuicType::Endpoint>,
    Option<usize>,
)
where
    ServerType: Server,
    QuicType: QuicConfig<ServerType>,
    H: Handler,
{
    #[cfg(unix)]
    if max_connections.is_none() {
        max_connections = rlimit::getrlimit(rlimit::Resource::NOFILE)
            .ok()
            .and_then(|(soft, _hard)| soft.try_into().ok())
            // floor(3 * limit / 4), computed in integers: a round-trip through
            // `f32` would round any limit above 2^24 before scaling it.
            .map(|limit: usize| limit - limit.div_ceil(4));
    }

    log::debug!("using max connections of {max_connections:?}");

    let quic_binding = if let Some(socket_addr) = info.tcp_socket_addr().copied() {
        let quic_binding = quic
            .bind(socket_addr, runtime, &mut info)
            .map(|r| r.expect("failed to bind QUIC endpoint"));

        if quic_binding.is_some() {
            info.shared_state_entry::<Headers>()
                .or_default()
                .try_insert_with(KnownHeaderName::AltSvc, || -> &'static str {
                    // Leaked on purpose to obtain a `&'static str`.
                    format!("h3=\":{}\"", socket_addr.port()).leak()
                });
        }

        quic_binding
    } else {
        None
    };

    // Only synthesize `Listeners` when none has been stored already.
    if info.shared_state::<Listeners>().is_none()
        && let Some(primary) = primary_listener(&info, is_secure)
    {
        let mut listeners = vec![primary];

        if quic_binding.is_some()
            && let Some(addr) = info.tcp_socket_addr().copied()
        {
            listeners.push(Listener::quic(addr));
        }

        info.insert_shared_state(Listeners(listeners));
    }

    insert_url(info.as_mut(), is_secure);

    // The handler sees the fully populated `info` before it is frozen.
    handler.init(&mut info).await;

    let context = Arc::new(HttpContext::from(info));
    let handler = ArcHandler::new(handler);

    (context, handler, quic_binding, max_connections)
}
/// Spawns a task on `runtime` that reacts to signals 2, 3 and 15.
///
/// Does nothing when `register` is `false`. Otherwise, the first signal
/// received while the swansong is not yet shutting down triggers
/// `swansong.shut_down()`; any signal received while it is already shutting
/// down exits the process with status 1.
pub(crate) fn spawn_signals_loop<R: RuntimeTrait>(
    context: Arc<HttpContext>,
    register: bool,
    runtime: R,
) {
    if register {
        let swansong = context.swansong().clone();

        runtime.clone().spawn(async move {
            let mut signals = pin!(runtime.hook_signals([2, 3, 15]));

            while let Some(_signal) = signals.next().await {
                // Sampled before the state check so both messages report it.
                let guard_count = swansong.guard_count();

                if !swansong.state().is_shutting_down() {
                    println!(
                        "\nShutting down gracefully. Waiting for {guard_count} shutdown guards to \
                         drop.\nControl-c again to force."
                    );
                    swansong.shut_down();
                    continue;
                }

                eprintln!(
                    "\nSecond interrupt, shutting down harshly (dropping {guard_count} guards)"
                );
                std::process::exit(1);
            }
        });
    }
}
/// Derives the primary [`Listener`] description from `info`.
///
/// Prefers the TCP socket address; on unix, falls back to a unix socket
/// address that has a pathname. Returns `None` when neither is available.
pub(crate) fn primary_listener(info: &Info, is_secure: bool) -> Option<Listener> {
    let tcp = info
        .tcp_socket_addr()
        .copied()
        .map(|addr| Listener::tcp(addr, is_secure));
    if tcp.is_some() {
        return tcp;
    }

    #[cfg(unix)]
    {
        let pathname = info
            .unix_socket_addr()
            .and_then(|addr| addr.as_pathname().map(std::path::Path::to_path_buf));
        if let Some(path) = pathname {
            return Some(Listener::unix(Some(path), is_secure));
        }
    }

    None
}
/// Inserts a [`Url`] describing the bound socket address into `state`, unless
/// one is already present.
///
/// The host is `localhost` for loopback addresses, the dotted form for other
/// IPv4 addresses, and the bracketed form (`[addr]`) for other IPv6 addresses.
/// The scheme follows `secure`, and the port is omitted when it is the
/// scheme's default (443 for https, 80 for http).
///
/// Returns `None` without modifying `state` when there is no [`SocketAddr`] in
/// `state`, when a `Url` entry already exists, or when the assembled string
/// fails to parse as a URL; returns `Some(())` after inserting.
fn insert_url(state: &mut TypeSet, secure: bool) -> Option<()> {
    let socket_addr = state.get::<SocketAddr>().copied()?;
    let vacant_entry = state.entry::<Url>().into_vacant()?;

    let host = match socket_addr.ip() {
        ip if ip.is_loopback() => "localhost".to_string(),
        std::net::IpAddr::V4(v4) => v4.to_string(),
        // An IPv6 literal must be bracketed in a URL authority; without the
        // brackets its colons are taken for the port separator and parsing fails.
        std::net::IpAddr::V6(v6) => format!("[{v6}]"),
    };

    let url = match (secure, socket_addr.port()) {
        (true, 443) => format!("https://{host}"),
        (false, 80) => format!("http://{host}"),
        (true, port) => format!("https://{host}:{port}/"),
        (false, port) => format!("http://{host}:{port}/"),
    };

    let url = Url::parse(&url).ok()?;

    vacant_entry.insert(url);
    Some(())
}