use aide::{
axum::{routing::get_with, ApiRouter},
openapi::OpenApi,
transform::TransformOperation,
};
use axum::{
body::Body, extract::FromRef, middleware as axum_middleware,
response::Response, routing::get, serve, Extension, Router,
};
use openidconnect::core::CoreClient;
use prometheus::{IntCounterVec, Opts, Registry};
use std::{path::PathBuf, sync::Arc};
use thiserror::Error;
use tokio_listener::ListenerAddress;
use tower_http::trace::TraceLayer;
use tower_sessions::{cookie::SameSite, MemoryStore, SessionManagerLayer};
use tracing::{error, info};
use super::me;
pub use crate::app::CliApp;
use crate::auth::{self, OidcConfig, OidcDiscoveryError};
use crate::server::{
health::HealthRegistry, metrics, openapi, shutdown, spa, systemd,
};
/// Extends [`CliApp`] with the ability to describe server run configurations.
pub trait ServerApp: CliApp {
/// Returns the [`ServerRunConfig`] values describing the server(s) of this application.
fn server_run_configs(&self) -> Vec<ServerRunConfig>;
}
/// Settings consumed when initialising [`BaseServerState`] and when running a [`Server`].
pub struct ServerRunConfig {
/// Application name; used as the OpenAPI document title and handed to the OpenAPI routes.
pub app_name: String,
/// Address the listener is bound to when the server starts listening.
pub listen_address: ListenerAddress,
/// Optional frontend path; when set, a SPA fallback service is mounted for it.
pub frontend_path: Option<PathBuf>,
/// Base URL passed to OIDC discovery together with `oidc`.
pub base_url: String,
/// OIDC settings; when `None`, no OIDC client is created.
pub oidc: Option<OidcConfig>,
}
/// Errors produced while initialising the base state or running the server.
#[derive(Debug, Error)]
pub enum ServerError {
/// OIDC discovery failed; converted automatically from [`OidcDiscoveryError`].
#[error("OIDC provider discovery failed: {0}")]
OidcDiscovery(#[from] OidcDiscoveryError),
/// The listener could not be bound to the configured address.
#[error("Failed to bind listener to {address}: {source}")]
ListenerBind {
// Display form of the address the bind was attempted on.
address: String,
// Underlying I/O error reported by the bind call.
#[source]
source: std::io::Error,
},
/// Creating or registering the request counter failed; converted from `prometheus::Error`.
#[error("HTTP-request metrics setup failed: {0}")]
RequestMetricsInit(#[from] prometheus::Error),
/// The serve loop terminated with an I/O error.
#[error("Server runtime error: {0}")]
Runtime(#[source] std::io::Error),
}
/// Shared infrastructure state that every server carries, regardless of the
/// application-specific state layered on top of it.
#[derive(Clone)]
pub struct BaseServerState {
/// Health registry; starts out as `HealthRegistry::default()`.
pub health_registry: HealthRegistry,
/// Prometheus registry the request counter is registered in.
pub metrics_registry: Arc<Registry>,
/// `http_requests_total` counter, labelled by `method` and `status`.
pub request_counter: IntCounterVec,
/// OIDC client obtained through discovery, or `None` when OIDC is not configured.
pub oidc_client: Option<Arc<CoreClient>>,
/// Copy of `ServerRunConfig::frontend_path`.
pub frontend_path: Option<PathBuf>,
}
impl BaseServerState {
    /// Builds the base state from `config`.
    ///
    /// Creates the `http_requests_total` counter, registers it in a fresh
    /// Prometheus registry and, when `config.oidc` is set, runs OIDC discovery
    /// against `config.base_url`.
    ///
    /// # Errors
    ///
    /// * [`ServerError::RequestMetricsInit`] if the counter cannot be created
    ///   or registered.
    /// * [`ServerError::OidcDiscovery`] if OIDC discovery fails.
    pub async fn init(config: &ServerRunConfig) -> Result<Self, ServerError> {
        let counter_opts = Opts::new(
            "http_requests_total",
            "Total HTTP requests by method and status",
        );
        let request_counter = IntCounterVec::new(counter_opts, &["method", "status"])?;

        let metrics = Registry::new();
        metrics.register(Box::new(request_counter.clone()))?;

        let oidc_client = if let Some(oidc) = &config.oidc {
            Some(auth::discover_oidc(oidc, &config.base_url).await?)
        } else {
            info!("OIDC not configured — running unauthenticated");
            None
        };

        Ok(Self {
            health_registry: HealthRegistry::default(),
            metrics_registry: Arc::new(metrics),
            request_counter,
            oidc_client,
            frontend_path: config.frontend_path.clone(),
        })
    }
}
impl FromRef<BaseServerState> for HealthRegistry {
fn from_ref(state: &BaseServerState) -> Self {
state.health_registry.clone()
}
}
/// Lets handlers extract the shared Prometheus registry from the base state.
impl FromRef<BaseServerState> for Arc<Registry> {
    fn from_ref(state: &BaseServerState) -> Self {
        // Only bumps the reference count; the registry itself is shared.
        Arc::clone(&state.metrics_registry)
    }
}
/// Lets handlers extract the optional OIDC client from the base state.
impl FromRef<BaseServerState> for Option<Arc<CoreClient>> {
    fn from_ref(state: &BaseServerState) -> Self {
        state.oidc_client.as_ref().map(Arc::clone)
    }
}
/// Lets handlers extract the HTTP request counter from the base state.
impl FromRef<BaseServerState> for IntCounterVec {
    fn from_ref(state: &BaseServerState) -> Self {
        Clone::clone(&state.request_counter)
    }
}
/// Implements `axum::extract::FromRef<$state_ty>` for the four base-state
/// extractor types by cloning the matching field out of `state.$field`.
///
/// * `$state_ty` — the application state type to implement the extractors for.
/// * `$field` — name of a field on `$state_ty` that exposes `health_registry`,
///   `metrics_registry`, `oidc_client` and `request_counter` (the field names
///   of `BaseServerState`).
#[macro_export]
macro_rules! impl_server_state {
($state_ty:ty, $field:ident) => {
// Health registry extractor.
impl ::axum::extract::FromRef<$state_ty>
for $crate::server::health::HealthRegistry
{
fn from_ref(state: &$state_ty) -> Self {
state.$field.health_registry.clone()
}
}
// Prometheus registry extractor.
impl ::axum::extract::FromRef<$state_ty>
for ::std::sync::Arc<::prometheus::Registry>
{
fn from_ref(state: &$state_ty) -> Self {
state.$field.metrics_registry.clone()
}
}
// Optional OIDC client extractor.
impl ::axum::extract::FromRef<$state_ty>
for ::std::option::Option<
::std::sync::Arc<::openidconnect::core::CoreClient>,
>
{
fn from_ref(state: &$state_ty) -> Self {
state.$field.oidc_client.clone()
}
}
// Request counter extractor.
impl ::axum::extract::FromRef<$state_ty> for ::prometheus::IntCounterVec {
fn from_ref(state: &$state_ty) -> Self {
state.$field.request_counter.clone()
}
}
};
}
/// Builder for the HTTP server: collects routes for state type `S` and turns
/// them into a running server or a test router.
pub struct Server<S = BaseServerState>
where
S: Clone + Send + Sync + 'static,
{
// State handed to the application routes when the router is built.
state: S,
// Base state; supplies the OIDC client, frontend path and request counter
// used while assembling the final router.
base: BaseServerState,
// Application routes registered so far.
router: ApiRouter<S>,
// Run configuration (application name, listen address, ...).
config: ServerRunConfig,
}
impl Server<BaseServerState> {
    /// Creates a server whose route state is the base state itself.
    ///
    /// `base` is kept for the infrastructure routes and a clone of it becomes
    /// the state of the (initially empty) application router.
    pub fn new(base: BaseServerState, config: ServerRunConfig) -> Self {
        Self {
            state: base.clone(),
            base,
            router: ApiRouter::new(),
            config,
        }
    }

    /// Returns the base state this server was created with.
    pub fn base_state(&self) -> &BaseServerState {
        &self.base
    }

    /// Switches the server to an application-specific state type.
    ///
    /// `f` receives a clone of the base state and returns the new state `S2`,
    /// which must expose the health registry, metrics registry and OIDC client
    /// through `FromRef`.
    ///
    /// Routes registered before this call are kept: they are bound to the base
    /// state they were written against and carried over into the new server,
    /// instead of being silently discarded.
    pub fn with_state<S2>(
        self,
        f: impl FnOnce(BaseServerState) -> S2,
    ) -> Server<S2>
    where
        S2: Clone + Send + Sync + 'static,
        HealthRegistry: FromRef<S2>,
        Arc<Registry>: FromRef<S2>,
        Option<Arc<CoreClient>>: FromRef<S2>,
    {
        let Self {
            state: base_route_state,
            base,
            router,
            config,
        } = self;
        let new_state = f(base.clone());
        Server {
            state: new_state,
            base,
            // Providing the state erases the router's "missing state" type, so
            // the already-registered routes fit into an `ApiRouter<S2>`.
            router: router.with_state(base_route_state),
            config,
        }
    }
}
impl<S> Server<S>
where
S: Clone + Send + Sync + 'static,
HealthRegistry: FromRef<S>,
Arc<Registry>: FromRef<S>,
Option<Arc<CoreClient>>: FromRef<S>,
{
pub fn api_route(
mut self,
path: &str,
method: aide::axum::routing::ApiMethodRouter<S>,
) -> Self {
self.router = self.router.api_route(path, method);
self
}
pub fn merge(mut self, router: ApiRouter<S>) -> Self {
self.router = self.router.merge(router);
self
}
pub async fn listen(self) -> Result<(), ServerError> {
let listen_address = self.config.listen_address.clone();
let listen_address_display = listen_address.to_string();
let app = self.build_router(true);
let listener = tokio_listener::Listener::bind(
&listen_address,
&tokio_listener::SystemOptions::default(),
&tokio_listener::UserOptions::default(),
)
.await
.map_err(|source| {
error!("Failed to bind to {}: {}", listen_address_display, source);
ServerError::ListenerBind {
address: listen_address_display.clone(),
source,
}
})?;
info!("Server listening on {}", listen_address_display);
systemd::notify_ready();
systemd::spawn_watchdog();
let shutdown_future = async {
if let Err(e) = shutdown::shutdown_signal().await {
error!(
"Shutdown-signal install failed; graceful shutdown disabled: {}",
e,
);
std::future::pending::<()>().await;
}
};
serve(listener, app.into_make_service())
.with_graceful_shutdown(shutdown_future)
.await
.map_err(|e| {
error!("Server error: {}", e);
ServerError::Runtime(e)
})?;
info!("Server shut down");
Ok(())
}
pub fn into_test_router(self) -> Router {
self.build_router(false)
}
fn build_router(self, secure_cookies: bool) -> Router {
aide::generate::extract_schemas(true);
let app_name = self.config.app_name.clone();
let state = self.state;
let base = self.base;
let mut api = OpenApi::default();
let infra_router = ApiRouter::new()
.api_route(
"/healthz",
get_with(
crate::server::health::healthz_handler,
|op: TransformOperation| op.description("Health check."),
),
)
.api_route(
"/metrics",
get_with(metrics::metrics_handler, |op: TransformOperation| {
op.description("Prometheus metrics in text/plain format.")
}),
);
let api_router = infra_router
.merge(self.router)
.with_state(state)
.finish_api_with(&mut api, |a| a.title(&app_name));
let api = Arc::new(api);
let auth_router = build_auth_routes(base.oidc_client.clone());
let me_router = build_me_route(base.oidc_client.clone());
let mut full = Router::new()
.merge(api_router)
.merge(auth_router)
.merge(me_router)
.merge(openapi::openapi_routes(api, &app_name));
if let Some(ref frontend_path) = base.frontend_path {
full = full.fallback_service(spa::spa_service(frontend_path));
}
let session_store = MemoryStore::default();
let session_layer = SessionManagerLayer::new(session_store)
.with_secure(secure_cookies)
.with_same_site(SameSite::Lax);
full
.layer(axum_middleware::from_fn(request_counter_middleware))
.layer(Extension(base.request_counter.clone()))
.layer(session_layer)
.layer(TraceLayer::new_for_http())
}
}
/// Middleware that increments the request counter once per request, labelled
/// with the request method and the response status code.
///
/// The counter is read from the request extensions, so an
/// `Extension<IntCounterVec>` layer must wrap this middleware.
async fn request_counter_middleware(
    Extension(counter): Extension<IntCounterVec>,
    request: axum::http::Request<Body>,
    next: axum::middleware::Next,
) -> Response {
    // Capture the method first: `next.run` consumes the request.
    let method_label = request.method().as_str().to_owned();
    let response = next.run(request).await;
    let status_label = response.status().as_u16().to_string();

    let labels = [method_label.as_str(), status_label.as_str()];
    counter.with_label_values(&labels).inc();
    response
}
/// Builds the `/auth/login`, `/auth/callback` and `/auth/logout` routes with
/// the optional OIDC client as their state.
fn build_auth_routes(oidc_client: Option<Arc<CoreClient>>) -> Router {
    let routes: Router<Option<Arc<CoreClient>>> = Router::new()
        .route("/auth/login", get(auth::login_handler))
        .route("/auth/callback", get(auth::callback_handler))
        .route("/auth/logout", get(auth::logout_handler));
    routes.with_state(oidc_client)
}
/// Builds the `/me` route with the optional OIDC client as its state.
fn build_me_route(oidc_client: Option<Arc<CoreClient>>) -> Router {
    let routes: Router<Option<Arc<CoreClient>>> =
        Router::new().route("/me", get(me::me_handler));
    routes.with_state(oidc_client)
}