Skip to main content

riley_auth_api/
server.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use axum::extract::DefaultBodyLimit;
5use axum::Router;
6use sqlx::PgPool;
7use tokio::net::TcpListener;
8use metrics_exporter_prometheus::PrometheusHandle;
9use tower_http::cors::{AllowOrigin, CorsLayer};
10use tower_http::trace::TraceLayer;
11
12use riley_auth_core::config::Config;
13use riley_auth_core::jwt::Keys;
14use riley_auth_core::oauth::ResolvedProvider;
15use riley_auth_core::webhooks;
16
17use crate::routes;
18
/// Names of every cookie the server issues, each of the form `<prefix>_<purpose>`.
///
/// `serve` builds this once from `server.cookie_prefix`.
#[derive(Clone, Debug)]
pub struct CookieNames {
    /// `<prefix>_access`
    pub access: String,
    /// `<prefix>_refresh`
    pub refresh: String,
    /// `<prefix>_oauth_state`
    pub oauth_state: String,
    /// `<prefix>_pkce`
    pub pkce: String,
    /// `<prefix>_setup`
    pub setup: String,
}

impl CookieNames {
    /// Derives all cookie names from `prefix`; e.g. `"auth"` yields `auth_access`,
    /// `auth_refresh`, `auth_oauth_state`, `auth_pkce` and `auth_setup`.
    pub fn from_prefix(prefix: &str) -> Self {
        let named = |purpose: &str| -> String {
            let mut name = String::with_capacity(prefix.len() + 1 + purpose.len());
            name.push_str(prefix);
            name.push('_');
            name.push_str(purpose);
            name
        };

        Self {
            access: named("access"),
            refresh: named("refresh"),
            oauth_state: named("oauth_state"),
            pkce: named("pkce"),
            setup: named("setup"),
        }
    }
}
40
41/// Shared application state.
42#[derive(Clone)]
43pub struct AppState {
44    pub config: Arc<Config>,
45    pub db: PgPool,
46    pub keys: Arc<Keys>,
47    pub http_client: reqwest::Client,
48    pub cookie_names: CookieNames,
49    pub username_regex: regex::Regex,
50    pub metrics_handle: Option<PrometheusHandle>,
51    pub providers: Arc<Vec<ResolvedProvider>>,
52    /// HTTP client for OAuth token exchange and profile fetching (reused across requests).
53    pub oauth_client: reqwest::Client,
54}
55
56pub async fn serve(config: Config, db: PgPool, keys: Keys) -> anyhow::Result<()> {
57    let addr = SocketAddr::new(config.server.host.parse()?, config.server.port);
58
59    let cors = build_cors(&config.server.cors_origins);
60
61    let behind_proxy = config.server.behind_proxy;
62    let rate_limit_backend = config.rate_limiting.backend.as_str();
63
64    // Build router with appropriate rate limiting backend
65    let base_router = match rate_limit_backend {
66        #[cfg(feature = "redis")]
67        "redis" => {
68            let redis_url = config
69                .rate_limiting
70                .redis_url
71                .as_ref()
72                .expect("redis_url validated at config load")
73                .resolve()?;
74            let limiter =
75                crate::rate_limit::TieredRedisRateLimiter::new(&redis_url, &config.rate_limiting.tiers)
76                    .await?;
77            let limiter = Arc::new(limiter);
78            tracing::info!("rate limiting backend: redis (tiered)");
79            routes::router_with_redis_rate_limit(behind_proxy, limiter)
80        }
81        #[cfg(not(feature = "redis"))]
82        "redis" => {
83            anyhow::bail!(
84                "rate_limiting.backend is \"redis\" but riley-auth was compiled without \
85                 the `redis` feature. Rebuild with `--features redis`."
86            );
87        }
88        _ => {
89            tracing::info!("rate limiting backend: in-memory (tiered)");
90            routes::router(behind_proxy, &config.rate_limiting.tiers)
91        }
92    };
93
94    // Initialize Prometheus metrics recorder if enabled
95    let metrics_handle = if config.metrics.enabled {
96        let handle = metrics_exporter_prometheus::PrometheusBuilder::new()
97            .install_recorder()
98            .map_err(|e| anyhow::anyhow!("failed to install metrics recorder: {e}"))?;
99        tracing::info!("metrics enabled, /metrics endpoint active");
100        Some(handle)
101    } else {
102        None
103    };
104
105    let cookie_names = CookieNames::from_prefix(&config.server.cookie_prefix);
106    let http_client = webhooks::build_webhook_client(config.webhooks.allow_private_ips);
107    if !config.webhooks.allow_private_ips {
108        tracing::info!("SSRF protection enabled for webhook delivery");
109    }
110    let username_regex = regex::Regex::new(&config.usernames.pattern)
111        .map_err(|e| anyhow::anyhow!("invalid username pattern: {e}"))?;
112
113    // Resolve OAuth providers at startup (OIDC discovery happens here)
114    let oauth_http = reqwest::Client::builder()
115        .user_agent("riley-auth")
116        .timeout(std::time::Duration::from_secs(10))
117        .build()?;
118    let providers = riley_auth_core::oauth::resolve_providers(
119        &config.oauth.providers,
120        &oauth_http,
121    )
122    .await?;
123    if providers.is_empty() {
124        tracing::warn!("no OAuth providers configured — login will not work");
125    } else {
126        let names: Vec<_> = providers.iter().map(|p| p.name.as_str()).collect();
127        tracing::info!(providers = ?names, "resolved {} OAuth provider(s)", providers.len());
128    }
129
130    let config = Arc::new(config);
131    let state = AppState {
132        config: Arc::clone(&config),
133        db: db.clone(),
134        keys: Arc::new(keys),
135        http_client: http_client.clone(),
136        cookie_names,
137        username_regex,
138        metrics_handle,
139        providers: Arc::new(providers),
140        oauth_client: oauth_http,
141    };
142
143    let app = Router::new()
144        .merge(base_router)
145        .layer(DefaultBodyLimit::max(1_048_576)) // 1 MiB
146        .layer(cors)
147        .layer(TraceLayer::new_for_http())
148        .with_state(state);
149
150    // Shutdown coordination: signal both the HTTP server and background workers
151    let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
152
153    // Start the webhook delivery worker
154    let delivery_shutdown = shutdown_rx.clone();
155    let block_private_ips = !config.webhooks.allow_private_ips;
156    let worker_handle = tokio::spawn(webhooks::delivery_worker(
157        db.clone(),
158        http_client,
159        config.webhooks.max_concurrent_deliveries,
160        block_private_ips,
161        delivery_shutdown,
162    ));
163
164    // Start the maintenance cleanup worker
165    let cleanup_handle = tokio::spawn(maintenance_worker(
166        db,
167        Arc::clone(&config),
168        shutdown_rx,
169    ));
170
171    tracing::info!(%addr, "starting server");
172    let listener = TcpListener::bind(addr).await?;
173
174    // Use into_make_service_with_connect_info so rate limit middleware
175    // can extract peer IP
176    axum::serve(
177        listener,
178        app.into_make_service_with_connect_info::<SocketAddr>(),
179    )
180    .with_graceful_shutdown(async move {
181        shutdown_signal().await;
182        let _ = shutdown_tx.send(true);
183    })
184    .await?;
185
186    // Wait for workers to finish draining
187    let _ = worker_handle.await;
188    let _ = cleanup_handle.await;
189
190    Ok(())
191}
192
193/// Background maintenance worker that periodically cleans up expired data.
194async fn maintenance_worker(
195    pool: PgPool,
196    config: Arc<Config>,
197    mut shutdown: tokio::sync::watch::Receiver<bool>,
198) {
199    let interval = std::time::Duration::from_secs(config.maintenance.cleanup_interval_secs);
200    let retention_days = config.maintenance.webhook_delivery_retention_days as i64;
201    let consumed_token_cutoff_secs = config.jwt.refresh_token_ttl_secs * 2;
202
203    tracing::info!(
204        interval_secs = config.maintenance.cleanup_interval_secs,
205        "maintenance worker started"
206    );
207
208    loop {
209        tokio::select! {
210            _ = tokio::time::sleep(interval) => {}
211            _ = shutdown.changed() => {
212                tracing::info!("maintenance worker shutting down");
213                return;
214            }
215        }
216
217        let cutoff = chrono::Utc::now() - chrono::Duration::seconds(consumed_token_cutoff_secs as i64);
218
219        match riley_auth_core::db::cleanup_expired_tokens(&pool).await {
220            Ok(n) if n > 0 => tracing::info!(count = n, "cleaned up expired refresh tokens"),
221            Err(e) => tracing::warn!("cleanup_expired_tokens failed: {e}"),
222            _ => {}
223        }
224
225        match riley_auth_core::db::cleanup_expired_auth_codes(&pool).await {
226            Ok(n) if n > 0 => tracing::info!(count = n, "cleaned up expired auth codes"),
227            Err(e) => tracing::warn!("cleanup_expired_auth_codes failed: {e}"),
228            _ => {}
229        }
230
231        match riley_auth_core::db::cleanup_expired_consent_requests(&pool).await {
232            Ok(n) if n > 0 => tracing::info!(count = n, "cleaned up expired consent requests"),
233            Err(e) => tracing::warn!("cleanup_expired_consent_requests failed: {e}"),
234            _ => {}
235        }
236
237        match riley_auth_core::db::cleanup_consumed_refresh_tokens(&pool, cutoff).await {
238            Ok(n) if n > 0 => tracing::info!(count = n, "cleaned up consumed refresh tokens"),
239            Err(e) => tracing::warn!("cleanup_consumed_refresh_tokens failed: {e}"),
240            _ => {}
241        }
242
243        match riley_auth_core::db::cleanup_webhook_deliveries(&pool, retention_days).await {
244            Ok(n) if n > 0 => tracing::info!(count = n, "cleaned up old webhook deliveries"),
245            Err(e) => tracing::warn!("cleanup_webhook_deliveries failed: {e}"),
246            _ => {}
247        }
248
249        match riley_auth_core::db::cleanup_webhook_outbox(&pool, retention_days).await {
250            Ok(n) if n > 0 => tracing::info!(count = n, "cleaned up old outbox entries"),
251            Err(e) => tracing::warn!("cleanup_webhook_outbox failed: {e}"),
252            _ => {}
253        }
254
255        match riley_auth_core::db::reset_stuck_outbox_entries(
256            &pool,
257            config.webhooks.stuck_processing_timeout_secs,
258        ).await {
259            Ok(n) if n > 0 => tracing::info!(count = n, "reset stuck processing outbox entries"),
260            Err(e) => tracing::warn!("reset_stuck_outbox_entries failed: {e}"),
261            _ => {}
262        }
263    }
264}
265
266fn build_cors(origins: &[String]) -> CorsLayer {
267    if origins.is_empty() {
268        // Restrictive CORS — no origins allowed, browsers block cross-origin requests
269        tracing::info!("no cors_origins configured — CORS disabled (same-origin only)");
270        CorsLayer::new()
271    } else if origins.len() == 1 && origins[0] == "*" {
272        tracing::warn!("cors_origins = [\"*\"] — using permissive CORS (not safe for production)");
273        CorsLayer::permissive()
274    } else {
275        let origins: Vec<_> = origins
276            .iter()
277            .filter_map(|o| match o.parse() {
278                Ok(v) => Some(v),
279                Err(e) => {
280                    tracing::warn!("ignoring unparseable CORS origin {o:?}: {e}");
281                    None
282                }
283            })
284            .collect();
285        CorsLayer::new()
286            .allow_origin(AllowOrigin::list(origins))
287            .allow_methods([
288                axum::http::Method::GET,
289                axum::http::Method::POST,
290                axum::http::Method::PATCH,
291                axum::http::Method::DELETE,
292            ])
293            .allow_headers([
294                axum::http::header::CONTENT_TYPE,
295                axum::http::header::AUTHORIZATION,
296                axum::http::HeaderName::from_static("x-requested-with"),
297            ])
298            .allow_credentials(true)
299    }
300}
301
302async fn shutdown_signal() {
303    let ctrl_c = tokio::signal::ctrl_c();
304
305    #[cfg(unix)]
306    {
307        let mut sigterm =
308            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
309                .expect("failed to install SIGTERM handler");
310
311        tokio::select! {
312            _ = ctrl_c => tracing::info!("received CTRL+C"),
313            _ = sigterm.recv() => tracing::info!("received SIGTERM"),
314        }
315    }
316
317    #[cfg(not(unix))]
318    {
319        ctrl_c.await.ok();
320        tracing::info!("received CTRL+C");
321    }
322}
323
#[cfg(test)]
mod tests {
    use super::*;

    /// Flattens the cookie names, in declaration order, for one-shot comparison.
    fn all_names(names: &CookieNames) -> [&str; 5] {
        [
            names.access.as_str(),
            names.refresh.as_str(),
            names.oauth_state.as_str(),
            names.pkce.as_str(),
            names.setup.as_str(),
        ]
    }

    #[test]
    fn cookie_names_default_prefix() {
        let names = CookieNames::from_prefix("auth");
        assert_eq!(
            all_names(&names),
            ["auth_access", "auth_refresh", "auth_oauth_state", "auth_pkce", "auth_setup"]
        );
    }

    #[test]
    fn cookie_names_custom_prefix() {
        let names = CookieNames::from_prefix("myapp");
        assert_eq!(
            all_names(&names),
            ["myapp_access", "myapp_refresh", "myapp_oauth_state", "myapp_pkce", "myapp_setup"]
        );
    }

    // The CORS tests only check that building the layer succeeds without panicking.

    #[test]
    fn build_cors_empty_origins() {
        // No origins configured: restrictive layer.
        let _layer = build_cors(&[]);
    }

    #[test]
    fn build_cors_wildcard() {
        // Exactly ["*"]: permissive layer.
        let wildcard = vec!["*".to_string()];
        let _layer = build_cors(&wildcard);
    }

    #[test]
    fn build_cors_explicit_origins() {
        let origins: Vec<String> = ["https://example.com", "https://app.example.com"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let _layer = build_cors(&origins);
    }
}