Skip to main content

pebble_cms/web/
mod.rs

1mod error;
2mod extractors;
3mod handlers;
4mod routes;
5pub mod security;
6mod state;
7
8pub use state::AppState;
9
10use crate::services::analytics::{
11    extract_browser_family, extract_device_type, extract_referrer_domain, generate_session_hash,
12    get_daily_salt, lookup_country, run_aggregation_job, Analytics, AnalyticsConfig,
13    AnalyticsEvent,
14};
15use crate::{Config, Database};
16use anyhow::Result;
17use axum::body::Body;
18use axum::extract::{ConnectInfo, State};
19use axum::http::Request;
20use axum::middleware::{self, Next};
21use axum::response::Response;
22use axum::Router;
23use std::net::SocketAddr;
24use std::path::PathBuf;
25use std::sync::Arc;
26use std::time::Instant;
27use tokio::net::TcpListener;
28use tower_http::compression::CompressionLayer;
29use tower_http::timeout::TimeoutLayer;
30use tower_http::trace::TraceLayer;
31
32pub async fn serve(
33    config: Config,
34    config_path: PathBuf,
35    db: Database,
36    addr: &str,
37    shutdown_rx: Option<tokio::sync::watch::Receiver<bool>>,
38) -> Result<()> {
39    let analytics_config = AnalyticsConfig::default();
40    let analytics = Arc::new(Analytics::with_config(db.clone(), analytics_config));
41
42    let state =
43        AppState::new(config, config_path, db.clone(), false)?.with_analytics(analytics.clone());
44    let state = Arc::new(state);
45
46    let analytics_aggregator = analytics.clone();
47    let mut agg_rx = shutdown_rx.clone().unwrap_or_else(|| {
48        let (_, rx) = tokio::sync::watch::channel(false);
49        rx
50    });
51    let agg_handle = tokio::spawn(async move {
52        tokio::select! {
53            _ = run_aggregation_job(analytics_aggregator) => {}
54            _ = async { while agg_rx.changed().await.is_ok() { if *agg_rx.borrow() { break; } } } => {
55                tracing::info!("Analytics aggregation stopping...");
56            }
57        }
58    });
59
60    let app = Router::new()
61        .merge(routes::public_routes())
62        .merge(routes::admin_routes())
63        .merge(routes::htmx_routes())
64        .merge(routes::api_routes())
65        .layer(middleware::from_fn_with_state(
66            state.clone(),
67            security::write_rate_limit_middleware,
68        ))
69        .layer(middleware::from_fn_with_state(
70            state.clone(),
71            analytics_middleware,
72        ))
73        .layer(middleware::from_fn(security::apply_security_headers))
74        .layer(CompressionLayer::new())
75        .layer(TimeoutLayer::with_status_code(axum::http::StatusCode::GATEWAY_TIMEOUT, std::time::Duration::from_secs(30)))
76        .layer(TraceLayer::new_for_http())
77        .with_state(state);
78
79    let listener = TcpListener::bind(addr).await?;
80    let app = app.into_make_service_with_connect_info::<SocketAddr>();
81    tracing::info!("Server listening on {}", addr);
82    axum::serve(listener, app)
83        .with_graceful_shutdown(shutdown_signal())
84        .await?;
85
86    // Signal background tasks to stop
87    agg_handle.abort();
88    tracing::info!("Server shut down gracefully");
89    Ok(())
90}
91
92pub async fn serve_production(
93    config: &Config,
94    config_path: PathBuf,
95    host: &str,
96    port: u16,
97) -> Result<()> {
98    let db = Database::open(&config.database.path)?;
99
100    let analytics_config = AnalyticsConfig::default();
101    let analytics = Arc::new(Analytics::with_config(db.clone(), analytics_config));
102
103    let state = AppState::new(config.clone(), config_path, db.clone(), true)?
104        .with_analytics(analytics.clone());
105    let state = Arc::new(state);
106
107    let analytics_aggregator = analytics.clone();
108    let agg_handle = tokio::spawn(async move {
109        run_aggregation_job(analytics_aggregator).await;
110    });
111
112    let app = Router::new()
113        .merge(routes::public_routes())
114        .merge(routes::api_routes())
115        .merge(routes::production_fallback_routes())
116        .layer(middleware::from_fn_with_state(
117            state.clone(),
118            analytics_middleware,
119        ))
120        .layer(middleware::from_fn(security::apply_security_headers))
121        .layer(CompressionLayer::new())
122        .layer(TimeoutLayer::with_status_code(axum::http::StatusCode::GATEWAY_TIMEOUT, std::time::Duration::from_secs(30)))
123        .layer(TraceLayer::new_for_http())
124        .with_state(state);
125
126    let addr = format!("{}:{}", host, port);
127    let listener = TcpListener::bind(&addr).await?;
128    let app = app.into_make_service_with_connect_info::<SocketAddr>();
129    tracing::info!("Production server listening on {}", addr);
130    axum::serve(listener, app)
131        .with_graceful_shutdown(shutdown_signal())
132        .await?;
133
134    agg_handle.abort();
135    tracing::info!("Production server shut down gracefully");
136    Ok(())
137}
138
139/// Listens for SIGTERM/SIGINT and returns when either is received.
140/// On Unix, also listens for SIGTERM. On all platforms, listens for Ctrl+C.
141async fn shutdown_signal() {
142    let ctrl_c = async {
143        tokio::signal::ctrl_c()
144            .await
145            .expect("Failed to install Ctrl+C handler");
146    };
147
148    #[cfg(unix)]
149    let terminate = async {
150        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
151            .expect("Failed to install SIGTERM handler")
152            .recv()
153            .await;
154    };
155
156    #[cfg(not(unix))]
157    let terminate = std::future::pending::<()>();
158
159    tokio::select! {
160        _ = ctrl_c => {
161            tracing::info!("Received Ctrl+C, initiating graceful shutdown...");
162        }
163        _ = terminate => {
164            tracing::info!("Received SIGTERM, initiating graceful shutdown...");
165        }
166    }
167}
168
169async fn analytics_middleware(
170    State(state): State<Arc<AppState>>,
171    ConnectInfo(addr): ConnectInfo<SocketAddr>,
172    request: Request<Body>,
173    next: Next,
174) -> Response {
175    let start = Instant::now();
176    let path = request.uri().path().to_string();
177
178    // Get DNT header before moving request
179    let dnt_header = request
180        .headers()
181        .get("dnt")
182        .and_then(|v| v.to_str().ok())
183        .map(|s| s.to_string());
184
185    // Check if we should track this request using analytics config
186    if let Some(analytics) = &state.analytics {
187        if !analytics.should_track(&path, dnt_header.as_deref()) {
188            return next.run(request).await;
189        }
190    } else if should_skip_tracking(&path) {
191        return next.run(request).await;
192    }
193
194    let user_agent = request
195        .headers()
196        .get("user-agent")
197        .and_then(|v| v.to_str().ok())
198        .unwrap_or("")
199        .to_string();
200
201    let referrer = request
202        .headers()
203        .get("referer")
204        .and_then(|v| v.to_str().ok())
205        .unwrap_or("")
206        .to_string();
207
208    let ip = addr.ip().to_string();
209
210    let response = next.run(request).await;
211
212    if let Some(analytics) = &state.analytics {
213        let daily_salt = get_daily_salt(&state.db).unwrap_or_else(|_| "default".to_string());
214        let session_hash = generate_session_hash(&ip, &user_agent, &daily_salt);
215        let response_time_ms = start.elapsed().as_millis() as i64;
216
217        let (content_id, content_type) = extract_content_info(&path, &state.db);
218
219        // Lookup country from IP (privacy-preserving: IP is not stored)
220        let country_code = if analytics.config().geo_lookup {
221            lookup_country(&ip)
222        } else {
223            None
224        };
225
226        let event = AnalyticsEvent {
227            path: path.clone(),
228            referrer_domain: extract_referrer_domain(&referrer),
229            country_code,
230            device_type: extract_device_type(&user_agent),
231            browser_family: extract_browser_family(&user_agent),
232            session_hash,
233            response_time_ms: Some(response_time_ms),
234            status_code: response.status().as_u16(),
235            content_id,
236            content_type,
237        };
238
239        // Record event immediately for real-time analytics
240        if let Err(e) = analytics.record_event(&event) {
241            tracing::error!("Failed to record analytics event: {}", e);
242        }
243    }
244
245    response
246}
247
/// Fallback tracking filter used by the analytics middleware when no `Analytics` is configured.
///
/// Returns `true` for paths that are never counted as page views. These are:
/// - paths under `/static`, `/media`, `/admin`, `/api`, `/htmx` or `/_`;
/// - the exact paths `/robots.txt`, `/favicon.ico`, `/health` and `/sitemap.xml`;
/// - anything ending in a common asset extension.
///
/// Prefix checks are plain string prefixes, so `/administrator` is also skipped. All matching
/// is case-sensitive.
fn should_skip_tracking(path: &str) -> bool {
    const IGNORED_PREFIXES: [&str; 6] = ["/static", "/media", "/admin", "/api", "/htmx", "/_"];
    const IGNORED_PATHS: [&str; 4] = ["/robots.txt", "/favicon.ico", "/health", "/sitemap.xml"];
    const ASSET_SUFFIXES: [&str; 7] = [".css", ".js", ".png", ".jpg", ".ico", ".woff", ".woff2"];

    if IGNORED_PATHS.contains(&path) {
        return true;
    }

    let under_ignored_prefix = IGNORED_PREFIXES
        .iter()
        .any(|prefix| path.starts_with(prefix));

    under_ignored_prefix || ASSET_SUFFIXES.iter().any(|suffix| path.ends_with(suffix))
}
262
263fn extract_content_info(path: &str, db: &Database) -> (Option<i64>, Option<String>) {
264    if path.starts_with("/posts/") {
265        let slug = path.trim_start_matches("/posts/");
266        if let Ok(conn) = db.get() {
267            if let Ok((id, content_type)) = conn.query_row(
268                "SELECT id, content_type FROM content WHERE slug = ?1",
269                [slug],
270                |row| Ok((row.get::<_, i64>(0)?, row.get::<_, String>(1)?)),
271            ) {
272                return (Some(id), Some(content_type));
273            }
274        }
275    } else if path.starts_with("/pages/") {
276        let slug = path.trim_start_matches("/pages/");
277        if let Ok(conn) = db.get() {
278            if let Ok((id, content_type)) = conn.query_row(
279                "SELECT id, content_type FROM content WHERE slug = ?1",
280                [slug],
281                |row| Ok((row.get::<_, i64>(0)?, row.get::<_, String>(1)?)),
282            ) {
283                return (Some(id), Some(content_type));
284            }
285        }
286    }
287    (None, None)
288}