Skip to main content

folk_plugin_http/
server.rs

1use std::collections::HashMap;
2use std::net::{IpAddr, SocketAddr};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::time::Instant;
6
7use anyhow::Result;
8use axum::Router;
9use axum::body::Body;
10use axum::extract::{ConnectInfo, State};
11use axum::http::{self, Request, Response};
12use axum::routing::any;
13use ipnet::IpNet;
14use tokio::net::TcpListener;
15use tokio::sync::watch;
16use tower::ServiceBuilder;
17use tower_http::limit::RequestBodyLimitLayer;
18use tower_http::timeout::TimeoutLayer;
19use tracing::{error, info};
20
21use crate::config::HttpConfig;
22use crate::hooks::{HookEngine, HookResult, RequestContext, ResponseContext};
23use crate::payload::encode_request;
24
25#[derive(Clone)]
26struct AppState {
27    executor: Arc<dyn folk_api::Executor>,
28    config: Arc<HttpConfig>,
29    active_connections: Arc<AtomicU64>,
30    hook_engine: Option<Arc<HookEngine>>,
31}
32
33pub struct HttpServer {
34    config: HttpConfig,
35    executor: Arc<dyn folk_api::Executor>,
36    active_connections: Arc<AtomicU64>,
37    hook_engine: Option<Arc<HookEngine>>,
38}
39
40impl HttpServer {
41    pub fn new(
42        config: HttpConfig,
43        executor: Arc<dyn folk_api::Executor>,
44        active_connections: Arc<AtomicU64>,
45        hook_engine: Option<Arc<HookEngine>>,
46    ) -> Self {
47        Self {
48            config,
49            executor,
50            active_connections,
51            hook_engine,
52        }
53    }
54
55    pub async fn run(self, shutdown: watch::Receiver<bool>) -> Result<()> {
56        let state = AppState {
57            executor: self.executor.clone(),
58            config: Arc::new(self.config.clone()),
59            active_connections: self.active_connections.clone(),
60            hook_engine: self.hook_engine.clone(),
61        };
62
63        let mut app = Router::new()
64            .route("/{*path}", any(handle))
65            .route("/", any(handle))
66            .with_state(state)
67            .layer(
68                ServiceBuilder::new()
69                    .layer(RequestBodyLimitLayer::new(self.config.max_request_size))
70                    .layer(TimeoutLayer::with_status_code(
71                        http::StatusCode::GATEWAY_TIMEOUT,
72                        self.config.write_timeout,
73                    )),
74            );
75
76        if self.config.compression.enabled {
77            app = app.layer(build_compression_layer(&self.config.compression));
78        }
79
80        #[cfg(feature = "tls")]
81        if let Some(ref tls) = self.config.tls {
82            return self.run_tls(app, tls, shutdown).await;
83        }
84
85        #[cfg(feature = "h2c")]
86        if self.config.h2c {
87            return self.run_h2c(app, shutdown).await;
88        }
89
90        self.run_plain(app, shutdown).await
91    }
92
93    async fn run_plain(&self, app: Router, shutdown: watch::Receiver<bool>) -> Result<()> {
94        let listener = TcpListener::bind(self.config.listen).await?;
95
96        axum::serve(
97            listener,
98            app.into_make_service_with_connect_info::<SocketAddr>(),
99        )
100        .with_graceful_shutdown(shutdown_signal(shutdown))
101        .await?;
102
103        Ok(())
104    }
105
106    #[cfg(feature = "tls")]
107    async fn run_tls(
108        &self,
109        app: Router,
110        tls: &crate::config::TlsConfig,
111        shutdown: watch::Receiver<bool>,
112    ) -> Result<()> {
113        use axum_server::Handle;
114        use axum_server::tls_rustls::RustlsConfig;
115
116        let rustls_config = RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?;
117
118        info!(cert = %tls.cert.display(), "TLS enabled");
119
120        let handle = Handle::new();
121        let shutdown_handle = handle.clone();
122        tokio::spawn(async move {
123            shutdown_signal(shutdown).await;
124            shutdown_handle.graceful_shutdown(None);
125        });
126
127        axum_server::bind_rustls(self.config.listen, rustls_config)
128            .handle(handle)
129            .serve(app.into_make_service_with_connect_info::<SocketAddr>())
130            .await?;
131
132        Ok(())
133    }
134
135    #[cfg(feature = "h2c")]
136    async fn run_h2c(&self, app: Router, mut shutdown: watch::Receiver<bool>) -> Result<()> {
137        use hyper_util::rt::{TokioExecutor, TokioIo};
138        use hyper_util::server::conn::auto::Builder as AutoBuilder;
139
140        info!("h2c (HTTP/2 cleartext) enabled");
141
142        let listener = TcpListener::bind(self.config.listen).await?;
143        let builder = Arc::new(AutoBuilder::new(TokioExecutor::new()));
144        let mut tasks = tokio::task::JoinSet::new();
145
146        loop {
147            tokio::select! {
148                result = listener.accept() => {
149                    let (stream, remote_addr) = result?;
150                    let app = app.clone();
151                    let builder = builder.clone();
152                    tasks.spawn(async move {
153                        let svc = hyper::service::service_fn(move |mut req: Request<hyper::body::Incoming>| {
154                            // Inject ConnectInfo manually since we bypass axum::serve.
155                            req.extensions_mut().insert(ConnectInfo(remote_addr));
156                            let app = app.clone();
157                            async move {
158                                let resp = tower::Service::call(&mut app.clone(), req).await;
159                                resp.map_err(|e| match e {})
160                            }
161                        });
162                        let _ = builder.serve_connection_with_upgrades(TokioIo::new(stream), svc).await;
163                    });
164                }
165                _ = async {
166                    loop {
167                        if shutdown.changed().await.is_err() || *shutdown.borrow() {
168                            break;
169                        }
170                    }
171                } => {
172                    break;
173                }
174            }
175        }
176
177        // Wait for active connections to finish
178        while tasks.join_next().await.is_some() {}
179
180        Ok(())
181    }
182}
183
184async fn shutdown_signal(mut shutdown: watch::Receiver<bool>) {
185    loop {
186        if shutdown.changed().await.is_err() || *shutdown.borrow() {
187            break;
188        }
189    }
190}
191
/// Scope guard that decrements the shared in-flight request counter when it
/// is dropped.
///
/// The guard does not increment the counter itself; the caller performs the
/// matching `fetch_add` before constructing it.
struct ConnectionGuard(Arc<AtomicU64>);

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        let Self(counter) = self;
        counter.fetch_sub(1, Ordering::Relaxed);
    }
}
199
200async fn handle(
201    State(state): State<AppState>,
202    connect_info: ConnectInfo<SocketAddr>,
203    req: Request<Body>,
204) -> Response<Body> {
205    state.active_connections.fetch_add(1, Ordering::Relaxed);
206    let _conn_guard = ConnectionGuard(state.active_connections.clone());
207    let start = Instant::now();
208    let method = req.method().clone();
209    let uri = req.uri().clone();
210    let peer_addr = connect_info.0;
211
212    let client_ip = resolve_client_ip(
213        peer_addr.ip(),
214        req.headers()
215            .get("x-forwarded-for")
216            .and_then(|v| v.to_str().ok()),
217        &state.config.trusted_proxies,
218    );
219
220    let (response, request_id) = handle_inner(&state, req, client_ip).await;
221
222    if state.config.access_log {
223        let duration = start.elapsed();
224        let status = response.status().as_u16();
225        let response_bytes = response
226            .headers()
227            .get(http::header::CONTENT_LENGTH)
228            .and_then(|v| v.to_str().ok())
229            .and_then(|v| v.parse::<u64>().ok())
230            .unwrap_or(0);
231        info!(
232            request_id = %request_id,
233            client_ip = %client_ip,
234            method = %method,
235            uri = %uri,
236            status = status,
237            duration_ms = duration.as_millis() as u64,
238            response_bytes = response_bytes,
239            "http request",
240        );
241    }
242
243    // _conn_guard drop handles fetch_sub
244    response
245}
246
247/// Handles the request and returns the response together with the `request_id`
248/// under which it was dispatched (empty before dispatch, e.g. on read errors).
249async fn handle_inner(
250    state: &AppState,
251    req: Request<Body>,
252    client_ip: IpAddr,
253) -> (Response<Body>, Arc<str>) {
254    let no_id: Arc<str> = Arc::from("");
255    let max_body = state.config.max_request_size;
256    let read_timeout = state.config.read_timeout;
257
258    // ── Extract request parts for hooks before consuming the body ────────────
259    let (parts, body) = req.into_parts();
260
261    let req_method = parts.method.to_string();
262    let req_path = parts.uri.path().to_string();
263    let req_query = parts.uri.query().unwrap_or("").to_string();
264    let req_headers: HashMap<String, String> = parts
265        .headers
266        .iter()
267        .filter_map(|(k, v)| Some((k.to_string(), v.to_str().ok()?.to_string())))
268        .collect();
269
270    // Reassemble so encode_request can consume the body.
271    let req_reassembled = Request::from_parts(parts, body);
272
273    // ── Read + encode body ───────────────────────────────────────────────────
274    let payload =
275        match tokio::time::timeout(read_timeout, encode_request(req_reassembled, max_body)).await {
276            Ok(Ok(p)) => p,
277            Ok(Err(e)) => {
278                error!(error = ?e, "encode request");
279                return (
280                    Response::builder()
281                        .status(500)
282                        .body(Body::from("encode error"))
283                        .unwrap(),
284                    no_id,
285                );
286            }
287            Err(_) => {
288                return (
289                    Response::builder()
290                        .status(408)
291                        .body(Body::from("request body read timeout"))
292                        .unwrap(),
293                    no_id,
294                );
295            }
296        };
297
298    // ── request.before hooks ──────────────────────────────────────────────────
299    let mut req_ctx = RequestContext {
300        method: req_method,
301        path: req_path,
302        query: req_query,
303        client_ip: client_ip.to_string(),
304        request_id: String::new(), // filled in after dispatch
305        headers: req_headers,
306        extra: HashMap::new(),
307        error: None,
308        short_circuited: false,
309    };
310
311    if let Some(ref engine) = state.hook_engine {
312        match engine.run_request_before(&mut req_ctx) {
313            HookResult::ShortCircuit(resp) => {
314                // Async hooks already spawned inside run_request_before.
315                return (resp, no_id);
316            }
317            HookResult::Continue => {}
318        }
319    }
320
321    // ── Dispatch to PHP worker ───────────────────────────────────────────────
322    let (response_value, request_id) = match state
323        .executor
324        .execute_value_traced("http.handle", payload)
325        .await
326    {
327        Ok(v) => v,
328        Err(e) => {
329            error!(error = ?e, "dispatch to worker");
330
331            // Fire request.error hooks.
332            if let Some(ref engine) = state.hook_engine {
333                let mut err_ctx = req_ctx.clone();
334                err_ctx.request_id = no_id.to_string();
335                err_ctx.error = Some(e.to_string());
336                engine.run_request_error(&err_ctx);
337            }
338
339            return (
340                Response::builder()
341                    .status(502)
342                    .body(Body::from("worker error"))
343                    .unwrap(),
344                no_id,
345            );
346        }
347    };
348
349    // ── Decode PHP response into parts ───────────────────────────────────────
350    let status = response_value
351        .get("status")
352        .and_then(|v| v.as_u64())
353        .unwrap_or(200) as u16;
354
355    let resp_headers: HashMap<String, String> = response_value
356        .get("headers")
357        .and_then(|v| v.as_object())
358        .map(|obj| {
359            obj.iter()
360                .filter_map(|(k, v)| Some((k.clone(), v.as_str()?.to_string())))
361                .collect()
362        })
363        .unwrap_or_default();
364
365    let body_str = response_value
366        .get("body")
367        .and_then(|v| v.as_str())
368        .unwrap_or("")
369        .to_string();
370    let body_encoding = response_value
371        .get("body_encoding")
372        .and_then(|v| v.as_str())
373        .map(str::to_string);
374
375    // Update request_id in req_ctx for logging purposes.
376    req_ctx.request_id = request_id.to_string();
377
378    // ── response.headers hooks ────────────────────────────────────────────────
379    let mut resp_ctx = ResponseContext {
380        status,
381        resp_headers,
382        body: None, // not yet available at this stage
383        short_circuited: false,
384    };
385
386    if let Some(ref engine) = state.hook_engine {
387        match engine.run_response_headers(&mut resp_ctx) {
388            HookResult::ShortCircuit(resp) => return (resp, request_id),
389            HookResult::Continue => {}
390        }
391    }
392
393    // ── Decode body bytes ────────────────────────────────────────────────────
394    let body_bytes = if body_encoding.as_deref() == Some("base64") {
395        use base64::Engine;
396        match base64::engine::general_purpose::STANDARD.decode(&body_str) {
397            Ok(b) => b,
398            Err(e) => {
399                error!(error = ?e, "decode base64 response body");
400                return (
401                    Response::builder()
402                        .status(500)
403                        .body(Body::from("decode error"))
404                        .unwrap(),
405                    request_id,
406                );
407            }
408        }
409    } else {
410        body_str.clone().into_bytes()
411    };
412
413    // ── response.after hooks ──────────────────────────────────────────────────
414    if let Some(ref engine) = state.hook_engine {
415        // Only clone body_bytes when response.after hooks actually need it.
416        if engine.has_event("response.after") {
417            resp_ctx.body = Some(body_bytes.clone());
418        }
419
420        match engine.run_response_after(&mut resp_ctx) {
421            HookResult::ShortCircuit(resp) => return (resp, request_id),
422            HookResult::Continue => {}
423        }
424    }
425
426    // ── Build final response ──────────────────────────────────────────────────
427    let final_body_bytes: bytes::Bytes = resp_ctx
428        .body
429        .take()
430        .map(bytes::Bytes::from)
431        .unwrap_or_else(|| bytes::Bytes::from(body_bytes));
432
433    let mut builder = Response::builder().status(resp_ctx.status);
434    for (k, v) in &resp_ctx.resp_headers {
435        builder = builder.header(k.as_str(), v.as_str());
436    }
437
438    let response = match builder.body(Body::from(final_body_bytes)) {
439        Ok(r) => r,
440        Err(e) => {
441            error!(error = ?e, "build response");
442            Response::builder()
443                .status(500)
444                .body(Body::from("build error"))
445                .unwrap()
446        }
447    };
448
449    (response, request_id)
450}
451
452/// Resolve the real client IP from X-Forwarded-For if the peer is a trusted proxy.
453///
454/// Walks the X-Forwarded-For chain from right to left, stopping at the first
455/// IP that is NOT in a trusted subnet. This is the standard secure algorithm
456/// (rightmost non-trusted).
457pub fn resolve_client_ip(peer_ip: IpAddr, xff: Option<&str>, trusted: &[IpNet]) -> IpAddr {
458    if trusted.is_empty() {
459        return peer_ip;
460    }
461
462    if !is_trusted(peer_ip, trusted) {
463        return peer_ip;
464    }
465
466    let Some(xff) = xff else {
467        return peer_ip;
468    };
469
470    let addrs: Vec<&str> = xff.split(',').map(|s| s.trim()).collect();
471
472    // Walk from right to left — the rightmost non-trusted IP is the client.
473    for addr_str in addrs.iter().rev() {
474        if let Ok(ip) = addr_str.parse::<IpAddr>() {
475            if !is_trusted(ip, trusted) {
476                return ip;
477            }
478        }
479    }
480
481    // All IPs in the chain are trusted — use peer IP.
482    peer_ip
483}
484
485fn is_trusted(ip: IpAddr, trusted: &[IpNet]) -> bool {
486    trusted.iter().any(|net| net.contains(&ip))
487}
488
489fn build_compression_layer(
490    config: &crate::config::CompressionConfig,
491) -> tower_http::compression::CompressionLayer<tower_http::compression::predicate::SizeAbove> {
492    use crate::config::CompressionAlgorithm;
493    use tower_http::compression::CompressionLayer;
494
495    let mut layer = CompressionLayer::new()
496        .no_gzip()
497        .no_br()
498        .no_zstd()
499        .no_deflate();
500
501    for algo in &config.algorithms {
502        layer = match algo {
503            CompressionAlgorithm::Gzip => layer.gzip(true),
504            CompressionAlgorithm::Br => layer.br(true),
505            CompressionAlgorithm::Zstd => layer.zstd(true),
506            CompressionAlgorithm::Deflate => layer.deflate(true),
507        };
508    }
509
510    #[allow(clippy::cast_possible_truncation)]
511    let min_size = config.min_size as u16;
512    layer.compress_when(tower_http::compression::predicate::SizeAbove::new(min_size))
513}