1mod handlers;
4pub mod middleware;
5
6use axum::{
7 Router,
8 extract::DefaultBodyLimit,
9 http::{HeaderValue, Method, header},
10 middleware::from_fn_with_state,
11 routing::{any, get},
12};
13use middleware::auth_middleware;
14use riley_cms_core::{RileyCms, RileyCmsConfig};
15use std::net::IpAddr;
16use std::net::SocketAddr;
17use std::sync::Arc;
18use tower_governor::GovernorError;
19use tower_governor::GovernorLayer;
20use tower_governor::governor::GovernorConfigBuilder;
21use tower_governor::key_extractor::{KeyExtractor, PeerIpKeyExtractor, SmartIpKeyExtractor};
22use tower_http::cors::CorsLayer;
23use tower_http::set_header::SetResponseHeaderLayer;
24use tower_http::trace::TraceLayer;
25
/// Rate-limiter key extractor that chooses how a request's client IP is determined.
///
/// * `behind_proxy == false`: the key comes from tower_governor's `PeerIpKeyExtractor`.
/// * `behind_proxy == true`: the key comes from `SmartIpKeyExtractor`, which consults
///   forwarding headers.
///
/// NOTE(review): forwarding headers are client-controlled unless a trusted proxy overwrites
/// them. Only enable `behind_proxy` when such a proxy sits in front of this server; otherwise
/// clients can pick their own rate-limit key.
#[derive(Clone, Copy, Debug)]
struct RileyCmsKeyExtractor {
    /// Whether to trust proxy headers when identifying the client IP.
    behind_proxy: bool,
}
38
39impl KeyExtractor for RileyCmsKeyExtractor {
40 type Key = IpAddr;
41
42 fn extract<T>(&self, req: &axum::http::Request<T>) -> Result<Self::Key, GovernorError> {
43 if self.behind_proxy {
44 SmartIpKeyExtractor.extract(req)
45 } else {
46 PeerIpKeyExtractor.extract(req)
47 }
48 }
49}
50
51pub struct AppState {
53 pub riley_cms: RileyCms,
54 pub config: RileyCmsConfig,
55}
56
57fn api_v1_routes() -> Router<Arc<AppState>> {
59 Router::new()
60 .route("/posts", get(handlers::list_posts))
61 .route("/posts/{slug}", get(handlers::get_post))
62 .route("/posts/{slug}/raw", get(handlers::get_post_raw))
63 .route("/series", get(handlers::list_series))
64 .route("/series/{slug}", get(handlers::get_series))
65 .route("/assets", get(handlers::list_assets))
66}
67
68pub fn build_router(state: Arc<AppState>) -> Router {
73 let cors = build_cors_layer(&state.config);
74
75 Router::new()
76 .nest("/api/v1", api_v1_routes())
78 .route("/health", get(handlers::health))
80 .route("/git/{*path}", any(handlers::git_handler))
82 .layer(from_fn_with_state(state.clone(), auth_middleware))
84 .with_state(state)
86 .layer(DefaultBodyLimit::disable())
89 .layer(cors)
90 .layer(SetResponseHeaderLayer::overriding(
91 header::X_CONTENT_TYPE_OPTIONS,
92 HeaderValue::from_static("nosniff"),
93 ))
94 .layer(SetResponseHeaderLayer::overriding(
95 header::X_FRAME_OPTIONS,
96 HeaderValue::from_static("DENY"),
97 ))
98 .layer(SetResponseHeaderLayer::overriding(
99 header::CONTENT_SECURITY_POLICY,
100 HeaderValue::from_static("default-src 'none'"),
101 ))
102 .layer(
103 TraceLayer::new_for_http().make_span_with(
104 tower_http::trace::DefaultMakeSpan::new()
105 .level(tracing::Level::INFO)
106 .include_headers(false),
107 ),
108 )
109}
110
111fn build_cors_layer(config: &RileyCmsConfig) -> CorsLayer {
116 let origins = config
117 .server
118 .as_ref()
119 .map(|s| &s.cors_origins)
120 .filter(|o| !o.is_empty());
121
122 match origins {
123 Some(origins) if origins.iter().any(|o| o == "*") => CorsLayer::permissive(),
124 Some(origins) => {
125 let origins: Vec<_> = origins.iter().filter_map(|o| o.parse().ok()).collect();
126 CorsLayer::new()
127 .allow_origin(origins)
128 .allow_methods([Method::GET, Method::OPTIONS])
129 .allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE])
130 }
131 None => CorsLayer::new(),
133 }
134}
135
136pub async fn serve(riley_cms: RileyCms) -> anyhow::Result<()> {
141 let config = riley_cms.config().clone();
142 let server_config = config.server.clone().unwrap_or_default();
143
144 let state = Arc::new(AppState { riley_cms, config });
145
146 let key_extractor = RileyCmsKeyExtractor {
150 behind_proxy: server_config.behind_proxy,
151 };
152 if server_config.behind_proxy {
153 tracing::info!(
154 "Rate limiter using proxy headers (X-Forwarded-For/X-Real-IP) for client IP"
155 );
156 }
157 let governor_conf = GovernorConfigBuilder::default()
158 .key_extractor(key_extractor)
159 .per_second(10)
160 .burst_size(50)
161 .finish()
162 .unwrap();
163 let governor_layer = GovernorLayer::new(governor_conf);
164
165 let app = build_router(state).layer(governor_layer);
166
167 let addr: SocketAddr = format!("{}:{}", server_config.host, server_config.port).parse()?;
168
169 tracing::info!("Starting server on {}", addr);
170
171 let listener = tokio::net::TcpListener::bind(addr).await?;
172 axum::serve(listener, app)
173 .with_graceful_shutdown(shutdown_signal())
174 .await?;
175
176 Ok(())
177}
178
179async fn shutdown_signal() {
181 let ctrl_c = async {
182 tokio::signal::ctrl_c()
183 .await
184 .expect("failed to install Ctrl+C handler");
185 };
186
187 #[cfg(unix)]
188 let terminate = async {
189 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
190 .expect("failed to install SIGTERM handler")
191 .recv()
192 .await;
193 };
194
195 #[cfg(not(unix))]
196 let terminate = std::future::pending::<()>();
197
198 tokio::select! {
199 _ = ctrl_c => {},
200 _ = terminate => {},
201 }
202
203 tracing::info!("Shutdown signal received, draining connections...");
204}