use super::{
config::{
Binding, CacheNode, Config, Directive, ElementOp, FilesDirective, HeaderOp,
HeadersDirective, HttpConfigNode, ProxyDirective, RedirectDirective, RewriteHtmlDirective,
Route, SelectBlock,
},
upstream,
};
use crate::{directory_listing::DirectoryListing, tls::Tls};
use trillium::{BoxedHandler, Conn, Handler, HttpConfig, KnownHeaderName, Method, Status};
use trillium_cache::{InMemoryStorage, client::Cache};
use trillium_client::Client;
use trillium_html_rewriter::{
HtmlRewriter, Settings,
html::{element, html_content::ContentType},
};
use trillium_logger::Logger;
use trillium_proxy::Proxy;
use trillium_router::Router;
use trillium_server_common::{ServerHandle, Swansong};
use trillium_static::StaticFileHandler;
/// Fallback cache capacity in bytes (256 MiB), used by `build_cache` when the
/// cache node has no `capacity` value.
const DEFAULT_CACHE_CAPACITY: u64 = 256 * 1024 * 1024;
/// Fallback maximum cacheable size in bytes (16 MiB), used by `build_cache`
/// when the cache node has no `max_body` value.
const DEFAULT_CACHE_MAX_BODY: u64 = 16 * 1024 * 1024;
/// Builds the HTTP [`Client`] from the default [`Tls`] value.
///
/// When `config.cache` is present, the handler produced by `build_cache` is
/// attached to the client; otherwise the plain client is returned.
pub fn build_client(config: &Config) -> Client {
    let base = Client::from(Tls::default());
    if let Some(node) = &config.cache {
        base.with_handler(build_cache(node))
    } else {
        base
    }
}
fn build_cache(cache: &CacheNode) -> impl trillium_client::ClientHandler {
let capacity = cache
.capacity
.as_deref()
.map_or(DEFAULT_CACHE_CAPACITY, parse_size);
let max_body = cache
.max_body
.as_deref()
.map_or(DEFAULT_CACHE_MAX_BODY, parse_size);
let mut storage = InMemoryStorage::new().with_max_capacity_bytes(capacity);
if let Some(tti) = &cache.time_to_idle {
storage = storage.with_time_to_idle(parse_duration(tti));
}
if let Some(ttl) = &cache.time_to_live {
storage = storage.with_time_to_live(parse_duration(ttl));
}
Cache::new(storage)
.with_max_cacheable_size(max_body)
.shared()
}
/// HTTP methods that `build_router` registers for every configured route
/// pattern (passed as the method list to `Router::any`).
const ROUTE_METHODS: &[Method] = &[
Method::Get,
Method::Head,
Method::Post,
Method::Put,
Method::Delete,
Method::Patch,
Method::Options,
Method::Connect,
Method::Trace,
];
/// Prints a colorized summary of every binding in `config` to stdout.
///
/// For each binding this prints its `scheme://host:port` line (`https` when
/// the binding has a `tls` section, `http` otherwise), then each host block's
/// patterns followed by its routes, then the binding-level routes. When host
/// blocks exist, the binding-level routes are introduced by a `(default)`
/// line and printed at the deeper indent.
pub fn print_startup(config: &Config) {
    use colored::Colorize;

    for binding in &config.bindings {
        let (host, port) = parse_listen(&binding.listen);
        let scheme = match &binding.tls {
            Some(_) => "https",
            None => "http",
        };
        let origin = format!("{scheme}://{host}:{port}");
        println!("{}", origin.bold().green());

        for hostblock in &binding.hosts {
            let patterns = hostblock.patterns.join(" ");
            println!(" {}", patterns.yellow());
            print_routes(&hostblock.routes, 4);
        }

        if binding.routes.is_empty() {
            continue;
        }

        let has_hosts = !binding.hosts.is_empty();
        if has_hosts {
            println!(" {}", "(default)".yellow().dimmed());
        }
        let indent = if has_hosts { 4 } else { 2 };
        print_routes(&binding.routes, indent);
    }
}
/// Prints one line per route: the pattern, an arrow, and a comma-separated
/// description of the route's directives.
///
/// `indent` is the number of leading spaces. Patterns are left-aligned and
/// padded to the width of the longest pattern in `routes` so the arrows line
/// up in a column.
fn print_routes(routes: &[Route], indent: usize) {
    use colored::Colorize;

    // Formatter widths are measured in `char`s, not bytes, so the column width
    // must be computed the same way or non-ASCII patterns would misalign.
    let width = routes
        .iter()
        .map(|r| r.pattern.chars().count())
        .max()
        .unwrap_or(0);

    for route in routes {
        let directives = route
            .directives
            .iter()
            .map(describe_directive)
            .collect::<Vec<_>>()
            .join(", ");
        println!(
            "{:indent$}{:<width$} {} {directives}",
            "",
            route.pattern.cyan(),
            "→".dimmed(),
        );
    }
}
/// Renders a short, human-readable description of a directive for the
/// startup summary.
fn describe_directive(directive: &Directive) -> String {
    match directive {
        Directive::Files(files) => {
            let root = files.root.display();
            format!("files {root}")
        }
        Directive::Proxy(proxy) => {
            let urls: Vec<&str> = proxy.upstreams.iter().map(|up| up.url.as_str()).collect();
            format!("proxy {}", urls.join(", "))
        }
        Directive::Redirect(redirect) => format!("redirect {}", redirect.to),
        Directive::Headers(_) => String::from("headers"),
        Directive::RewriteHtml(rewrite) => {
            let count = rewrite.selects.len();
            format!("rewrite-html ({count} selectors)")
        }
    }
}
/// Spawns a server for one binding and returns its [`ServerHandle`].
///
/// The server listens on the host/port parsed from `binding.listen`, uses a
/// clone of `swansong`, and is configured without signal handling. If the
/// binding has an `http` section, its HTTP config and optional connection
/// limit are applied. When `super::sni::build` yields TLS material, the
/// acceptor (and, with the `h3` feature, the QUIC config) is installed before
/// spawning.
pub fn spawn_binding(
    binding: &Binding,
    config: &Config,
    swansong: &Swansong,
    client: &Client,
) -> ServerHandle {
    let (host, port) = parse_listen(&binding.listen);
    let mut server = trillium_smol::config()
        .with_host(&host)
        .with_port(port)
        .with_swansong(swansong.clone())
        .without_signals();

    if let Some(http) = binding.http.as_ref() {
        server = server.with_http_config(http_config(http));
        if let Some(limit) = http.max_connections {
            server = server.with_max_connections(Some(limit));
        }
    }

    let handler = binding_handler(binding, config, client);

    // No TLS material for this binding: serve plaintext.
    let Some(tls) = super::sni::build(binding) else {
        return server.spawn(handler);
    };

    #[cfg(feature = "h3")]
    let server = server.with_acceptor(tls.acceptor).with_quic(tls.quic);
    #[cfg(not(feature = "h3"))]
    let server = server.with_acceptor(tls.acceptor);

    server.spawn(handler)
}
/// Translates an [`HttpConfigNode`] into a trillium [`HttpConfig`].
///
/// Starts from `HttpConfig::default()` and overrides only the limits that are
/// present in the node (`received_body_max_len`, `head_max_len`).
///
/// # Panics
///
/// Panics if a configured size string is invalid (via `parse_size`), or if
/// `head_max_len` does not fit in `usize` on the current platform.
fn http_config(node: &HttpConfigNode) -> HttpConfig {
    let mut cfg = HttpConfig::default();
    if let Some(s) = &node.received_body_max_len {
        cfg = cfg.with_received_body_max_len(parse_size(s));
    }
    if let Some(s) = &node.head_max_len {
        // A bare `as usize` would silently truncate on targets where `usize`
        // is narrower than 64 bits; reject the value instead.
        let len = usize::try_from(parse_size(s))
            .unwrap_or_else(|_| panic!("size {s:?} does not fit in usize on this platform"));
        cfg = cfg.with_head_max_len(len);
    }
    cfg
}
/// Parses a human-readable size string into a byte count.
///
/// # Panics
///
/// Panics if `s` cannot be parsed by `size::Size::from_str`, or if the parsed
/// byte count cannot be converted to `u64` (reported as a negative size).
fn parse_size(s: &str) -> u64 {
    let parsed = match size::Size::from_str(s) {
        Ok(parsed) => parsed,
        Err(e) => panic!("invalid size {s:?}: {e}"),
    };
    match u64::try_from(parsed.bytes()) {
        Ok(bytes) => bytes,
        Err(_) => panic!("size {s:?} must not be negative"),
    }
}
/// Parses a human-readable duration string with `humantime`.
///
/// # Panics
///
/// Panics if `s` is not a valid duration.
fn parse_duration(s: &str) -> std::time::Duration {
    match humantime::parse_duration(s) {
        Ok(duration) => duration,
        Err(e) => panic!("invalid duration {s:?}: {e}"),
    }
}
/// Assembles the full handler stack for one binding.
///
/// The returned tuple runs, in order: the logger, the optional rate limiter,
/// the optional caching-headers handler (only when `config.cache` is set),
/// the optional compression handler (enabled unless `config.compression` is
/// `Some(false)`), and finally the dispatcher. The dispatcher is a plain
/// router when the binding has no host blocks; otherwise it is a host router
/// whose fallback is the binding-level router, if the binding has any routes.
///
/// # Panics
///
/// Panics if the configured rate limit is rejected by
/// `crate::ratelimit::limiter_for`.
pub fn binding_handler(binding: &Binding, config: &Config, client: &Client) -> impl Handler {
    let dispatcher = if !binding.hosts.is_empty() {
        let hosts = binding
            .hosts
            .iter()
            .map(|block| (block.patterns.clone(), build_router(&block.routes, client)))
            .collect();
        let fallback = if binding.routes.is_empty() {
            None
        } else {
            Some(build_router(&binding.routes, client))
        };
        BoxedHandler::new(super::host::HostRouter::new(hosts, fallback))
    } else {
        BoxedHandler::new(build_router(&binding.routes, client))
    };

    let compression = match config.compression {
        Some(false) => None,
        _ => Some(trillium_compression::compression()),
    };

    let rate_limit = match &config.rate_limit {
        Some(rl) => Some(
            match crate::ratelimit::limiter_for(&rl.rate, rl.burst) {
                Ok(limiter) => limiter,
                Err(e) => panic!("invalid rate-limit {:?}: {e}", rl.rate),
            },
        ),
        None => None,
    };

    let caching_headers = if config.cache.is_some() {
        Some(trillium_caching_headers::caching_headers())
    } else {
        None
    };

    (
        Logger::new().without_init_message(),
        rate_limit,
        caching_headers,
        compression,
        dispatcher,
    )
}
/// Builds a [`Router`] that registers every route's pattern for all of
/// [`ROUTE_METHODS`], handled by that route's directive stack.
fn build_router(routes: &[Route], client: &Client) -> Router {
    routes.iter().fold(Router::new(), |router, route| {
        router.any(
            ROUTE_METHODS,
            route.pattern.as_str(),
            route_stack(route, client),
        )
    })
}
/// Collects the handlers for a route by pushing each of its directives, in
/// declaration order, onto a fresh stack.
fn route_stack(route: &Route, client: &Client) -> Vec<BoxedHandler> {
    let mut handlers = Vec::new();
    route
        .directives
        .iter()
        .for_each(|directive| push_directive(&mut handlers, directive, client));
    handlers
}
/// Appends the handler(s) for a single directive to `stack`.
///
/// File, proxy, and HTML-rewrite directives delegate to their `push_*`
/// helper; redirect and header directives are boxed and pushed directly.
fn push_directive(stack: &mut Vec<BoxedHandler>, directive: &Directive, client: &Client) {
    match directive {
        Directive::Files(d) => push_files(stack, d),
        Directive::Proxy(d) => push_proxy(stack, d, client),
        Directive::RewriteHtml(d) => push_rewrite_html(stack, d),
        Directive::Redirect(d) => {
            let handler = Redirect::new(d);
            stack.push(BoxedHandler::new(handler));
        }
        Directive::Headers(d) => {
            let handler = Headers::new(d);
            stack.push(BoxedHandler::new(handler));
        }
    }
}
/// Appends an [`HtmlRewriter`] handler built from a rewrite-html directive.
///
/// The rewriter's settings factory produces one element content handler per
/// [`SelectBlock`]; each handler applies that block's ops to the matched
/// element in the order they are listed.
fn push_rewrite_html(stack: &mut Vec<BoxedHandler>, rewrite: &RewriteHtmlDirective) {
// Own a copy of the select blocks so the `move` factory closure below does not
// borrow from `rewrite`.
let selects = rewrite.selects.clone();
let handler = HtmlRewriter::new(move || Settings {
element_content_handlers: selects
.iter()
// Each element handler closure takes ownership of its selector and ops, so
// the blocks are cloned every time the factory builds a `Settings`.
.cloned()
.map(|SelectBlock { selector, ops }| {
element!(selector, move |el| {
for op in &ops {
match op {
ElementOp::SetAttribute(name, value) => {
// The result of `set_attribute` is deliberately discarded.
let _ = el.set_attribute(name, value);
}
ElementOp::RemoveAttribute(name) => el.remove_attribute(name),
ElementOp::Before(html) => el.before(html, ContentType::Html),
ElementOp::After(html) => el.after(html, ContentType::Html),
ElementOp::Prepend(html) => el.prepend(html, ContentType::Html),
ElementOp::Append(html) => el.append(html, ContentType::Html),
ElementOp::SetInner(html) => {
el.set_inner_content(html, ContentType::Html)
}
// Same call as `SetInner`, but the content is passed as `ContentType::Text`.
ElementOp::SetText(text) => {
el.set_inner_content(text, ContentType::Text)
}
ElementOp::Replace(html) => el.replace(html, ContentType::Html),
ElementOp::SetTag(name) => {
// The result of `set_tag_name` is deliberately discarded.
let _ = el.set_tag_name(name);
}
ElementOp::Remove => el.remove(),
ElementOp::Unwrap => el.remove_and_keep_content(),
}
}
Ok(())
})
})
.collect(),
// All remaining settings come from `Settings::new_send()`.
..Settings::new_send()
});
stack.push(BoxedHandler::new(handler));
}
/// Appends a static-file handler rooted at `files.root`, using the configured
/// index file when one is set.
///
/// When `directory_listing` is `Some(true)`, a [`DirectoryListing`] handler is
/// pushed after the static-file handler.
fn push_files(stack: &mut Vec<BoxedHandler>, files: &FilesDirective) {
    let base = StaticFileHandler::new(&files.root);
    let static_files = match &files.index {
        Some(index) => base.with_index_file(index),
        None => base,
    };
    stack.push(BoxedHandler::new(static_files));

    if matches!(files.directory_listing, Some(true)) {
        stack.push(BoxedHandler::new(DirectoryListing));
    }
}
/// Appends a [`Proxy`] handler that forwards through a clone of `client` to
/// the upstream selector built from the directive.
///
/// The proxy is configured with the via pseudonym `trillium-gateway`,
/// websocket upgrades, and `proxy_not_found`.
fn push_proxy(stack: &mut Vec<BoxedHandler>, proxy: &ProxyDirective, client: &Client) {
    let upstream_client = client.clone();
    let selector = upstream::build_selector(proxy);
    let proxy_handler = Proxy::new(upstream_client, selector)
        .with_via_pseudonym("trillium-gateway")
        .with_websocket_upgrades()
        .proxy_not_found();
    stack.push(BoxedHandler::new(proxy_handler));
}
/// Splits a `host:port` or `:port` listen string into its host and port.
///
/// The split happens at the last `:`. An empty host (the `:port` form) is
/// replaced with `0.0.0.0`; any other host text is returned unchanged.
///
/// # Panics
///
/// Panics if `listen` contains no `:` or if the port is not a valid `u16`.
pub fn parse_listen(listen: &str) -> (String, u16) {
    let Some((host, port)) = listen.rsplit_once(':') else {
        panic!("listen must be host:port or :port (got {listen:?})");
    };
    let port: u16 = match port.parse() {
        Ok(port) => port,
        Err(_) => panic!("invalid port in listen {listen:?}"),
    };
    let host = match host {
        "" => String::from("0.0.0.0"),
        named => named.to_owned(),
    };
    (host, port)
}
/// Handler state for a redirect directive: the target location and the
/// response status to send with it.
#[derive(Debug, Clone)]
struct Redirect {
    /// Redirect target, copied from the directive's `to` value.
    to: String,
    /// Response status; [`Status::Found`] when the directive sets none.
    status: Status,
}

impl Redirect {
    /// Builds the handler from a [`RedirectDirective`].
    ///
    /// # Panics
    ///
    /// Panics if the directive's status code is rejected by `Status::try_from`.
    fn new(redirect: &RedirectDirective) -> Self {
        let status = redirect.status.map_or(Status::Found, |code| {
            Status::try_from(code).unwrap_or_else(|_| panic!("invalid redirect status {code}"))
        });
        let to = redirect.to.clone();
        Self { to, status }
    }
}
impl Handler for Redirect {
    /// Sets the `Location` header and the configured status, then halts the
    /// conn.
    async fn run(&self, conn: Conn) -> Conn {
        let location = self.to.clone();
        conn.with_response_header(KnownHeaderName::Location, location)
            .with_status(self.status)
            .halt()
    }
}
/// Handler state for a headers directive: the list of header operations to
/// apply to the response.
#[derive(Debug, Clone)]
struct Headers {
    /// Operations copied from the directive, kept in their original order.
    ops: Vec<HeaderOp>,
}

impl Headers {
    /// Builds the handler by cloning the directive's operations.
    fn new(headers: &HeadersDirective) -> Self {
        let ops = headers.ops.clone();
        Self { ops }
    }
}
impl Handler for Headers {
    /// Passes the conn through untouched; all work happens in `before_send`.
    async fn run(&self, conn: Conn) -> Conn {
        conn
    }

    /// Applies every configured header operation, in order, to the response
    /// headers: `Add` appends a value, `Set` inserts it, and `Remove` removes
    /// the header.
    async fn before_send(&self, mut conn: Conn) -> Conn {
        let response_headers = conn.response_headers_mut();
        self.ops.iter().for_each(|op| match op {
            HeaderOp::Add(name, value) => {
                response_headers.append(name.clone(), value.clone());
            }
            HeaderOp::Set(name, value) => {
                response_headers.insert(name.clone(), value.clone());
            }
            HeaderOp::Remove(name) => {
                response_headers.remove(name.clone());
            }
        });
        conn
    }
}