mod acl;
#[cfg(feature = "tee")]
mod attestation;
mod audit;
mod auth;
mod auth_portal;
mod backup;
mod backup_blob;
mod bootstrap;
mod cache;
mod capabilities;
mod config;
mod contexts;
mod did_templates;
#[cfg(feature = "webvh")]
mod did_webvh;
mod health;
pub mod keys;
#[cfg(feature = "webvh")]
mod passkey_vms;
#[cfg(feature = "webvh")]
mod protocol;
pub(crate) mod trust_tasks;
mod vta;
use std::sync::Arc;
use axum::Router;
use axum::extract::DefaultBodyLimit;
use axum::http::{HeaderName, HeaderValue, Method};
use axum::routing::{delete, get, post, put};
use tower_governor::GovernorLayer;
use tower_governor::governor::GovernorConfigBuilder;
use tower_http::cors::{AllowOrigin, CorsLayer};
use crate::server::AppState;
/// Default request-body cap (1 MiB) applied to every API route that does not set its own.
const MAX_BODY_SIZE: usize = 1 << 20;
/// Tighter request-body cap (64 KiB) for routes reachable without authentication.
const UNAUTH_BODY_SIZE: usize = 64 << 10;
/// Size ceiling (100 MiB) for backup blob payloads; exported to the parent module.
pub(super) const BACKUP_BLOB_BODY_SIZE: usize = 100 << 20;
/// Value passed to `GovernorConfigBuilder::per_second` for unauthenticated routes.
///
/// NOTE(review): in tower_governor, `per_second(n)` sets the interval after which ONE
/// quota element is replenished, in seconds. This therefore yields a sustained rate of
/// one request every 5 s per key, not 5 requests/second as the name suggests.
/// Confirm which is intended.
const UNAUTH_RPS: u64 = 5;
/// Burst size (maximum quota held per key) for the unauthenticated rate limiter.
const UNAUTH_BURST: u32 = 10;
/// Router exposing only the liveness probe at `GET /health`.
pub fn health_router() -> Router<AppState> {
    let probe = get(health::health);
    Router::new().route("/health", probe)
}
/// [`health_router`] wrapped in the CORS layer derived from `allowed_origins`.
///
/// When the origin list yields no usable entry (see [`build_cors_layer`]), the
/// router is returned without any CORS handling.
pub fn health_router_with_cors(allowed_origins: &[String]) -> Router<AppState> {
    let base = health_router();
    if let Some(cors) = build_cors_layer(allowed_origins) {
        base.layer(cors)
    } else {
        base
    }
}
/// Builds the CORS layer for the configured browser origins.
///
/// The following entries are dropped from `allowed_origins`:
/// - empty strings;
/// - the `*` wildcard, so only explicit origins are ever allowed;
/// - entries that are not valid HTTP header values.
///
/// Returns `None` (meaning "install no CORS layer at all") when nothing usable remains,
/// including when the input list is empty.
///
/// The method list must cover every non-CORS-safelisted method the API routes use.
/// `PUT` is required by e.g. `/contexts/{id}/did`, `/did-templates/{name}` and
/// `/cache/{key}`; without it browsers reject the preflight for those routes.
fn build_cors_layer(allowed_origins: &[String]) -> Option<CorsLayer> {
    let parsed: Vec<HeaderValue> = allowed_origins
        .iter()
        .map(String::as_str)
        .filter(|o| !o.is_empty() && *o != "*")
        .filter_map(|o| HeaderValue::from_str(o).ok())
        .collect();
    if parsed.is_empty() {
        return None;
    }
    Some(
        CorsLayer::new()
            .allow_origin(AllowOrigin::list(parsed))
            .allow_methods([
                Method::GET,
                Method::POST,
                Method::PUT,
                Method::DELETE,
                Method::PATCH,
            ])
            .allow_headers([
                HeaderName::from_static("content-type"),
                HeaderName::from_static("authorization"),
                HeaderName::from_static("x-backup-token"),
            ])
            .max_age(std::time::Duration::from_secs(60)),
    )
}
/// Full API router with CORS disabled and the unauthenticated rate limiter keyed on
/// the TCP peer address (forwarding headers are not trusted).
pub fn router() -> Router<AppState> {
    let no_origins: &[String] = &[];
    let trust_forwarded_for = false;
    router_with_cors(no_origins, trust_forwarded_for)
}
pub fn router_with_cors(allowed_origins: &[String], trust_xff: bool) -> Router<AppState> {
let unauth = Router::new()
.route("/bootstrap/request", post(bootstrap::request))
.route("/auth/passkey-login/start", post(auth::passkey_login_start))
.route(
"/auth/passkey-login/finish",
post(auth::passkey_login_finish),
)
.route("/auth/challenge", post(auth::challenge))
.route("/auth/", post(auth::authenticate))
.route("/auth/refresh", post(auth::refresh))
.layer(DefaultBodyLimit::max(UNAUTH_BODY_SIZE));
#[cfg(feature = "webvh")]
let unauth = unauth
.route("/did/{did}/log", get(did_webvh::get_did_log_public_handler));
let unauth = if trust_xff {
let cfg = Arc::new(
GovernorConfigBuilder::default()
.per_second(UNAUTH_RPS)
.burst_size(UNAUTH_BURST)
.key_extractor(tower_governor::key_extractor::SmartIpKeyExtractor)
.finish()
.expect("governor config values are static and non-zero"),
);
unauth.layer(GovernorLayer::new(cfg))
} else {
let cfg = Arc::new(
GovernorConfigBuilder::default()
.per_second(UNAUTH_RPS)
.burst_size(UNAUTH_BURST)
.key_extractor(tower_governor::key_extractor::PeerIpKeyExtractor)
.finish()
.expect("governor config values are static and non-zero"),
);
unauth.layer(GovernorLayer::new(cfg))
};
let auth_portal_router = Router::new().route("/auth/portal", get(auth_portal::portal_handler));
#[cfg(feature = "webvh")]
let auth_provision = Router::new().route(
"/bootstrap/provision-integration",
post(bootstrap::provision_integration),
);
let router = Router::new().merge(unauth);
#[cfg(feature = "webvh")]
let router = router.merge(auth_provision);
let router = router.merge(auth_portal_router);
let router = router
.route(
"/auth/sessions",
get(auth::session_list).delete(auth::revoke_sessions_by_did),
)
.route("/auth/sessions/{session_id}", delete(auth::revoke_session))
.route("/api/trust-tasks", post(trust_tasks::dispatch_trust_task))
.route(
"/config",
get(config::get_config).patch(config::update_config),
)
.route("/keys", get(keys::list_keys).post(keys::create_key))
.route(
"/keys/{key_id}",
get(keys::get_key)
.delete(keys::invalidate_key)
.patch(keys::rename_key),
)
.route("/keys/{key_id}/secret", get(keys::get_key_secret))
.route("/keys/{key_id}/sign", post(keys::sign_with_key))
.route("/keys/import/wrapping-key", get(keys::get_wrapping_key))
.route("/keys/import", post(keys::import_key))
.route("/keys/seeds", get(keys::list_seeds))
.route("/keys/seeds/rotate", post(keys::rotate_seed))
.route(
"/contexts",
get(contexts::list_contexts_handler).post(contexts::create_context_handler),
)
.route(
"/contexts/{id}",
get(contexts::get_context_handler)
.patch(contexts::update_context_handler)
.delete(contexts::delete_context_handler),
)
.route(
"/contexts/{id}/did",
put(contexts::update_context_did_handler),
)
.route(
"/contexts/{id}/delete-preview",
get(contexts::preview_delete_context_handler),
)
.route(
"/did-templates",
get(did_templates::list_handler).post(did_templates::create_handler),
)
.route(
"/did-templates/{name}",
get(did_templates::get_handler)
.put(did_templates::update_handler)
.delete(did_templates::delete_handler),
)
.route(
"/did-templates/{name}/render",
post(did_templates::render_handler),
)
.route(
"/contexts/{id}/did-templates",
get(did_templates::list_context_handler).post(did_templates::create_context_handler),
)
.route(
"/contexts/{id}/did-templates/{name}",
get(did_templates::get_context_handler)
.put(did_templates::update_context_handler)
.delete(did_templates::delete_context_handler),
)
.route(
"/contexts/{id}/did-templates/{name}/render",
post(did_templates::render_context_handler),
)
.route("/acl", get(acl::list_acl).post(acl::create_acl))
.route("/acl/swap", post(acl::swap_acl))
.route(
"/acl/{did}",
get(acl::get_acl)
.patch(acl::update_acl)
.delete(acl::delete_acl),
)
.route("/audit/logs", get(audit::list_audit_logs))
.route(
"/audit/retention",
get(audit::get_retention).patch(audit::update_retention),
)
.route(
"/cache/{key}",
get(cache::get_cached)
.put(cache::put_cached)
.delete(cache::delete_cached),
);
#[cfg(feature = "tee")]
let router = router
.route("/attestation/status", get(attestation::status))
.route(
"/attestation/report",
get(attestation::cached_report).post(attestation::generate_report),
)
.route(
"/attestation/mnemonic",
get(attestation::mnemonic_status).post(attestation::mnemonic_export),
)
.route("/attestation/did-log", get(attestation::did_log));
#[cfg(feature = "webvh")]
let router = router
.route(
"/services/didcomm/enable",
post(protocol::enable_didcomm_handler),
)
.route(
"/services/didcomm/disable",
post(protocol::disable_didcomm_handler),
)
.route("/services/rest/enable", post(protocol::enable_rest_handler))
.route("/services/rest/update", post(protocol::update_rest_handler))
.route(
"/services/rest/disable",
post(protocol::disable_rest_handler),
)
.route(
"/services/rest/rollback",
post(protocol::rollback_rest_handler),
)
.route(
"/services/webauthn/enable",
post(protocol::enable_webauthn_handler),
)
.route(
"/services/webauthn/update",
post(protocol::update_webauthn_handler),
)
.route(
"/services/webauthn/disable",
post(protocol::disable_webauthn_handler),
)
.route(
"/services/webauthn/rollback",
post(protocol::rollback_webauthn_handler),
)
.route("/services", get(protocol::list_services_handler))
.route("/services/didcomm/drain", get(protocol::list_drain_handler))
.route(
"/services/didcomm/update",
post(protocol::update_didcomm_handler),
)
.route(
"/services/didcomm/rollback",
post(protocol::rollback_didcomm_handler),
)
.route(
"/mediators/drain/cancel",
post(protocol::drain_cancel_handler),
)
.route("/mediators/report", get(protocol::mediator_report_handler));
#[cfg(feature = "webvh")]
let router = router
.route(
"/webvh/servers",
get(did_webvh::list_servers_handler).post(did_webvh::add_server_handler),
)
.route(
"/webvh/servers/{id}",
axum::routing::patch(did_webvh::update_server_handler)
.delete(did_webvh::remove_server_handler),
)
.route(
"/webvh/servers/{id}/domains",
get(did_webvh::list_server_domains_handler),
)
.route(
"/webvh/dids",
get(did_webvh::list_dids_handler).post(did_webvh::create_did_handler),
)
.route(
"/webvh/dids/{did}",
get(did_webvh::get_did_handler).delete(did_webvh::delete_did_handler),
)
.route("/webvh/dids/{did}/log", get(did_webvh::get_did_log_handler))
.route(
"/webvh/dids/{did}/register-server",
post(did_webvh::register_did_with_server_handler),
)
.route(
"/contexts/{ctx_id}/dids/{scid}/update",
post(did_webvh::update_did_handler),
)
.route(
"/contexts/{ctx_id}/dids/{scid}/rotate-keys",
post(did_webvh::rotate_did_keys_handler),
)
.route(
"/did/verification-methods/passkey/challenge",
post(passkey_vms::enroll_challenge_handler),
)
.route(
"/did/verification-methods/passkey",
post(passkey_vms::enroll_submit_handler).get(passkey_vms::list_passkeys_handler),
)
.route(
"/did/verification-methods/passkey/{fragment}",
delete(passkey_vms::revoke_passkey_handler),
);
let router = router
.route("/vta/restart", post(vta::restart))
.route("/metrics", get(vta::metrics))
.route("/backup/export", post(backup::export))
.route("/backup/import", post(backup::import));
let backup_blob_router = Router::new()
.route(
"/backup/blob/{bundle_id}",
get(backup_blob::get_blob).post(backup_blob::post_blob),
)
.layer(DefaultBodyLimit::disable());
let router = router.merge(backup_blob_router);
let router = router
.route("/health/details", get(health::health_details))
.route("/capabilities", get(capabilities::capabilities));
let router = router.layer(DefaultBodyLimit::max(MAX_BODY_SIZE));
match build_cors_layer(allowed_origins) {
Some(cors) => router.layer(cors),
None => router,
}
}
#[cfg(test)]
mod cors_tests {
    //! Unit tests for origin filtering in `build_cors_layer` and for wiring the
    //! health router with and without CORS.
    use super::*;

    /// Turns string literals into the owned `String` list the builders accept.
    fn origins(list: &[&str]) -> Vec<String> {
        list.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn empty_list_disables_cors_entirely() {
        let none: Vec<String> = Vec::new();
        assert!(build_cors_layer(&none).is_none());
    }

    #[test]
    fn explicit_origin_produces_layer() {
        let layer = build_cors_layer(&origins(&["http://localhost:8000"]));
        assert!(layer.is_some());
    }

    #[test]
    fn invalid_origin_filtered_out_and_empty_result_returns_none() {
        // A trailing newline is not a legal header value, so the only entry is dropped.
        let layer = build_cors_layer(&origins(&["http://localhost:8000\n"]));
        assert!(layer.is_none());
    }

    #[test]
    fn wildcard_alone_yields_no_layer() {
        let layer = build_cors_layer(&origins(&["*"]));
        assert!(
            layer.is_none(),
            "wildcard must be filtered to None, never partial-applied"
        );
    }

    #[test]
    fn wildcard_mixed_with_explicit_origins_drops_wildcard_keeps_others() {
        let mixed = origins(&["*", "http://localhost:8000"]);
        assert!(build_cors_layer(&mixed).is_some());
    }

    #[test]
    fn empty_origin_string_filtered() {
        let mixed = origins(&["", "http://x"]);
        assert!(build_cors_layer(&mixed).is_some());
    }

    #[test]
    fn health_router_with_cors_builds_both_branches() {
        // Explicit origin, no origins, and wildcard-only must all build without panicking.
        for list in [
            origins(&["http://localhost:8000"]),
            Vec::new(),
            origins(&["*"]),
        ] {
            let _router = health_router_with_cors(&list);
        }
    }
}