use anyhow::Result;
use axum::{
Router,
extract::{Json, Path, Query, Request, State},
http::{HeaderMap, StatusCode, header},
response::{IntoResponse, Response},
routing::{any, post},
};
use serde::Deserialize;
use std::collections::BTreeMap;
use std::sync::{Arc, RwLock};
use tokio::task::JoinSet;
use crate::{
assembled_statistical_sequences::AssembledStatisticalSequences,
config::Config,
means_of_production::{self, MeansOfProduction, Outcome},
tenx_programmer::{TenXProgrammer, TenXProgrammerCounters},
};
/// Crate version as reported by Cargo (`CARGO_PKG_VERSION`), fixed at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Everything the request handlers read to answer a request.
///
/// Lives behind the lock of [`IocaineState`]; the control endpoint replaces
/// `config`, `template` and `request_handler` when a new configuration is
/// loaded.
#[derive(Debug, Clone)]
pub struct IocaineStateSnapshot {
    /// The active configuration.
    pub config: Config,
    /// Metrics counters; `None` when no metrics server was started.
    pub counters: Option<TenXProgrammerCounters>,
    /// Renders the challenge (`challenge.jinja`) and garbage (`main.jinja`) pages.
    pub template: Arc<AssembledStatisticalSequences>,
    /// Decides how each request is treated. `None` when
    /// `server.request_handler.path` is unset, in which case every request
    /// is served garbage.
    pub request_handler: Option<Arc<MeansOfProduction>>,
}
/// Shared, lock-protected state handed to every axum handler.
pub type IocaineState = Arc<RwLock<IocaineStateSnapshot>>;
/// The iocaine server: owns the configuration it was started with and runs
/// the main, metrics and control listeners.
#[derive(Debug)]
pub struct Iocaine {
    /// Configuration the servers are started with.
    pub config: Config,
}
impl Iocaine {
    /// Wraps an already-loaded configuration.
    ///
    /// # Errors
    ///
    /// Never fails at present; the `Result` is part of the public signature.
    pub fn new(config: Config) -> Result<Self> {
        Ok(Self { config })
    }
    /// Builds the initial shared state for the request handlers.
    ///
    /// `counters` is `None` when no metrics server is running.
    ///
    /// # Errors
    ///
    /// Fails if `config.server.request_handler.path` is set and the request
    /// handler at that path cannot be loaded.
    pub fn make_state(
        config: &Config,
        counters: Option<TenXProgrammerCounters>,
    ) -> Result<IocaineState> {
        let request_handler = if let Some(path) = &config.server.request_handler.path {
            Some(Arc::new(MeansOfProduction::new(path)?))
        } else {
            None
        };
        let state = IocaineStateSnapshot {
            config: config.clone(),
            counters,
            template: Arc::new(AssembledStatisticalSequences::new(config)),
            request_handler,
        };
        Ok(Arc::new(RwLock::new(state)))
    }
    /// Router for the public listener: every method on `/` and on any
    /// sub-path is answered by [`handler`].
    fn main_app(state: IocaineState) -> Router {
        Router::new()
            .route("/", any(handler))
            .route("/{*path}", any(handler))
            .layer(tower_http::trace::TraceLayer::new_for_http())
            .with_state(state)
    }
    /// Router for the control listener (`POST /config/load`).
    fn control_app(state: IocaineState) -> Router {
        Router::new()
            .route("/config/load", post(control_config_load))
            .layer(tower_http::trace::TraceLayer::new_for_http())
            .with_state(state)
    }
    /// Binds and serves the main listener, plus the metrics listener (when
    /// metrics are enabled) and the control listener (when configured), until
    /// a shutdown signal arrives.
    ///
    /// # Errors
    ///
    /// Fails if building the metrics or the state fails, if any bind fails,
    /// if any server stops with an I/O error, or if a server task panics.
    /// On the first such failure the remaining servers are aborted.
    async fn start_server(self) -> Result<()> {
        let bind = &self.config.server.bind;
        let metrics_bind = &self.config.metrics.bind;
        let mut opts = tokio_listener::UserOptions::default();
        opts.unix_listen_unlink = true;
        opts.unix_listen_chmod = self.config.server.unix_listen_access;
        let metrics = TenXProgrammer::new(&self.config.metrics)?;
        let state = Self::make_state(&self.config, metrics.as_ref().map(|v| v.counters.clone()))?;
        let app = Self::main_app(state.clone());
        let listener =
            tokio_listener::Listener::bind(bind, &tokio_listener::SystemOptions::default(), &opts)
                .await?;
        let mut servers = JoinSet::new();
        servers.spawn(async move {
            axum::serve(listener, app)
                .with_graceful_shutdown(shutdown_signal())
                .await
        });
        if let Some(metrics) = metrics {
            // NOTE(review): this reuses the main listener's `opts`, so a UNIX
            // metrics socket gets `server.unix_listen_access` — confirm intended.
            let metrics_listener = tokio_listener::Listener::bind(
                metrics_bind,
                &tokio_listener::SystemOptions::default(),
                &opts,
            )
            .await?;
            let metrics_app = metrics.app();
            servers.spawn(async move {
                axum::serve(metrics_listener, metrics_app)
                    .with_graceful_shutdown(shutdown_signal())
                    .await
            });
        }
        if let Some(control) = &self.config.server.control {
            let mut opts = tokio_listener::UserOptions::default();
            opts.unix_listen_unlink = true;
            opts.unix_listen_chmod = control.unix_listen_access;
            let listener = tokio_listener::Listener::bind(
                &control.bind,
                &tokio_listener::SystemOptions::default(),
                &opts,
            )
            .await?;
            let control_app = Self::control_app(state);
            servers.spawn(async move {
                axum::serve(listener, control_app)
                    .with_graceful_shutdown(shutdown_signal())
                    .await
            });
        }
        // Surface the first failure instead of discarding the results: the
        // outer `?` catches a panicked/cancelled task (`JoinError`), the inner
        // one an I/O error returned by the server itself. Returning early
        // drops `servers`, which aborts the servers still running.
        while let Some(joined) = servers.join_next().await {
            joined??;
        }
        Ok(())
    }
    /// Runs all configured servers until shutdown.
    ///
    /// # Errors
    ///
    /// Propagates any error from starting or running the servers.
    pub async fn run(self) -> Result<()> {
        self.start_server().await
    }
}
/// Completes when the process receives Ctrl+C (SIGINT) or SIGTERM, whichever
/// arrives first; every listener uses it as its graceful-shutdown trigger.
///
/// # Panics
///
/// Panics if either signal handler cannot be installed.
pub async fn shutdown_signal() {
    use tokio::signal::unix::{SignalKind, signal};
    let on_ctrl_c = async {
        tokio::signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };
    let on_sigterm = async {
        let mut sigterm =
            signal(SignalKind::terminate()).expect("failed to install signal handler");
        sigterm.recv().await;
    };
    tokio::select! {
        () = on_ctrl_c => {},
        () = on_sigterm => {},
    }
}
#[must_use]
pub fn handle_request(
headers: &HeaderMap,
state: &IocaineStateSnapshot,
method: &str,
path: Option<Path<String>>,
params: &BTreeMap<String, String>,
) -> impl IntoResponse + use<> {
if let Some(ref request_handler) = state.request_handler {
let p = path.as_ref().map(|p| p.0.clone());
request_handler
.decide(headers.clone(), p, method)
.map_or_else(
|variant| misdirect(headers, state, variant).into_response(),
|variant| match variant {
Outcome::Garbage => {
poison(headers, state, path, params, variant).into_response()
}
Outcome::Challenge => {
challenge(headers, state, path, params, variant).into_response()
}
Outcome::NotForUs => server_error().into_response(),
},
)
} else {
poison(headers, state, path, params, Outcome::Garbage).into_response()
}
}
/// Axum handler for every path on the main listener.
///
/// Collects the request's headers, optional path, query parameters and
/// method, then answers via [`handle_request`] while holding a read lock on
/// the shared state.
async fn handler(
    headers: HeaderMap,
    State(state): State<IocaineState>,
    path: Option<Path<String>>,
    Query(params): Query<BTreeMap<String, String>>,
    request: Request,
) -> impl IntoResponse {
    let method = request.method().to_string();
    // A poisoned lock only means some writer panicked while holding it; every
    // field still holds a valid value, so keep serving instead of panicking on
    // every subsequent request.
    let snapshot = state
        .read()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    handle_request(&headers, &snapshot, &method, path, &params)
}
#[derive(Debug, Deserialize)]
struct ControlConfigLoad {
pub path: String,
}
async fn control_config_load(
State(state): State<IocaineState>,
Json(payload): Json<ControlConfigLoad>,
) -> Result<impl IntoResponse, AppError> {
let Ok(config) = Config::load(&payload.path) else {
return Ok((StatusCode::UNPROCESSABLE_ENTITY, ""));
};
if !state.read().unwrap().config.is_compatible(&config) {
return Ok((StatusCode::CONFLICT, ""));
}
let request_handler = if let Some(path) = &config.server.request_handler.path {
match MeansOfProduction::new(path) {
Ok(v) => Some(Arc::new(v)),
Err(e) => {
tracing::error!(
{ config_file = payload.path },
"Failed to load request handler: {e}"
);
return Ok((StatusCode::UNPROCESSABLE_ENTITY, ""));
}
}
} else {
None
};
if let Ok(mut new_state) = state.write() {
new_state.config = config.clone();
new_state.request_handler = request_handler;
new_state.template = Arc::new(AssembledStatisticalSequences::new(&config));
} else {
tracing::error!("Failed to lock state for writing");
return Ok((StatusCode::INTERNAL_SERVER_ERROR, ""));
}
Ok((StatusCode::ACCEPTED, ""))
}
/// Rejects a request with `421 Misdirected Request` and an empty body.
///
/// When metrics are enabled, the request is first counted under the verdict
/// `reject::<outcome>`.
fn misdirect(
    headers: &axum::http::HeaderMap,
    state: &IocaineStateSnapshot,
    outcome: Outcome,
) -> impl IntoResponse {
    let response = (StatusCode::MISDIRECTED_REQUEST, "");
    let Some(counters) = state.counters.as_ref() else {
        return response;
    };
    let verdict = format!("reject::{outcome}");
    let label_values =
        TenXProgrammer::build_label_values(&state.template.config.metrics, headers, &verdict);
    counters.request_counter.with_label_values(&label_values).inc();
    response
}
/// An empty-bodied `500 Internal Server Error` response.
fn server_error() -> impl IntoResponse {
    let status = StatusCode::INTERNAL_SERVER_ERROR;
    let body = "";
    (status, body)
}
fn challenge(
headers: &axum::http::HeaderMap,
state: &IocaineStateSnapshot,
path: Option<Path<String>>,
params: &BTreeMap<String, String>,
outcome: Outcome,
) -> std::result::Result<impl IntoResponse, AppError> {
let default_host = axum::http::HeaderValue::from_static("<unknown>");
let host = headers.get("host").unwrap_or(&default_host).to_str()?;
let path = path.unwrap_or(Path(String::new()));
let (content_type, challenge) =
state
.template
.generate(host, &path, params, "challenge.jinja")?;
if let Some(ref counters) = state.counters {
let verdict = format!("accept::{outcome}");
let labels =
TenXProgrammer::build_label_values(&state.template.config.metrics, headers, &verdict);
counters.request_counter.with_label_values(&labels).inc();
counters.challenge_counter.with_label_values(&labels).inc();
}
let mut headers = HeaderMap::new();
headers.insert(header::CONTENT_TYPE, content_type.parse()?);
if state.config.templates.minify.enable
&& (content_type.starts_with("text/html") || content_type.starts_with("text/css"))
{
let config = &state.config.templates.minify;
let cfg = minify_html::Cfg {
minify_css: config.minify_css,
minify_js: false,
minify_doctype: false,
..Default::default()
};
let minified = minify_html::minify(challenge.as_bytes(), &cfg);
Ok((headers, minified))
} else {
Ok((headers, challenge.into()))
}
}
fn poison(
headers: &axum::http::HeaderMap,
state: &IocaineStateSnapshot,
path: Option<Path<String>>,
params: &BTreeMap<String, String>,
outcome: Outcome,
) -> std::result::Result<impl IntoResponse, AppError> {
let default_host = axum::http::HeaderValue::from_static("<unknown>");
let host = headers.get("host").unwrap_or(&default_host).to_str()?;
let path = path.unwrap_or(Path(String::new()));
let (content_type, garbage) = state.template.generate(host, &path, params, "main.jinja")?;
if let Some(ref counters) = state.counters {
let verdict = format!("accept::{outcome}");
let labels =
TenXProgrammer::build_label_values(&state.template.config.metrics, headers, &verdict);
counters.request_counter.with_label_values(&labels).inc();
counters
.garbage_served_counter
.with_label_values(&labels)
.inc_by(garbage.len() as u64);
let depth = path.chars().filter(|c| *c == '/').count() as u64;
let maze_depth_counter = counters.maze_depth.with_label_values(&labels);
let maze_depth = maze_depth_counter.get();
if depth > maze_depth {
maze_depth_counter.inc_by(depth - maze_depth);
}
}
let mut headers = HeaderMap::new();
headers.insert(header::CONTENT_TYPE, content_type.parse()?);
if state.config.templates.minify.enable
&& (content_type.starts_with("text/html") || content_type.starts_with("text/css"))
{
let config = &state.config.templates.minify;
let cfg = minify_html::Cfg {
minify_css: config.minify_css,
minify_js: false,
minify_doctype: false,
..Default::default()
};
let minified = minify_html::minify(garbage.as_bytes(), &cfg);
Ok((headers, minified))
} else {
Ok((headers, garbage.into()))
}
}
pub struct AppError(anyhow::Error);
impl IntoResponse for AppError {
fn into_response(self) -> Response {
tracing::error!("Internal server error: {}", self.0);
(StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response()
}
}
impl From<axum::http::header::ToStrError> for AppError {
fn from(e: axum::http::header::ToStrError) -> Self {
Self(e.into())
}
}
impl From<anyhow::Error> for AppError {
fn from(e: anyhow::Error) -> Self {
Self(e)
}
}
impl From<std::io::Error> for AppError {
fn from(e: std::io::Error) -> Self {
Self(e.into())
}
}
impl From<std::string::FromUtf8Error> for AppError {
fn from(e: std::string::FromUtf8Error) -> Self {
Self(e.into())
}
}
impl From<axum::http::header::InvalidHeaderValue> for AppError {
fn from(e: axum::http::header::InvalidHeaderValue) -> Self {
Self(e.into())
}
}
impl From<means_of_production::Error> for AppError {
fn from(e: means_of_production::Error) -> Self {
Self(e.into())
}
}