1use axum::{
31 body::Body,
32 http::{HeaderValue, Method, Request, Response, StatusCode, header::HeaderName},
33};
34use std::{
35 collections::HashSet,
36 future::Future,
37 pin::Pin,
38 str::FromStr,
39 task::{Context, Poll},
40};
41
/// Field names stripped from every forwarded request and response.
///
/// Stored in lowercase because each entry is passed to `HeaderName::from_static`.
/// `connection` and `upgrade` are skipped while a WebSocket upgrade is being preserved.
/// `is_hop_by_hop_header` also consults this list.
///
/// NOTE(review): RFC 2616 §13.5.1 additionally listed `Proxy-Authenticate` and
/// `Proxy-Authorization` as hop-by-hop. Confirm their omission here is intentional.
static HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-connection",
    "transfer-encoding",
    "te",
    "trailer",
    "upgrade",
];
52use tower::{Layer, Service};
53
/// One member of a `Via` field list: `received-protocol received-by [comment]`.
#[derive(Debug, Clone)]
#[allow(dead_code)] // NOTE(review): no caller visible in this module; confirm it is still needed
struct ViaEntry {
    /// Received-protocol token exactly as written, e.g. `1.1` or `HTTP/1.1`.
    protocol: String,
    /// Received-by name, i.e. the part before any `:port`.
    pseudonym: String,
    /// Port that followed the name, without the leading `:`.
    port: Option<String>,
    /// Parenthesised comment, including the outer parentheses.
    comment: Option<String>,
}

impl ViaEntry {
    /// Parses a single Via member such as `1.1 proxy:8080 (Example/1.0)`.
    ///
    /// Returns `None` when the member lacks either the protocol or the received-by token.
    /// The comment, if any, runs from the first `(` to the last `)` that follows it.
    fn parse(entry: &str) -> Option<Self> {
        let mut tokens = entry.split_whitespace();
        let protocol = tokens.next()?.to_string();
        let received_by = tokens.next()?;

        let (pseudonym, port) = match received_by.split_once(':') {
            Some((name, port)) => (name.to_string(), Some(port.to_string())),
            None => (received_by.to_string(), None),
        };

        // Search for ')' only *after* the opening '('. A stray ')' earlier in the member
        // would otherwise yield an inverted slice range and panic.
        let comment = entry.find('(').and_then(|start| {
            entry[start..]
                .rfind(')')
                .map(|offset| entry[start..=start + offset].to_string())
        });

        Some(ViaEntry {
            protocol,
            pseudonym,
            port,
            comment,
        })
    }
}

impl std::fmt::Display for ViaEntry {
    /// Renders the entry back into Via-member syntax: `protocol name[:port] [comment]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} {}", self.protocol, self.pseudonym)?;
        if let Some(port) = &self.port {
            write!(f, ":{port}")?;
        }
        if let Some(comment) = &self.comment {
            write!(f, " {comment}")?;
        }
        Ok(())
    }
}

/// Splits a `Via` field value into members and parses each one with `ViaEntry::parse`.
///
/// Commas inside parenthesised comments do not split members. Comments may nest and may
/// contain `\`-escaped characters. Members that fail to parse are skipped.
#[allow(dead_code)] // NOTE(review): no caller visible in this module; confirm it is still needed
fn parse_via_header(header: &str) -> Vec<ViaEntry> {
    let mut members = Vec::new();
    let mut depth = 0usize;
    let mut escaped = false;
    let mut start = 0;

    for (idx, ch) in header.char_indices() {
        if escaped {
            escaped = false;
            continue;
        }
        match ch {
            '\\' if depth > 0 => escaped = true,
            '(' => depth += 1,
            ')' => depth = depth.saturating_sub(1),
            ',' if depth == 0 => {
                members.push(&header[start..idx]);
                start = idx + 1;
            }
            _ => {}
        }
    }
    members.push(&header[start..]);

    members
        .into_iter()
        .filter_map(|member| ViaEntry::parse(member.trim()))
        .collect()
}
115
/// Behaviour switches for the RFC 9110 intermediary layer.
#[derive(Clone, Debug)]
pub struct Rfc9110Config {
    /// Host names this intermediary answers to. A request whose URI host is in this set
    /// is answered with `508 Loop Detected`. `None` disables the check.
    pub server_names: Option<HashSet<String>>,
    /// Name written into, and searched for in, the `Via` field. `"proxy"` is used when `None`.
    pub pseudonym: Option<String>,
    /// Only consulted when a pseudonym is set.
    /// - `true` collapses existing same-protocol `Via` members into one.
    /// - `false` replaces the field with `1.1 firewall`.
    pub combine_via: bool,
    /// Keep `Connection`/`Upgrade` on WebSocket handshakes instead of stripping them.
    pub preserve_websocket_headers: bool,
}

impl Default for Rfc9110Config {
    /// No loop-detection names and no explicit pseudonym.
    /// Via combining and WebSocket header preservation are both on.
    fn default() -> Self {
        Rfc9110Config {
            preserve_websocket_headers: true,
            combine_via: true,
            pseudonym: None,
            server_names: None,
        }
    }
}
142
143#[derive(Clone)]
145pub struct Rfc9110Layer {
146 config: Rfc9110Config,
147}
148
149impl Default for Rfc9110Layer {
150 fn default() -> Self {
151 Self::new()
152 }
153}
154
155impl Rfc9110Layer {
156 pub fn new() -> Self {
158 Self {
159 config: Rfc9110Config::default(),
160 }
161 }
162
163 pub fn with_config(config: Rfc9110Config) -> Self {
165 Self { config }
166 }
167}
168
169impl<S> Layer<S> for Rfc9110Layer {
170 type Service = Rfc9110<S>;
171
172 fn layer(&self, inner: S) -> Self::Service {
173 Rfc9110 {
174 inner,
175 config: self.config.clone(),
176 }
177 }
178}
179
180#[derive(Clone)]
182pub struct Rfc9110<S> {
183 inner: S,
184 config: Rfc9110Config,
185}
186
187impl<S> Service<Request<Body>> for Rfc9110<S>
188where
189 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
190 S::Future: Send + 'static,
191{
192 type Response = S::Response;
193 type Error = S::Error;
194 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
195
196 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
197 self.inner.poll_ready(cx)
198 }
199
200 fn call(&mut self, mut request: Request<Body>) -> Self::Future {
201 let mut inner = self.inner.clone();
202 let config = self.config.clone();
203
204 Box::pin(async move {
205 if let Some(response) = detect_loop(&request, &config) {
207 return Ok(response);
208 }
209
210 let original_max_forwards =
212 if request.method() != Method::TRACE && request.method() != Method::OPTIONS {
213 request.headers().get(http::header::MAX_FORWARDS).cloned()
214 } else {
215 None
216 };
217
218 if let Some(response) = process_max_forwards(&mut request) {
220 return Ok(response);
221 }
222
223 let max_forwards = request.headers().get(http::header::MAX_FORWARDS).cloned();
225
226 let is_websocket =
228 config.preserve_websocket_headers && is_websocket_upgrade_request(&request);
229
230 process_connection_header(&mut request, is_websocket);
232
233 let preserved_headers = request.headers().clone();
235
236 let via_header = add_via_header(&mut request, &config);
238
239 let mut response = inner.call(request).await?;
241
242 let is_websocket_response =
244 is_websocket && response.status() == StatusCode::SWITCHING_PROTOCOLS;
245 process_response_headers(&mut response, is_websocket_response);
246
247 if let Some(via) = via_header {
249 if config.pseudonym.is_some() && !config.combine_via {
251 response
252 .headers_mut()
253 .insert(http::header::VIA, HeaderValue::from_static("1.1 firewall"));
254 } else {
255 response.headers_mut().insert(http::header::VIA, via);
256 }
257 }
258
259 if let Some(max_forwards) = original_max_forwards {
261 response
263 .headers_mut()
264 .insert(http::header::MAX_FORWARDS, max_forwards);
265 } else if let Some(max_forwards) = max_forwards {
266 response
268 .headers_mut()
269 .insert(http::header::MAX_FORWARDS, max_forwards);
270 }
271
272 for (name, value) in preserved_headers.iter() {
274 if !is_hop_by_hop_header(name) {
275 response.headers_mut().insert(name, value.clone());
276 }
277 }
278
279 Ok(response)
280 })
281 }
282}
283
284fn detect_loop(request: &Request<Body>, config: &Rfc9110Config) -> Option<Response<Body>> {
286 if let Some(server_names) = &config.server_names
288 && let Some(host) = request.uri().host()
289 && server_names.contains(host)
290 {
291 let mut response = Response::new(Body::empty());
292 *response.status_mut() = StatusCode::LOOP_DETECTED;
293 return Some(response);
294 }
295
296 if let Some(via) = request.headers().get(http::header::VIA)
298 && let Ok(via_str) = via.to_str()
299 {
300 let pseudonym = config.pseudonym.as_deref().unwrap_or("proxy");
301 let via_entries: Vec<&str> = via_str.split(',').map(str::trim).collect();
302
303 for entry in via_entries {
305 let parts: Vec<&str> = entry.split_whitespace().collect();
306 if parts.len() >= 2 && parts[1] == pseudonym {
307 let mut response = Response::new(Body::empty());
308 *response.status_mut() = StatusCode::LOOP_DETECTED;
309 return Some(response);
310 }
311 }
312 }
313
314 None
315}
316
317fn process_max_forwards(request: &mut Request<Body>) -> Option<Response<Body>> {
319 let method = request.method();
320
321 if let Some(max_forwards) = request.headers().get(http::header::MAX_FORWARDS) {
323 if *method != Method::TRACE && *method != Method::OPTIONS {
324 return None;
326 }
327
328 if let Ok(value_str) = max_forwards.to_str() {
329 if let Ok(value) = value_str.parse::<u32>() {
330 if value == 0 {
331 let mut response = Response::new(Body::empty());
332 if *method == Method::TRACE {
333 *response.body_mut() = Body::from(format!("{request:?}"));
334 } else {
335 response.headers_mut().insert(
337 http::header::ALLOW,
338 HeaderValue::from_static("GET, HEAD, OPTIONS, TRACE"),
339 );
340 }
341 *response.status_mut() = StatusCode::OK;
342 Some(response)
343 } else {
344 let new_value = value - 1;
346 request.headers_mut().insert(
347 http::header::MAX_FORWARDS,
348 HeaderValue::from_str(&new_value.to_string()).unwrap(),
349 );
350 None
351 }
352 } else {
353 None }
355 } else {
356 None }
358 } else {
359 None }
361}
362
/// Hop-by-hop field names that are kept, not stripped, while a WebSocket upgrade is being
/// preserved. Lowercase, matching the entries of `HOP_BY_HOP_HEADERS`.
static WEBSOCKET_HEADERS: &[&str] = &["connection", "upgrade"];
365
366fn process_connection_header(request: &mut Request<Body>, preserve_websocket: bool) {
368 let mut headers_to_remove = HashSet::new();
369
370 for &name in HOP_BY_HOP_HEADERS {
372 if preserve_websocket && WEBSOCKET_HEADERS.contains(&name) {
374 continue;
375 }
376 headers_to_remove.insert(HeaderName::from_static(name));
377 }
378
379 if let Some(connection) = request
381 .headers()
382 .get_all(http::header::CONNECTION)
383 .iter()
384 .next()
385 && let Ok(connection_str) = connection.to_str()
386 {
387 for header in connection_str.split(',') {
388 let header = header.trim();
389 if preserve_websocket
391 && WEBSOCKET_HEADERS
392 .iter()
393 .any(|h| header.eq_ignore_ascii_case(h))
394 {
395 continue;
396 }
397 if let Ok(header_name) = HeaderName::from_str(header)
398 && (is_hop_by_hop_header(&header_name) || !is_end_to_end_header(&header_name))
399 {
400 headers_to_remove.insert(header_name);
401 }
402 }
403 }
404
405 let headers_to_remove = headers_to_remove; let headers_to_remove: Vec<_> = request
408 .headers()
409 .iter()
410 .filter(|(k, _)| {
411 headers_to_remove
412 .iter()
413 .any(|h| k.as_str().eq_ignore_ascii_case(h.as_str()))
414 })
415 .map(|(k, _)| k.clone())
416 .collect();
417
418 for header in headers_to_remove {
419 request.headers_mut().remove(&header);
420 }
421}
422
423fn add_via_header(request: &mut Request<Body>, config: &Rfc9110Config) -> Option<HeaderValue> {
425 let protocol_version = match request.version() {
427 http::Version::HTTP_09 => "0.9",
428 http::Version::HTTP_10 => "1.0",
429 http::Version::HTTP_11 => "1.1",
430 http::Version::HTTP_2 => "2.0",
431 http::Version::HTTP_3 => "3.0",
432 _ => "1.1", };
434
435 let pseudonym = config.pseudonym.as_deref().unwrap_or("proxy");
437
438 if config.pseudonym.is_some() && !config.combine_via {
440 let via = HeaderValue::from_static("1.1 firewall");
441 request.headers_mut().insert(http::header::VIA, via.clone());
442 return Some(via);
443 }
444
445 let mut via_values = Vec::new();
447 if let Some(existing_via) = request.headers().get(http::header::VIA)
448 && let Ok(existing_via_str) = existing_via.to_str()
449 {
450 if config.combine_via && config.pseudonym.is_some() {
452 let entries: Vec<_> = existing_via_str.split(',').map(|s| s.trim()).collect();
453 let all_same_protocol = entries.iter().all(|s| s.starts_with(protocol_version));
454 if all_same_protocol {
455 let via = HeaderValue::from_str(&format!(
456 "{} {}",
457 protocol_version,
458 config.pseudonym.as_ref().unwrap()
459 ))
460 .ok()?;
461 request.headers_mut().insert(http::header::VIA, via.clone());
462 return Some(via);
463 }
464 }
465 via_values.extend(existing_via_str.split(',').map(|s| s.trim().to_string()));
466 }
467
468 let new_value = format!("{protocol_version} {pseudonym}");
470 via_values.push(new_value);
471
472 let combined_via = via_values.join(", ");
474 let via = HeaderValue::from_str(&combined_via).ok()?;
475 request.headers_mut().insert(http::header::VIA, via.clone());
476 Some(via)
477}
478
479fn process_response_headers(response: &mut Response<Body>, preserve_websocket: bool) {
481 let mut headers_to_remove = HashSet::new();
482
483 for &name in HOP_BY_HOP_HEADERS {
485 if preserve_websocket && WEBSOCKET_HEADERS.contains(&name) {
487 continue;
488 }
489 headers_to_remove.insert(HeaderName::from_static(name));
490 }
491
492 if let Some(connection) = response
494 .headers()
495 .get_all(http::header::CONNECTION)
496 .iter()
497 .next()
498 && let Ok(connection_str) = connection.to_str()
499 {
500 for header in connection_str.split(',') {
501 let header = header.trim();
502 if preserve_websocket
504 && WEBSOCKET_HEADERS
505 .iter()
506 .any(|h| header.eq_ignore_ascii_case(h))
507 {
508 continue;
509 }
510 if let Ok(header_name) = HeaderName::from_str(header)
511 && (is_hop_by_hop_header(&header_name) || !is_end_to_end_header(&header_name))
512 {
513 headers_to_remove.insert(header_name);
514 }
515 }
516 }
517
518 let headers_to_remove = headers_to_remove; let headers_to_remove: Vec<_> = response
521 .headers()
522 .iter()
523 .filter(|(k, _)| {
524 headers_to_remove
525 .iter()
526 .any(|h| k.as_str().eq_ignore_ascii_case(h.as_str()))
527 })
528 .map(|(k, _)| k.clone())
529 .collect();
530
531 for header in headers_to_remove {
532 response.headers_mut().remove(&header);
533 }
534
535 if let Some(via) = response.headers().get(http::header::VIA)
537 && let Ok(via_str) = via.to_str()
538 && via_str.contains("firewall")
539 {
540 response
541 .headers_mut()
542 .insert(http::header::VIA, HeaderValue::from_static("1.1 firewall"));
543 }
544}
545
546fn is_hop_by_hop_header(name: &HeaderName) -> bool {
548 HOP_BY_HOP_HEADERS
549 .iter()
550 .any(|h| name.as_str().eq_ignore_ascii_case(h))
551 || name.as_str().eq_ignore_ascii_case("via")
552}
553
554fn is_end_to_end_header(name: &HeaderName) -> bool {
556 matches!(
557 name.as_str(),
558 "cache-control"
559 | "authorization"
560 | "content-length"
561 | "content-type"
562 | "content-encoding"
563 | "accept"
564 | "accept-encoding"
565 | "accept-language"
566 | "range"
567 | "cookie"
568 | "set-cookie"
569 | "etag"
570 )
571}
572
573fn is_websocket_upgrade_request(request: &Request<Body>) -> bool {
577 request.headers().contains_key("sec-websocket-key")
578 && request.headers().contains_key("sec-websocket-version")
579}