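//! folk_plugin_http/server.rs
//!
//! HTTP front end of the plugin: builds the axum router, serves it over plain
//! TCP, TLS or h2c, and forwards each request through the hook pipeline to the
//! worker executor.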

1use std::collections::HashMap;
2use std::net::{IpAddr, SocketAddr};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::time::Instant;
6
7use anyhow::Result;
8use axum::Router;
9use axum::body::Body;
10use axum::extract::{ConnectInfo, State};
11use axum::http::{self, Request, Response};
12use axum::routing::any;
13use ipnet::IpNet;
14use tokio::net::TcpListener;
15use tokio::sync::watch;
16use tower::ServiceBuilder;
17use tower_http::limit::RequestBodyLimitLayer;
18use tower_http::timeout::TimeoutLayer;
19use tracing::{error, info, warn};
20
21use crate::config::HttpConfig;
22use crate::hooks::{HookEngine, HookResult, RequestContext, ResponseContext};
23use crate::payload::encode_request;
24
25#[derive(Clone)]
26struct AppState {
27    executor: Arc<dyn folk_api::Executor>,
28    config: Arc<HttpConfig>,
29    active_connections: Arc<AtomicU64>,
30    hook_engine: Option<Arc<HookEngine>>,
31}
32
33pub struct HttpServer {
34    config: HttpConfig,
35    executor: Arc<dyn folk_api::Executor>,
36    active_connections: Arc<AtomicU64>,
37    hook_engine: Option<Arc<HookEngine>>,
38}
39
40impl HttpServer {
41    pub fn new(
42        config: HttpConfig,
43        executor: Arc<dyn folk_api::Executor>,
44        active_connections: Arc<AtomicU64>,
45        hook_engine: Option<Arc<HookEngine>>,
46    ) -> Self {
47        Self {
48            config,
49            executor,
50            active_connections,
51            hook_engine,
52        }
53    }
54
55    pub async fn run(self, shutdown: watch::Receiver<bool>) -> Result<()> {
56        let state = AppState {
57            executor: self.executor.clone(),
58            config: Arc::new(self.config.clone()),
59            active_connections: self.active_connections.clone(),
60            hook_engine: self.hook_engine.clone(),
61        };
62
63        let mut app = Router::new()
64            .route("/{*path}", any(handle))
65            .route("/", any(handle))
66            .with_state(state)
67            .layer(
68                ServiceBuilder::new()
69                    .layer(RequestBodyLimitLayer::new(self.config.max_request_size))
70                    .layer(TimeoutLayer::with_status_code(
71                        http::StatusCode::GATEWAY_TIMEOUT,
72                        self.config.write_timeout,
73                    )),
74            );
75
76        if self.config.compression.enabled {
77            app = app.layer(build_compression_layer(&self.config.compression));
78        }
79
80        #[cfg(feature = "tls")]
81        if let Some(ref tls) = self.config.tls {
82            return self.run_tls(app, tls, shutdown).await;
83        }
84
85        #[cfg(feature = "h2c")]
86        if self.config.h2c {
87            return self.run_h2c(app, shutdown).await;
88        }
89
90        self.run_plain(app, shutdown).await
91    }
92
93    async fn run_plain(&self, app: Router, shutdown: watch::Receiver<bool>) -> Result<()> {
94        let listener = TcpListener::bind(self.config.listen).await?;
95
96        axum::serve(
97            listener,
98            app.into_make_service_with_connect_info::<SocketAddr>(),
99        )
100        .with_graceful_shutdown(shutdown_signal(shutdown))
101        .await?;
102
103        Ok(())
104    }
105
106    #[cfg(feature = "tls")]
107    async fn run_tls(
108        &self,
109        app: Router,
110        tls: &crate::config::TlsConfig,
111        shutdown: watch::Receiver<bool>,
112    ) -> Result<()> {
113        use axum_server::Handle;
114        use axum_server::tls_rustls::RustlsConfig;
115
116        let rustls_config = RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?;
117
118        info!(cert = %tls.cert.display(), "TLS enabled");
119
120        let handle = Handle::new();
121        let shutdown_handle = handle.clone();
122        tokio::spawn(async move {
123            shutdown_signal(shutdown).await;
124            shutdown_handle.graceful_shutdown(None);
125        });
126
127        axum_server::bind_rustls(self.config.listen, rustls_config)
128            .handle(handle)
129            .serve(app.into_make_service_with_connect_info::<SocketAddr>())
130            .await?;
131
132        Ok(())
133    }
134
135    #[cfg(feature = "h2c")]
136    async fn run_h2c(&self, app: Router, mut shutdown: watch::Receiver<bool>) -> Result<()> {
137        use hyper_util::rt::{TokioExecutor, TokioIo};
138        use hyper_util::server::conn::auto::Builder as AutoBuilder;
139
140        info!("h2c (HTTP/2 cleartext) enabled");
141
142        let listener = TcpListener::bind(self.config.listen).await?;
143        let builder = Arc::new(AutoBuilder::new(TokioExecutor::new()));
144        let mut tasks = tokio::task::JoinSet::new();
145
146        loop {
147            tokio::select! {
148                result = listener.accept() => {
149                    let (stream, remote_addr) = result?;
150                    let app = app.clone();
151                    let builder = builder.clone();
152                    tasks.spawn(async move {
153                        let svc = hyper::service::service_fn(move |mut req: Request<hyper::body::Incoming>| {
154                            // Inject ConnectInfo manually since we bypass axum::serve.
155                            req.extensions_mut().insert(ConnectInfo(remote_addr));
156                            let app = app.clone();
157                            async move {
158                                let resp = tower::Service::call(&mut app.clone(), req).await;
159                                resp.map_err(|e| match e {})
160                            }
161                        });
162                        let _ = builder.serve_connection_with_upgrades(TokioIo::new(stream), svc).await;
163                    });
164                }
165                _ = async {
166                    loop {
167                        if shutdown.changed().await.is_err() || *shutdown.borrow() {
168                            break;
169                        }
170                    }
171                } => {
172                    break;
173                }
174            }
175        }
176
177        // Drain in-flight connections, but only up to shutdown_timeout.
178        let deadline = tokio::time::sleep(self.config.shutdown_timeout);
179        tokio::pin!(deadline);
180        loop {
181            tokio::select! {
182                _ = &mut deadline => {
183                    let remaining = tasks.len();
184                    if remaining > 0 {
185                        warn!(remaining, "h2c graceful shutdown timed out; aborting connections");
186                        tasks.abort_all();
187                        while tasks.join_next().await.is_some() {}
188                    }
189                    break;
190                }
191                result = tasks.join_next() => {
192                    if result.is_none() {
193                        break;
194                    }
195                }
196            }
197        }
198
199        Ok(())
200    }
201}
202
203async fn shutdown_signal(mut shutdown: watch::Receiver<bool>) {
204    loop {
205        if shutdown.changed().await.is_err() || *shutdown.borrow() {
206            break;
207        }
208    }
209}
210
/// Drop guard that subtracts one from the shared counter it holds.
///
/// The guard never increments: whoever creates it adds one first, and the guard
/// takes it back when it goes out of scope, whichever way the scope is left.
struct ConnectionGuard(Arc<AtomicU64>);

impl Drop for ConnectionGuard {
    /// Subtracts one from the counter.
    fn drop(&mut self) {
        let Self(counter) = self;
        counter.fetch_sub(1, Ordering::Relaxed);
    }
}
218
219async fn handle(
220    State(state): State<AppState>,
221    connect_info: ConnectInfo<SocketAddr>,
222    req: Request<Body>,
223) -> Response<Body> {
224    state.active_connections.fetch_add(1, Ordering::Relaxed);
225    let _conn_guard = ConnectionGuard(state.active_connections.clone());
226    let start = Instant::now();
227    let method = req.method().clone();
228    let uri = req.uri().clone();
229    let peer_addr = connect_info.0;
230
231    let client_ip = resolve_client_ip(
232        peer_addr.ip(),
233        req.headers()
234            .get("x-forwarded-for")
235            .and_then(|v| v.to_str().ok()),
236        &state.config.trusted_proxies,
237    );
238
239    let (response, request_id) = handle_inner(&state, req, client_ip).await;
240
241    if state.config.access_log {
242        let duration = start.elapsed();
243        let status = response.status().as_u16();
244        let response_bytes = response
245            .headers()
246            .get(http::header::CONTENT_LENGTH)
247            .and_then(|v| v.to_str().ok())
248            .and_then(|v| v.parse::<u64>().ok())
249            .unwrap_or(0);
250        info!(
251            request_id = %request_id,
252            client_ip = %client_ip,
253            method = %method,
254            uri = %uri,
255            status = status,
256            duration_ms = duration.as_millis() as u64,
257            response_bytes = response_bytes,
258            "http request",
259        );
260    }
261
262    // _conn_guard drop handles fetch_sub
263    response
264}
265
266/// Handles the request and returns the response together with the `request_id`
267/// under which it was dispatched (empty before dispatch, e.g. on read errors).
268async fn handle_inner(
269    state: &AppState,
270    req: Request<Body>,
271    client_ip: IpAddr,
272) -> (Response<Body>, Arc<str>) {
273    let no_id: Arc<str> = Arc::from("");
274    let max_body = state.config.max_request_size;
275    let read_timeout = state.config.read_timeout;
276
277    // ── Extract request parts for hooks before consuming the body ────────────
278    let (parts, body) = req.into_parts();
279
280    let req_method = parts.method.to_string();
281    let req_path = parts.uri.path().to_string();
282    let req_query = parts.uri.query().unwrap_or("").to_string();
283    let req_headers: HashMap<String, String> = parts
284        .headers
285        .iter()
286        .filter_map(|(k, v)| Some((k.to_string(), v.to_str().ok()?.to_string())))
287        .collect();
288
289    // Reassemble so encode_request can consume the body.
290    let req_reassembled = Request::from_parts(parts, body);
291
292    // ── Read + encode body ───────────────────────────────────────────────────
293    let payload =
294        match tokio::time::timeout(read_timeout, encode_request(req_reassembled, max_body)).await {
295            Ok(Ok(p)) => p,
296            Ok(Err(e)) => {
297                error!(error = ?e, "encode request");
298                return (
299                    Response::builder()
300                        .status(500)
301                        .body(Body::from("encode error"))
302                        .unwrap(),
303                    no_id,
304                );
305            }
306            Err(_) => {
307                return (
308                    Response::builder()
309                        .status(408)
310                        .body(Body::from("request body read timeout"))
311                        .unwrap(),
312                    no_id,
313                );
314            }
315        };
316
317    // ── request.before hooks ──────────────────────────────────────────────────
318    let mut req_ctx = RequestContext {
319        method: req_method,
320        path: req_path,
321        query: req_query,
322        client_ip: client_ip.to_string(),
323        request_id: String::new(), // filled in after dispatch
324        headers: req_headers,
325        extra: HashMap::new(),
326        error: None,
327        short_circuited: false,
328    };
329
330    if let Some(ref engine) = state.hook_engine {
331        match engine.run_request_before(&mut req_ctx) {
332            HookResult::ShortCircuit(resp) => {
333                // Async hooks already spawned inside run_request_before.
334                return (resp, no_id);
335            }
336            HookResult::Continue => {}
337        }
338    }
339
340    // ── Dispatch to PHP worker ───────────────────────────────────────────────
341    let (response_value, request_id) = match state
342        .executor
343        .execute_value_traced("http.handle", payload)
344        .await
345    {
346        Ok(v) => v,
347        Err(e) => {
348            error!(error = ?e, "dispatch to worker");
349
350            // Fire request.error hooks.
351            if let Some(ref engine) = state.hook_engine {
352                let mut err_ctx = req_ctx.clone();
353                err_ctx.request_id = no_id.to_string();
354                err_ctx.error = Some(e.to_string());
355                if let HookResult::ShortCircuit(resp) = engine.run_request_error(&mut err_ctx) {
356                    return (resp, no_id);
357                }
358            }
359
360            return (
361                Response::builder()
362                    .status(502)
363                    .body(Body::from("worker error"))
364                    .unwrap(),
365                no_id,
366            );
367        }
368    };
369
370    // ── Decode PHP response into parts ───────────────────────────────────────
371    let status = response_value
372        .get("status")
373        .and_then(|v| v.as_u64())
374        .unwrap_or(200) as u16;
375
376    let resp_headers: HashMap<String, String> = response_value
377        .get("headers")
378        .and_then(|v| v.as_object())
379        .map(|obj| {
380            obj.iter()
381                .filter_map(|(k, v)| Some((k.clone(), v.as_str()?.to_string())))
382                .collect()
383        })
384        .unwrap_or_default();
385
386    let body_str = response_value
387        .get("body")
388        .and_then(|v| v.as_str())
389        .unwrap_or("")
390        .to_string();
391    let body_encoding = response_value
392        .get("body_encoding")
393        .and_then(|v| v.as_str())
394        .map(str::to_string);
395
396    // Update request_id in req_ctx for logging purposes.
397    req_ctx.request_id = request_id.to_string();
398
399    // ── response.headers hooks ────────────────────────────────────────────────
400    let mut resp_ctx = ResponseContext {
401        status,
402        resp_headers,
403        body: None, // not yet available at this stage
404        short_circuited: false,
405    };
406
407    if let Some(ref engine) = state.hook_engine {
408        match engine.run_response_headers(&mut resp_ctx) {
409            HookResult::ShortCircuit(resp) => return (resp, request_id),
410            HookResult::Continue => {}
411        }
412    }
413
414    // ── Decode body bytes ────────────────────────────────────────────────────
415    let body_bytes = if body_encoding.as_deref() == Some("base64") {
416        use base64::Engine;
417        match base64::engine::general_purpose::STANDARD.decode(&body_str) {
418            Ok(b) => b,
419            Err(e) => {
420                error!(error = ?e, "decode base64 response body");
421                return (
422                    Response::builder()
423                        .status(500)
424                        .body(Body::from("decode error"))
425                        .unwrap(),
426                    request_id,
427                );
428            }
429        }
430    } else {
431        body_str.clone().into_bytes()
432    };
433
434    // ── response.after hooks ──────────────────────────────────────────────────
435    if let Some(ref engine) = state.hook_engine {
436        // Only clone body_bytes when response.after hooks actually need it.
437        if engine.has_event("response.after") {
438            resp_ctx.body = Some(body_bytes.clone());
439        }
440
441        match engine.run_response_after(&mut resp_ctx) {
442            HookResult::ShortCircuit(resp) => return (resp, request_id),
443            HookResult::Continue => {}
444        }
445    }
446
447    // ── Build final response ──────────────────────────────────────────────────
448    let final_body_bytes: bytes::Bytes = resp_ctx
449        .body
450        .take()
451        .map(bytes::Bytes::from)
452        .unwrap_or_else(|| bytes::Bytes::from(body_bytes));
453
454    let mut builder = Response::builder().status(resp_ctx.status);
455    for (k, v) in &resp_ctx.resp_headers {
456        builder = builder.header(k.as_str(), v.as_str());
457    }
458
459    let response = match builder.body(Body::from(final_body_bytes)) {
460        Ok(r) => r,
461        Err(e) => {
462            error!(error = ?e, "build response");
463            Response::builder()
464                .status(500)
465                .body(Body::from("build error"))
466                .unwrap()
467        }
468    };
469
470    (response, request_id)
471}
472
473/// Resolve the real client IP from X-Forwarded-For if the peer is a trusted proxy.
474///
475/// Walks the X-Forwarded-For chain from right to left, stopping at the first
476/// IP that is NOT in a trusted subnet. This is the standard secure algorithm
477/// (rightmost non-trusted).
478pub fn resolve_client_ip(peer_ip: IpAddr, xff: Option<&str>, trusted: &[IpNet]) -> IpAddr {
479    if trusted.is_empty() {
480        return peer_ip;
481    }
482
483    if !is_trusted(peer_ip, trusted) {
484        return peer_ip;
485    }
486
487    let Some(xff) = xff else {
488        return peer_ip;
489    };
490
491    let addrs: Vec<&str> = xff.split(',').map(|s| s.trim()).collect();
492
493    // Walk from right to left — the rightmost non-trusted IP is the client.
494    // An unparseable entry is treated as an untrusted boundary: stop and return peer_ip.
495    // Skipping garbage entries would allow a client to inject a fake trusted hop.
496    for addr_str in addrs.iter().rev() {
497        match addr_str.parse::<IpAddr>() {
498            Ok(ip) if !is_trusted(ip, trusted) => return ip,
499            Ok(_) => {}               // trusted hop, keep walking left
500            Err(_) => return peer_ip, // unparseable = untrusted boundary
501        }
502    }
503
504    // All IPs in the chain are trusted — use peer IP.
505    peer_ip
506}
507
508fn is_trusted(ip: IpAddr, trusted: &[IpNet]) -> bool {
509    trusted.iter().any(|net| net.contains(&ip))
510}
511
512fn build_compression_layer(
513    config: &crate::config::CompressionConfig,
514) -> tower_http::compression::CompressionLayer<tower_http::compression::predicate::SizeAbove> {
515    use crate::config::CompressionAlgorithm;
516    use tower_http::compression::CompressionLayer;
517
518    let mut layer = CompressionLayer::new()
519        .no_gzip()
520        .no_br()
521        .no_zstd()
522        .no_deflate();
523
524    for algo in &config.algorithms {
525        layer = match algo {
526            CompressionAlgorithm::Gzip => layer.gzip(true),
527            CompressionAlgorithm::Br => layer.br(true),
528            CompressionAlgorithm::Zstd => layer.zstd(true),
529            CompressionAlgorithm::Deflate => layer.deflate(true),
530        };
531    }
532
533    #[allow(clippy::cast_possible_truncation)]
534    let min_size = config.min_size as u16;
535    layer.compress_when(tower_http::compression::predicate::SizeAbove::new(min_size))
536}