riley_cms_api/
lib.rs

1//! riley-cms-api: HTTP API server for riley_cms
2
3mod handlers;
4pub mod middleware;
5
6use axum::{
7    Router,
8    extract::DefaultBodyLimit,
9    http::{HeaderValue, Method, header},
10    middleware::from_fn_with_state,
11    routing::{any, get},
12};
13use middleware::auth_middleware;
14use riley_cms_core::{RileyCms, RileyCmsConfig};
15use std::net::IpAddr;
16use std::net::SocketAddr;
17use std::sync::Arc;
18use tower_governor::GovernorError;
19use tower_governor::GovernorLayer;
20use tower_governor::governor::GovernorConfigBuilder;
21use tower_governor::key_extractor::{KeyExtractor, PeerIpKeyExtractor, SmartIpKeyExtractor};
22use tower_http::cors::CorsLayer;
23use tower_http::set_header::SetResponseHeaderLayer;
24use tower_http::trace::TraceLayer;
25
/// Rate-limiter key: the client's IP address, resolved according to deployment mode.
///
/// * `behind_proxy == true`: the IP comes from proxy headers (X-Forwarded-For, then
///   X-Real-IP, then Forwarded) through `SmartIpKeyExtractor`. Enable this only when a
///   trusted reverse proxy (nginx, Cloudflare, etc.) sits in front of the server.
/// * `behind_proxy == false` (the documented default): the TCP peer address is used
///   through `PeerIpKeyExtractor`, which is correct for direct-to-internet deployments.
///
/// NOTE(review): proxy headers are client-supplied. Unless the proxy overwrites (rather
/// than appends to) X-Forwarded-For, a client may be able to choose its own rate-limit
/// key. Confirm which header entry the extractor uses and how the proxy is configured.
#[derive(Clone, Copy, Debug)]
struct RileyCmsKeyExtractor {
    /// Read the client IP from proxy headers instead of the socket peer address.
    behind_proxy: bool,
}
38
39impl KeyExtractor for RileyCmsKeyExtractor {
40    type Key = IpAddr;
41
42    fn extract<T>(&self, req: &axum::http::Request<T>) -> Result<Self::Key, GovernorError> {
43        if self.behind_proxy {
44            SmartIpKeyExtractor.extract(req)
45        } else {
46            PeerIpKeyExtractor.extract(req)
47        }
48    }
49}
50
51/// Application state shared across handlers
52pub struct AppState {
53    pub riley_cms: RileyCms,
54    pub config: RileyCmsConfig,
55}
56
57/// Build the versioned API routes
58fn api_v1_routes() -> Router<Arc<AppState>> {
59    Router::new()
60        .route("/posts", get(handlers::list_posts))
61        .route("/posts/{slug}", get(handlers::get_post))
62        .route("/posts/{slug}/raw", get(handlers::get_post_raw))
63        .route("/series", get(handlers::list_series))
64        .route("/series/{slug}", get(handlers::get_series))
65        .route("/assets", get(handlers::list_assets))
66}
67
68/// Build the Axum router with all routes.
69///
70/// Note: Rate limiting is applied separately in `serve()` because it requires
71/// real TCP connection info (peer IP) which isn't available in `oneshot` tests.
72pub fn build_router(state: Arc<AppState>) -> Router {
73    let cors = build_cors_layer(&state.config);
74
75    Router::new()
76        // Versioned API routes
77        .nest("/api/v1", api_v1_routes())
78        // Health check (unversioned)
79        .route("/health", get(handlers::health))
80        // Git Smart HTTP routes (uses Basic Auth, not Bearer token)
81        .route("/git/{*path}", any(handlers::git_handler))
82        // Auth middleware - runs on all routes, sets AuthStatus in extensions
83        .layer(from_fn_with_state(state.clone(), auth_middleware))
84        // State and other middleware
85        .with_state(state)
86        // Disable Axum's default 2MB body limit. The git handler enforces its own
87        // streaming limit (default 100MB), and all other routes are GET-only.
88        .layer(DefaultBodyLimit::disable())
89        .layer(cors)
90        .layer(SetResponseHeaderLayer::overriding(
91            header::X_CONTENT_TYPE_OPTIONS,
92            HeaderValue::from_static("nosniff"),
93        ))
94        .layer(SetResponseHeaderLayer::overriding(
95            header::X_FRAME_OPTIONS,
96            HeaderValue::from_static("DENY"),
97        ))
98        .layer(SetResponseHeaderLayer::overriding(
99            header::CONTENT_SECURITY_POLICY,
100            HeaderValue::from_static("default-src 'none'"),
101        ))
102        .layer(
103            TraceLayer::new_for_http().make_span_with(
104                tower_http::trace::DefaultMakeSpan::new()
105                    .level(tracing::Level::INFO)
106                    .include_headers(false),
107            ),
108        )
109}
110
111/// Build CORS layer from config.
112///
113/// Defaults to denying all cross-origin requests if `cors_origins` is not configured.
114/// Set `cors_origins = ["*"]` to allow all origins, or specify explicit origins.
115fn build_cors_layer(config: &RileyCmsConfig) -> CorsLayer {
116    let origins = config
117        .server
118        .as_ref()
119        .map(|s| &s.cors_origins)
120        .filter(|o| !o.is_empty());
121
122    match origins {
123        Some(origins) if origins.iter().any(|o| o == "*") => CorsLayer::permissive(),
124        Some(origins) => {
125            let origins: Vec<_> = origins.iter().filter_map(|o| o.parse().ok()).collect();
126            CorsLayer::new()
127                .allow_origin(origins)
128                .allow_methods([Method::GET, Method::OPTIONS])
129                .allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE])
130        }
131        // Default: deny all cross-origin requests (secure by default)
132        None => CorsLayer::new(),
133    }
134}
135
136/// Run the API server with graceful shutdown support.
137///
138/// The server will drain in-flight connections when receiving SIGINT (Ctrl+C)
139/// or SIGTERM (Docker stop / Kubernetes terminate).
140pub async fn serve(riley_cms: RileyCms) -> anyhow::Result<()> {
141    let config = riley_cms.config().clone();
142    let server_config = config.server.clone().unwrap_or_default();
143
144    let state = Arc::new(AppState { riley_cms, config });
145
146    // Rate limiting: 50 burst capacity, replenish 10/second per IP.
147    // Allows normal browsing but prevents brute-force on auth endpoints.
148    // Applied here (not in build_router) because it requires real TCP peer IP.
149    let key_extractor = RileyCmsKeyExtractor {
150        behind_proxy: server_config.behind_proxy,
151    };
152    if server_config.behind_proxy {
153        tracing::info!(
154            "Rate limiter using proxy headers (X-Forwarded-For/X-Real-IP) for client IP"
155        );
156    }
157    let governor_conf = GovernorConfigBuilder::default()
158        .key_extractor(key_extractor)
159        .per_second(10)
160        .burst_size(50)
161        .finish()
162        .unwrap();
163    let governor_layer = GovernorLayer::new(governor_conf);
164
165    let app = build_router(state).layer(governor_layer);
166
167    let addr: SocketAddr = format!("{}:{}", server_config.host, server_config.port).parse()?;
168
169    tracing::info!("Starting server on {}", addr);
170
171    let listener = tokio::net::TcpListener::bind(addr).await?;
172    axum::serve(listener, app)
173        .with_graceful_shutdown(shutdown_signal())
174        .await?;
175
176    Ok(())
177}
178
179/// Wait for a shutdown signal (SIGINT or SIGTERM).
180async fn shutdown_signal() {
181    let ctrl_c = async {
182        tokio::signal::ctrl_c()
183            .await
184            .expect("failed to install Ctrl+C handler");
185    };
186
187    #[cfg(unix)]
188    let terminate = async {
189        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
190            .expect("failed to install SIGTERM handler")
191            .recv()
192            .await;
193    };
194
195    #[cfg(not(unix))]
196    let terminate = std::future::pending::<()>();
197
198    tokio::select! {
199        _ = ctrl_c => {},
200        _ = terminate => {},
201    }
202
203    tracing::info!("Shutdown signal received, draining connections...");
204}