1use std::net::{IpAddr, SocketAddr};
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::time::Instant;
5
6use anyhow::Result;
7use axum::Router;
8use axum::body::Body;
9use axum::extract::{ConnectInfo, State};
10use axum::http::{self, Request, Response};
11use axum::routing::any;
12use ipnet::IpNet;
13use tokio::net::TcpListener;
14use tokio::sync::watch;
15use tower::ServiceBuilder;
16use tower_http::limit::RequestBodyLimitLayer;
17use tower_http::timeout::TimeoutLayer;
18use tracing::{error, info};
19
20use crate::config::HttpConfig;
21use crate::payload::{decode_response, encode_request};
22
/// Per-request state handed to the axum handlers through `Router::with_state`.
///
/// Every field is an `Arc`, so cloning the state only bumps reference counts.
#[derive(Clone)]
struct AppState {
    /// Executor that encoded requests are dispatched to.
    executor: Arc<dyn folk_api::Executor>,
    /// HTTP configuration (limits, timeouts, trusted proxies, access-log switch).
    config: Arc<HttpConfig>,
    /// Counter incremented on entry to `handle` and decremented when that
    /// invocation finishes, i.e. the number of requests currently in flight.
    active_connections: Arc<AtomicU64>,
}
29
/// HTTP front-end that forwards every incoming request to an executor.
///
/// Construct it with [`HttpServer::new`] and drive it with [`HttpServer::run`].
pub struct HttpServer {
    /// Listener address, limits, timeouts and optional TLS / h2c / compression settings.
    config: HttpConfig,
    /// Executor that requests are dispatched to.
    executor: Arc<dyn folk_api::Executor>,
    /// Shared counter of requests currently being handled.
    active_connections: Arc<AtomicU64>,
}
35
36impl HttpServer {
37 pub fn new(
38 config: HttpConfig,
39 executor: Arc<dyn folk_api::Executor>,
40 active_connections: Arc<AtomicU64>,
41 ) -> Self {
42 Self {
43 config,
44 executor,
45 active_connections,
46 }
47 }
48
49 pub async fn run(self, shutdown: watch::Receiver<bool>) -> Result<()> {
50 let state = AppState {
51 executor: self.executor.clone(),
52 config: Arc::new(self.config.clone()),
53 active_connections: self.active_connections.clone(),
54 };
55
56 let mut app = Router::new()
57 .route("/{*path}", any(handle))
58 .route("/", any(handle))
59 .with_state(state)
60 .layer(
61 ServiceBuilder::new()
62 .layer(RequestBodyLimitLayer::new(self.config.max_request_size))
63 .layer(TimeoutLayer::with_status_code(
64 http::StatusCode::GATEWAY_TIMEOUT,
65 self.config.write_timeout,
66 )),
67 );
68
69 if self.config.compression.enabled {
70 app = app.layer(build_compression_layer(&self.config.compression));
71 }
72
73 #[cfg(feature = "tls")]
74 if let Some(ref tls) = self.config.tls {
75 return self.run_tls(app, tls, shutdown).await;
76 }
77
78 #[cfg(feature = "h2c")]
79 if self.config.h2c {
80 return self.run_h2c(app, shutdown).await;
81 }
82
83 self.run_plain(app, shutdown).await
84 }
85
86 async fn run_plain(&self, app: Router, shutdown: watch::Receiver<bool>) -> Result<()> {
87 let listener = TcpListener::bind(self.config.listen).await?;
88
89 axum::serve(
90 listener,
91 app.into_make_service_with_connect_info::<SocketAddr>(),
92 )
93 .with_graceful_shutdown(shutdown_signal(shutdown))
94 .await?;
95
96 Ok(())
97 }
98
99 #[cfg(feature = "tls")]
100 async fn run_tls(
101 &self,
102 app: Router,
103 tls: &crate::config::TlsConfig,
104 shutdown: watch::Receiver<bool>,
105 ) -> Result<()> {
106 use axum_server::Handle;
107 use axum_server::tls_rustls::RustlsConfig;
108
109 let rustls_config = RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?;
110
111 info!(cert = %tls.cert.display(), "TLS enabled");
112
113 let handle = Handle::new();
114 let shutdown_handle = handle.clone();
115 tokio::spawn(async move {
116 shutdown_signal(shutdown).await;
117 shutdown_handle.graceful_shutdown(None);
118 });
119
120 axum_server::bind_rustls(self.config.listen, rustls_config)
121 .handle(handle)
122 .serve(app.into_make_service_with_connect_info::<SocketAddr>())
123 .await?;
124
125 Ok(())
126 }
127
128 #[cfg(feature = "h2c")]
129 async fn run_h2c(&self, app: Router, mut shutdown: watch::Receiver<bool>) -> Result<()> {
130 use hyper_util::rt::{TokioExecutor, TokioIo};
131 use hyper_util::server::conn::auto::Builder as AutoBuilder;
132
133 info!("h2c (HTTP/2 cleartext) enabled");
134
135 let listener = TcpListener::bind(self.config.listen).await?;
136 let builder = Arc::new(AutoBuilder::new(TokioExecutor::new()));
137 let mut tasks = tokio::task::JoinSet::new();
138
139 loop {
140 tokio::select! {
141 result = listener.accept() => {
142 let (stream, remote_addr) = result?;
143 let app = app.clone();
144 let builder = builder.clone();
145 tasks.spawn(async move {
146 let svc = hyper::service::service_fn(move |mut req: Request<hyper::body::Incoming>| {
147 req.extensions_mut().insert(ConnectInfo(remote_addr));
149 let app = app.clone();
150 async move {
151 let resp = tower::Service::call(&mut app.clone(), req).await;
152 resp.map_err(|e| match e {})
153 }
154 });
155 let _ = builder.serve_connection_with_upgrades(TokioIo::new(stream), svc).await;
156 });
157 }
158 _ = async {
159 loop {
160 if shutdown.changed().await.is_err() || *shutdown.borrow() {
161 break;
162 }
163 }
164 } => {
165 break;
166 }
167 }
168 }
169
170 while tasks.join_next().await.is_some() {}
172
173 Ok(())
174 }
175}
176
177async fn shutdown_signal(mut shutdown: watch::Receiver<bool>) {
178 loop {
179 if shutdown.changed().await.is_err() || *shutdown.borrow() {
180 break;
181 }
182 }
183}
184
/// Guard that decrements the wrapped counter by one when it is dropped.
///
/// The matching increment is performed by whoever creates the guard.
struct ConnectionGuard(Arc<AtomicU64>);

impl Drop for ConnectionGuard {
    /// Subtracts one from the counter using relaxed ordering.
    fn drop(&mut self) {
        self.0.fetch_sub(1, Ordering::Relaxed);
    }
}
192
193async fn handle(
194 State(state): State<AppState>,
195 connect_info: ConnectInfo<SocketAddr>,
196 req: Request<Body>,
197) -> Response<Body> {
198 state.active_connections.fetch_add(1, Ordering::Relaxed);
199 let _conn_guard = ConnectionGuard(state.active_connections.clone());
200 let start = Instant::now();
201 let method = req.method().clone();
202 let uri = req.uri().clone();
203 let peer_addr = connect_info.0;
204
205 let client_ip = resolve_client_ip(
206 peer_addr.ip(),
207 req.headers()
208 .get("x-forwarded-for")
209 .and_then(|v| v.to_str().ok()),
210 &state.config.trusted_proxies,
211 );
212
213 let (response, request_id) = handle_inner(&state, req).await;
214
215 if state.config.access_log {
216 let duration = start.elapsed();
217 let status = response.status().as_u16();
218 let response_bytes = response
219 .headers()
220 .get(http::header::CONTENT_LENGTH)
221 .and_then(|v| v.to_str().ok())
222 .and_then(|v| v.parse::<u64>().ok())
223 .unwrap_or(0);
224 info!(
225 request_id = %request_id,
226 client_ip = %client_ip,
227 method = %method,
228 uri = %uri,
229 status = status,
230 duration_ms = duration.as_millis() as u64,
231 response_bytes = response_bytes,
232 "http request",
233 );
234 }
235
236 response
238}
239
240async fn handle_inner(state: &AppState, req: Request<Body>) -> (Response<Body>, Arc<str>) {
243 let no_id: Arc<str> = Arc::from("");
244 let max_body = state.config.max_request_size;
245 let read_timeout = state.config.read_timeout;
246 let payload = match tokio::time::timeout(read_timeout, encode_request(req, max_body)).await {
247 Ok(Ok(p)) => p,
248 Ok(Err(e)) => {
249 error!(error = ?e, "encode request");
250 return (
251 Response::builder()
252 .status(500)
253 .body(Body::from("encode error"))
254 .unwrap(),
255 no_id,
256 );
257 }
258 Err(_) => {
259 return (
260 Response::builder()
261 .status(408)
262 .body(Body::from("request body read timeout"))
263 .unwrap(),
264 no_id,
265 );
266 }
267 };
268
269 let (response_value, request_id) = match state
270 .executor
271 .execute_value_traced("http.handle", payload)
272 .await
273 {
274 Ok(v) => v,
275 Err(e) => {
276 error!(error = ?e, "dispatch to worker");
277 return (
278 Response::builder()
279 .status(502)
280 .body(Body::from("worker error"))
281 .unwrap(),
282 no_id,
283 );
284 }
285 };
286
287 let response = match decode_response(response_value) {
288 Ok(r) => r,
289 Err(e) => {
290 error!(error = ?e, "decode response");
291 Response::builder()
292 .status(500)
293 .body(Body::from("decode error"))
294 .unwrap()
295 }
296 };
297 (response, request_id)
298}
299
300pub fn resolve_client_ip(peer_ip: IpAddr, xff: Option<&str>, trusted: &[IpNet]) -> IpAddr {
306 if trusted.is_empty() {
307 return peer_ip;
308 }
309
310 if !is_trusted(peer_ip, trusted) {
311 return peer_ip;
312 }
313
314 let Some(xff) = xff else {
315 return peer_ip;
316 };
317
318 let addrs: Vec<&str> = xff.split(',').map(|s| s.trim()).collect();
319
320 for addr_str in addrs.iter().rev() {
322 if let Ok(ip) = addr_str.parse::<IpAddr>() {
323 if !is_trusted(ip, trusted) {
324 return ip;
325 }
326 }
327 }
328
329 peer_ip
331}
332
333fn is_trusted(ip: IpAddr, trusted: &[IpNet]) -> bool {
334 trusted.iter().any(|net| net.contains(&ip))
335}
336
337fn build_compression_layer(
338 config: &crate::config::CompressionConfig,
339) -> tower_http::compression::CompressionLayer<tower_http::compression::predicate::SizeAbove> {
340 use crate::config::CompressionAlgorithm;
341 use tower_http::compression::CompressionLayer;
342
343 let mut layer = CompressionLayer::new()
344 .no_gzip()
345 .no_br()
346 .no_zstd()
347 .no_deflate();
348
349 for algo in &config.algorithms {
350 layer = match algo {
351 CompressionAlgorithm::Gzip => layer.gzip(true),
352 CompressionAlgorithm::Br => layer.br(true),
353 CompressionAlgorithm::Zstd => layer.zstd(true),
354 CompressionAlgorithm::Deflate => layer.deflate(true),
355 };
356 }
357
358 #[allow(clippy::cast_possible_truncation)]
359 let min_size = config.min_size as u16;
360 layer.compress_when(tower_http::compression::predicate::SizeAbove::new(min_size))
361}