Skip to main content

folk_plugin_http/
server.rs

1use std::collections::HashMap;
2use std::net::{IpAddr, SocketAddr};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::time::Instant;
6
7use anyhow::Result;
8use axum::Router;
9use axum::body::Body;
10use axum::extract::{ConnectInfo, State};
11use axum::http::{self, Request, Response};
12use axum::routing::any;
13use ipnet::IpNet;
14use tokio::net::TcpListener;
15use tokio::sync::watch;
16use tower::ServiceBuilder;
17use tower_http::limit::RequestBodyLimitLayer;
18use tower_http::timeout::TimeoutLayer;
19#[cfg(feature = "h2c")]
20use tracing::warn;
21use tracing::{error, info};
22
23use crate::config::HttpConfig;
24use crate::hooks::{HookEngine, HookResult, RequestContext, ResponseContext};
25use crate::payload::encode_request;
26
27#[derive(Clone)]
28struct AppState {
29    executor: Arc<dyn folk_api::Executor>,
30    config: Arc<HttpConfig>,
31    active_connections: Arc<AtomicU64>,
32    hook_engine: Option<Arc<HookEngine>>,
33}
34
35pub struct HttpServer {
36    config: HttpConfig,
37    executor: Arc<dyn folk_api::Executor>,
38    active_connections: Arc<AtomicU64>,
39    hook_engine: Option<Arc<HookEngine>>,
40}
41
42impl HttpServer {
43    pub fn new(
44        config: HttpConfig,
45        executor: Arc<dyn folk_api::Executor>,
46        active_connections: Arc<AtomicU64>,
47        hook_engine: Option<Arc<HookEngine>>,
48    ) -> Self {
49        Self {
50            config,
51            executor,
52            active_connections,
53            hook_engine,
54        }
55    }
56
57    pub async fn run(self, shutdown: watch::Receiver<bool>) -> Result<()> {
58        let state = AppState {
59            executor: self.executor.clone(),
60            config: Arc::new(self.config.clone()),
61            active_connections: self.active_connections.clone(),
62            hook_engine: self.hook_engine.clone(),
63        };
64
65        let mut app = Router::new()
66            .route("/{*path}", any(handle))
67            .route("/", any(handle))
68            .with_state(state)
69            .layer(
70                ServiceBuilder::new()
71                    .layer(RequestBodyLimitLayer::new(self.config.max_request_size))
72                    .layer(TimeoutLayer::with_status_code(
73                        http::StatusCode::GATEWAY_TIMEOUT,
74                        self.config.write_timeout,
75                    )),
76            );
77
78        if self.config.compression.enabled {
79            app = app.layer(build_compression_layer(&self.config.compression));
80        }
81
82        #[cfg(feature = "tls")]
83        if let Some(ref tls) = self.config.tls {
84            return self.run_tls(app, tls, shutdown).await;
85        }
86
87        #[cfg(feature = "h2c")]
88        if self.config.h2c {
89            return self.run_h2c(app, shutdown).await;
90        }
91
92        self.run_plain(app, shutdown).await
93    }
94
95    async fn run_plain(&self, app: Router, shutdown: watch::Receiver<bool>) -> Result<()> {
96        let listener = TcpListener::bind(self.config.listen).await?;
97
98        axum::serve(
99            listener,
100            app.into_make_service_with_connect_info::<SocketAddr>(),
101        )
102        .with_graceful_shutdown(shutdown_signal(shutdown))
103        .await?;
104
105        Ok(())
106    }
107
108    #[cfg(feature = "tls")]
109    async fn run_tls(
110        &self,
111        app: Router,
112        tls: &crate::config::TlsConfig,
113        shutdown: watch::Receiver<bool>,
114    ) -> Result<()> {
115        use axum_server::Handle;
116        use axum_server::tls_rustls::RustlsConfig;
117
118        let rustls_config = RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?;
119
120        info!(cert = %tls.cert.display(), "TLS enabled");
121
122        let handle = Handle::new();
123        let shutdown_handle = handle.clone();
124        tokio::spawn(async move {
125            shutdown_signal(shutdown).await;
126            shutdown_handle.graceful_shutdown(None);
127        });
128
129        axum_server::bind_rustls(self.config.listen, rustls_config)
130            .handle(handle)
131            .serve(app.into_make_service_with_connect_info::<SocketAddr>())
132            .await?;
133
134        Ok(())
135    }
136
137    #[cfg(feature = "h2c")]
138    async fn run_h2c(&self, app: Router, mut shutdown: watch::Receiver<bool>) -> Result<()> {
139        use hyper_util::rt::{TokioExecutor, TokioIo};
140        use hyper_util::server::conn::auto::Builder as AutoBuilder;
141
142        info!("h2c (HTTP/2 cleartext) enabled");
143
144        let listener = TcpListener::bind(self.config.listen).await?;
145        let builder = Arc::new(AutoBuilder::new(TokioExecutor::new()));
146        let mut tasks = tokio::task::JoinSet::new();
147
148        loop {
149            tokio::select! {
150                result = listener.accept() => {
151                    let (stream, remote_addr) = result?;
152                    let app = app.clone();
153                    let builder = builder.clone();
154                    tasks.spawn(async move {
155                        let svc = hyper::service::service_fn(move |mut req: Request<hyper::body::Incoming>| {
156                            // Inject ConnectInfo manually since we bypass axum::serve.
157                            req.extensions_mut().insert(ConnectInfo(remote_addr));
158                            let app = app.clone();
159                            async move {
160                                let resp = tower::Service::call(&mut app.clone(), req).await;
161                                resp.map_err(|e| match e {})
162                            }
163                        });
164                        let _ = builder.serve_connection_with_upgrades(TokioIo::new(stream), svc).await;
165                    });
166                }
167                _ = async {
168                    loop {
169                        if shutdown.changed().await.is_err() || *shutdown.borrow() {
170                            break;
171                        }
172                    }
173                } => {
174                    break;
175                }
176            }
177        }
178
179        // Drain in-flight connections, but only up to shutdown_timeout.
180        let deadline = tokio::time::sleep(self.config.shutdown_timeout);
181        tokio::pin!(deadline);
182        loop {
183            tokio::select! {
184                _ = &mut deadline => {
185                    let remaining = tasks.len();
186                    if remaining > 0 {
187                        warn!(remaining, "h2c graceful shutdown timed out; aborting connections");
188                        tasks.abort_all();
189                        while tasks.join_next().await.is_some() {}
190                    }
191                    break;
192                }
193                result = tasks.join_next() => {
194                    if result.is_none() {
195                        break;
196                    }
197                }
198            }
199        }
200
201        Ok(())
202    }
203}
204
205async fn shutdown_signal(mut shutdown: watch::Receiver<bool>) {
206    loop {
207        if shutdown.changed().await.is_err() || *shutdown.borrow() {
208            break;
209        }
210    }
211}
212
/// Drop guard that decrements the wrapped counter exactly once.
///
/// The guard does not increment on construction; the caller is expected to
/// have done the matching `fetch_add` already.
struct ConnectionGuard(Arc<AtomicU64>);

impl Drop for ConnectionGuard {
    /// Subtracts one from the counter, on normal return and on unwinding alike.
    fn drop(&mut self) {
        let Self(counter) = self;
        counter.fetch_sub(1, Ordering::Relaxed);
    }
}
220
221async fn handle(
222    State(state): State<AppState>,
223    connect_info: ConnectInfo<SocketAddr>,
224    req: Request<Body>,
225) -> Response<Body> {
226    state.active_connections.fetch_add(1, Ordering::Relaxed);
227    let _conn_guard = ConnectionGuard(state.active_connections.clone());
228    let start = Instant::now();
229    let method = req.method().clone();
230    let uri = req.uri().clone();
231    let peer_addr = connect_info.0;
232
233    let client_ip = resolve_client_ip(
234        peer_addr.ip(),
235        req.headers()
236            .get("x-forwarded-for")
237            .and_then(|v| v.to_str().ok()),
238        &state.config.trusted_proxies,
239    );
240
241    let (response, request_id) = handle_inner(&state, req, client_ip).await;
242
243    if state.config.access_log {
244        let duration = start.elapsed();
245        let status = response.status().as_u16();
246        let response_bytes = response
247            .headers()
248            .get(http::header::CONTENT_LENGTH)
249            .and_then(|v| v.to_str().ok())
250            .and_then(|v| v.parse::<u64>().ok())
251            .unwrap_or(0);
252        info!(
253            request_id = %request_id,
254            client_ip = %client_ip,
255            method = %method,
256            uri = %uri,
257            status = status,
258            duration_ms = duration.as_millis() as u64,
259            response_bytes = response_bytes,
260            "http request",
261        );
262    }
263
264    // _conn_guard drop handles fetch_sub
265    response
266}
267
268/// Handles the request and returns the response together with the `request_id`
269/// under which it was dispatched (empty before dispatch, e.g. on read errors).
270async fn handle_inner(
271    state: &AppState,
272    req: Request<Body>,
273    client_ip: IpAddr,
274) -> (Response<Body>, Arc<str>) {
275    let no_id: Arc<str> = Arc::from("");
276    let max_body = state.config.max_request_size;
277    let read_timeout = state.config.read_timeout;
278
279    // ── Extract request parts; build hook context only when hooks are active ──
280    let (parts, body) = req.into_parts();
281
282    // Only pay the header-copy cost when there is actually a hook engine.
283    // At 10K rps with 20 headers per request this avoids ~200K String
284    // allocations per second on the no-hooks fast path.
285    let mut req_ctx: Option<RequestContext> = if state.hook_engine.is_some() {
286        let headers: HashMap<String, String> = parts
287            .headers
288            .iter()
289            .filter_map(|(k, v)| Some((k.to_string(), v.to_str().ok()?.to_string())))
290            .collect();
291        Some(RequestContext {
292            method: parts.method.to_string(),
293            path: parts.uri.path().to_string(),
294            query: parts.uri.query().unwrap_or("").to_string(),
295            client_ip: client_ip.to_string(),
296            request_id: String::new(), // filled in after dispatch
297            headers,
298            extra: HashMap::new(),
299            error: None,
300            short_circuited: false,
301        })
302    } else {
303        None
304    };
305
306    // Reassemble so encode_request can consume the body.
307    let req_reassembled = Request::from_parts(parts, body);
308
309    // ── Read + encode body ───────────────────────────────────────────────────
310    let payload =
311        match tokio::time::timeout(read_timeout, encode_request(req_reassembled, max_body)).await {
312            Ok(Ok(p)) => p,
313            Ok(Err(e)) => {
314                error!(error = ?e, "encode request");
315                return (
316                    Response::builder()
317                        .status(500)
318                        .body(Body::from("encode error"))
319                        .unwrap(),
320                    no_id,
321                );
322            }
323            Err(_) => {
324                return (
325                    Response::builder()
326                        .status(408)
327                        .body(Body::from("request body read timeout"))
328                        .unwrap(),
329                    no_id,
330                );
331            }
332        };
333
334    // ── request.before hooks ──────────────────────────────────────────────────
335    if let (Some(engine), Some(ctx)) = (&state.hook_engine, &mut req_ctx) {
336        match engine.run_request_before(ctx) {
337            HookResult::ShortCircuit(resp) => {
338                // Async hooks already spawned inside run_request_before.
339                return (resp, no_id);
340            }
341            HookResult::Continue => {}
342        }
343    }
344
345    // ── Dispatch to PHP worker ───────────────────────────────────────────────
346    let (response_value, request_id) = match state
347        .executor
348        .execute_value_traced("http.handle", payload)
349        .await
350    {
351        Ok(v) => v,
352        Err(e) => {
353            error!(error = ?e, "dispatch to worker");
354
355            // Fire request.error hooks.
356            if let (Some(engine), Some(ctx)) = (&state.hook_engine, &req_ctx) {
357                let mut err_ctx = ctx.clone();
358                err_ctx.request_id = no_id.to_string();
359                err_ctx.error = Some(e.to_string());
360                if let HookResult::ShortCircuit(resp) = engine.run_request_error(&mut err_ctx) {
361                    return (resp, no_id);
362                }
363            }
364
365            return (
366                Response::builder()
367                    .status(502)
368                    .body(Body::from("worker error"))
369                    .unwrap(),
370                no_id,
371            );
372        }
373    };
374
375    // ── Decode PHP response into parts ───────────────────────────────────────
376    let status = response_value
377        .get("status")
378        .and_then(|v| v.as_u64())
379        .unwrap_or(200) as u16;
380
381    let resp_headers: HashMap<String, String> = response_value
382        .get("headers")
383        .and_then(|v| v.as_object())
384        .map(|obj| {
385            obj.iter()
386                .filter_map(|(k, v)| Some((k.clone(), v.as_str()?.to_string())))
387                .collect()
388        })
389        .unwrap_or_default();
390
391    let body_str = response_value
392        .get("body")
393        .and_then(|v| v.as_str())
394        .unwrap_or("")
395        .to_string();
396    let body_encoding = response_value
397        .get("body_encoding")
398        .and_then(|v| v.as_str())
399        .map(str::to_string);
400
401    // Update request_id in req_ctx for logging purposes.
402    if let Some(ref mut ctx) = req_ctx {
403        ctx.request_id = request_id.to_string();
404    }
405
406    // ── response.headers hooks ────────────────────────────────────────────────
407    let mut resp_ctx = ResponseContext {
408        status,
409        resp_headers,
410        body: None, // not yet available at this stage
411        short_circuited: false,
412    };
413
414    if let Some(ref engine) = state.hook_engine {
415        match engine.run_response_headers(&mut resp_ctx) {
416            HookResult::ShortCircuit(resp) => return (resp, request_id),
417            HookResult::Continue => {}
418        }
419    }
420
421    // ── Decode body bytes ────────────────────────────────────────────────────
422    let body_bytes = if body_encoding.as_deref() == Some("base64") {
423        use base64::Engine;
424        match base64::engine::general_purpose::STANDARD.decode(&body_str) {
425            Ok(b) => b,
426            Err(e) => {
427                error!(error = ?e, "decode base64 response body");
428                return (
429                    Response::builder()
430                        .status(500)
431                        .body(Body::from("decode error"))
432                        .unwrap(),
433                    request_id,
434                );
435            }
436        }
437    } else {
438        body_str.clone().into_bytes()
439    };
440
441    // ── response.after hooks ──────────────────────────────────────────────────
442    if let Some(ref engine) = state.hook_engine {
443        // Only clone body_bytes when response.after hooks actually need it.
444        if engine.has_event("response.after") {
445            resp_ctx.body = Some(body_bytes.clone());
446        }
447
448        match engine.run_response_after(&mut resp_ctx) {
449            HookResult::ShortCircuit(resp) => return (resp, request_id),
450            HookResult::Continue => {}
451        }
452    }
453
454    // ── Build final response ──────────────────────────────────────────────────
455    let final_body_bytes: bytes::Bytes = resp_ctx
456        .body
457        .take()
458        .map(bytes::Bytes::from)
459        .unwrap_or_else(|| bytes::Bytes::from(body_bytes));
460
461    let mut builder = Response::builder().status(resp_ctx.status);
462    for (k, v) in &resp_ctx.resp_headers {
463        builder = builder.header(k.as_str(), v.as_str());
464    }
465
466    let response = match builder.body(Body::from(final_body_bytes)) {
467        Ok(r) => r,
468        Err(e) => {
469            error!(error = ?e, "build response");
470            Response::builder()
471                .status(500)
472                .body(Body::from("build error"))
473                .unwrap()
474        }
475    };
476
477    (response, request_id)
478}
479
480/// Resolve the real client IP from X-Forwarded-For if the peer is a trusted proxy.
481///
482/// Walks the X-Forwarded-For chain from right to left, stopping at the first
483/// IP that is NOT in a trusted subnet. This is the standard secure algorithm
484/// (rightmost non-trusted).
485pub fn resolve_client_ip(peer_ip: IpAddr, xff: Option<&str>, trusted: &[IpNet]) -> IpAddr {
486    if trusted.is_empty() {
487        return peer_ip;
488    }
489
490    if !is_trusted(peer_ip, trusted) {
491        return peer_ip;
492    }
493
494    let Some(xff) = xff else {
495        return peer_ip;
496    };
497
498    let addrs: Vec<&str> = xff.split(',').map(|s| s.trim()).collect();
499
500    // Walk from right to left — the rightmost non-trusted IP is the client.
501    // An unparseable entry is treated as an untrusted boundary: stop and return peer_ip.
502    // Skipping garbage entries would allow a client to inject a fake trusted hop.
503    for addr_str in addrs.iter().rev() {
504        match addr_str.parse::<IpAddr>() {
505            Ok(ip) if !is_trusted(ip, trusted) => return ip,
506            Ok(_) => {}               // trusted hop, keep walking left
507            Err(_) => return peer_ip, // unparseable = untrusted boundary
508        }
509    }
510
511    // All IPs in the chain are trusted — use peer IP.
512    peer_ip
513}
514
515fn is_trusted(ip: IpAddr, trusted: &[IpNet]) -> bool {
516    trusted.iter().any(|net| net.contains(&ip))
517}
518
519fn build_compression_layer(
520    config: &crate::config::CompressionConfig,
521) -> tower_http::compression::CompressionLayer<tower_http::compression::predicate::SizeAbove> {
522    use crate::config::CompressionAlgorithm;
523    use tower_http::compression::CompressionLayer;
524
525    let mut layer = CompressionLayer::new()
526        .no_gzip()
527        .no_br()
528        .no_zstd()
529        .no_deflate();
530
531    for algo in &config.algorithms {
532        layer = match algo {
533            CompressionAlgorithm::Gzip => layer.gzip(true),
534            CompressionAlgorithm::Br => layer.br(true),
535            CompressionAlgorithm::Zstd => layer.zstd(true),
536            CompressionAlgorithm::Deflate => layer.deflate(true),
537        };
538    }
539
540    // min_size is validated to fit u16 before this function is called (see PluginFactory::create)
541    let min_size =
542        u16::try_from(config.min_size).expect("min_size validated in PluginFactory::create");
543    layer.compress_when(tower_http::compression::predicate::SizeAbove::new(min_size))
544}