use crate::pool::SessionPool;
use axum::Router;
use axum_session::{SessionConfig, SessionLayer, SessionStore};
use axum_session_auth::AuthConfig as AxumAuthConfig;
use crate::auth::AuthLayer;
use crate::config::AuthConfig;
/// Wires authentication, sessions, and auxiliary auth services into `router`.
///
/// Steps, in order:
/// 1. Runs `crate::auth::sync_bootstrap_admin` against `cfg.pool`; an error aborts install.
/// 2. (`_oauth-core` feature) When `cfg.oauth` is non-empty, merges the
///    `/auth/{provider}/login` and `/auth/{provider}/callback` routes and builds one
///    `ProviderInfo` per configured provider. Without the feature the list is empty.
/// 3. (`ratelimit` feature) When `cfg.rate_limit` is set, adds a `GovernorLayer` keyed by
///    `LenientIpKeyExtractor` and spawns a task that calls `retain_recent()` on the
///    limiter every 60 seconds.
/// 4. Adds `Extension` layers for the pool, the audit config, the provider list (as
///    `Arc<Vec<ProviderInfo>>`), and, with the `mail` feature, the mailer.
/// 5. When `cfg.audit.retention_days > 0`, spawns a task that waits 60 seconds and then
///    calls `crate::auth::audit::prune` once per hour, logging results to stderr.
/// 6. Adds the `AuthLayer` (anonymous user id `Some(1)`), then the `SessionLayer` backed
///    by a `SessionStore` built from `cfg.pool` and the session settings in `cfg`.
///
/// Layer order matters. Per axum's `Router::layer` docs, each `.layer(...)` call wraps
/// everything added before it, so the layer added last runs first on a request. The
/// resulting stack, from outermost to innermost, is:
/// 1. `SessionLayer`
/// 2. `AuthLayer`
/// 3. the extensions
/// 4. the rate limiter, if enabled
///
/// The same docs note that layers apply only to routes already on the router. Routes
/// merged into the returned router later are therefore not covered.
///
/// The background tasks are detached (`tokio::spawn` handles are dropped) and loop
/// forever. Calling `install` more than once spawns them again.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - syncing the bootstrap admin fails;
/// - the governor builder rejects the rate-limit settings;
/// - the session store cannot be created.
pub async fn install(router: Router, cfg: AuthConfig) -> anyhow::Result<Router> {
    let mut router = router;
    // Runs before any routes or layers are added; a failure aborts `install`.
    crate::auth::sync_bootstrap_admin(&cfg.pool).await?;

    // OAuth: mount the login/callback routes and collect the provider list that is
    // exposed to handlers through an `Extension` further down.
    #[cfg(feature = "_oauth-core")]
    let provider_infos: Vec<crate::wire::ProviderInfo> = {
        let mut infos = Vec::new();
        if !cfg.oauth.is_empty() {
            let oauth_router = Router::new()
                .route(
                    "/auth/{provider}/login",
                    axum::routing::get(crate::oauth::oauth_login),
                )
                .route(
                    "/auth/{provider}/callback",
                    axum::routing::get(crate::oauth::oauth_callback),
                )
                .with_state(cfg.oauth.clone());
            // Merged before any `.layer(...)` call below, so every layer added in this
            // function also wraps the OAuth routes.
            router = router.merge(oauth_router);
            for p in cfg.oauth.list() {
                infos.push(crate::wire::ProviderInfo {
                    name: p.name().to_string(),
                    display_name: p.display_name().to_string(),
                    // Must stay in sync with the login route path registered above.
                    login_url: format!("/auth/{}/login", p.name()),
                    icon_svg: p.icon_svg().map(|s| s.to_string()),
                });
            }
        }
        infos
    };
    // Without OAuth support the provider list is always empty, but the
    // `Extension` carrying it is still installed below.
    #[cfg(not(feature = "_oauth-core"))]
    let provider_infos: Vec<crate::wire::ProviderInfo> = Vec::new();

    // Optional per-client rate limiting. This is the first `.layer(...)` call, so the
    // governor ends up innermost.
    // NOTE(review): requests that end up rate-limited therefore still pass through
    // the session and auth layers first. Confirm that ordering is intended.
    #[cfg(feature = "ratelimit")]
    if let Some(rl) = cfg.rate_limit.as_ref() {
        // NOTE(review): per tower_governor docs, `per_second(n)` sets the interval in
        // seconds after which one unit of quota is replenished. It is not a
        // requests-per-second rate. Confirm `rl.per_second` is meant that way.
        let governor_config = tower_governor::governor::GovernorConfigBuilder::default()
            .key_extractor(LenientIpKeyExtractor)
            .per_second(rl.per_second)
            .burst_size(rl.burst)
            .finish()
            // `finish()` returns an `Option`; map `None` to a descriptive error.
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "invalid rate-limit config (per_second={}, burst={})",
                    rl.per_second,
                    rl.burst,
                )
            })?;
        let governor_config = std::sync::Arc::new(governor_config);
        // Detached housekeeping task: every 60s, call governor's `retain_recent` so
        // limiter state for keys that have gone quiet can be dropped.
        let limiter = governor_config.limiter().clone();
        tokio::spawn(async move {
            loop {
                tokio::time::sleep(std::time::Duration::from_secs(60)).await;
                limiter.retain_recent();
            }
        });
        router = router.layer(tower_governor::GovernorLayer::new(governor_config));
    }

    // Shared values made available to handlers and extractors via `Extension<...>`.
    router = router.layer(axum::Extension(cfg.pool.clone()));
    router = router.layer(axum::Extension(cfg.audit.clone()));
    // Wrapped in `Arc` because axum clones an extension's value into every request.
    router = router.layer(axum::Extension(std::sync::Arc::new(provider_infos)));
    #[cfg(feature = "mail")]
    {
        router = router.layer(axum::Extension(cfg.mailer.clone()));
    }

    // Audit-log retention: a detached task that waits 60s, then prunes old events
    // once per hour. Failures are logged to stderr and the loop keeps running.
    if cfg.audit.retention_days > 0 {
        let prune_pool = cfg.pool.clone();
        let retention = cfg.audit.retention_days;
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_secs(60)).await;
            loop {
                match crate::auth::audit::prune(&prune_pool, retention).await {
                    Ok(n) if n > 0 => {
                        eprintln!("[audit] pruned {n} event(s) older than {retention}d");
                    }
                    // Nothing was pruned: stay quiet.
                    Ok(_) => {}
                    Err(err) => {
                        eprintln!("[audit] WARN: prune failed: {err}");
                    }
                }
                tokio::time::sleep(std::time::Duration::from_secs(3600)).await;
            }
        });
    }

    // Added after the extensions, so it sits outside them in the layer stack.
    // NOTE(review): anonymous requests are mapped to user id 1. Confirm that id is a
    // dedicated guest account and can never be the bootstrap admin synced above.
    router = router.layer(
        AuthLayer::new(Some(cfg.pool.clone()))
            .with_config(AxumAuthConfig::<i64>::default().with_anonymous_user_id(Some(1))),
    );
    // `cfg.pool` is moved here (converted via `.into()` for the session store), which
    // is why every earlier use of the pool takes a clone.
    let session_store = SessionStore::<SessionPool>::new(
        Some(cfg.pool.into()),
        SessionConfig::default()
            .with_table_name(cfg.session_table_name)
            .with_ip_and_user_agent(false)
            .with_lifetime(cfg.session_lifetime)
            .with_max_lifetime(cfg.session_max_lifetime)
            .with_max_age(Some(cfg.cookie_max_age)),
    )
    .await?;
    // Added last, so it is the outermost layer and runs before `AuthLayer`.
    router = router.layer(SessionLayer::new(session_store));
    Ok(router)
}
/// Rate-limit key extractor that never rejects a request for lack of a client IP.
///
/// It delegates to `SmartIpKeyExtractor`. When that cannot determine an address, the
/// request is keyed as `0.0.0.0` (`Ipv4Addr::UNSPECIFIED`) instead of failing. All such
/// requests therefore share a single rate-limit bucket.
///
/// NOTE(review): per tower_governor docs, `SmartIpKeyExtractor` consults forwarding
/// headers, which clients can spoof unless a trusted proxy overwrites them. Confirm the
/// deployment sits behind such a proxy.
#[cfg(feature = "ratelimit")]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct LenientIpKeyExtractor;

#[cfg(feature = "ratelimit")]
impl tower_governor::key_extractor::KeyExtractor for LenientIpKeyExtractor {
    type Key = std::net::IpAddr;

    /// Returns the client IP found by `SmartIpKeyExtractor`, or `0.0.0.0` when none
    /// can be extracted. This never returns `Err`.
    fn extract<T>(
        &self,
        req: &http::Request<T>,
    ) -> Result<Self::Key, tower_governor::GovernorError> {
        use std::net::{IpAddr, Ipv4Addr};
        use tower_governor::key_extractor::SmartIpKeyExtractor;

        // Shared key for every request whose origin address cannot be determined.
        const FALLBACK: IpAddr = IpAddr::V4(Ipv4Addr::UNSPECIFIED);

        let key = match SmartIpKeyExtractor.extract(req) {
            Ok(ip) => ip,
            Err(_) => FALLBACK,
        };
        Ok(key)
    }
}