1mod error;
2mod extractors;
3mod handlers;
4mod routes;
5pub mod security;
6mod state;
7
8pub use state::AppState;
9
10use crate::services::analytics::{
11 extract_browser_family, extract_device_type, extract_referrer_domain, generate_session_hash,
12 get_daily_salt, lookup_country, run_aggregation_job, Analytics, AnalyticsConfig,
13 AnalyticsEvent,
14};
15use crate::{Config, Database};
16use anyhow::Result;
17use axum::body::Body;
18use axum::extract::{ConnectInfo, State};
19use axum::http::Request;
20use axum::middleware::{self, Next};
21use axum::response::Response;
22use axum::Router;
23use std::net::SocketAddr;
24use std::path::PathBuf;
25use std::sync::Arc;
26use std::time::Instant;
27use tokio::net::TcpListener;
28use tower_http::compression::CompressionLayer;
29use tower_http::timeout::TimeoutLayer;
30use tower_http::trace::TraceLayer;
31
32pub async fn serve(
33 config: Config,
34 config_path: PathBuf,
35 db: Database,
36 addr: &str,
37 shutdown_rx: Option<tokio::sync::watch::Receiver<bool>>,
38) -> Result<()> {
39 let analytics_config = AnalyticsConfig::default();
40 let analytics = Arc::new(Analytics::with_config(db.clone(), analytics_config));
41
42 let state =
43 AppState::new(config, config_path, db.clone(), false)?.with_analytics(analytics.clone());
44 let state = Arc::new(state);
45
46 let analytics_aggregator = analytics.clone();
47 let mut agg_rx = shutdown_rx.clone().unwrap_or_else(|| {
48 let (_, rx) = tokio::sync::watch::channel(false);
49 rx
50 });
51 let agg_handle = tokio::spawn(async move {
52 tokio::select! {
53 _ = run_aggregation_job(analytics_aggregator) => {}
54 _ = async { while agg_rx.changed().await.is_ok() { if *agg_rx.borrow() { break; } } } => {
55 tracing::info!("Analytics aggregation stopping...");
56 }
57 }
58 });
59
60 let app = Router::new()
61 .merge(routes::public_routes())
62 .merge(routes::admin_routes())
63 .merge(routes::htmx_routes())
64 .merge(routes::api_routes())
65 .layer(middleware::from_fn_with_state(
66 state.clone(),
67 security::write_rate_limit_middleware,
68 ))
69 .layer(middleware::from_fn_with_state(
70 state.clone(),
71 analytics_middleware,
72 ))
73 .layer(middleware::from_fn(security::apply_security_headers))
74 .layer(CompressionLayer::new())
75 .layer(TimeoutLayer::with_status_code(axum::http::StatusCode::GATEWAY_TIMEOUT, std::time::Duration::from_secs(30)))
76 .layer(TraceLayer::new_for_http())
77 .with_state(state);
78
79 let listener = TcpListener::bind(addr).await?;
80 let app = app.into_make_service_with_connect_info::<SocketAddr>();
81 tracing::info!("Server listening on {}", addr);
82 axum::serve(listener, app)
83 .with_graceful_shutdown(shutdown_signal())
84 .await?;
85
86 agg_handle.abort();
88 tracing::info!("Server shut down gracefully");
89 Ok(())
90}
91
92pub async fn serve_production(
93 config: &Config,
94 config_path: PathBuf,
95 host: &str,
96 port: u16,
97) -> Result<()> {
98 let db = Database::open(&config.database.path)?;
99
100 let analytics_config = AnalyticsConfig::default();
101 let analytics = Arc::new(Analytics::with_config(db.clone(), analytics_config));
102
103 let state = AppState::new(config.clone(), config_path, db.clone(), true)?
104 .with_analytics(analytics.clone());
105 let state = Arc::new(state);
106
107 let analytics_aggregator = analytics.clone();
108 let agg_handle = tokio::spawn(async move {
109 run_aggregation_job(analytics_aggregator).await;
110 });
111
112 let app = Router::new()
113 .merge(routes::public_routes())
114 .merge(routes::api_routes())
115 .merge(routes::production_fallback_routes())
116 .layer(middleware::from_fn_with_state(
117 state.clone(),
118 analytics_middleware,
119 ))
120 .layer(middleware::from_fn(security::apply_security_headers))
121 .layer(CompressionLayer::new())
122 .layer(TimeoutLayer::with_status_code(axum::http::StatusCode::GATEWAY_TIMEOUT, std::time::Duration::from_secs(30)))
123 .layer(TraceLayer::new_for_http())
124 .with_state(state);
125
126 let addr = format!("{}:{}", host, port);
127 let listener = TcpListener::bind(&addr).await?;
128 let app = app.into_make_service_with_connect_info::<SocketAddr>();
129 tracing::info!("Production server listening on {}", addr);
130 axum::serve(listener, app)
131 .with_graceful_shutdown(shutdown_signal())
132 .await?;
133
134 agg_handle.abort();
135 tracing::info!("Production server shut down gracefully");
136 Ok(())
137}
138
139async fn shutdown_signal() {
142 let ctrl_c = async {
143 tokio::signal::ctrl_c()
144 .await
145 .expect("Failed to install Ctrl+C handler");
146 };
147
148 #[cfg(unix)]
149 let terminate = async {
150 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
151 .expect("Failed to install SIGTERM handler")
152 .recv()
153 .await;
154 };
155
156 #[cfg(not(unix))]
157 let terminate = std::future::pending::<()>();
158
159 tokio::select! {
160 _ = ctrl_c => {
161 tracing::info!("Received Ctrl+C, initiating graceful shutdown...");
162 }
163 _ = terminate => {
164 tracing::info!("Received SIGTERM, initiating graceful shutdown...");
165 }
166 }
167}
168
169async fn analytics_middleware(
170 State(state): State<Arc<AppState>>,
171 ConnectInfo(addr): ConnectInfo<SocketAddr>,
172 request: Request<Body>,
173 next: Next,
174) -> Response {
175 let start = Instant::now();
176 let path = request.uri().path().to_string();
177
178 let dnt_header = request
180 .headers()
181 .get("dnt")
182 .and_then(|v| v.to_str().ok())
183 .map(|s| s.to_string());
184
185 if let Some(analytics) = &state.analytics {
187 if !analytics.should_track(&path, dnt_header.as_deref()) {
188 return next.run(request).await;
189 }
190 } else if should_skip_tracking(&path) {
191 return next.run(request).await;
192 }
193
194 let user_agent = request
195 .headers()
196 .get("user-agent")
197 .and_then(|v| v.to_str().ok())
198 .unwrap_or("")
199 .to_string();
200
201 let referrer = request
202 .headers()
203 .get("referer")
204 .and_then(|v| v.to_str().ok())
205 .unwrap_or("")
206 .to_string();
207
208 let ip = addr.ip().to_string();
209
210 let response = next.run(request).await;
211
212 if let Some(analytics) = &state.analytics {
213 let daily_salt = get_daily_salt(&state.db).unwrap_or_else(|_| "default".to_string());
214 let session_hash = generate_session_hash(&ip, &user_agent, &daily_salt);
215 let response_time_ms = start.elapsed().as_millis() as i64;
216
217 let (content_id, content_type) = extract_content_info(&path, &state.db);
218
219 let country_code = if analytics.config().geo_lookup {
221 lookup_country(&ip)
222 } else {
223 None
224 };
225
226 let event = AnalyticsEvent {
227 path: path.clone(),
228 referrer_domain: extract_referrer_domain(&referrer),
229 country_code,
230 device_type: extract_device_type(&user_agent),
231 browser_family: extract_browser_family(&user_agent),
232 session_hash,
233 response_time_ms: Some(response_time_ms),
234 status_code: response.status().as_u16(),
235 content_id,
236 content_type,
237 };
238
239 if let Err(e) = analytics.record_event(&event) {
241 tracing::error!("Failed to record analytics event: {}", e);
242 }
243 }
244
245 response
246}
247
/// Returns `true` when `path` is infrastructure or a static asset, not a page view.
///
/// A path is skipped if any of the following holds:
/// - It is one of the section roots `/static`, `/media`, `/admin`, `/api` or
///   `/htmx`, or lies beneath one. Section roots are matched as whole path
///   segments, so `/static/app.css` is skipped but `/statistics` is not.
/// - It starts with `/_` (internal endpoints).
/// - It is one of a few well-known exact paths (robots, favicon, health, sitemap).
/// - Its final extension is a known asset type, compared case-insensitively.
fn should_skip_tracking(path: &str) -> bool {
    const SKIP_SECTIONS: [&str; 5] = ["/static", "/media", "/admin", "/api", "/htmx"];
    const SKIP_EXACT: [&str; 4] = ["/robots.txt", "/favicon.ico", "/health", "/sitemap.xml"];
    const ASSET_EXTENSIONS: [&str; 11] = [
        "css", "js", "png", "jpg", "jpeg", "gif", "svg", "webp", "ico", "woff", "woff2",
    ];

    let in_skipped_section = SKIP_SECTIONS.iter().any(|section| {
        matches!(
            path.strip_prefix(*section),
            Some(rest) if rest.is_empty() || rest.starts_with('/')
        )
    });

    // Any "/_"-prefixed path is internal, so this one stays a raw prefix match.
    let is_internal = path.starts_with("/_");

    let is_exact = SKIP_EXACT.iter().any(|exact| *exact == path);

    // Text after the last '.'; a '/' in it means the dot was in an earlier
    // segment, so it can never match one of the known extensions.
    let is_asset = match path.rsplit_once('.') {
        Some((_, ext)) => ASSET_EXTENSIONS
            .iter()
            .any(|known| ext.eq_ignore_ascii_case(known)),
        None => false,
    };

    in_skipped_section || is_internal || is_exact || is_asset
}
262
263fn extract_content_info(path: &str, db: &Database) -> (Option<i64>, Option<String>) {
264 if path.starts_with("/posts/") {
265 let slug = path.trim_start_matches("/posts/");
266 if let Ok(conn) = db.get() {
267 if let Ok((id, content_type)) = conn.query_row(
268 "SELECT id, content_type FROM content WHERE slug = ?1",
269 [slug],
270 |row| Ok((row.get::<_, i64>(0)?, row.get::<_, String>(1)?)),
271 ) {
272 return (Some(id), Some(content_type));
273 }
274 }
275 } else if path.starts_with("/pages/") {
276 let slug = path.trim_start_matches("/pages/");
277 if let Ok(conn) = db.get() {
278 if let Ok((id, content_type)) = conn.query_row(
279 "SELECT id, content_type FROM content WHERE slug = ?1",
280 [slug],
281 |row| Ok((row.get::<_, i64>(0)?, row.get::<_, String>(1)?)),
282 ) {
283 return (Some(id), Some(content_type));
284 }
285 }
286 }
287 (None, None)
288}