1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use axum::extract::DefaultBodyLimit;
5use axum::Router;
6use sqlx::PgPool;
7use tokio::net::TcpListener;
8use metrics_exporter_prometheus::PrometheusHandle;
9use tower_http::cors::{AllowOrigin, CorsLayer};
10use tower_http::trace::TraceLayer;
11
12use riley_auth_core::config::Config;
13use riley_auth_core::jwt::Keys;
14use riley_auth_core::oauth::ResolvedProvider;
15use riley_auth_core::webhooks;
16
17use crate::routes;
18
/// Names of every cookie the service sets, all derived from one configurable prefix.
///
/// Each name has the shape `{prefix}_{suffix}`; for the prefix `auth` this gives
/// `auth_access`, `auth_refresh`, `auth_oauth_state`, `auth_pkce` and `auth_setup`.
#[derive(Clone, Debug)]
pub struct CookieNames {
    /// `{prefix}_access`
    pub access: String,
    /// `{prefix}_refresh`
    pub refresh: String,
    /// `{prefix}_oauth_state`
    pub oauth_state: String,
    /// `{prefix}_pkce`
    pub pkce: String,
    /// `{prefix}_setup`
    pub setup: String,
}

impl CookieNames {
    /// Build all cookie names by joining `prefix` and a fixed suffix with `_`.
    ///
    /// The prefix is used verbatim; an empty prefix yields names such as `_access`.
    pub fn from_prefix(prefix: &str) -> Self {
        let named = |suffix: &str| [prefix, suffix].join("_");
        Self {
            access: named("access"),
            refresh: named("refresh"),
            oauth_state: named("oauth_state"),
            pkce: named("pkce"),
            setup: named("setup"),
        }
    }
}
40
/// Shared application state, attached to the router with `Router::with_state` in `serve`.
///
/// Built once at startup; the `Arc`-wrapped fields share a single allocation across clones.
#[derive(Clone)]
pub struct AppState {
    /// Loaded service configuration.
    pub config: Arc<Config>,
    /// Postgres connection pool (sqlx `PgPool`).
    pub db: PgPool,
    /// JWT keys (`riley_auth_core::jwt::Keys`).
    pub keys: Arc<Keys>,
    /// Client returned by `webhooks::build_webhook_client(config.webhooks.allow_private_ips)`;
    /// `serve` also passes a clone of it to the webhook delivery worker.
    pub http_client: reqwest::Client,
    /// Cookie names derived from `server.cookie_prefix`.
    pub cookie_names: CookieNames,
    /// Compiled form of `usernames.pattern`.
    pub username_regex: regex::Regex,
    /// Prometheus exporter handle; `Some` only when `metrics.enabled` is set.
    pub metrics_handle: Option<PrometheusHandle>,
    /// OAuth providers resolved at startup from `oauth.providers`.
    pub providers: Arc<Vec<ResolvedProvider>>,
    /// HTTP client with user agent `riley-auth` and a 10 s timeout. `serve` uses it to resolve
    /// the OAuth providers; NOTE(review): presumably routes reuse it for provider calls — confirm.
    pub oauth_client: reqwest::Client,
}
55
56pub async fn serve(config: Config, db: PgPool, keys: Keys) -> anyhow::Result<()> {
57 let addr = SocketAddr::new(config.server.host.parse()?, config.server.port);
58
59 let cors = build_cors(&config.server.cors_origins);
60
61 let behind_proxy = config.server.behind_proxy;
62 let rate_limit_backend = config.rate_limiting.backend.as_str();
63
64 let base_router = match rate_limit_backend {
66 #[cfg(feature = "redis")]
67 "redis" => {
68 let redis_url = config
69 .rate_limiting
70 .redis_url
71 .as_ref()
72 .expect("redis_url validated at config load")
73 .resolve()?;
74 let limiter =
75 crate::rate_limit::TieredRedisRateLimiter::new(&redis_url, &config.rate_limiting.tiers)
76 .await?;
77 let limiter = Arc::new(limiter);
78 tracing::info!("rate limiting backend: redis (tiered)");
79 routes::router_with_redis_rate_limit(behind_proxy, limiter)
80 }
81 #[cfg(not(feature = "redis"))]
82 "redis" => {
83 anyhow::bail!(
84 "rate_limiting.backend is \"redis\" but riley-auth was compiled without \
85 the `redis` feature. Rebuild with `--features redis`."
86 );
87 }
88 _ => {
89 tracing::info!("rate limiting backend: in-memory (tiered)");
90 routes::router(behind_proxy, &config.rate_limiting.tiers)
91 }
92 };
93
94 let metrics_handle = if config.metrics.enabled {
96 let handle = metrics_exporter_prometheus::PrometheusBuilder::new()
97 .install_recorder()
98 .map_err(|e| anyhow::anyhow!("failed to install metrics recorder: {e}"))?;
99 tracing::info!("metrics enabled, /metrics endpoint active");
100 Some(handle)
101 } else {
102 None
103 };
104
105 let cookie_names = CookieNames::from_prefix(&config.server.cookie_prefix);
106 let http_client = webhooks::build_webhook_client(config.webhooks.allow_private_ips);
107 if !config.webhooks.allow_private_ips {
108 tracing::info!("SSRF protection enabled for webhook delivery");
109 }
110 let username_regex = regex::Regex::new(&config.usernames.pattern)
111 .map_err(|e| anyhow::anyhow!("invalid username pattern: {e}"))?;
112
113 let oauth_http = reqwest::Client::builder()
115 .user_agent("riley-auth")
116 .timeout(std::time::Duration::from_secs(10))
117 .build()?;
118 let providers = riley_auth_core::oauth::resolve_providers(
119 &config.oauth.providers,
120 &oauth_http,
121 )
122 .await?;
123 if providers.is_empty() {
124 tracing::warn!("no OAuth providers configured — login will not work");
125 } else {
126 let names: Vec<_> = providers.iter().map(|p| p.name.as_str()).collect();
127 tracing::info!(providers = ?names, "resolved {} OAuth provider(s)", providers.len());
128 }
129
130 let config = Arc::new(config);
131 let state = AppState {
132 config: Arc::clone(&config),
133 db: db.clone(),
134 keys: Arc::new(keys),
135 http_client: http_client.clone(),
136 cookie_names,
137 username_regex,
138 metrics_handle,
139 providers: Arc::new(providers),
140 oauth_client: oauth_http,
141 };
142
143 let app = Router::new()
144 .merge(base_router)
145 .layer(DefaultBodyLimit::max(1_048_576)) .layer(cors)
147 .layer(TraceLayer::new_for_http())
148 .with_state(state);
149
150 let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
152
153 let delivery_shutdown = shutdown_rx.clone();
155 let block_private_ips = !config.webhooks.allow_private_ips;
156 let worker_handle = tokio::spawn(webhooks::delivery_worker(
157 db.clone(),
158 http_client,
159 config.webhooks.max_concurrent_deliveries,
160 block_private_ips,
161 delivery_shutdown,
162 ));
163
164 let cleanup_handle = tokio::spawn(maintenance_worker(
166 db,
167 Arc::clone(&config),
168 shutdown_rx,
169 ));
170
171 tracing::info!(%addr, "starting server");
172 let listener = TcpListener::bind(addr).await?;
173
174 axum::serve(
177 listener,
178 app.into_make_service_with_connect_info::<SocketAddr>(),
179 )
180 .with_graceful_shutdown(async move {
181 shutdown_signal().await;
182 let _ = shutdown_tx.send(true);
183 })
184 .await?;
185
186 let _ = worker_handle.await;
188 let _ = cleanup_handle.await;
189
190 Ok(())
191}
192
193async fn maintenance_worker(
195 pool: PgPool,
196 config: Arc<Config>,
197 mut shutdown: tokio::sync::watch::Receiver<bool>,
198) {
199 let interval = std::time::Duration::from_secs(config.maintenance.cleanup_interval_secs);
200 let retention_days = config.maintenance.webhook_delivery_retention_days as i64;
201 let consumed_token_cutoff_secs = config.jwt.refresh_token_ttl_secs * 2;
202
203 tracing::info!(
204 interval_secs = config.maintenance.cleanup_interval_secs,
205 "maintenance worker started"
206 );
207
208 loop {
209 tokio::select! {
210 _ = tokio::time::sleep(interval) => {}
211 _ = shutdown.changed() => {
212 tracing::info!("maintenance worker shutting down");
213 return;
214 }
215 }
216
217 let cutoff = chrono::Utc::now() - chrono::Duration::seconds(consumed_token_cutoff_secs as i64);
218
219 match riley_auth_core::db::cleanup_expired_tokens(&pool).await {
220 Ok(n) if n > 0 => tracing::info!(count = n, "cleaned up expired refresh tokens"),
221 Err(e) => tracing::warn!("cleanup_expired_tokens failed: {e}"),
222 _ => {}
223 }
224
225 match riley_auth_core::db::cleanup_expired_auth_codes(&pool).await {
226 Ok(n) if n > 0 => tracing::info!(count = n, "cleaned up expired auth codes"),
227 Err(e) => tracing::warn!("cleanup_expired_auth_codes failed: {e}"),
228 _ => {}
229 }
230
231 match riley_auth_core::db::cleanup_expired_consent_requests(&pool).await {
232 Ok(n) if n > 0 => tracing::info!(count = n, "cleaned up expired consent requests"),
233 Err(e) => tracing::warn!("cleanup_expired_consent_requests failed: {e}"),
234 _ => {}
235 }
236
237 match riley_auth_core::db::cleanup_consumed_refresh_tokens(&pool, cutoff).await {
238 Ok(n) if n > 0 => tracing::info!(count = n, "cleaned up consumed refresh tokens"),
239 Err(e) => tracing::warn!("cleanup_consumed_refresh_tokens failed: {e}"),
240 _ => {}
241 }
242
243 match riley_auth_core::db::cleanup_webhook_deliveries(&pool, retention_days).await {
244 Ok(n) if n > 0 => tracing::info!(count = n, "cleaned up old webhook deliveries"),
245 Err(e) => tracing::warn!("cleanup_webhook_deliveries failed: {e}"),
246 _ => {}
247 }
248
249 match riley_auth_core::db::cleanup_webhook_outbox(&pool, retention_days).await {
250 Ok(n) if n > 0 => tracing::info!(count = n, "cleaned up old outbox entries"),
251 Err(e) => tracing::warn!("cleanup_webhook_outbox failed: {e}"),
252 _ => {}
253 }
254
255 match riley_auth_core::db::reset_stuck_outbox_entries(
256 &pool,
257 config.webhooks.stuck_processing_timeout_secs,
258 ).await {
259 Ok(n) if n > 0 => tracing::info!(count = n, "reset stuck processing outbox entries"),
260 Err(e) => tracing::warn!("reset_stuck_outbox_entries failed: {e}"),
261 _ => {}
262 }
263 }
264}
265
266fn build_cors(origins: &[String]) -> CorsLayer {
267 if origins.is_empty() {
268 tracing::info!("no cors_origins configured — CORS disabled (same-origin only)");
270 CorsLayer::new()
271 } else if origins.len() == 1 && origins[0] == "*" {
272 tracing::warn!("cors_origins = [\"*\"] — using permissive CORS (not safe for production)");
273 CorsLayer::permissive()
274 } else {
275 let origins: Vec<_> = origins
276 .iter()
277 .filter_map(|o| match o.parse() {
278 Ok(v) => Some(v),
279 Err(e) => {
280 tracing::warn!("ignoring unparseable CORS origin {o:?}: {e}");
281 None
282 }
283 })
284 .collect();
285 CorsLayer::new()
286 .allow_origin(AllowOrigin::list(origins))
287 .allow_methods([
288 axum::http::Method::GET,
289 axum::http::Method::POST,
290 axum::http::Method::PATCH,
291 axum::http::Method::DELETE,
292 ])
293 .allow_headers([
294 axum::http::header::CONTENT_TYPE,
295 axum::http::header::AUTHORIZATION,
296 axum::http::HeaderName::from_static("x-requested-with"),
297 ])
298 .allow_credentials(true)
299 }
300}
301
302async fn shutdown_signal() {
303 let ctrl_c = tokio::signal::ctrl_c();
304
305 #[cfg(unix)]
306 {
307 let mut sigterm =
308 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
309 .expect("failed to install SIGTERM handler");
310
311 tokio::select! {
312 _ = ctrl_c => tracing::info!("received CTRL+C"),
313 _ = sigterm.recv() => tracing::info!("received SIGTERM"),
314 }
315 }
316
317 #[cfg(not(unix))]
318 {
319 ctrl_c.await.ok();
320 tracing::info!("received CTRL+C");
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    /// Apply a CORS layer to an empty router.
    ///
    /// Constructing a `CorsLayer` alone does not exercise it. tower-http runs its CORS
    /// consistency checks, such as rejecting credentials combined with wildcard rules, when
    /// the layer wraps a service. The `build_cors_*` tests therefore go through this helper
    /// instead of only building the layer.
    fn apply_cors(layer: CorsLayer) {
        let _app: Router = Router::new().layer(layer);
    }

    #[test]
    fn cookie_names_default_prefix() {
        let names = CookieNames::from_prefix("auth");
        assert_eq!(names.access, "auth_access");
        assert_eq!(names.refresh, "auth_refresh");
        assert_eq!(names.oauth_state, "auth_oauth_state");
        assert_eq!(names.pkce, "auth_pkce");
        assert_eq!(names.setup, "auth_setup");
    }

    #[test]
    fn cookie_names_custom_prefix() {
        let names = CookieNames::from_prefix("myapp");
        assert_eq!(names.access, "myapp_access");
        assert_eq!(names.refresh, "myapp_refresh");
        assert_eq!(names.oauth_state, "myapp_oauth_state");
        assert_eq!(names.pkce, "myapp_pkce");
        assert_eq!(names.setup, "myapp_setup");
    }

    #[test]
    fn build_cors_empty_origins() {
        apply_cors(build_cors(&[]));
    }

    #[test]
    fn build_cors_wildcard() {
        apply_cors(build_cors(&["*".to_string()]));
    }

    #[test]
    fn build_cors_explicit_origins() {
        apply_cors(build_cors(&[
            "https://example.com".to_string(),
            "https://app.example.com".to_string(),
        ]));
    }

    #[test]
    fn build_cors_skips_unparseable_origin() {
        // A newline is not a valid header-value byte, so this entry is skipped with a warning.
        apply_cors(build_cors(&[
            "https://example.com".to_string(),
            "bad\norigin".to_string(),
        ]));
    }
}