1use axum::{
31 body::Body,
32 http::{HeaderValue, Method, Request, Response, StatusCode, header::HeaderName},
33};
34use std::{
35 collections::{HashMap, HashSet},
36 future::Future,
37 pin::Pin,
38 str::FromStr,
39 task::{Context, Poll},
40};
41
/// Lowercase names of header fields that only describe a single transport hop.
///
/// They are stripped from requests and responses passing through this middleware.
/// `connection` and `upgrade` also appear in `WEBSOCKET_HEADERS` so they can be retained on
/// WebSocket handshakes. Entries must stay lowercase because they are fed to
/// `HeaderName::from_static`, which rejects uppercase input.
static HOP_BY_HOP_HEADERS: &[&str] = &[
    // Connection management.
    "connection",
    "keep-alive",
    "proxy-connection",
    // Transfer codings and framing.
    "transfer-encoding",
    "te",
    "trailer",
    // Protocol switching.
    "upgrade",
];
52use tower::{Layer, Service};
53
/// One list member of a `Via` header: `received-protocol received-by [comment]`.
#[derive(Debug, Clone)]
// NOTE(review): presumably silences dead-code warnings for fields that are only read by the
// `Display` impl, which nothing in this file invokes — confirm before removing.
#[allow(dead_code)]
struct ViaEntry {
    /// First whitespace-separated token, e.g. `1.1`.
    protocol: String,
    /// Second token up to (not including) its first `:`.
    pseudonym: String,
    /// Text after the first `:` of the second token, if there was a colon.
    port: Option<String>,
    /// Text from the first `(` through the last `)` after it, parentheses included.
    comment: Option<String>,
}

impl ViaEntry {
    /// Parses a single (already trimmed) `Via` list member.
    ///
    /// Returns `None` when the entry lacks either the protocol token or the received-by token.
    /// Never panics, even on unbalanced parentheses.
    fn parse(entry: &str) -> Option<Self> {
        let mut tokens = entry.split_whitespace();
        let protocol = tokens.next()?.to_string();
        let received_by = tokens.next()?;

        // "host:port" -> ("host", Some("port")); without a colon the port is absent.
        let (pseudonym, port) = match received_by.split_once(':') {
            Some((name, port)) => (name.to_string(), Some(port.to_string())),
            None => (received_by.to_string(), None),
        };

        // Only look for ')' *after* the opening '(' so malformed input such as "1.1 p) (x"
        // cannot yield an inverted byte range (slicing with start > end panics).
        let comment = entry.find('(').and_then(|start| {
            entry[start..]
                .rfind(')')
                .map(|offset| entry[start..=start + offset].to_string())
        });

        Some(ViaEntry {
            protocol,
            pseudonym,
            port,
            comment,
        })
    }
}

impl std::fmt::Display for ViaEntry {
    /// Re-serialises the entry as `protocol pseudonym[:port][ comment]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} {}", self.protocol, self.pseudonym)?;
        if let Some(port) = &self.port {
            write!(f, ":{port}")?;
        }
        if let Some(comment) = &self.comment {
            write!(f, " {comment}")?;
        }
        Ok(())
    }
}
106
107fn parse_via_header(header: &str) -> Vec<ViaEntry> {
109 header
110 .split(',')
111 .filter_map(|entry| ViaEntry::parse(entry.trim()))
112 .collect()
113}
114
115fn group_by_protocol(entries: Vec<ViaEntry>) -> HashMap<String, Vec<ViaEntry>> {
117 let mut groups = HashMap::new();
118 for entry in entries {
119 groups
120 .entry(entry.protocol.clone())
121 .or_insert_with(Vec::new)
122 .push(entry);
123 }
124 groups
125}
126
/// Configuration for the RFC 9110 proxy middleware.
#[derive(Clone, Debug)]
pub struct Rfc9110Config {
    /// Host names identifying this server.
    ///
    /// A request whose URI host matches one of them is answered with `508 Loop Detected`.
    /// `None` disables this check.
    pub server_names: Option<HashSet<String>>,
    /// Pseudonym written into, and matched against, `Via` entries. `None` means `"proxy"`.
    pub pseudonym: Option<String>,
    /// Controls how `Via` is rewritten when a pseudonym is configured.
    ///
    /// When `true`, an incoming `Via` whose entries all start with this hop's protocol
    /// version is collapsed into this hop's single entry. When `false`, the fixed value
    /// `1.1 firewall` is written.
    pub combine_via: bool,
    /// Keep `Connection`/`Upgrade` on WebSocket upgrade requests and their `101` responses.
    pub preserve_websocket_headers: bool,
}

impl Default for Rfc9110Config {
    /// No server names, default pseudonym, `combine_via` on, WebSocket preservation on.
    fn default() -> Self {
        Rfc9110Config {
            preserve_websocket_headers: true,
            combine_via: true,
            pseudonym: None,
            server_names: None,
        }
    }
}
153
154#[derive(Clone)]
156pub struct Rfc9110Layer {
157 config: Rfc9110Config,
158}
159
160impl Default for Rfc9110Layer {
161 fn default() -> Self {
162 Self::new()
163 }
164}
165
166impl Rfc9110Layer {
167 pub fn new() -> Self {
169 Self {
170 config: Rfc9110Config::default(),
171 }
172 }
173
174 pub fn with_config(config: Rfc9110Config) -> Self {
176 Self { config }
177 }
178}
179
180impl<S> Layer<S> for Rfc9110Layer {
181 type Service = Rfc9110<S>;
182
183 fn layer(&self, inner: S) -> Self::Service {
184 Rfc9110 {
185 inner,
186 config: self.config.clone(),
187 }
188 }
189}
190
191#[derive(Clone)]
193pub struct Rfc9110<S> {
194 inner: S,
195 config: Rfc9110Config,
196}
197
198impl<S> Service<Request<Body>> for Rfc9110<S>
199where
200 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
201 S::Future: Send + 'static,
202{
203 type Response = S::Response;
204 type Error = S::Error;
205 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
206
207 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
208 self.inner.poll_ready(cx)
209 }
210
211 fn call(&mut self, mut request: Request<Body>) -> Self::Future {
212 let mut inner = self.inner.clone();
213 let config = self.config.clone();
214
215 Box::pin(async move {
216 if let Some(response) = detect_loop(&request, &config) {
218 return Ok(response);
219 }
220
221 let original_max_forwards =
223 if request.method() != Method::TRACE && request.method() != Method::OPTIONS {
224 request.headers().get(http::header::MAX_FORWARDS).cloned()
225 } else {
226 None
227 };
228
229 if let Some(response) = process_max_forwards(&mut request) {
231 return Ok(response);
232 }
233
234 let max_forwards = request.headers().get(http::header::MAX_FORWARDS).cloned();
236
237 let is_websocket =
239 config.preserve_websocket_headers && is_websocket_upgrade_request(&request);
240
241 process_connection_header(&mut request, is_websocket);
243
244 let preserved_headers = request.headers().clone();
246
247 let via_header = add_via_header(&mut request, &config);
249
250 let mut response = inner.call(request).await?;
252
253 let is_websocket_response =
255 is_websocket && response.status() == StatusCode::SWITCHING_PROTOCOLS;
256 process_response_headers(&mut response, is_websocket_response);
257
258 if let Some(via) = via_header {
260 if config.pseudonym.is_some() && !config.combine_via {
262 response
263 .headers_mut()
264 .insert(http::header::VIA, HeaderValue::from_static("1.1 firewall"));
265 } else {
266 response.headers_mut().insert(http::header::VIA, via);
267 }
268 }
269
270 if let Some(max_forwards) = original_max_forwards {
272 response
274 .headers_mut()
275 .insert(http::header::MAX_FORWARDS, max_forwards);
276 } else if let Some(max_forwards) = max_forwards {
277 response
279 .headers_mut()
280 .insert(http::header::MAX_FORWARDS, max_forwards);
281 }
282
283 for (name, value) in preserved_headers.iter() {
285 if !is_hop_by_hop_header(name) {
286 response.headers_mut().insert(name, value.clone());
287 }
288 }
289
290 Ok(response)
291 })
292 }
293}
294
295fn detect_loop(request: &Request<Body>, config: &Rfc9110Config) -> Option<Response<Body>> {
297 if let Some(server_names) = &config.server_names
299 && let Some(host) = request.uri().host()
300 && server_names.contains(host)
301 {
302 let mut response = Response::new(Body::empty());
303 *response.status_mut() = StatusCode::LOOP_DETECTED;
304 return Some(response);
305 }
306
307 if let Some(via) = request.headers().get(http::header::VIA)
309 && let Ok(via_str) = via.to_str()
310 {
311 let pseudonym = config.pseudonym.as_deref().unwrap_or("proxy");
312 let via_entries: Vec<&str> = via_str.split(',').map(str::trim).collect();
313
314 for entry in via_entries {
316 let parts: Vec<&str> = entry.split_whitespace().collect();
317 if parts.len() >= 2 && parts[1] == pseudonym {
318 let mut response = Response::new(Body::empty());
319 *response.status_mut() = StatusCode::LOOP_DETECTED;
320 return Some(response);
321 }
322 }
323 }
324
325 None
326}
327
328fn process_max_forwards(request: &mut Request<Body>) -> Option<Response<Body>> {
330 let method = request.method();
331
332 if let Some(max_forwards) = request.headers().get(http::header::MAX_FORWARDS) {
334 if *method != Method::TRACE && *method != Method::OPTIONS {
335 return None;
337 }
338
339 if let Ok(value_str) = max_forwards.to_str() {
340 if let Ok(value) = value_str.parse::<u32>() {
341 if value == 0 {
342 let mut response = Response::new(Body::empty());
343 if *method == Method::TRACE {
344 *response.body_mut() = Body::from(format!("{request:?}"));
345 } else {
346 response.headers_mut().insert(
348 http::header::ALLOW,
349 HeaderValue::from_static("GET, HEAD, OPTIONS, TRACE"),
350 );
351 }
352 *response.status_mut() = StatusCode::OK;
353 Some(response)
354 } else {
355 let new_value = value - 1;
357 request.headers_mut().insert(
358 http::header::MAX_FORWARDS,
359 HeaderValue::from_str(&new_value.to_string()).unwrap(),
360 );
361 None
362 }
363 } else {
364 None }
366 } else {
367 None }
369 } else {
370 None }
372}
373
/// Lowercase names of the hop-by-hop headers that are kept when a WebSocket upgrade is
/// being preserved, because the handshake depends on them.
static WEBSOCKET_HEADERS: &[&str] = &[
    "connection",
    "upgrade",
];
376
377fn process_connection_header(request: &mut Request<Body>, preserve_websocket: bool) {
379 let mut headers_to_remove = HashSet::new();
380
381 for &name in HOP_BY_HOP_HEADERS {
383 if preserve_websocket && WEBSOCKET_HEADERS.contains(&name) {
385 continue;
386 }
387 headers_to_remove.insert(HeaderName::from_static(name));
388 }
389
390 if let Some(connection) = request
392 .headers()
393 .get_all(http::header::CONNECTION)
394 .iter()
395 .next()
396 && let Ok(connection_str) = connection.to_str()
397 {
398 for header in connection_str.split(',') {
399 let header = header.trim();
400 if preserve_websocket
402 && WEBSOCKET_HEADERS
403 .iter()
404 .any(|h| header.eq_ignore_ascii_case(h))
405 {
406 continue;
407 }
408 if let Ok(header_name) = HeaderName::from_str(header)
409 && (is_hop_by_hop_header(&header_name) || !is_end_to_end_header(&header_name))
410 {
411 headers_to_remove.insert(header_name);
412 }
413 }
414 }
415
416 let headers_to_remove = headers_to_remove; let headers_to_remove: Vec<_> = request
419 .headers()
420 .iter()
421 .filter(|(k, _)| {
422 headers_to_remove
423 .iter()
424 .any(|h| k.as_str().eq_ignore_ascii_case(h.as_str()))
425 })
426 .map(|(k, _)| k.clone())
427 .collect();
428
429 for header in headers_to_remove {
430 request.headers_mut().remove(&header);
431 }
432}
433
434fn add_via_header(request: &mut Request<Body>, config: &Rfc9110Config) -> Option<HeaderValue> {
436 let protocol_version = match request.version() {
438 http::Version::HTTP_09 => "0.9",
439 http::Version::HTTP_10 => "1.0",
440 http::Version::HTTP_11 => "1.1",
441 http::Version::HTTP_2 => "2.0",
442 http::Version::HTTP_3 => "3.0",
443 _ => "1.1", };
445
446 let pseudonym = config.pseudonym.as_deref().unwrap_or("proxy");
448
449 if config.pseudonym.is_some() && !config.combine_via {
451 let via = HeaderValue::from_static("1.1 firewall");
452 request.headers_mut().insert(http::header::VIA, via.clone());
453 return Some(via);
454 }
455
456 let mut via_values = Vec::new();
458 if let Some(existing_via) = request.headers().get(http::header::VIA)
459 && let Ok(existing_via_str) = existing_via.to_str()
460 {
461 if config.combine_via && config.pseudonym.is_some() {
463 let entries: Vec<_> = existing_via_str.split(',').map(|s| s.trim()).collect();
464 let all_same_protocol = entries.iter().all(|s| s.starts_with(protocol_version));
465 if all_same_protocol {
466 let via = HeaderValue::from_str(&format!(
467 "{} {}",
468 protocol_version,
469 config.pseudonym.as_ref().unwrap()
470 ))
471 .ok()?;
472 request.headers_mut().insert(http::header::VIA, via.clone());
473 return Some(via);
474 }
475 }
476 via_values.extend(existing_via_str.split(',').map(|s| s.trim().to_string()));
477 }
478
479 let new_value = format!("{protocol_version} {pseudonym}");
481 via_values.push(new_value);
482
483 let combined_via = via_values.join(", ");
485 let via = HeaderValue::from_str(&combined_via).ok()?;
486 request.headers_mut().insert(http::header::VIA, via.clone());
487 Some(via)
488}
489
490fn process_response_headers(response: &mut Response<Body>, preserve_websocket: bool) {
492 let mut headers_to_remove = HashSet::new();
493
494 for &name in HOP_BY_HOP_HEADERS {
496 if preserve_websocket && WEBSOCKET_HEADERS.contains(&name) {
498 continue;
499 }
500 headers_to_remove.insert(HeaderName::from_static(name));
501 }
502
503 if let Some(connection) = response
505 .headers()
506 .get_all(http::header::CONNECTION)
507 .iter()
508 .next()
509 && let Ok(connection_str) = connection.to_str()
510 {
511 for header in connection_str.split(',') {
512 let header = header.trim();
513 if preserve_websocket
515 && WEBSOCKET_HEADERS
516 .iter()
517 .any(|h| header.eq_ignore_ascii_case(h))
518 {
519 continue;
520 }
521 if let Ok(header_name) = HeaderName::from_str(header)
522 && (is_hop_by_hop_header(&header_name) || !is_end_to_end_header(&header_name))
523 {
524 headers_to_remove.insert(header_name);
525 }
526 }
527 }
528
529 let headers_to_remove = headers_to_remove; let headers_to_remove: Vec<_> = response
532 .headers()
533 .iter()
534 .filter(|(k, _)| {
535 headers_to_remove
536 .iter()
537 .any(|h| k.as_str().eq_ignore_ascii_case(h.as_str()))
538 })
539 .map(|(k, _)| k.clone())
540 .collect();
541
542 for header in headers_to_remove {
543 response.headers_mut().remove(&header);
544 }
545
546 if let Some(via) = response.headers().get(http::header::VIA)
548 && let Ok(via_str) = via.to_str()
549 {
550 let entries = parse_via_header(via_str);
551 let _groups = group_by_protocol(entries);
552
553 if via_str.contains("firewall") {
555 response
556 .headers_mut()
557 .insert(http::header::VIA, HeaderValue::from_static("1.1 firewall"));
558 }
559 }
560}
561
562fn is_hop_by_hop_header(name: &HeaderName) -> bool {
564 HOP_BY_HOP_HEADERS
565 .iter()
566 .any(|h| name.as_str().eq_ignore_ascii_case(h))
567 || name.as_str().eq_ignore_ascii_case("via")
568}
569
570fn is_end_to_end_header(name: &HeaderName) -> bool {
572 matches!(
573 name.as_str(),
574 "cache-control"
575 | "authorization"
576 | "content-length"
577 | "content-type"
578 | "content-encoding"
579 | "accept"
580 | "accept-encoding"
581 | "accept-language"
582 | "range"
583 | "cookie"
584 | "set-cookie"
585 | "etag"
586 )
587}
588
589fn is_websocket_upgrade_request(request: &Request<Body>) -> bool {
593 request.headers().contains_key("sec-websocket-key")
594 && request.headers().contains_key("sec-websocket-version")
595}