mod proxy;
use crate::{
common::{nonce, LOCAL, NETWORK, SERVER},
config::rt::RtcServe,
tls::TlsConfig,
watch::WatchSystem,
ws,
};
use anyhow::{Context, Result};
use axum::{
body::{Body, Bytes},
extract::{self, ws::WebSocketUpgrade},
http::{
header::{HeaderName, CONTENT_LENGTH, CONTENT_TYPE, HOST},
HeaderValue, StatusCode,
},
middleware::Next,
response::{IntoResponse, Response},
routing::{get, get_service, Router},
};
use axum_server::Handle;
use futures_util::FutureExt;
use hickory_resolver::TokioResolver;
use http::{header::CONTENT_SECURITY_POLICY, HeaderMap};
use proxy::{ProxyBuilder, ProxyClientOptions};
use std::{
collections::{BTreeSet, HashMap, HashSet},
net::{IpAddr, Ipv4Addr, SocketAddr},
path::PathBuf,
sync::Arc,
time::Duration,
};
use tokio::{
select,
sync::{broadcast, watch},
task::JoinHandle,
};
use tower_http::{
services::{ServeDir, ServeFile},
set_header::SetResponseHeaderLayer,
trace::TraceLayer,
};
use tracing::log;
/// Name of the HTML entry file; served as the fallback for unknown paths unless SPA mode is off.
const INDEX_HTML: &str = "index.html";
/// Runs the watch/build system together with the development HTTP server.
pub struct ServeSystem {
    /// Runtime serve configuration.
    cfg: Arc<RtcServe>,
    /// Watch system performing builds; created in `new` with the websocket state sender.
    watch: WatchSystem,
    /// URL opened in the browser when `cfg.open` is set.
    open_http_addr: String,
    /// Shutdown broadcast sender; the server subscribes to it and it is dropped in `run`.
    shutdown_tx: broadcast::Sender<()>,
    /// Receiving side of the websocket state channel, handed on to the server state.
    ws_state: watch::Receiver<ws::State>,
}
impl ServeSystem {
pub async fn new(cfg: Arc<RtcServe>, shutdown: broadcast::Sender<()>) -> Result<Self> {
let (ws_state_tx, ws_state) = watch::channel(ws::State::default());
let watch = WatchSystem::new(
cfg.watch.clone(),
shutdown.clone(),
Some(ws_state_tx),
cfg.ws_protocol,
)
.await?;
let prefix = if cfg.tls.is_some() { "https" } else { "http" };
let address = cfg.addresses.first().map_or_else(
|| SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), cfg.port),
|ipaddr| SocketAddr::new(*ipaddr, cfg.port),
);
let base = cfg.serve_base()?;
let open_http_addr = format!("{prefix}://{address}{base}");
Ok(Self {
cfg,
watch,
open_http_addr,
shutdown_tx: shutdown,
ws_state,
})
}
#[tracing::instrument(level = "trace", skip(self))]
pub async fn run(mut self) -> Result<()> {
let _build_res = self.watch.build().await; let watch_handle = tokio::spawn(self.watch.run());
let server_handle = Self::spawn_server(
self.cfg.clone(),
self.shutdown_tx.subscribe(),
self.ws_state,
)
.await?;
if self.cfg.open {
if let Err(err) = open::that(self.open_http_addr) {
tracing::error!(error = ?err, "error opening browser");
}
}
drop(self.shutdown_tx);
select! {
r = watch_handle => {
match r {
Err(err) => {
tracing::error!(error = ?err, "error joining watch system handle");
Err(err)
}
_ => r,
}?;
},
r = server_handle => {
match r {
Err(err) => {
tracing::error!(error = ?err, "error joining server handle");
Err(err)
}
_ => r,
}??;
},
}
Ok(())
}
#[tracing::instrument(level = "trace", skip(cfg, shutdown_rx))]
async fn spawn_server(
cfg: Arc<RtcServe>,
shutdown_rx: broadcast::Receiver<()>,
ws_state: watch::Receiver<ws::State>,
) -> Result<JoinHandle<Result<()>>> {
let serve_base_url = cfg.serve_base()?;
let state = Arc::new(State::new(
cfg.watch.build.final_dist.clone(),
serve_base_url.to_string(),
cfg.clone(),
ws_state,
)?);
let router = router(state, cfg.clone())?;
let addr = cfg
.addresses
.iter()
.map(|addr| (*addr, cfg.port).into())
.collect::<Vec<_>>();
let aliases = cfg
.aliases
.iter()
.map(|alias| format!("{alias}:{}", cfg.port))
.collect::<Vec<_>>();
show_listening(
&cfg,
&addr,
&aliases,
&serve_base_url,
!cfg.disable_address_lookup,
)
.await;
let server = run_server(addr, cfg.tls.clone(), router, shutdown_rx);
Ok(tokio::spawn(async move {
match server.await {
Err(err) => {
tracing::error!(error = ?err, "error from server task");
Err(err)
}
r => r,
}
}))
}
}
async fn show_listening(
cfg: &RtcServe,
addr: &[SocketAddr],
aliases: &[String],
base: &str,
lookup: bool,
) {
let mut cache = HashSet::new();
let prefix = if cfg.tls.is_some() { "https" } else { "http" };
let interfaces = local_ip_address::list_afinet_netifas()
.map(|addr| {
addr.into_iter()
.map(|(_name, addr)| addr)
.collect::<Vec<_>>()
})
.unwrap_or(vec![IpAddr::V4(Ipv4Addr::LOCALHOST)]);
let mut addresses = BTreeSet::<SocketAddr>::new();
for addr in addr {
if addr.ip().is_unspecified() {
addresses.extend(interfaces.iter().filter_map(|ipaddr| match ipaddr {
IpAddr::V4(_ip) if addr.is_ipv4() => Some(SocketAddr::new(*ipaddr, addr.port())),
IpAddr::V6(_ip) if addr.is_ipv6() => Some(SocketAddr::new(*ipaddr, addr.port())),
_ => None,
}));
} else {
addresses.insert(*addr);
}
}
fn is_loopback(address: &SocketAddr) -> bool {
match address {
SocketAddr::V4(addr) => addr.ip().is_loopback(),
SocketAddr::V6(addr) => addr.ip().is_loopback(),
}
}
tracing::info!("{SERVER}server listening at:");
for address in &addresses {
show_address(
&mut cache,
is_loopback(address),
format!("{prefix}://{address}{base}"),
);
}
for alias in aliases {
show_address(&mut cache, true, alias);
}
if lookup {
match TokioResolver::builder_tokio().map(|r| r.build()) {
Ok(resolver) => {
for address in &addresses {
let local = is_loopback(address);
if let Ok(names) = resolver.reverse_lookup(address.ip()).await {
for name in names {
show_address(
&mut cache,
local,
format!("{prefix}://{name}:{port}{base}", port = address.port()),
);
}
}
}
}
Err(err) => {
log::warn!("Failed to create system resolver, skipping address resolution: {err}");
}
}
}
}
fn show_address(cache: &mut HashSet<String>, local: bool, address: impl Into<String>) {
let address = address.into();
if cache.insert(address.clone()) {
tracing::info!(" {}{address}", if local { LOCAL } else { NETWORK });
}
}
async fn run_server(
addr: Vec<SocketAddr>,
tls: Option<TlsConfig>,
router: Router,
mut shutdown_rx: broadcast::Receiver<()>,
) -> Result<()> {
let shutdown_handle = Handle::new();
let shutdown = |handle: Handle| async move {
let _res = shutdown_rx.recv().await;
tracing::debug!("server is shutting down");
handle.graceful_shutdown(Some(Duration::from_secs(0)));
};
tokio::spawn(shutdown(shutdown_handle.clone()));
let mut tasks = vec![];
for addr in addr {
let router = router.clone();
let shutdown_handle = shutdown_handle.clone();
match &tls {
Some(tls) =>
{
#[allow(unreachable_code)]
match tls.clone() {
#[cfg(feature = "rustls")]
TlsConfig::Rustls { config } => {
tasks.push(
async move {
axum_server::bind_rustls(addr, config)
.handle(shutdown_handle)
.serve(router.into_make_service())
.await
}
.boxed(),
);
}
#[cfg(feature = "native-tls")]
TlsConfig::Native { config } => {
tasks.push(
async move {
axum_server::bind_openssl(addr, config)
.handle(shutdown_handle)
.serve(router.into_make_service())
.await
}
.boxed(),
);
}
}
}
None => tasks.push(
async move {
axum_server::bind(addr)
.handle(shutdown_handle)
.serve(router.into_make_service())
.await
}
.boxed(),
),
};
}
let (result, _, _) = futures_util::future::select_all(tasks).await;
Ok(result?)
}
/// Shared state handed to the router and its handlers.
pub struct State {
    /// Directory static assets are served from.
    pub dist_dir: PathBuf,
    /// Base path the application is served under.
    pub serve_base: String,
    /// Receiver for the websocket state fed by the watch system.
    pub ws_state: watch::Receiver<ws::State>,
    /// Websocket base path, normalized in `State::new` to end with `/`; substituted for
    /// `{{__TRUNK_WS_BASE__}}` in served HTML.
    pub ws_base: String,
    /// Extra response headers (name -> value) applied to served static files.
    pub headers: HashMap<String, String>,
    /// Runtime serve configuration.
    pub cfg: Arc<RtcServe>,
}
impl State {
    /// Build the shared server state.
    ///
    /// The websocket base is taken from `cfg.ws_base()` and given a trailing `/` if it lacks
    /// one. Response headers are copied from `cfg.headers`.
    ///
    /// # Errors
    /// Propagates any error returned by `cfg.ws_base()`.
    pub fn new(
        dist_dir: PathBuf,
        serve_base: String,
        cfg: Arc<RtcServe>,
        ws_state: watch::Receiver<ws::State>,
    ) -> Result<Self> {
        let ws_base = {
            let mut normalized = cfg.ws_base()?.to_string();
            if !normalized.ends_with('/') {
                normalized.push('/');
            }
            normalized
        };
        let headers = cfg.headers.clone();
        Ok(Self {
            dist_dir,
            serve_base,
            ws_state,
            ws_base,
            headers,
            cfg,
        })
    }
}
fn router(state: Arc<State>, cfg: Arc<RtcServe>) -> Result<Router> {
let mut serve_dir = if cfg.no_spa {
get_service(ServeDir::new(&state.dist_dir))
} else {
get_service(
ServeDir::new(&state.dist_dir)
.fallback(ServeFile::new(state.dist_dir.join(INDEX_HTML))),
)
};
for (key, value) in &state.headers {
let name = HeaderName::from_bytes(key.as_bytes())
.with_context(|| format!("invalid header {:?}", key))?;
let value: HeaderValue = value
.parse()
.with_context(|| format!("invalid header value {:?} for header {}", value, name))?;
serve_dir = serve_dir.layer(SetResponseHeaderLayer::overriding(name, value))
}
let mut router = Router::new()
.route(
"/.well-known/trunk/ws",
get(
|ws: WebSocketUpgrade, state: axum::extract::State<Arc<State>>| async move {
ws.on_upgrade(|socket| async move { ws::handle_ws(socket, state.0).await })
},
),
)
.fallback_service(
get_service(serve_dir)
.handle_error(|error| async move {
tracing::error!(?error, "failed serving static file");
StatusCode::INTERNAL_SERVER_ERROR
})
.layer(axum::middleware::from_fn_with_state(
state.clone(),
html_address_middleware,
)),
)
.layer(TraceLayer::new_for_http());
if state.serve_base != "/" {
router = Router::new().nest(&state.serve_base, router);
}
let router = router.with_state(state.clone());
tracing::info!(
"{}serving static assets at -> {}",
SERVER,
state.serve_base.as_str()
);
let mut builder = ProxyBuilder::new(cfg.tls.is_some(), router);
for proxy in &cfg.proxies {
let mut request_headers = HeaderMap::new();
for (key, value) in &proxy.request_headers {
let name = HeaderName::from_bytes(key.as_bytes())
.with_context(|| format!("invalid header {:?}", key))?;
let value: HeaderValue = value
.parse()
.with_context(|| format!("invalid header value {:?} for header {}", value, name))?;
request_headers.insert(name, value);
}
builder = builder.register_proxy(
proxy.ws,
&proxy.backend,
&request_headers,
proxy.rewrite.clone(),
ProxyClientOptions {
insecure: proxy.insecure,
no_system_proxy: proxy.no_system_proxy,
redirect: !proxy.no_redirect,
},
)?;
}
Ok(builder.build())
}
async fn html_address_middleware(
extract::State(state): extract::State<Arc<State>>,
request: extract::Request,
next: Next,
) -> Response {
let host = request.headers().get(HOST).cloned();
let response = next.run(request).await;
if !response.status().is_success() {
return response;
}
let is_html = response
.headers()
.get(CONTENT_TYPE)
.map(|t| t == "text/html")
.unwrap_or_default();
if !is_html {
return response;
}
let (parts, body) = response.into_parts();
let nonce = match &state.cfg.create_nonce {
Some(p) => match nonce() {
Ok(nonce) => Some((p.as_str(), nonce)),
Err(err) => {
tracing::warn!("Failed to create nonce: {err}");
return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to create nonce")
.into_response();
}
},
None => None,
};
match axum::body::to_bytes(body, 100 * 1024 * 1024).await {
Err(err) => {
tracing::debug!("Unable to intercept: {err}");
(parts, Bytes::default()).into_response()
}
Ok(bytes) => {
let mut parts = parts;
let mut bytes = bytes;
match std::str::from_utf8(&bytes) {
Ok(data_str) => {
tracing::debug!("Replacing variable");
let host = host
.and_then(|uri| uri.to_str().map(|s| format!("'{}'", s)).ok())
.unwrap_or_else(|| "window.location.host".into());
let mut data_str = data_str
.replace("'{{__TRUNK_ADDRESS__}}'", &host)
.replace("`{{__TRUNK_ADDRESS__}}`", &host)
.replace("{{__TRUNK_WS_BASE__}}", &state.ws_base);
let mut csp = None;
if let Some((var, val)) = nonce {
data_str = data_str.replace(var, &val);
csp = state
.cfg
.csp
.as_ref()
.map(|csp| csp.join(";").replace("{{NONCE}}", &val).parse());
}
match csp {
Some(Ok(csp)) => {
parts.headers.insert(CONTENT_SECURITY_POLICY, csp);
}
Some(Err(e)) => tracing::error!("failed to encode csp header: {e:?}"),
None => {}
};
let bytes_vec = data_str.as_bytes().to_vec();
parts.headers.insert(CONTENT_LENGTH, bytes_vec.len().into());
bytes = Bytes::from(bytes_vec);
}
Err(err) => {
tracing::debug!("Unable to parse for injecting: {err}");
}
}
(parts, bytes).into_response()
}
}
}
pub(crate) type ServerResult<T> = std::result::Result<T, ServerError>;
pub(crate) struct ServerError(pub anyhow::Error);
impl From<anyhow::Error> for ServerError {
fn from(src: anyhow::Error) -> Self {
ServerError(src)
}
}
impl IntoResponse for ServerError {
fn into_response(self) -> Response {
tracing::error!(error = ?self.0, "error handling request");
let mut res = Response::new(Body::empty());
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
res
}
}