1use std::collections::HashMap;
2use std::net::{IpAddr, SocketAddr};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::time::Instant;
6
7use anyhow::Result;
8use axum::Router;
9use axum::body::Body;
10use axum::extract::{ConnectInfo, State};
11use axum::http::{self, Request, Response};
12use axum::routing::any;
13use ipnet::IpNet;
14use tokio::net::TcpListener;
15use tokio::sync::watch;
16use tower::ServiceBuilder;
17use tower_http::limit::RequestBodyLimitLayer;
18use tower_http::timeout::TimeoutLayer;
19use tracing::{error, info, warn};
20
21use crate::config::HttpConfig;
22use crate::hooks::{HookEngine, HookResult, RequestContext, ResponseContext};
23use crate::payload::encode_request;
24
25#[derive(Clone)]
26struct AppState {
27 executor: Arc<dyn folk_api::Executor>,
28 config: Arc<HttpConfig>,
29 active_connections: Arc<AtomicU64>,
30 hook_engine: Option<Arc<HookEngine>>,
31}
32
33pub struct HttpServer {
34 config: HttpConfig,
35 executor: Arc<dyn folk_api::Executor>,
36 active_connections: Arc<AtomicU64>,
37 hook_engine: Option<Arc<HookEngine>>,
38}
39
40impl HttpServer {
41 pub fn new(
42 config: HttpConfig,
43 executor: Arc<dyn folk_api::Executor>,
44 active_connections: Arc<AtomicU64>,
45 hook_engine: Option<Arc<HookEngine>>,
46 ) -> Self {
47 Self {
48 config,
49 executor,
50 active_connections,
51 hook_engine,
52 }
53 }
54
55 pub async fn run(self, shutdown: watch::Receiver<bool>) -> Result<()> {
56 let state = AppState {
57 executor: self.executor.clone(),
58 config: Arc::new(self.config.clone()),
59 active_connections: self.active_connections.clone(),
60 hook_engine: self.hook_engine.clone(),
61 };
62
63 let mut app = Router::new()
64 .route("/{*path}", any(handle))
65 .route("/", any(handle))
66 .with_state(state)
67 .layer(
68 ServiceBuilder::new()
69 .layer(RequestBodyLimitLayer::new(self.config.max_request_size))
70 .layer(TimeoutLayer::with_status_code(
71 http::StatusCode::GATEWAY_TIMEOUT,
72 self.config.write_timeout,
73 )),
74 );
75
76 if self.config.compression.enabled {
77 app = app.layer(build_compression_layer(&self.config.compression));
78 }
79
80 #[cfg(feature = "tls")]
81 if let Some(ref tls) = self.config.tls {
82 return self.run_tls(app, tls, shutdown).await;
83 }
84
85 #[cfg(feature = "h2c")]
86 if self.config.h2c {
87 return self.run_h2c(app, shutdown).await;
88 }
89
90 self.run_plain(app, shutdown).await
91 }
92
93 async fn run_plain(&self, app: Router, shutdown: watch::Receiver<bool>) -> Result<()> {
94 let listener = TcpListener::bind(self.config.listen).await?;
95
96 axum::serve(
97 listener,
98 app.into_make_service_with_connect_info::<SocketAddr>(),
99 )
100 .with_graceful_shutdown(shutdown_signal(shutdown))
101 .await?;
102
103 Ok(())
104 }
105
106 #[cfg(feature = "tls")]
107 async fn run_tls(
108 &self,
109 app: Router,
110 tls: &crate::config::TlsConfig,
111 shutdown: watch::Receiver<bool>,
112 ) -> Result<()> {
113 use axum_server::Handle;
114 use axum_server::tls_rustls::RustlsConfig;
115
116 let rustls_config = RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?;
117
118 info!(cert = %tls.cert.display(), "TLS enabled");
119
120 let handle = Handle::new();
121 let shutdown_handle = handle.clone();
122 tokio::spawn(async move {
123 shutdown_signal(shutdown).await;
124 shutdown_handle.graceful_shutdown(None);
125 });
126
127 axum_server::bind_rustls(self.config.listen, rustls_config)
128 .handle(handle)
129 .serve(app.into_make_service_with_connect_info::<SocketAddr>())
130 .await?;
131
132 Ok(())
133 }
134
135 #[cfg(feature = "h2c")]
136 async fn run_h2c(&self, app: Router, mut shutdown: watch::Receiver<bool>) -> Result<()> {
137 use hyper_util::rt::{TokioExecutor, TokioIo};
138 use hyper_util::server::conn::auto::Builder as AutoBuilder;
139
140 info!("h2c (HTTP/2 cleartext) enabled");
141
142 let listener = TcpListener::bind(self.config.listen).await?;
143 let builder = Arc::new(AutoBuilder::new(TokioExecutor::new()));
144 let mut tasks = tokio::task::JoinSet::new();
145
146 loop {
147 tokio::select! {
148 result = listener.accept() => {
149 let (stream, remote_addr) = result?;
150 let app = app.clone();
151 let builder = builder.clone();
152 tasks.spawn(async move {
153 let svc = hyper::service::service_fn(move |mut req: Request<hyper::body::Incoming>| {
154 req.extensions_mut().insert(ConnectInfo(remote_addr));
156 let app = app.clone();
157 async move {
158 let resp = tower::Service::call(&mut app.clone(), req).await;
159 resp.map_err(|e| match e {})
160 }
161 });
162 let _ = builder.serve_connection_with_upgrades(TokioIo::new(stream), svc).await;
163 });
164 }
165 _ = async {
166 loop {
167 if shutdown.changed().await.is_err() || *shutdown.borrow() {
168 break;
169 }
170 }
171 } => {
172 break;
173 }
174 }
175 }
176
177 let deadline = tokio::time::sleep(self.config.shutdown_timeout);
179 tokio::pin!(deadline);
180 loop {
181 tokio::select! {
182 _ = &mut deadline => {
183 let remaining = tasks.len();
184 if remaining > 0 {
185 warn!(remaining, "h2c graceful shutdown timed out; aborting connections");
186 tasks.abort_all();
187 while tasks.join_next().await.is_some() {}
188 }
189 break;
190 }
191 result = tasks.join_next() => {
192 if result.is_none() {
193 break;
194 }
195 }
196 }
197 }
198
199 Ok(())
200 }
201}
202
203async fn shutdown_signal(mut shutdown: watch::Receiver<bool>) {
204 loop {
205 if shutdown.changed().await.is_err() || *shutdown.borrow() {
206 break;
207 }
208 }
209}
210
/// Scope guard that decrements a shared in-flight counter when dropped.
///
/// Callers are expected to increment the counter just before creating the
/// guard. Because the decrement lives in `Drop`, it runs on every exit path,
/// including early returns and a dropped (cancelled) future.
struct ConnectionGuard(Arc<AtomicU64>);

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        let Self(counter) = self;
        counter.fetch_sub(1, Ordering::Relaxed);
    }
}
218
219async fn handle(
220 State(state): State<AppState>,
221 connect_info: ConnectInfo<SocketAddr>,
222 req: Request<Body>,
223) -> Response<Body> {
224 state.active_connections.fetch_add(1, Ordering::Relaxed);
225 let _conn_guard = ConnectionGuard(state.active_connections.clone());
226 let start = Instant::now();
227 let method = req.method().clone();
228 let uri = req.uri().clone();
229 let peer_addr = connect_info.0;
230
231 let client_ip = resolve_client_ip(
232 peer_addr.ip(),
233 req.headers()
234 .get("x-forwarded-for")
235 .and_then(|v| v.to_str().ok()),
236 &state.config.trusted_proxies,
237 );
238
239 let (response, request_id) = handle_inner(&state, req, client_ip).await;
240
241 if state.config.access_log {
242 let duration = start.elapsed();
243 let status = response.status().as_u16();
244 let response_bytes = response
245 .headers()
246 .get(http::header::CONTENT_LENGTH)
247 .and_then(|v| v.to_str().ok())
248 .and_then(|v| v.parse::<u64>().ok())
249 .unwrap_or(0);
250 info!(
251 request_id = %request_id,
252 client_ip = %client_ip,
253 method = %method,
254 uri = %uri,
255 status = status,
256 duration_ms = duration.as_millis() as u64,
257 response_bytes = response_bytes,
258 "http request",
259 );
260 }
261
262 response
264}
265
266async fn handle_inner(
269 state: &AppState,
270 req: Request<Body>,
271 client_ip: IpAddr,
272) -> (Response<Body>, Arc<str>) {
273 let no_id: Arc<str> = Arc::from("");
274 let max_body = state.config.max_request_size;
275 let read_timeout = state.config.read_timeout;
276
277 let (parts, body) = req.into_parts();
279
280 let req_method = parts.method.to_string();
281 let req_path = parts.uri.path().to_string();
282 let req_query = parts.uri.query().unwrap_or("").to_string();
283 let req_headers: HashMap<String, String> = parts
284 .headers
285 .iter()
286 .filter_map(|(k, v)| Some((k.to_string(), v.to_str().ok()?.to_string())))
287 .collect();
288
289 let req_reassembled = Request::from_parts(parts, body);
291
292 let payload =
294 match tokio::time::timeout(read_timeout, encode_request(req_reassembled, max_body)).await {
295 Ok(Ok(p)) => p,
296 Ok(Err(e)) => {
297 error!(error = ?e, "encode request");
298 return (
299 Response::builder()
300 .status(500)
301 .body(Body::from("encode error"))
302 .unwrap(),
303 no_id,
304 );
305 }
306 Err(_) => {
307 return (
308 Response::builder()
309 .status(408)
310 .body(Body::from("request body read timeout"))
311 .unwrap(),
312 no_id,
313 );
314 }
315 };
316
317 let mut req_ctx = RequestContext {
319 method: req_method,
320 path: req_path,
321 query: req_query,
322 client_ip: client_ip.to_string(),
323 request_id: String::new(), headers: req_headers,
325 extra: HashMap::new(),
326 error: None,
327 short_circuited: false,
328 };
329
330 if let Some(ref engine) = state.hook_engine {
331 match engine.run_request_before(&mut req_ctx) {
332 HookResult::ShortCircuit(resp) => {
333 return (resp, no_id);
335 }
336 HookResult::Continue => {}
337 }
338 }
339
340 let (response_value, request_id) = match state
342 .executor
343 .execute_value_traced("http.handle", payload)
344 .await
345 {
346 Ok(v) => v,
347 Err(e) => {
348 error!(error = ?e, "dispatch to worker");
349
350 if let Some(ref engine) = state.hook_engine {
352 let mut err_ctx = req_ctx.clone();
353 err_ctx.request_id = no_id.to_string();
354 err_ctx.error = Some(e.to_string());
355 if let HookResult::ShortCircuit(resp) = engine.run_request_error(&mut err_ctx) {
356 return (resp, no_id);
357 }
358 }
359
360 return (
361 Response::builder()
362 .status(502)
363 .body(Body::from("worker error"))
364 .unwrap(),
365 no_id,
366 );
367 }
368 };
369
370 let status = response_value
372 .get("status")
373 .and_then(|v| v.as_u64())
374 .unwrap_or(200) as u16;
375
376 let resp_headers: HashMap<String, String> = response_value
377 .get("headers")
378 .and_then(|v| v.as_object())
379 .map(|obj| {
380 obj.iter()
381 .filter_map(|(k, v)| Some((k.clone(), v.as_str()?.to_string())))
382 .collect()
383 })
384 .unwrap_or_default();
385
386 let body_str = response_value
387 .get("body")
388 .and_then(|v| v.as_str())
389 .unwrap_or("")
390 .to_string();
391 let body_encoding = response_value
392 .get("body_encoding")
393 .and_then(|v| v.as_str())
394 .map(str::to_string);
395
396 req_ctx.request_id = request_id.to_string();
398
399 let mut resp_ctx = ResponseContext {
401 status,
402 resp_headers,
403 body: None, short_circuited: false,
405 };
406
407 if let Some(ref engine) = state.hook_engine {
408 match engine.run_response_headers(&mut resp_ctx) {
409 HookResult::ShortCircuit(resp) => return (resp, request_id),
410 HookResult::Continue => {}
411 }
412 }
413
414 let body_bytes = if body_encoding.as_deref() == Some("base64") {
416 use base64::Engine;
417 match base64::engine::general_purpose::STANDARD.decode(&body_str) {
418 Ok(b) => b,
419 Err(e) => {
420 error!(error = ?e, "decode base64 response body");
421 return (
422 Response::builder()
423 .status(500)
424 .body(Body::from("decode error"))
425 .unwrap(),
426 request_id,
427 );
428 }
429 }
430 } else {
431 body_str.clone().into_bytes()
432 };
433
434 if let Some(ref engine) = state.hook_engine {
436 if engine.has_event("response.after") {
438 resp_ctx.body = Some(body_bytes.clone());
439 }
440
441 match engine.run_response_after(&mut resp_ctx) {
442 HookResult::ShortCircuit(resp) => return (resp, request_id),
443 HookResult::Continue => {}
444 }
445 }
446
447 let final_body_bytes: bytes::Bytes = resp_ctx
449 .body
450 .take()
451 .map(bytes::Bytes::from)
452 .unwrap_or_else(|| bytes::Bytes::from(body_bytes));
453
454 let mut builder = Response::builder().status(resp_ctx.status);
455 for (k, v) in &resp_ctx.resp_headers {
456 builder = builder.header(k.as_str(), v.as_str());
457 }
458
459 let response = match builder.body(Body::from(final_body_bytes)) {
460 Ok(r) => r,
461 Err(e) => {
462 error!(error = ?e, "build response");
463 Response::builder()
464 .status(500)
465 .body(Body::from("build error"))
466 .unwrap()
467 }
468 };
469
470 (response, request_id)
471}
472
473pub fn resolve_client_ip(peer_ip: IpAddr, xff: Option<&str>, trusted: &[IpNet]) -> IpAddr {
479 if trusted.is_empty() {
480 return peer_ip;
481 }
482
483 if !is_trusted(peer_ip, trusted) {
484 return peer_ip;
485 }
486
487 let Some(xff) = xff else {
488 return peer_ip;
489 };
490
491 let addrs: Vec<&str> = xff.split(',').map(|s| s.trim()).collect();
492
493 for addr_str in addrs.iter().rev() {
497 match addr_str.parse::<IpAddr>() {
498 Ok(ip) if !is_trusted(ip, trusted) => return ip,
499 Ok(_) => {} Err(_) => return peer_ip, }
502 }
503
504 peer_ip
506}
507
508fn is_trusted(ip: IpAddr, trusted: &[IpNet]) -> bool {
509 trusted.iter().any(|net| net.contains(&ip))
510}
511
512fn build_compression_layer(
513 config: &crate::config::CompressionConfig,
514) -> tower_http::compression::CompressionLayer<tower_http::compression::predicate::SizeAbove> {
515 use crate::config::CompressionAlgorithm;
516 use tower_http::compression::CompressionLayer;
517
518 let mut layer = CompressionLayer::new()
519 .no_gzip()
520 .no_br()
521 .no_zstd()
522 .no_deflate();
523
524 for algo in &config.algorithms {
525 layer = match algo {
526 CompressionAlgorithm::Gzip => layer.gzip(true),
527 CompressionAlgorithm::Br => layer.br(true),
528 CompressionAlgorithm::Zstd => layer.zstd(true),
529 CompressionAlgorithm::Deflate => layer.deflate(true),
530 };
531 }
532
533 #[allow(clippy::cast_possible_truncation)]
534 let min_size = config.min_size as u16;
535 layer.compress_when(tower_http::compression::predicate::SizeAbove::new(min_size))
536}