Skip to main content

folk_plugin_http/
server.rs

1use std::net::{IpAddr, SocketAddr};
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::time::Instant;
5
6use anyhow::Result;
7use axum::Router;
8use axum::body::Body;
9use axum::extract::{ConnectInfo, State};
10use axum::http::{self, Request, Response};
11use axum::routing::any;
12use ipnet::IpNet;
13use tokio::net::TcpListener;
14use tokio::sync::watch;
15use tower::ServiceBuilder;
16use tower_http::limit::RequestBodyLimitLayer;
17use tower_http::timeout::TimeoutLayer;
18use tracing::{error, info};
19
20use crate::config::HttpConfig;
21use crate::payload::{decode_response, encode_request};
22
23#[derive(Clone)]
24struct AppState {
25    executor: Arc<dyn folk_api::Executor>,
26    config: Arc<HttpConfig>,
27    active_connections: Arc<AtomicU64>,
28}
29
30pub struct HttpServer {
31    config: HttpConfig,
32    executor: Arc<dyn folk_api::Executor>,
33    active_connections: Arc<AtomicU64>,
34}
35
36impl HttpServer {
37    pub fn new(
38        config: HttpConfig,
39        executor: Arc<dyn folk_api::Executor>,
40        active_connections: Arc<AtomicU64>,
41    ) -> Self {
42        Self {
43            config,
44            executor,
45            active_connections,
46        }
47    }
48
49    pub async fn run(self, shutdown: watch::Receiver<bool>) -> Result<()> {
50        let state = AppState {
51            executor: self.executor.clone(),
52            config: Arc::new(self.config.clone()),
53            active_connections: self.active_connections.clone(),
54        };
55
56        let mut app = Router::new()
57            .route("/{*path}", any(handle))
58            .route("/", any(handle))
59            .with_state(state)
60            .layer(
61                ServiceBuilder::new()
62                    .layer(RequestBodyLimitLayer::new(self.config.max_request_size))
63                    .layer(TimeoutLayer::with_status_code(
64                        http::StatusCode::GATEWAY_TIMEOUT,
65                        self.config.write_timeout,
66                    )),
67            );
68
69        if self.config.compression.enabled {
70            app = app.layer(build_compression_layer(&self.config.compression));
71        }
72
73        #[cfg(feature = "tls")]
74        if let Some(ref tls) = self.config.tls {
75            return self.run_tls(app, tls, shutdown).await;
76        }
77
78        #[cfg(feature = "h2c")]
79        if self.config.h2c {
80            return self.run_h2c(app, shutdown).await;
81        }
82
83        self.run_plain(app, shutdown).await
84    }
85
86    async fn run_plain(&self, app: Router, shutdown: watch::Receiver<bool>) -> Result<()> {
87        let listener = TcpListener::bind(self.config.listen).await?;
88
89        axum::serve(
90            listener,
91            app.into_make_service_with_connect_info::<SocketAddr>(),
92        )
93        .with_graceful_shutdown(shutdown_signal(shutdown))
94        .await?;
95
96        Ok(())
97    }
98
99    #[cfg(feature = "tls")]
100    async fn run_tls(
101        &self,
102        app: Router,
103        tls: &crate::config::TlsConfig,
104        shutdown: watch::Receiver<bool>,
105    ) -> Result<()> {
106        use axum_server::Handle;
107        use axum_server::tls_rustls::RustlsConfig;
108
109        let rustls_config = RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?;
110
111        info!(cert = %tls.cert.display(), "TLS enabled");
112
113        let handle = Handle::new();
114        let shutdown_handle = handle.clone();
115        tokio::spawn(async move {
116            shutdown_signal(shutdown).await;
117            shutdown_handle.graceful_shutdown(None);
118        });
119
120        axum_server::bind_rustls(self.config.listen, rustls_config)
121            .handle(handle)
122            .serve(app.into_make_service_with_connect_info::<SocketAddr>())
123            .await?;
124
125        Ok(())
126    }
127
128    #[cfg(feature = "h2c")]
129    async fn run_h2c(&self, app: Router, mut shutdown: watch::Receiver<bool>) -> Result<()> {
130        use hyper_util::rt::{TokioExecutor, TokioIo};
131        use hyper_util::server::conn::auto::Builder as AutoBuilder;
132
133        info!("h2c (HTTP/2 cleartext) enabled");
134
135        let listener = TcpListener::bind(self.config.listen).await?;
136        let builder = Arc::new(AutoBuilder::new(TokioExecutor::new()));
137        let mut tasks = tokio::task::JoinSet::new();
138
139        loop {
140            tokio::select! {
141                result = listener.accept() => {
142                    let (stream, remote_addr) = result?;
143                    let app = app.clone();
144                    let builder = builder.clone();
145                    tasks.spawn(async move {
146                        let svc = hyper::service::service_fn(move |mut req: Request<hyper::body::Incoming>| {
147                            // Inject ConnectInfo manually since we bypass axum::serve.
148                            req.extensions_mut().insert(ConnectInfo(remote_addr));
149                            let app = app.clone();
150                            async move {
151                                let resp = tower::Service::call(&mut app.clone(), req).await;
152                                resp.map_err(|e| match e {})
153                            }
154                        });
155                        let _ = builder.serve_connection_with_upgrades(TokioIo::new(stream), svc).await;
156                    });
157                }
158                _ = async {
159                    loop {
160                        if shutdown.changed().await.is_err() || *shutdown.borrow() {
161                            break;
162                        }
163                    }
164                } => {
165                    break;
166                }
167            }
168        }
169
170        // Wait for active connections to finish
171        while tasks.join_next().await.is_some() {}
172
173        Ok(())
174    }
175}
176
177async fn shutdown_signal(mut shutdown: watch::Receiver<bool>) {
178    loop {
179        if shutdown.changed().await.is_err() || *shutdown.borrow() {
180            break;
181        }
182    }
183}
184
/// Drop guard that decrements the shared in-flight counter by one.
///
/// The handler increments the counter and then creates this guard. The decrement
/// therefore happens on every exit path, including when the handler's future is
/// dropped before it completes.
struct ConnectionGuard(Arc<AtomicU64>);

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        let Self(counter) = self;
        counter.fetch_sub(1, Ordering::Relaxed);
    }
}
192
193async fn handle(
194    State(state): State<AppState>,
195    connect_info: ConnectInfo<SocketAddr>,
196    req: Request<Body>,
197) -> Response<Body> {
198    state.active_connections.fetch_add(1, Ordering::Relaxed);
199    let _conn_guard = ConnectionGuard(state.active_connections.clone());
200    let start = Instant::now();
201    let method = req.method().clone();
202    let uri = req.uri().clone();
203    let peer_addr = connect_info.0;
204
205    let client_ip = resolve_client_ip(
206        peer_addr.ip(),
207        req.headers()
208            .get("x-forwarded-for")
209            .and_then(|v| v.to_str().ok()),
210        &state.config.trusted_proxies,
211    );
212
213    let response = handle_inner(&state, req).await;
214
215    if state.config.access_log {
216        let duration = start.elapsed();
217        let status = response.status().as_u16();
218        let response_bytes = response
219            .headers()
220            .get(http::header::CONTENT_LENGTH)
221            .and_then(|v| v.to_str().ok())
222            .and_then(|v| v.parse::<u64>().ok())
223            .unwrap_or(0);
224        info!(
225            client_ip = %client_ip,
226            method = %method,
227            uri = %uri,
228            status = status,
229            duration_ms = duration.as_millis() as u64,
230            response_bytes = response_bytes,
231            "http request",
232        );
233    }
234
235    // _conn_guard drop handles fetch_sub
236    response
237}
238
239async fn handle_inner(state: &AppState, req: Request<Body>) -> Response<Body> {
240    let max_body = state.config.max_request_size;
241    let read_timeout = state.config.read_timeout;
242    let payload = match tokio::time::timeout(read_timeout, encode_request(req, max_body)).await {
243        Ok(Ok(p)) => p,
244        Ok(Err(e)) => {
245            error!(error = ?e, "encode request");
246            return Response::builder()
247                .status(500)
248                .body(Body::from("encode error"))
249                .unwrap();
250        }
251        Err(_) => {
252            return Response::builder()
253                .status(408)
254                .body(Body::from("request body read timeout"))
255                .unwrap();
256        }
257    };
258
259    let response_value = match state.executor.execute_value("http.handle", payload).await {
260        Ok(v) => v,
261        Err(e) => {
262            error!(error = ?e, "dispatch to worker");
263            return Response::builder()
264                .status(502)
265                .body(Body::from("worker error"))
266                .unwrap();
267        }
268    };
269
270    match decode_response(response_value) {
271        Ok(r) => r,
272        Err(e) => {
273            error!(error = ?e, "decode response");
274            Response::builder()
275                .status(500)
276                .body(Body::from("decode error"))
277                .unwrap()
278        }
279    }
280}
281
282/// Resolve the real client IP from X-Forwarded-For if the peer is a trusted proxy.
283///
284/// Walks the X-Forwarded-For chain from right to left, stopping at the first
285/// IP that is NOT in a trusted subnet. This is the standard secure algorithm
286/// (rightmost non-trusted).
287pub fn resolve_client_ip(peer_ip: IpAddr, xff: Option<&str>, trusted: &[IpNet]) -> IpAddr {
288    if trusted.is_empty() {
289        return peer_ip;
290    }
291
292    if !is_trusted(peer_ip, trusted) {
293        return peer_ip;
294    }
295
296    let Some(xff) = xff else {
297        return peer_ip;
298    };
299
300    let addrs: Vec<&str> = xff.split(',').map(|s| s.trim()).collect();
301
302    // Walk from right to left — the rightmost non-trusted IP is the client.
303    for addr_str in addrs.iter().rev() {
304        if let Ok(ip) = addr_str.parse::<IpAddr>() {
305            if !is_trusted(ip, trusted) {
306                return ip;
307            }
308        }
309    }
310
311    // All IPs in the chain are trusted — use peer IP.
312    peer_ip
313}
314
315fn is_trusted(ip: IpAddr, trusted: &[IpNet]) -> bool {
316    trusted.iter().any(|net| net.contains(&ip))
317}
318
319fn build_compression_layer(
320    config: &crate::config::CompressionConfig,
321) -> tower_http::compression::CompressionLayer<tower_http::compression::predicate::SizeAbove> {
322    use crate::config::CompressionAlgorithm;
323    use tower_http::compression::CompressionLayer;
324
325    let mut layer = CompressionLayer::new()
326        .no_gzip()
327        .no_br()
328        .no_zstd()
329        .no_deflate();
330
331    for algo in &config.algorithms {
332        layer = match algo {
333            CompressionAlgorithm::Gzip => layer.gzip(true),
334            CompressionAlgorithm::Br => layer.br(true),
335            CompressionAlgorithm::Zstd => layer.zstd(true),
336            CompressionAlgorithm::Deflate => layer.deflate(true),
337        };
338    }
339
340    #[allow(clippy::cast_possible_truncation)]
341    let min_size = config.min_size as u16;
342    layer.compress_when(tower_http::compression::predicate::SizeAbove::new(min_size))
343}