1use axum::{
31 body::Body,
32 http::{HeaderValue, Method, Request, Response, StatusCode, header::HeaderName},
33};
34use std::{
35 collections::{HashMap, HashSet},
36 future::Future,
37 pin::Pin,
38 str::FromStr,
39 task::{Context, Poll},
40};
41
/// Field names that are always stripped at each hop, regardless of what `Connection` lists.
///
/// Every entry is lowercase so it can be handed to `HeaderName::from_static`.
static HOP_BY_HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-connection",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];
52use tower::{Layer, Service};
53
/// One member of a `Via` field value, e.g. `1.1 proxy:8080 (comment)`.
// NOTE(review): kept from the original; it may be unnecessary now that `Display` reads every
// field, but removing it could surface warnings in builds where parsing goes unused.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct ViaEntry {
    /// First whitespace-separated token (the received protocol, e.g. `1.1`).
    protocol: String,
    /// Received-by host or pseudonym, without any trailing `:port`.
    pseudonym: String,
    /// Text after the port separator `:`, if one was present.
    port: Option<String>,
    /// Parenthesised comment, parentheses included, if one was present.
    comment: Option<String>,
}

impl ViaEntry {
    /// Parses a single (already trimmed) `Via` member.
    ///
    /// Returns `None` when either the protocol token or the received-by token is missing.
    fn parse(entry: &str) -> Option<Self> {
        let mut tokens = entry.split_whitespace();
        let protocol = tokens.next()?.to_string();
        let received_by = tokens.next()?;

        // Split on the *last* colon, and only if it is not inside an IPv6 literal such as
        // `[::1]`; otherwise `[::1]:8080` would be cut at its first colon.
        let (pseudonym, port) = match received_by.rfind(':') {
            Some(idx) if !received_by[idx..].contains(']') => (
                received_by[..idx].to_string(),
                Some(received_by[idx + 1..].to_string()),
            ),
            _ => (received_by.to_string(), None),
        };

        // Only accept a comment when the closing `)` really follows the opening `(`; a stray
        // `)` earlier in the member used to produce an inverted range and panic.
        let comment = entry.find('(').and_then(|start| {
            entry
                .rfind(')')
                .filter(|&end| end > start)
                .map(|end| entry[start..=end].to_string())
        });

        Some(Self {
            protocol,
            pseudonym,
            port,
            comment,
        })
    }
}

impl std::fmt::Display for ViaEntry {
    /// Renders the member back as `protocol pseudonym[:port][ comment]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.protocol)?;
        f.write_str(" ")?;
        f.write_str(&self.pseudonym)?;
        if let Some(port) = self.port.as_deref() {
            write!(f, ":{port}")?;
        }
        match self.comment.as_deref() {
            Some(comment) => write!(f, " {comment}"),
            None => Ok(()),
        }
    }
}
106
107fn parse_via_header(header: &str) -> Vec<ViaEntry> {
109 header
110 .split(',')
111 .filter_map(|entry| ViaEntry::parse(entry.trim()))
112 .collect()
113}
114
115fn group_by_protocol(entries: Vec<ViaEntry>) -> HashMap<String, Vec<ViaEntry>> {
117 let mut groups = HashMap::new();
118 for entry in entries {
119 groups
120 .entry(entry.protocol.clone())
121 .or_insert_with(Vec::new)
122 .push(entry);
123 }
124 groups
125}
126
/// Settings for the RFC 9110 middleware.
#[derive(Clone, Debug)]
pub struct Rfc9110Config {
    /// Host names that identify this server. An absolute-form request whose target host is
    /// one of them is answered with `508 Loop Detected`. `None` disables that check.
    pub server_names: Option<HashSet<String>>,
    /// Name this intermediary uses in `Via`, and looks for when detecting loops.
    /// `"proxy"` is used when this is `None`.
    pub pseudonym: Option<String>,
    /// Controls `Via` rewriting when a pseudonym is set.
    /// - `false`: a fixed `1.1 firewall` entry is written.
    /// - `true`: an existing `Via` whose members all share the request's protocol version is
    ///   collapsed into this intermediary's single entry.
    pub combine_via: bool,
}

impl Default for Rfc9110Config {
    /// No server names, no pseudonym, `Via` combining enabled.
    fn default() -> Self {
        Rfc9110Config {
            combine_via: true,
            pseudonym: None,
            server_names: None,
        }
    }
}
147
148#[derive(Clone)]
150pub struct Rfc9110Layer {
151 config: Rfc9110Config,
152}
153
154impl Default for Rfc9110Layer {
155 fn default() -> Self {
156 Self::new()
157 }
158}
159
160impl Rfc9110Layer {
161 pub fn new() -> Self {
163 Self {
164 config: Rfc9110Config::default(),
165 }
166 }
167
168 pub fn with_config(config: Rfc9110Config) -> Self {
170 Self { config }
171 }
172}
173
174impl<S> Layer<S> for Rfc9110Layer {
175 type Service = Rfc9110<S>;
176
177 fn layer(&self, inner: S) -> Self::Service {
178 Rfc9110 {
179 inner,
180 config: self.config.clone(),
181 }
182 }
183}
184
185#[derive(Clone)]
187pub struct Rfc9110<S> {
188 inner: S,
189 config: Rfc9110Config,
190}
191
192impl<S> Service<Request<Body>> for Rfc9110<S>
193where
194 S: Service<Request<Body>, Response = Response<Body>> + Clone + Send + 'static,
195 S::Future: Send + 'static,
196{
197 type Response = S::Response;
198 type Error = S::Error;
199 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
200
201 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
202 self.inner.poll_ready(cx)
203 }
204
205 fn call(&mut self, mut request: Request<Body>) -> Self::Future {
206 let mut inner = self.inner.clone();
207 let config = self.config.clone();
208
209 Box::pin(async move {
210 if let Some(response) = detect_loop(&request, &config) {
212 return Ok(response);
213 }
214
215 let original_max_forwards =
217 if request.method() != Method::TRACE && request.method() != Method::OPTIONS {
218 request.headers().get(http::header::MAX_FORWARDS).cloned()
219 } else {
220 None
221 };
222
223 if let Some(response) = process_max_forwards(&mut request) {
225 return Ok(response);
226 }
227
228 let max_forwards = request.headers().get(http::header::MAX_FORWARDS).cloned();
230
231 process_connection_header(&mut request);
233
234 let preserved_headers = request.headers().clone();
236
237 let via_header = add_via_header(&mut request, &config);
239
240 let mut response = inner.call(request).await?;
242
243 process_response_headers(&mut response);
245
246 if let Some(via) = via_header {
248 if config.pseudonym.is_some() && !config.combine_via {
250 response
251 .headers_mut()
252 .insert(http::header::VIA, HeaderValue::from_static("1.1 firewall"));
253 } else {
254 response.headers_mut().insert(http::header::VIA, via);
255 }
256 }
257
258 if let Some(max_forwards) = original_max_forwards {
260 response
262 .headers_mut()
263 .insert(http::header::MAX_FORWARDS, max_forwards);
264 } else if let Some(max_forwards) = max_forwards {
265 response
267 .headers_mut()
268 .insert(http::header::MAX_FORWARDS, max_forwards);
269 }
270
271 for (name, value) in preserved_headers.iter() {
273 if !is_hop_by_hop_header(name) {
274 response.headers_mut().insert(name, value.clone());
275 }
276 }
277
278 Ok(response)
279 })
280 }
281}
282
283fn detect_loop(request: &Request<Body>, config: &Rfc9110Config) -> Option<Response<Body>> {
285 if let Some(server_names) = &config.server_names {
287 if let Some(host) = request.uri().host() {
288 if server_names.contains(host) {
289 let mut response = Response::new(Body::empty());
290 *response.status_mut() = StatusCode::LOOP_DETECTED;
291 return Some(response);
292 }
293 }
294 }
295
296 if let Some(via) = request.headers().get(http::header::VIA) {
298 if let Ok(via_str) = via.to_str() {
299 let pseudonym = config.pseudonym.as_deref().unwrap_or("proxy");
300 let via_entries: Vec<&str> = via_str.split(',').map(str::trim).collect();
301
302 for entry in via_entries {
304 let parts: Vec<&str> = entry.split_whitespace().collect();
305 if parts.len() >= 2 && parts[1] == pseudonym {
306 let mut response = Response::new(Body::empty());
307 *response.status_mut() = StatusCode::LOOP_DETECTED;
308 return Some(response);
309 }
310 }
311 }
312 }
313
314 None
315}
316
317fn process_max_forwards(request: &mut Request<Body>) -> Option<Response<Body>> {
319 let method = request.method();
320
321 if let Some(max_forwards) = request.headers().get(http::header::MAX_FORWARDS) {
323 if *method != Method::TRACE && *method != Method::OPTIONS {
324 return None;
326 }
327
328 if let Ok(value_str) = max_forwards.to_str() {
329 if let Ok(value) = value_str.parse::<u32>() {
330 if value == 0 {
331 let mut response = Response::new(Body::empty());
332 if *method == Method::TRACE {
333 *response.body_mut() = Body::from(format!("{request:?}"));
334 } else {
335 response.headers_mut().insert(
337 http::header::ALLOW,
338 HeaderValue::from_static("GET, HEAD, OPTIONS, TRACE"),
339 );
340 }
341 *response.status_mut() = StatusCode::OK;
342 Some(response)
343 } else {
344 let new_value = value - 1;
346 request.headers_mut().insert(
347 http::header::MAX_FORWARDS,
348 HeaderValue::from_str(&new_value.to_string()).unwrap(),
349 );
350 None
351 }
352 } else {
353 None }
355 } else {
356 None }
358 } else {
359 None }
361}
362
363fn process_connection_header(request: &mut Request<Body>) {
365 let mut headers_to_remove = HashSet::new();
366
367 for &name in HOP_BY_HOP_HEADERS {
369 headers_to_remove.insert(HeaderName::from_static(name));
370 }
371
372 if let Some(connection) = request
374 .headers()
375 .get_all(http::header::CONNECTION)
376 .iter()
377 .next()
378 {
379 if let Ok(connection_str) = connection.to_str() {
380 for header in connection_str.split(',') {
381 let header = header.trim();
382 if let Ok(header_name) = HeaderName::from_str(header) {
383 if is_hop_by_hop_header(&header_name) || !is_end_to_end_header(&header_name) {
384 headers_to_remove.insert(header_name);
385 }
386 }
387 }
388 }
389 }
390
391 let headers_to_remove = headers_to_remove; let headers_to_remove: Vec<_> = request
394 .headers()
395 .iter()
396 .filter(|(k, _)| {
397 headers_to_remove
398 .iter()
399 .any(|h| k.as_str().eq_ignore_ascii_case(h.as_str()))
400 })
401 .map(|(k, _)| k.clone())
402 .collect();
403
404 for header in headers_to_remove {
405 request.headers_mut().remove(&header);
406 }
407}
408
409fn add_via_header(request: &mut Request<Body>, config: &Rfc9110Config) -> Option<HeaderValue> {
411 let protocol_version = match request.version() {
413 http::Version::HTTP_09 => "0.9",
414 http::Version::HTTP_10 => "1.0",
415 http::Version::HTTP_11 => "1.1",
416 http::Version::HTTP_2 => "2.0",
417 http::Version::HTTP_3 => "3.0",
418 _ => "1.1", };
420
421 let pseudonym = config.pseudonym.as_deref().unwrap_or("proxy");
423
424 if config.pseudonym.is_some() && !config.combine_via {
426 let via = HeaderValue::from_static("1.1 firewall");
427 request.headers_mut().insert(http::header::VIA, via.clone());
428 return Some(via);
429 }
430
431 let mut via_values = Vec::new();
433 if let Some(existing_via) = request.headers().get(http::header::VIA) {
434 if let Ok(existing_via_str) = existing_via.to_str() {
435 if config.combine_via && config.pseudonym.is_some() {
437 let entries: Vec<_> = existing_via_str.split(',').map(|s| s.trim()).collect();
438 let all_same_protocol = entries.iter().all(|s| s.starts_with(protocol_version));
439 if all_same_protocol {
440 let via = HeaderValue::from_str(&format!(
441 "{} {}",
442 protocol_version,
443 config.pseudonym.as_ref().unwrap()
444 ))
445 .ok()?;
446 request.headers_mut().insert(http::header::VIA, via.clone());
447 return Some(via);
448 }
449 }
450 via_values.extend(existing_via_str.split(',').map(|s| s.trim().to_string()));
451 }
452 }
453
454 let new_value = format!("{protocol_version} {pseudonym}");
456 via_values.push(new_value);
457
458 let combined_via = via_values.join(", ");
460 let via = HeaderValue::from_str(&combined_via).ok()?;
461 request.headers_mut().insert(http::header::VIA, via.clone());
462 Some(via)
463}
464
465fn process_response_headers(response: &mut Response<Body>) {
467 let mut headers_to_remove = HashSet::new();
468
469 for &name in HOP_BY_HOP_HEADERS {
471 headers_to_remove.insert(HeaderName::from_static(name));
472 }
473
474 if let Some(connection) = response
476 .headers()
477 .get_all(http::header::CONNECTION)
478 .iter()
479 .next()
480 {
481 if let Ok(connection_str) = connection.to_str() {
482 for header in connection_str.split(',') {
483 let header = header.trim();
484 if let Ok(header_name) = HeaderName::from_str(header) {
485 if is_hop_by_hop_header(&header_name) || !is_end_to_end_header(&header_name) {
486 headers_to_remove.insert(header_name);
487 }
488 }
489 }
490 }
491 }
492
493 let headers_to_remove = headers_to_remove; let headers_to_remove: Vec<_> = response
496 .headers()
497 .iter()
498 .filter(|(k, _)| {
499 headers_to_remove
500 .iter()
501 .any(|h| k.as_str().eq_ignore_ascii_case(h.as_str()))
502 })
503 .map(|(k, _)| k.clone())
504 .collect();
505
506 for header in headers_to_remove {
507 response.headers_mut().remove(&header);
508 }
509
510 if let Some(via) = response.headers().get(http::header::VIA) {
512 if let Ok(via_str) = via.to_str() {
513 let entries = parse_via_header(via_str);
514 let _groups = group_by_protocol(entries);
515
516 if let Some(via_header) = response.headers().get(http::header::VIA) {
518 if let Ok(via_str) = via_header.to_str() {
519 if via_str.contains("firewall") {
520 response
521 .headers_mut()
522 .insert(http::header::VIA, HeaderValue::from_static("1.1 firewall"));
523 }
524 }
525 }
526 }
527 }
528}
529
530fn is_hop_by_hop_header(name: &HeaderName) -> bool {
532 HOP_BY_HOP_HEADERS
533 .iter()
534 .any(|h| name.as_str().eq_ignore_ascii_case(h))
535 || name.as_str().eq_ignore_ascii_case("via")
536}
537
538fn is_end_to_end_header(name: &HeaderName) -> bool {
540 matches!(
541 name.as_str(),
542 "cache-control"
543 | "authorization"
544 | "content-length"
545 | "content-type"
546 | "content-encoding"
547 | "accept"
548 | "accept-encoding"
549 | "accept-language"
550 | "range"
551 | "cookie"
552 | "set-cookie"
553 | "etag"
554 )
555}