use crate::pool::SessionPool;
use axum::Router;
use axum_session::{SessionConfig, SessionLayer, SessionStore};
use axum_session_auth::AuthConfig as AxumAuthConfig;
use crate::auth::AuthLayer;
use crate::config::AuthConfig;
/// Wires authentication, session handling and related middleware onto `router`.
///
/// In source order this function:
/// 1. awaits `crate::auth::sync_bootstrap_admin` against `cfg.pool`;
/// 2. (feature `_oauth-core`) merges the OAuth login/callback routes when
///    `cfg.oauth` is non-empty and collects one `ProviderInfo` per provider;
/// 3. (feature `ratelimit`) adds a `GovernorLayer` when `cfg.rate_limit` is set;
/// 4. attaches `axum::Extension` layers for the pool, the audit config, the
///    provider list, the mailer (feature `mail`) and the optional resource
///    authority;
/// 5. (feature `tokens`) adds a middleware that calls `crate::api_key::bearer_auth`;
/// 6. spawns an audit-prune task when `cfg.audit.retention_days > 0`;
/// 7. adds `AuthLayer`, then `SessionLayer`, then a middleware that stamps the
///    headers from `build_security_headers` onto every response.
///
/// NOTE(review): with axum's `Router::layer`, a layer added later wraps the
/// ones added earlier, so the statement order below is significant — confirm
/// against the axum middleware docs before reordering anything.
///
/// Two detached `tokio::spawn` tasks may be started (rate-limiter cleanup and
/// audit pruning); their join handles are dropped and both loop indefinitely.
///
/// # Errors
///
/// Returns an error if `sync_bootstrap_admin` fails, if the rate-limit builder's
/// `finish()` yields `None` for the configured `per_second`/`burst`, or if
/// creating the `SessionStore` fails.
pub async fn install(router: Router, cfg: AuthConfig) -> anyhow::Result<Router> {
let mut router = router;
// Runs before any routes or layers are touched; a failure aborts installation.
crate::auth::sync_bootstrap_admin(&cfg.pool).await?;
// With OAuth compiled in, register the provider routes and build the list of
// provider descriptors that is later shared through an `Extension`.
#[cfg(feature = "_oauth-core")]
let provider_infos: Vec<crate::wire::ProviderInfo> = {
let mut infos = Vec::new();
// No configured providers: no OAuth routes are merged and the list stays empty.
if !cfg.oauth.is_empty() {
let oauth_router = Router::new()
.route(
"/auth/{provider}/login",
axum::routing::get(crate::oauth::oauth_login),
)
.route(
"/auth/{provider}/callback",
axum::routing::get(crate::oauth::oauth_callback),
)
.with_state(cfg.oauth.clone());
router = router.merge(oauth_router);
for p in cfg.oauth.list() {
infos.push(crate::wire::ProviderInfo {
name: p.name().to_string(),
display_name: p.display_name().to_string(),
// Same path shape as the login route registered above.
login_url: format!("/auth/{}/login", p.name()),
icon_svg: p.icon_svg().map(|s| s.to_string()),
});
}
}
infos
};
// Without OAuth support the provider list is always empty, but it is still
// installed as an extension further down.
#[cfg(not(feature = "_oauth-core"))]
let provider_infos: Vec<crate::wire::ProviderInfo> = Vec::new();
// Optional rate limiting: only when the feature is compiled in AND a
// rate-limit config is present.
#[cfg(feature = "ratelimit")]
if let Some(rl) = cfg.rate_limit.as_ref() {
let governor_config = tower_governor::governor::GovernorConfigBuilder::default()
// Keyed by client IP via the extractor defined in this file, which never errors.
.key_extractor(LenientIpKeyExtractor)
.per_second(rl.per_second)
.burst_size(rl.burst)
.finish()
// `finish()` returned `None`: surface the offending values in the error.
.ok_or_else(|| {
anyhow::anyhow!(
"invalid rate-limit config (per_second={}, burst={})",
rl.per_second,
rl.burst,
)
})?;
let governor_config = std::sync::Arc::new(governor_config);
let limiter = governor_config.limiter().clone();
// Detached housekeeping task: every 60 s call `retain_recent()` on the limiter.
// The join handle is dropped, so nothing here ever stops the loop.
tokio::spawn(async move {
loop {
tokio::time::sleep(std::time::Duration::from_secs(60)).await;
limiter.retain_recent();
}
});
router = router.layer(tower_governor::GovernorLayer::new(governor_config));
}
// Shared values made available to handlers through `axum::Extension`.
router = router.layer(axum::Extension(cfg.pool.clone()));
router = router.layer(axum::Extension(cfg.audit.clone()));
router = router.layer(axum::Extension(std::sync::Arc::new(provider_infos)));
#[cfg(feature = "mail")]
{
router = router.layer(axum::Extension(cfg.mailer.clone()));
}
// The resource authority extension is only installed when one is configured.
if let Some(authority) = cfg.resource_authority.clone() {
router = router.layer(axum::Extension(authority));
}
// Bearer-token middleware: each request is handed to `bearer_auth` together
// with a fresh clone of the pool.
#[cfg(feature = "tokens")]
{
let pool = cfg.pool.clone();
router = router.layer(axum::middleware::from_fn(
move |req: axum::extract::Request, next: axum::middleware::Next| {
crate::api_key::bearer_auth(pool.clone(), req, next)
},
));
}
// Audit retention: a value of 0 (or less) disables pruning entirely.
if cfg.audit.retention_days > 0 {
let prune_pool = cfg.pool.clone();
let retention = cfg.audit.retention_days;
// Detached task: wait 60 s once, then prune and sleep 3600 s, forever.
tokio::spawn(async move {
tokio::time::sleep(std::time::Duration::from_secs(60)).await;
loop {
match crate::auth::audit::prune(&prune_pool, retention).await {
// Only log when something was actually removed.
Ok(n) if n > 0 => {
eprintln!("[audit] pruned {n} event(s) older than {retention}d");
}
Ok(_) => {}
// A failed prune is reported on stderr and retried on the next cycle.
Err(err) => {
eprintln!("[audit] WARN: prune failed: {err}");
}
}
tokio::time::sleep(std::time::Duration::from_secs(3600)).await;
}
});
}
// Built here, while `cfg` is still whole: this borrows all of `cfg`, and
// `cfg.pool` / `cfg.session_table_name` are moved out further down.
let security_headers = std::sync::Arc::new(build_security_headers(&cfg));
// Auth layer backed by the pool; the anonymous user id is configured as 1.
router = router.layer(
AuthLayer::new(Some(cfg.pool.clone()))
.with_config(AxumAuthConfig::<i64>::default().with_anonymous_user_id(Some(1))),
);
// Session store settings come straight from `cfg`. `cfg.pool` is consumed
// here (converted into the store's pool type), so this is its last use.
let session_store = SessionStore::<SessionPool>::new(
Some(cfg.pool.into()),
SessionConfig::default()
.with_table_name(cfg.session_table_name)
.with_ip_and_user_agent(false)
.with_secure(cfg.cookie_secure)
.with_lifetime(cfg.session_lifetime)
.with_max_lifetime(cfg.session_max_lifetime)
.with_max_age(Some(cfg.cookie_max_age)),
)
.await?;
router = router.layer(SessionLayer::new(session_store));
// Last layer added: after the inner service has produced a response, copy
// every precomputed security header onto it.
router = router.layer(axum::middleware::from_fn(
move |req: axum::extract::Request, next: axum::middleware::Next| {
// Clone the `Arc` per request so the returned future owns its own handle.
let headers = security_headers.clone();
async move {
let mut resp = next.run(req).await;
let out = resp.headers_mut();
for (name, value) in headers.iter() {
// NOTE(review): `insert` is expected to replace any same-named header
// already set by an inner handler — confirm against `http::HeaderMap::insert`.
out.insert(name.clone(), value.clone());
}
resp
}
},
));
Ok(router)
}
/// Produces the security header name/value pairs derived from `cfg`.
///
/// The result always starts with six fixed headers, in a fixed order. If
/// `cfg.hsts` is set, a `Strict-Transport-Security` entry is appended; if
/// `cfg.csp` is set, a `Content-Security-Policy` entry is appended after it.
/// A configured value that is not a valid header value is skipped and a
/// warning is printed to stderr.
fn build_security_headers(cfg: &AuthConfig) -> Vec<(http::HeaderName, http::HeaderValue)> {
    use http::{HeaderName, HeaderValue};

    // Fixed headers emitted regardless of configuration, in output order.
    const BASELINE: [(&str, &str); 6] = [
        ("x-content-type-options", "nosniff"),
        ("referrer-policy", "strict-origin-when-cross-origin"),
        ("x-frame-options", "SAMEORIGIN"),
        ("cross-origin-opener-policy", "same-origin"),
        ("x-permitted-cross-domain-policies", "none"),
        (
            "permissions-policy",
            "camera=(), microphone=(), geolocation=()",
        ),
    ];

    let mut out: Vec<(HeaderName, HeaderValue)> = BASELINE
        .iter()
        .map(|&(name, value)| {
            (
                HeaderName::from_static(name),
                HeaderValue::from_static(value),
            )
        })
        .collect();

    // Optional HSTS: parsed only when configured.
    match cfg.hsts.as_deref().map(HeaderValue::from_str) {
        Some(Ok(value)) => out.push((http::header::STRICT_TRANSPORT_SECURITY, value)),
        Some(Err(e)) => eprintln!("[security] WARN: ignoring invalid HSTS value: {e}"),
        None => {}
    }

    // Optional CSP: same handling, appended after HSTS.
    match cfg.csp.as_deref().map(HeaderValue::from_str) {
        Some(Ok(value)) => out.push((http::header::CONTENT_SECURITY_POLICY, value)),
        Some(Err(e)) => eprintln!("[security] WARN: ignoring invalid CSP value: {e}"),
        None => {}
    }

    out
}
/// Rate-limit key extractor that keys requests by client IP and never fails.
///
/// It delegates to `SmartIpKeyExtractor`; when that extractor returns an
/// error, the unspecified IPv4 address is used as the key instead, so all such
/// requests share a single key.
#[cfg(feature = "ratelimit")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct LenientIpKeyExtractor;

#[cfg(feature = "ratelimit")]
impl tower_governor::key_extractor::KeyExtractor for LenientIpKeyExtractor {
    type Key = std::net::IpAddr;

    /// Returns the IP chosen by `SmartIpKeyExtractor`, or `Ipv4Addr::UNSPECIFIED`
    /// if it reports an error. This method always returns `Ok`.
    fn extract<T>(
        &self,
        req: &http::Request<T>,
    ) -> Result<Self::Key, tower_governor::GovernorError> {
        use std::net::{IpAddr, Ipv4Addr};
        use tower_governor::key_extractor::SmartIpKeyExtractor;

        let key = match tower_governor::key_extractor::KeyExtractor::extract(
            &SmartIpKeyExtractor,
            req,
        ) {
            Ok(ip) => ip,
            // Extraction failed: fall back to the shared placeholder address.
            Err(_) => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
        };
        Ok(key)
    }
}