Skip to main content

folk_plugin_http/
server.rs

1use std::collections::HashMap;
2use std::net::{IpAddr, SocketAddr};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::time::Instant;
6
7use anyhow::Result;
8use axum::Router;
9use axum::body::Body;
10use axum::extract::{ConnectInfo, State};
11use axum::http::{self, Request, Response};
12use axum::routing::any;
13use ipnet::IpNet;
14use tokio::net::TcpListener;
15use tokio::sync::watch;
16use tower::ServiceBuilder;
17use tower_http::limit::RequestBodyLimitLayer;
18use tower_http::timeout::TimeoutLayer;
19use tracing::{error, info};
20
21use crate::config::HttpConfig;
22use crate::hooks::{HookEngine, HookResult, RequestContext, ResponseContext};
23use crate::payload::encode_request;
24
25#[derive(Clone)]
26struct AppState {
27    executor: Arc<dyn folk_api::Executor>,
28    config: Arc<HttpConfig>,
29    active_connections: Arc<AtomicU64>,
30    hook_engine: Option<Arc<HookEngine>>,
31}
32
33pub struct HttpServer {
34    config: HttpConfig,
35    executor: Arc<dyn folk_api::Executor>,
36    active_connections: Arc<AtomicU64>,
37    hook_engine: Option<Arc<HookEngine>>,
38}
39
40impl HttpServer {
41    pub fn new(
42        config: HttpConfig,
43        executor: Arc<dyn folk_api::Executor>,
44        active_connections: Arc<AtomicU64>,
45        hook_engine: Option<Arc<HookEngine>>,
46    ) -> Self {
47        Self {
48            config,
49            executor,
50            active_connections,
51            hook_engine,
52        }
53    }
54
55    pub async fn run(self, shutdown: watch::Receiver<bool>) -> Result<()> {
56        let state = AppState {
57            executor: self.executor.clone(),
58            config: Arc::new(self.config.clone()),
59            active_connections: self.active_connections.clone(),
60            hook_engine: self.hook_engine.clone(),
61        };
62
63        let mut app = Router::new()
64            .route("/{*path}", any(handle))
65            .route("/", any(handle))
66            .with_state(state)
67            .layer(
68                ServiceBuilder::new()
69                    .layer(RequestBodyLimitLayer::new(self.config.max_request_size))
70                    .layer(TimeoutLayer::with_status_code(
71                        http::StatusCode::GATEWAY_TIMEOUT,
72                        self.config.write_timeout,
73                    )),
74            );
75
76        if self.config.compression.enabled {
77            app = app.layer(build_compression_layer(&self.config.compression));
78        }
79
80        #[cfg(feature = "tls")]
81        if let Some(ref tls) = self.config.tls {
82            return self.run_tls(app, tls, shutdown).await;
83        }
84
85        #[cfg(feature = "h2c")]
86        if self.config.h2c {
87            return self.run_h2c(app, shutdown).await;
88        }
89
90        self.run_plain(app, shutdown).await
91    }
92
93    async fn run_plain(&self, app: Router, shutdown: watch::Receiver<bool>) -> Result<()> {
94        let listener = TcpListener::bind(self.config.listen).await?;
95
96        axum::serve(
97            listener,
98            app.into_make_service_with_connect_info::<SocketAddr>(),
99        )
100        .with_graceful_shutdown(shutdown_signal(shutdown))
101        .await?;
102
103        Ok(())
104    }
105
106    #[cfg(feature = "tls")]
107    async fn run_tls(
108        &self,
109        app: Router,
110        tls: &crate::config::TlsConfig,
111        shutdown: watch::Receiver<bool>,
112    ) -> Result<()> {
113        use axum_server::Handle;
114        use axum_server::tls_rustls::RustlsConfig;
115
116        let rustls_config = RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?;
117
118        info!(cert = %tls.cert.display(), "TLS enabled");
119
120        let handle = Handle::new();
121        let shutdown_handle = handle.clone();
122        tokio::spawn(async move {
123            shutdown_signal(shutdown).await;
124            shutdown_handle.graceful_shutdown(None);
125        });
126
127        axum_server::bind_rustls(self.config.listen, rustls_config)
128            .handle(handle)
129            .serve(app.into_make_service_with_connect_info::<SocketAddr>())
130            .await?;
131
132        Ok(())
133    }
134
135    #[cfg(feature = "h2c")]
136    async fn run_h2c(&self, app: Router, mut shutdown: watch::Receiver<bool>) -> Result<()> {
137        use hyper_util::rt::{TokioExecutor, TokioIo};
138        use hyper_util::server::conn::auto::Builder as AutoBuilder;
139
140        info!("h2c (HTTP/2 cleartext) enabled");
141
142        let listener = TcpListener::bind(self.config.listen).await?;
143        let builder = Arc::new(AutoBuilder::new(TokioExecutor::new()));
144        let mut tasks = tokio::task::JoinSet::new();
145
146        loop {
147            tokio::select! {
148                result = listener.accept() => {
149                    let (stream, remote_addr) = result?;
150                    let app = app.clone();
151                    let builder = builder.clone();
152                    tasks.spawn(async move {
153                        let svc = hyper::service::service_fn(move |mut req: Request<hyper::body::Incoming>| {
154                            // Inject ConnectInfo manually since we bypass axum::serve.
155                            req.extensions_mut().insert(ConnectInfo(remote_addr));
156                            let app = app.clone();
157                            async move {
158                                let resp = tower::Service::call(&mut app.clone(), req).await;
159                                resp.map_err(|e| match e {})
160                            }
161                        });
162                        let _ = builder.serve_connection_with_upgrades(TokioIo::new(stream), svc).await;
163                    });
164                }
165                _ = async {
166                    loop {
167                        if shutdown.changed().await.is_err() || *shutdown.borrow() {
168                            break;
169                        }
170                    }
171                } => {
172                    break;
173                }
174            }
175        }
176
177        // Wait for active connections to finish
178        while tasks.join_next().await.is_some() {}
179
180        Ok(())
181    }
182}
183
184async fn shutdown_signal(mut shutdown: watch::Receiver<bool>) {
185    loop {
186        if shutdown.changed().await.is_err() || *shutdown.borrow() {
187            break;
188        }
189    }
190}
191
/// Guard that decrements a shared in-flight counter when it goes out of scope.
///
/// It does not increment anything itself. The caller bumps the counter right
/// before constructing the guard, and the guard undoes that on every exit path.
struct ConnectionGuard(Arc<AtomicU64>);

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        let Self(counter) = self;
        counter.fetch_sub(1, Ordering::Relaxed);
    }
}
199
200async fn handle(
201    State(state): State<AppState>,
202    connect_info: ConnectInfo<SocketAddr>,
203    req: Request<Body>,
204) -> Response<Body> {
205    state.active_connections.fetch_add(1, Ordering::Relaxed);
206    let _conn_guard = ConnectionGuard(state.active_connections.clone());
207    let start = Instant::now();
208    let method = req.method().clone();
209    let uri = req.uri().clone();
210    let peer_addr = connect_info.0;
211
212    let client_ip = resolve_client_ip(
213        peer_addr.ip(),
214        req.headers()
215            .get("x-forwarded-for")
216            .and_then(|v| v.to_str().ok()),
217        &state.config.trusted_proxies,
218    );
219
220    let (response, request_id) = handle_inner(&state, req, client_ip).await;
221
222    if state.config.access_log {
223        let duration = start.elapsed();
224        let status = response.status().as_u16();
225        let response_bytes = response
226            .headers()
227            .get(http::header::CONTENT_LENGTH)
228            .and_then(|v| v.to_str().ok())
229            .and_then(|v| v.parse::<u64>().ok())
230            .unwrap_or(0);
231        info!(
232            request_id = %request_id,
233            client_ip = %client_ip,
234            method = %method,
235            uri = %uri,
236            status = status,
237            duration_ms = duration.as_millis() as u64,
238            response_bytes = response_bytes,
239            "http request",
240        );
241    }
242
243    // _conn_guard drop handles fetch_sub
244    response
245}
246
247/// Handles the request and returns the response together with the `request_id`
248/// under which it was dispatched (empty before dispatch, e.g. on read errors).
249async fn handle_inner(
250    state: &AppState,
251    req: Request<Body>,
252    client_ip: IpAddr,
253) -> (Response<Body>, Arc<str>) {
254    let no_id: Arc<str> = Arc::from("");
255    let max_body = state.config.max_request_size;
256    let read_timeout = state.config.read_timeout;
257
258    // ── Extract request parts for hooks before consuming the body ────────────
259    let (parts, body) = req.into_parts();
260
261    let req_method = parts.method.to_string();
262    let req_path = parts.uri.path().to_string();
263    let req_query = parts.uri.query().unwrap_or("").to_string();
264    let req_headers: HashMap<String, String> = parts
265        .headers
266        .iter()
267        .filter_map(|(k, v)| Some((k.to_string(), v.to_str().ok()?.to_string())))
268        .collect();
269
270    // Reassemble so encode_request can consume the body.
271    let req_reassembled = Request::from_parts(parts, body);
272
273    // ── Read + encode body ───────────────────────────────────────────────────
274    let payload =
275        match tokio::time::timeout(read_timeout, encode_request(req_reassembled, max_body)).await {
276            Ok(Ok(p)) => p,
277            Ok(Err(e)) => {
278                error!(error = ?e, "encode request");
279                return (
280                    Response::builder()
281                        .status(500)
282                        .body(Body::from("encode error"))
283                        .unwrap(),
284                    no_id,
285                );
286            }
287            Err(_) => {
288                return (
289                    Response::builder()
290                        .status(408)
291                        .body(Body::from("request body read timeout"))
292                        .unwrap(),
293                    no_id,
294                );
295            }
296        };
297
298    // ── request.before hooks ──────────────────────────────────────────────────
299    let mut req_ctx = RequestContext {
300        method: req_method,
301        path: req_path,
302        query: req_query,
303        client_ip: client_ip.to_string(),
304        request_id: String::new(), // filled in after dispatch
305        headers: req_headers,
306        extra: HashMap::new(),
307        error: None,
308        short_circuited: false,
309    };
310
311    if let Some(ref engine) = state.hook_engine {
312        match engine.run_request_before(&mut req_ctx) {
313            HookResult::ShortCircuit(resp) => {
314                // Async hooks already spawned inside run_request_before.
315                return (resp, no_id);
316            }
317            HookResult::Continue => {}
318        }
319    }
320
321    // ── Dispatch to PHP worker ───────────────────────────────────────────────
322    let (response_value, request_id) = match state
323        .executor
324        .execute_value_traced("http.handle", payload)
325        .await
326    {
327        Ok(v) => v,
328        Err(e) => {
329            error!(error = ?e, "dispatch to worker");
330
331            // Fire request.error hooks.
332            if let Some(ref engine) = state.hook_engine {
333                let mut err_ctx = req_ctx.clone();
334                err_ctx.request_id = no_id.to_string();
335                err_ctx.error = Some(e.to_string());
336                if let HookResult::ShortCircuit(resp) = engine.run_request_error(&mut err_ctx) {
337                    return (resp, no_id);
338                }
339            }
340
341            return (
342                Response::builder()
343                    .status(502)
344                    .body(Body::from("worker error"))
345                    .unwrap(),
346                no_id,
347            );
348        }
349    };
350
351    // ── Decode PHP response into parts ───────────────────────────────────────
352    let status = response_value
353        .get("status")
354        .and_then(|v| v.as_u64())
355        .unwrap_or(200) as u16;
356
357    let resp_headers: HashMap<String, String> = response_value
358        .get("headers")
359        .and_then(|v| v.as_object())
360        .map(|obj| {
361            obj.iter()
362                .filter_map(|(k, v)| Some((k.clone(), v.as_str()?.to_string())))
363                .collect()
364        })
365        .unwrap_or_default();
366
367    let body_str = response_value
368        .get("body")
369        .and_then(|v| v.as_str())
370        .unwrap_or("")
371        .to_string();
372    let body_encoding = response_value
373        .get("body_encoding")
374        .and_then(|v| v.as_str())
375        .map(str::to_string);
376
377    // Update request_id in req_ctx for logging purposes.
378    req_ctx.request_id = request_id.to_string();
379
380    // ── response.headers hooks ────────────────────────────────────────────────
381    let mut resp_ctx = ResponseContext {
382        status,
383        resp_headers,
384        body: None, // not yet available at this stage
385        short_circuited: false,
386    };
387
388    if let Some(ref engine) = state.hook_engine {
389        match engine.run_response_headers(&mut resp_ctx) {
390            HookResult::ShortCircuit(resp) => return (resp, request_id),
391            HookResult::Continue => {}
392        }
393    }
394
395    // ── Decode body bytes ────────────────────────────────────────────────────
396    let body_bytes = if body_encoding.as_deref() == Some("base64") {
397        use base64::Engine;
398        match base64::engine::general_purpose::STANDARD.decode(&body_str) {
399            Ok(b) => b,
400            Err(e) => {
401                error!(error = ?e, "decode base64 response body");
402                return (
403                    Response::builder()
404                        .status(500)
405                        .body(Body::from("decode error"))
406                        .unwrap(),
407                    request_id,
408                );
409            }
410        }
411    } else {
412        body_str.clone().into_bytes()
413    };
414
415    // ── response.after hooks ──────────────────────────────────────────────────
416    if let Some(ref engine) = state.hook_engine {
417        // Only clone body_bytes when response.after hooks actually need it.
418        if engine.has_event("response.after") {
419            resp_ctx.body = Some(body_bytes.clone());
420        }
421
422        match engine.run_response_after(&mut resp_ctx) {
423            HookResult::ShortCircuit(resp) => return (resp, request_id),
424            HookResult::Continue => {}
425        }
426    }
427
428    // ── Build final response ──────────────────────────────────────────────────
429    let final_body_bytes: bytes::Bytes = resp_ctx
430        .body
431        .take()
432        .map(bytes::Bytes::from)
433        .unwrap_or_else(|| bytes::Bytes::from(body_bytes));
434
435    let mut builder = Response::builder().status(resp_ctx.status);
436    for (k, v) in &resp_ctx.resp_headers {
437        builder = builder.header(k.as_str(), v.as_str());
438    }
439
440    let response = match builder.body(Body::from(final_body_bytes)) {
441        Ok(r) => r,
442        Err(e) => {
443            error!(error = ?e, "build response");
444            Response::builder()
445                .status(500)
446                .body(Body::from("build error"))
447                .unwrap()
448        }
449    };
450
451    (response, request_id)
452}
453
454/// Resolve the real client IP from X-Forwarded-For if the peer is a trusted proxy.
455///
456/// Walks the X-Forwarded-For chain from right to left, stopping at the first
457/// IP that is NOT in a trusted subnet. This is the standard secure algorithm
458/// (rightmost non-trusted).
459pub fn resolve_client_ip(peer_ip: IpAddr, xff: Option<&str>, trusted: &[IpNet]) -> IpAddr {
460    if trusted.is_empty() {
461        return peer_ip;
462    }
463
464    if !is_trusted(peer_ip, trusted) {
465        return peer_ip;
466    }
467
468    let Some(xff) = xff else {
469        return peer_ip;
470    };
471
472    let addrs: Vec<&str> = xff.split(',').map(|s| s.trim()).collect();
473
474    // Walk from right to left — the rightmost non-trusted IP is the client.
475    // An unparseable entry is treated as an untrusted boundary: stop and return peer_ip.
476    // Skipping garbage entries would allow a client to inject a fake trusted hop.
477    for addr_str in addrs.iter().rev() {
478        match addr_str.parse::<IpAddr>() {
479            Ok(ip) if !is_trusted(ip, trusted) => return ip,
480            Ok(_) => {}               // trusted hop, keep walking left
481            Err(_) => return peer_ip, // unparseable = untrusted boundary
482        }
483    }
484
485    // All IPs in the chain are trusted — use peer IP.
486    peer_ip
487}
488
489fn is_trusted(ip: IpAddr, trusted: &[IpNet]) -> bool {
490    trusted.iter().any(|net| net.contains(&ip))
491}
492
493fn build_compression_layer(
494    config: &crate::config::CompressionConfig,
495) -> tower_http::compression::CompressionLayer<tower_http::compression::predicate::SizeAbove> {
496    use crate::config::CompressionAlgorithm;
497    use tower_http::compression::CompressionLayer;
498
499    let mut layer = CompressionLayer::new()
500        .no_gzip()
501        .no_br()
502        .no_zstd()
503        .no_deflate();
504
505    for algo in &config.algorithms {
506        layer = match algo {
507            CompressionAlgorithm::Gzip => layer.gzip(true),
508            CompressionAlgorithm::Br => layer.br(true),
509            CompressionAlgorithm::Zstd => layer.zstd(true),
510            CompressionAlgorithm::Deflate => layer.deflate(true),
511        };
512    }
513
514    #[allow(clippy::cast_possible_truncation)]
515    let min_size = config.min_size as u16;
516    layer.compress_when(tower_http::compression::predicate::SizeAbove::new(min_size))
517}