1use std::collections::HashMap;
2use std::net::{IpAddr, SocketAddr};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::time::Instant;
6
7use anyhow::Result;
8use axum::Router;
9use axum::body::Body;
10use axum::extract::{ConnectInfo, State};
11use axum::http::{self, Request, Response};
12use axum::routing::any;
13use ipnet::IpNet;
14use tokio::net::TcpListener;
15use tokio::sync::watch;
16use tower::ServiceBuilder;
17use tower_http::limit::RequestBodyLimitLayer;
18use tower_http::timeout::TimeoutLayer;
19#[cfg(feature = "h2c")]
20use tracing::warn;
21use tracing::{error, info};
22
23use crate::config::HttpConfig;
24use crate::hooks::{HookEngine, HookResult, RequestContext, ResponseContext};
25use crate::payload::encode_request;
26
27#[derive(Clone)]
28struct AppState {
29 executor: Arc<dyn folk_api::Executor>,
30 config: Arc<HttpConfig>,
31 active_connections: Arc<AtomicU64>,
32 hook_engine: Option<Arc<HookEngine>>,
33}
34
35pub struct HttpServer {
36 config: HttpConfig,
37 executor: Arc<dyn folk_api::Executor>,
38 active_connections: Arc<AtomicU64>,
39 hook_engine: Option<Arc<HookEngine>>,
40}
41
42impl HttpServer {
43 pub fn new(
44 config: HttpConfig,
45 executor: Arc<dyn folk_api::Executor>,
46 active_connections: Arc<AtomicU64>,
47 hook_engine: Option<Arc<HookEngine>>,
48 ) -> Self {
49 Self {
50 config,
51 executor,
52 active_connections,
53 hook_engine,
54 }
55 }
56
57 pub async fn run(self, shutdown: watch::Receiver<bool>) -> Result<()> {
58 let state = AppState {
59 executor: self.executor.clone(),
60 config: Arc::new(self.config.clone()),
61 active_connections: self.active_connections.clone(),
62 hook_engine: self.hook_engine.clone(),
63 };
64
65 let mut app = Router::new()
66 .route("/{*path}", any(handle))
67 .route("/", any(handle))
68 .with_state(state)
69 .layer(
70 ServiceBuilder::new()
71 .layer(RequestBodyLimitLayer::new(self.config.max_request_size))
72 .layer(TimeoutLayer::with_status_code(
73 http::StatusCode::GATEWAY_TIMEOUT,
74 self.config.write_timeout,
75 )),
76 );
77
78 if self.config.compression.enabled {
79 app = app.layer(build_compression_layer(&self.config.compression));
80 }
81
82 #[cfg(feature = "tls")]
83 if let Some(ref tls) = self.config.tls {
84 return self.run_tls(app, tls, shutdown).await;
85 }
86
87 #[cfg(feature = "h2c")]
88 if self.config.h2c {
89 return self.run_h2c(app, shutdown).await;
90 }
91
92 self.run_plain(app, shutdown).await
93 }
94
95 async fn run_plain(&self, app: Router, shutdown: watch::Receiver<bool>) -> Result<()> {
96 let listener = TcpListener::bind(self.config.listen).await?;
97
98 axum::serve(
99 listener,
100 app.into_make_service_with_connect_info::<SocketAddr>(),
101 )
102 .with_graceful_shutdown(shutdown_signal(shutdown))
103 .await?;
104
105 Ok(())
106 }
107
108 #[cfg(feature = "tls")]
109 async fn run_tls(
110 &self,
111 app: Router,
112 tls: &crate::config::TlsConfig,
113 shutdown: watch::Receiver<bool>,
114 ) -> Result<()> {
115 use axum_server::Handle;
116 use axum_server::tls_rustls::RustlsConfig;
117
118 let rustls_config = RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?;
119
120 info!(cert = %tls.cert.display(), "TLS enabled");
121
122 let handle = Handle::new();
123 let shutdown_handle = handle.clone();
124 tokio::spawn(async move {
125 shutdown_signal(shutdown).await;
126 shutdown_handle.graceful_shutdown(None);
127 });
128
129 axum_server::bind_rustls(self.config.listen, rustls_config)
130 .handle(handle)
131 .serve(app.into_make_service_with_connect_info::<SocketAddr>())
132 .await?;
133
134 Ok(())
135 }
136
137 #[cfg(feature = "h2c")]
138 async fn run_h2c(&self, app: Router, mut shutdown: watch::Receiver<bool>) -> Result<()> {
139 use hyper_util::rt::{TokioExecutor, TokioIo};
140 use hyper_util::server::conn::auto::Builder as AutoBuilder;
141
142 info!("h2c (HTTP/2 cleartext) enabled");
143
144 let listener = TcpListener::bind(self.config.listen).await?;
145 let builder = Arc::new(AutoBuilder::new(TokioExecutor::new()));
146 let mut tasks = tokio::task::JoinSet::new();
147
148 loop {
149 tokio::select! {
150 result = listener.accept() => {
151 let (stream, remote_addr) = result?;
152 let app = app.clone();
153 let builder = builder.clone();
154 tasks.spawn(async move {
155 let svc = hyper::service::service_fn(move |mut req: Request<hyper::body::Incoming>| {
156 req.extensions_mut().insert(ConnectInfo(remote_addr));
158 let app = app.clone();
159 async move {
160 let resp = tower::Service::call(&mut app.clone(), req).await;
161 resp.map_err(|e| match e {})
162 }
163 });
164 let _ = builder.serve_connection_with_upgrades(TokioIo::new(stream), svc).await;
165 });
166 }
167 _ = async {
168 loop {
169 if shutdown.changed().await.is_err() || *shutdown.borrow() {
170 break;
171 }
172 }
173 } => {
174 break;
175 }
176 }
177 }
178
179 let deadline = tokio::time::sleep(self.config.shutdown_timeout);
181 tokio::pin!(deadline);
182 loop {
183 tokio::select! {
184 _ = &mut deadline => {
185 let remaining = tasks.len();
186 if remaining > 0 {
187 warn!(remaining, "h2c graceful shutdown timed out; aborting connections");
188 tasks.abort_all();
189 while tasks.join_next().await.is_some() {}
190 }
191 break;
192 }
193 result = tasks.join_next() => {
194 if result.is_none() {
195 break;
196 }
197 }
198 }
199 }
200
201 Ok(())
202 }
203}
204
205async fn shutdown_signal(mut shutdown: watch::Receiver<bool>) {
206 loop {
207 if shutdown.changed().await.is_err() || *shutdown.borrow() {
208 break;
209 }
210 }
211}
212
/// Guard that decrements a shared counter when dropped.
///
/// The caller increments the counter before creating the guard. The guard only performs
/// the matching decrement, which therefore happens on every exit path of the owning
/// scope, including early drop of an async future.
struct ConnectionGuard(Arc<AtomicU64>);

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        let Self(counter) = self;
        counter.fetch_sub(1, Ordering::Relaxed);
    }
}
220
221async fn handle(
222 State(state): State<AppState>,
223 connect_info: ConnectInfo<SocketAddr>,
224 req: Request<Body>,
225) -> Response<Body> {
226 state.active_connections.fetch_add(1, Ordering::Relaxed);
227 let _conn_guard = ConnectionGuard(state.active_connections.clone());
228 let start = Instant::now();
229 let method = req.method().clone();
230 let uri = req.uri().clone();
231 let peer_addr = connect_info.0;
232
233 let client_ip = resolve_client_ip(
234 peer_addr.ip(),
235 req.headers()
236 .get("x-forwarded-for")
237 .and_then(|v| v.to_str().ok()),
238 &state.config.trusted_proxies,
239 );
240
241 let (response, request_id) = handle_inner(&state, req, client_ip).await;
242
243 if state.config.access_log {
244 let duration = start.elapsed();
245 let status = response.status().as_u16();
246 let response_bytes = response
247 .headers()
248 .get(http::header::CONTENT_LENGTH)
249 .and_then(|v| v.to_str().ok())
250 .and_then(|v| v.parse::<u64>().ok())
251 .unwrap_or(0);
252 info!(
253 request_id = %request_id,
254 client_ip = %client_ip,
255 method = %method,
256 uri = %uri,
257 status = status,
258 duration_ms = duration.as_millis() as u64,
259 response_bytes = response_bytes,
260 "http request",
261 );
262 }
263
264 response
266}
267
268async fn handle_inner(
271 state: &AppState,
272 req: Request<Body>,
273 client_ip: IpAddr,
274) -> (Response<Body>, Arc<str>) {
275 let no_id: Arc<str> = Arc::from("");
276 let max_body = state.config.max_request_size;
277 let read_timeout = state.config.read_timeout;
278
279 let (parts, body) = req.into_parts();
281
282 let mut req_ctx: Option<RequestContext> = if state.hook_engine.is_some() {
286 let headers: HashMap<String, String> = parts
287 .headers
288 .iter()
289 .filter_map(|(k, v)| Some((k.to_string(), v.to_str().ok()?.to_string())))
290 .collect();
291 Some(RequestContext {
292 method: parts.method.to_string(),
293 path: parts.uri.path().to_string(),
294 query: parts.uri.query().unwrap_or("").to_string(),
295 client_ip: client_ip.to_string(),
296 request_id: String::new(), headers,
298 extra: HashMap::new(),
299 error: None,
300 short_circuited: false,
301 })
302 } else {
303 None
304 };
305
306 let req_reassembled = Request::from_parts(parts, body);
308
309 let payload =
311 match tokio::time::timeout(read_timeout, encode_request(req_reassembled, max_body)).await {
312 Ok(Ok(p)) => p,
313 Ok(Err(e)) => {
314 error!(error = ?e, "encode request");
315 return (
316 Response::builder()
317 .status(500)
318 .body(Body::from("encode error"))
319 .unwrap(),
320 no_id,
321 );
322 }
323 Err(_) => {
324 return (
325 Response::builder()
326 .status(408)
327 .body(Body::from("request body read timeout"))
328 .unwrap(),
329 no_id,
330 );
331 }
332 };
333
334 if let (Some(engine), Some(ctx)) = (&state.hook_engine, &mut req_ctx) {
336 match engine.run_request_before(ctx) {
337 HookResult::ShortCircuit(resp) => {
338 return (resp, no_id);
340 }
341 HookResult::Continue => {}
342 }
343 }
344
345 let (response_value, request_id) = match state
347 .executor
348 .execute_value_traced("http.handle", payload)
349 .await
350 {
351 Ok(v) => v,
352 Err(e) => {
353 error!(error = ?e, "dispatch to worker");
354
355 if let (Some(engine), Some(ctx)) = (&state.hook_engine, &req_ctx) {
357 let mut err_ctx = ctx.clone();
358 err_ctx.request_id = no_id.to_string();
359 err_ctx.error = Some(e.to_string());
360 if let HookResult::ShortCircuit(resp) = engine.run_request_error(&mut err_ctx) {
361 return (resp, no_id);
362 }
363 }
364
365 return (
366 Response::builder()
367 .status(502)
368 .body(Body::from("worker error"))
369 .unwrap(),
370 no_id,
371 );
372 }
373 };
374
375 let status = response_value
377 .get("status")
378 .and_then(|v| v.as_u64())
379 .unwrap_or(200) as u16;
380
381 let resp_headers: HashMap<String, String> = response_value
382 .get("headers")
383 .and_then(|v| v.as_object())
384 .map(|obj| {
385 obj.iter()
386 .filter_map(|(k, v)| Some((k.clone(), v.as_str()?.to_string())))
387 .collect()
388 })
389 .unwrap_or_default();
390
391 let body_str = response_value
392 .get("body")
393 .and_then(|v| v.as_str())
394 .unwrap_or("")
395 .to_string();
396 let body_encoding = response_value
397 .get("body_encoding")
398 .and_then(|v| v.as_str())
399 .map(str::to_string);
400
401 if let Some(ref mut ctx) = req_ctx {
403 ctx.request_id = request_id.to_string();
404 }
405
406 let mut resp_ctx = ResponseContext {
408 status,
409 resp_headers,
410 body: None, short_circuited: false,
412 };
413
414 if let Some(ref engine) = state.hook_engine {
415 match engine.run_response_headers(&mut resp_ctx) {
416 HookResult::ShortCircuit(resp) => return (resp, request_id),
417 HookResult::Continue => {}
418 }
419 }
420
421 let body_bytes = if body_encoding.as_deref() == Some("base64") {
423 use base64::Engine;
424 match base64::engine::general_purpose::STANDARD.decode(&body_str) {
425 Ok(b) => b,
426 Err(e) => {
427 error!(error = ?e, "decode base64 response body");
428 return (
429 Response::builder()
430 .status(500)
431 .body(Body::from("decode error"))
432 .unwrap(),
433 request_id,
434 );
435 }
436 }
437 } else {
438 body_str.clone().into_bytes()
439 };
440
441 if let Some(ref engine) = state.hook_engine {
443 if engine.has_event("response.after") {
445 resp_ctx.body = Some(body_bytes.clone());
446 }
447
448 match engine.run_response_after(&mut resp_ctx) {
449 HookResult::ShortCircuit(resp) => return (resp, request_id),
450 HookResult::Continue => {}
451 }
452 }
453
454 let final_body_bytes: bytes::Bytes = resp_ctx
456 .body
457 .take()
458 .map(bytes::Bytes::from)
459 .unwrap_or_else(|| bytes::Bytes::from(body_bytes));
460
461 let mut builder = Response::builder().status(resp_ctx.status);
462 for (k, v) in &resp_ctx.resp_headers {
463 builder = builder.header(k.as_str(), v.as_str());
464 }
465
466 let response = match builder.body(Body::from(final_body_bytes)) {
467 Ok(r) => r,
468 Err(e) => {
469 error!(error = ?e, "build response");
470 Response::builder()
471 .status(500)
472 .body(Body::from("build error"))
473 .unwrap()
474 }
475 };
476
477 (response, request_id)
478}
479
480pub fn resolve_client_ip(peer_ip: IpAddr, xff: Option<&str>, trusted: &[IpNet]) -> IpAddr {
486 if trusted.is_empty() {
487 return peer_ip;
488 }
489
490 if !is_trusted(peer_ip, trusted) {
491 return peer_ip;
492 }
493
494 let Some(xff) = xff else {
495 return peer_ip;
496 };
497
498 let addrs: Vec<&str> = xff.split(',').map(|s| s.trim()).collect();
499
500 for addr_str in addrs.iter().rev() {
504 match addr_str.parse::<IpAddr>() {
505 Ok(ip) if !is_trusted(ip, trusted) => return ip,
506 Ok(_) => {} Err(_) => return peer_ip, }
509 }
510
511 peer_ip
513}
514
515fn is_trusted(ip: IpAddr, trusted: &[IpNet]) -> bool {
516 trusted.iter().any(|net| net.contains(&ip))
517}
518
519fn build_compression_layer(
520 config: &crate::config::CompressionConfig,
521) -> tower_http::compression::CompressionLayer<tower_http::compression::predicate::SizeAbove> {
522 use crate::config::CompressionAlgorithm;
523 use tower_http::compression::CompressionLayer;
524
525 let mut layer = CompressionLayer::new()
526 .no_gzip()
527 .no_br()
528 .no_zstd()
529 .no_deflate();
530
531 for algo in &config.algorithms {
532 layer = match algo {
533 CompressionAlgorithm::Gzip => layer.gzip(true),
534 CompressionAlgorithm::Br => layer.br(true),
535 CompressionAlgorithm::Zstd => layer.zstd(true),
536 CompressionAlgorithm::Deflate => layer.deflate(true),
537 };
538 }
539
540 let min_size =
542 u16::try_from(config.min_size).expect("min_size validated in PluginFactory::create");
543 layer.compress_when(tower_http::compression::predicate::SizeAbove::new(min_size))
544}