1use std::{net::SocketAddr, str::FromStr};
2
3use http::{
4 HeaderMap, HeaderName, HeaderValue, Method, StatusCode, Version,
5 header::{self, AsHeaderName, InvalidHeaderValue},
6 uri::{Scheme, Uri},
7};
8use n0_error::{Result, StackResultExt, StdResultExt, anyerr, ensure_any};
9use tokio::io::{self, AsyncRead, AsyncWrite, AsyncWriteExt};
10
11use crate::{
12 downstream::SrcAddr,
13 util::{Prebufferable, Prebuffered},
14};
15
/// Headers that `filter_hop_by_hop_headers` removes unconditionally before a message is
/// forwarded.
///
/// `Upgrade` is not listed, so a standalone `Upgrade` header is left in place; the `upgrade`
/// connection option is kept via `ALLOWED_CONNECTION_HEADERS`. `Keep-Alive` is not listed
/// either, so it is only stripped when a `Connection` header names it.
/// NOTE(review): RFC 9110 §7.6.1 also suggests dropping `Keep-Alive` / `Proxy-Connection`
/// unconditionally — confirm whether that is wanted here.
const HOP_BY_HOP_HEADERS: &[HeaderName] = &[
    header::CONNECTION,
    header::PROXY_AUTHENTICATE,
    header::PROXY_AUTHORIZATION,
    header::TE,
    header::TRAILER,
    header::TRANSFER_ENCODING,
];

/// Header that `HttpRequest::set_forwarded_for` appends the client's socket address to.
const X_FORWARDED_FOR: &str = "x-forwarded-for";
/// Header that `HttpRequest::set_target` moves the original `Host` value into.
const X_FORWARDED_HOST: &str = "x-forwarded-host";

/// Connection options that `filter_hop_by_hop_headers` preserves (and re-emits in a fresh
/// `Connection` header) instead of stripping — presumably so protocol upgrades can pass
/// through the proxy.
const ALLOWED_CONNECTION_HEADERS: &[HeaderName; 1] = &[header::UPGRADE];
30
31pub fn filter_hop_by_hop_headers(headers: &mut HeaderMap<HeaderValue>) {
39 let connection_headers = headers
41 .get_all(header::CONNECTION)
42 .iter()
43 .filter_map(|v| v.to_str().ok())
44 .flat_map(|s| s.split(','))
45 .filter_map(|name| name.trim().parse::<HeaderName>().ok());
46
47 let (connection_keep, connection_remove): (Vec<_>, Vec<_>) =
48 connection_headers.partition(|h| ALLOWED_CONNECTION_HEADERS.contains(h));
49
50 for name in HOP_BY_HOP_HEADERS {
52 headers.remove(name);
53 }
54
55 for name in connection_remove {
57 headers.remove(&name);
58 }
59
60 if !connection_keep.is_empty() {
61 if let Ok(value) = HeaderValue::from_str(&connection_keep.join(", ")) {
62 headers.insert(header::CONNECTION, value);
63 }
64 }
65}
66
/// A `host:port` pair naming the endpoint a proxied request should be sent to.
///
/// `Display` renders it as `{host}:{port}`.
#[derive(Debug, Clone, derive_more::Display, Ord, PartialOrd, Hash, Eq, PartialEq)]
#[display("{host}:{port}")]
pub struct Authority {
    /// Host name or address literal; when parsed from a URI this is the URI authority's
    /// `host()` string.
    pub host: String,
    /// Port number; when parsed from an absolute URI without an explicit port it is
    /// defaulted from the `http`/`https` scheme.
    pub port: u16,
}
79
80impl FromStr for Authority {
81 type Err = n0_error::AnyError;
82 fn from_str(s: &str) -> Result<Self, Self::Err> {
83 Self::from_authority_str(s)
84 }
85}
86
87impl Authority {
88 pub fn new(host: String, port: u16) -> Self {
90 Self { host, port }
91 }
92
93 pub fn from_authority_uri(uri: &Uri) -> Result<Self> {
101 ensure_any!(uri.scheme().is_none(), "Expected URI without scheme");
102 ensure_any!(uri.path_and_query().is_none(), "Expected URI without path");
103 let authority = uri.authority().context("Expected URI with authority")?;
104 let host = authority.host();
105 let port = authority.port_u16().context("Expected URI with port")?;
106 Ok(Self {
107 host: host.to_string(),
108 port,
109 })
110 }
111
112 pub fn from_absolute_uri(uri: &Uri) -> Result<Self> {
122 let authority = uri.authority().context("Expected URI with authority")?;
123 let host = authority.host();
124 let port = match authority.port_u16() {
125 Some(port) => port,
126 None => match uri.scheme() {
127 Some(scheme) if *scheme == Scheme::HTTP => 80,
128 Some(scheme) if *scheme == Scheme::HTTPS => 443,
129 _ => Err(anyerr!("Expected URI with port or http(s) scheme"))?,
130 },
131 };
132 Ok(Self {
133 host: host.to_string(),
134 port,
135 })
136 }
137
138 pub fn from_authority_str(s: &str) -> Result<Self> {
142 Self::from_authority_uri(&Uri::from_str(s).std_context("Invalid authority string")?)
143 }
144
145 pub fn from_absolute_uri_str(s: &str) -> Result<Self> {
149 Self::from_absolute_uri(&Uri::from_str(s).std_context("Invalid authority string")?)
150 }
151
152 pub(super) fn to_addr(&self) -> String {
153 format!("{}:{}", self.host, self.port)
154 }
155
156 pub(crate) fn to_connect_request(&self) -> String {
157 let host = &self.host;
158 let port = &self.port;
159 format!("CONNECT {host}:{port} HTTP/1.1\r\nHost: {host}:{port}\r\n\r\n")
160 }
161}
162
163pub(crate) fn absolute_target_to_origin_form(target: &Uri) -> Result<Uri> {
166 let path_and_query = target.path_and_query().map(|pq| pq.as_str()).unwrap_or("/");
167 Uri::from_str(path_and_query).std_context("invalid path_and_query")
168}
169
/// The head (request line and headers) of an HTTP request; the body is not part of it.
///
/// Built either by parsing HTTP/1.x bytes (`parse`, `peek`, `read`) or from
/// `http::request::Parts`, and serialized back out as HTTP/1.1 by `write`.
#[derive(Debug)]
pub struct HttpRequest {
    /// Protocol version. Requests parsed from bytes are recorded as HTTP/1.1 or HTTP/1.0.
    pub version: Version,
    /// Request headers, in received order (repeated names are kept as separate values).
    pub headers: HeaderMap<HeaderValue>,
    /// Request target: origin-form, absolute-form or authority-form.
    pub uri: Uri,
    /// Request method.
    pub method: Method,
}
185
186impl HttpRequest {
187 pub fn from_parts(parts: http::request::Parts) -> Self {
189 Self {
190 version: parts.version,
191 headers: parts.headers,
192 method: parts.method,
193 uri: parts.uri,
194 }
195 }
196
197 pub fn parse_with_len(buf: &[u8]) -> Result<Option<(usize, Self)>> {
202 let mut headers = [httparse::EMPTY_HEADER; 64];
203 let mut req = httparse::Request::new(&mut headers);
204 match req.parse(buf).std_context("Invalid HTTP request")? {
205 httparse::Status::Partial => Ok(None),
206 httparse::Status::Complete(header_len) => {
207 Self::from_parsed_request(req).map(|req| Some((header_len, req)))
208 }
209 }
210 }
211
212 fn from_parsed_request(req: httparse::Request) -> Result<Self> {
214 let method_str = req.method.context("Missing HTTP method")?;
215 let method = method_str.parse().std_context("Invalid HTTP method")?;
216 let path = req.path.context("Missing request target")?;
217 let uri = Uri::from_str(path).std_context("Invalid request target")?;
218 let headers = HeaderMap::from_iter(req.headers.iter_mut().flat_map(|h| {
219 let value = HeaderValue::from_bytes(h.value).ok()?;
220 let name = http::HeaderName::from_bytes(h.name.as_bytes()).ok()?;
221 Some((name, value))
222 }));
223 let version = if req.version == Some(1) {
224 http::Version::HTTP_11
225 } else {
226 http::Version::HTTP_10
227 };
228 Ok(Self {
229 version,
230 headers,
231 uri,
232 method,
233 })
234 }
235
236 pub async fn peek(reader: &mut Prebuffered<impl AsyncRead + Unpin>) -> Result<(usize, Self)> {
242 while !reader.is_full() {
243 reader.buffer_more().await?;
244 if let Some(request) = Self::parse_with_len(reader.buffer())? {
245 return Ok(request);
246 }
247 }
248 Err(io::Error::new(
249 io::ErrorKind::OutOfMemory,
250 "Buffer size limit reached before end of request header section",
251 )
252 .into())
253 }
254
255 pub async fn read(reader: &mut Prebuffered<impl AsyncRead + Unpin>) -> Result<Self> {
260 let (len, response) = Self::peek(reader).await?;
261 reader.discard(len);
262 Ok(response)
263 }
264
265 pub fn parse(buf: &[u8]) -> Result<Option<Self>> {
267 Ok(Self::parse_with_len(buf)?.map(|(_len, req)| req))
268 }
269
270 pub fn try_into_proxy_request(self) -> Result<HttpProxyRequest> {
277 let kind = match self.method {
278 Method::CONNECT => {
279 let target = Authority::from_authority_uri(&self.uri)?;
280 HttpProxyRequestKind::Tunnel { target }
281 }
282 _ => {
283 if self.uri.scheme().is_none() || self.uri.authority().is_none() {
284 return Err(anyerr!("Missing absolute-form request target"));
285 }
286 HttpProxyRequestKind::Absolute {
287 target: self.uri.clone(),
288 method: self.method,
289 }
290 }
291 };
292 Ok(HttpProxyRequest {
293 headers: self.headers,
294 kind,
295 })
296 }
297
298 pub fn host(&self) -> Option<&str> {
303 if self.version >= Version::HTTP_2 {
304 self.uri.host()
305 } else {
306 self.header_str(http::header::HOST)
307 }
308 }
309
310 pub fn header_str(&self, name: impl AsHeaderName) -> Option<&str> {
312 self.headers.get(name).and_then(|x| x.to_str().ok())
313 }
314
315 pub fn classify(&self) -> Result<HttpRequestKind> {
322 let uri = &self.uri;
323 match self.method {
324 Method::CONNECT => {
325 ensure_any!(
326 uri.scheme().is_none()
327 && uri.path_and_query().is_none()
328 && uri.authority().is_some()
329 && uri.authority().and_then(|a| a.port_u16()).is_some(),
330 "Invalid request-target form for CONNECT request"
331 );
332
333 Ok(HttpRequestKind::Tunnel)
334 }
335 _ => {
336 if self.uri.scheme().is_some() && self.version < Version::HTTP_2 {
339 ensure_any!(
340 self.uri.authority().is_some(),
341 "Invalid request target: scheme without authority"
342 );
343 Ok(HttpRequestKind::Http1Absolute)
344 } else {
345 Ok(HttpRequestKind::Origin)
346 }
347 }
348 }
349 }
350
351 pub fn set_forwarded_for(&mut self, src_addr: SocketAddr) -> &mut Self {
356 self.headers.append(
357 X_FORWARDED_FOR,
358 HeaderValue::from_str(&src_addr.to_string()).expect("valid header value"),
359 );
360 self
361 }
362
363 pub fn set_forwarded_for_if_tcp(&mut self, src_addr: SrcAddr) -> &mut Self {
367 match src_addr {
368 SrcAddr::Tcp(addr) => self.set_forwarded_for(addr),
369 #[cfg(unix)]
370 SrcAddr::Unix(_) => self,
371 }
372 }
373
374 pub fn remove_headers(
376 &mut self,
377 names: impl IntoIterator<Item = impl AsHeaderName>,
378 ) -> &mut Self {
379 for header in names {
380 self.headers.remove(header);
381 }
382 self
383 }
384
385 pub fn set_via(
389 &mut self,
390 pseudonym: impl std::fmt::Display,
391 ) -> Result<&mut Self, InvalidHeaderValue> {
392 self.headers.append(
393 header::VIA,
394 HeaderValue::from_str(&format!("{:?} {}", self.version, pseudonym))?,
395 );
396 Ok(self)
397 }
398
399 pub fn set_target(&mut self, target: Uri) -> Result<&mut Self, InvalidHeaderValue> {
403 if let Some(original_host) = self.headers.remove(header::HOST) {
404 self.headers.insert(X_FORWARDED_HOST, original_host);
405 }
406 if let Some(authority) = target.authority() {
407 self.headers
408 .insert(header::HOST, HeaderValue::from_str(authority.as_str())?);
409 }
410 self.uri = target;
411 Ok(self)
412 }
413
414 pub fn set_absolute_http_authority(&mut self, authority: Authority) -> Result<&mut Self> {
419 let mut parts = self.uri.clone().into_parts();
420 parts.authority = Some(authority.to_string().parse().anyerr()?);
421 parts.scheme = Some(Scheme::HTTP);
422 let uri = Uri::from_parts(parts).anyerr()?;
423 self.set_target(uri).anyerr()?;
424 Ok(self)
425 }
426
427 pub(crate) async fn write(
428 &self,
429 writer: &mut (impl AsyncWrite + Send + Unpin),
430 ) -> io::Result<()> {
431 let Self {
432 method,
433 uri,
434 headers,
435 ..
436 } = self;
437 writer.write_all(method.as_str().as_bytes()).await?;
438 writer.write_all(b" ").await?;
439 if let Some(s) = uri.scheme() {
440 writer.write_all(s.as_str().as_bytes()).await?;
441 writer.write_all(b"://").await?;
442 }
443 if let Some(s) = uri.authority() {
444 writer.write_all(s.as_str().as_bytes()).await?;
445 }
446 writer.write_all(uri.path().as_bytes()).await?;
447 if let Some(s) = uri.query() {
448 writer.write_all(b"?").await?;
449 writer.write_all(s.as_bytes()).await?;
450 }
451 writer.write_all(b" HTTP/1.1\r\n").await?;
452 for (key, value) in headers.iter() {
453 writer.write_all(key.as_str().as_bytes()).await?;
454 writer.write_all(b": ").await?;
455 writer.write_all(value.as_bytes()).await?;
456 writer.write_all(b"\r\n").await?;
457 }
458 writer.write_all(b"\r\n").await?;
459 Ok(())
460 }
461}
462
/// How a request's target should be treated, as reported by `HttpRequest::classify`.
#[derive(Debug, Eq, PartialEq)]
pub enum HttpRequestKind {
    /// A `CONNECT` request with an authority-form target (`host:port`, port required).
    Tunnel,
    /// A non-`CONNECT` HTTP/1.x request whose target is absolute-form, i.e. carries a
    /// scheme and an authority.
    Http1Absolute,
    /// Any other request. This covers HTTP/1.x targets without a scheme, and every
    /// non-`CONNECT` request at HTTP/2 or later regardless of its URI.
    Origin,
}
475
/// What a client asked a forward proxy to do, as derived by
/// `HttpRequest::try_into_proxy_request`.
#[derive(Debug, Hash, Eq, PartialEq)]
pub enum HttpProxyRequestKind {
    /// A `CONNECT` request for a tunnel to `target`.
    Tunnel {
        /// The `host:port` taken from the authority-form request target.
        target: Authority,
    },
    /// A request with any other method and an absolute-form target (scheme and
    /// authority both present).
    Absolute {
        /// The full absolute-form request target.
        target: Uri,
        /// The request method.
        method: Method,
    },
}
495
496impl HttpProxyRequestKind {
497 pub fn authority(&self) -> Result<Authority> {
501 match self {
502 HttpProxyRequestKind::Tunnel { target } => Ok(target.clone()),
503 HttpProxyRequestKind::Absolute { target, .. } => {
504 let target = Authority::from_absolute_uri(&target)?;
505 Ok(target)
506 }
507 }
508 }
509}
510
/// A request received by a forward proxy, as produced by
/// `HttpRequest::try_into_proxy_request`.
#[derive(derive_more::Debug)]
pub struct HttpProxyRequest {
    /// Whether this is a `CONNECT` tunnel or an absolute-form request, plus its target.
    pub kind: HttpProxyRequestKind,
    /// The headers of the original request, unchanged.
    pub headers: HeaderMap<http::HeaderValue>,
}
522
/// The head (status line and headers) of an HTTP response; the body is not part of it.
#[derive(derive_more::Debug)]
pub struct HttpResponse {
    /// Response status code.
    pub status: StatusCode,
    /// Custom reason phrase. When `None`, the canonical phrase for `status` is used when
    /// rendering the status line.
    pub reason: Option<String>,
    /// Response headers, in received/insertion order.
    pub headers: HeaderMap<http::HeaderValue>,
}
536
537impl HttpResponse {
538 pub(crate) fn new(status: StatusCode) -> Self {
539 Self {
540 status,
541 reason: None,
542 headers: HeaderMap::new(),
543 }
544 }
545
546 pub(crate) fn with_reason(status: StatusCode, reason: impl ToString) -> Self {
547 Self {
548 status,
549 reason: Some(reason.to_string()),
550 headers: HeaderMap::new(),
551 }
552 }
553
554 pub(crate) fn no_body(mut self) -> Self {
555 self.headers.insert(
556 http::header::CONTENT_LENGTH,
557 HeaderValue::from_str("0").unwrap(),
558 );
559 self
560 }
561
562 pub(crate) async fn write(
563 &self,
564 writer: &mut (impl AsyncWrite + Send + Unpin),
565 finalize: bool,
566 ) -> io::Result<()> {
567 writer.write_all(self.status_line().as_bytes()).await?;
568 for (key, value) in self.headers.iter() {
569 writer.write_all(key.as_str().as_bytes()).await?;
570 writer.write_all(b": ").await?;
571 writer.write_all(value.as_bytes()).await?;
572 writer.write_all(b"\r\n").await?;
573 }
574 if finalize {
575 writer.write_all(b"\r\n").await?;
576 }
577 Ok(())
578 }
579
580 pub fn reason(&self) -> &str {
582 self.reason
583 .as_deref()
584 .or(self.status.canonical_reason())
585 .unwrap_or("")
586 }
587
588 pub fn status_line(&self) -> String {
590 format!(
591 "HTTP/1.1 {} {}\r\n",
592 self.status.as_u16(),
593 self.reason
594 .as_deref()
595 .or(self.status.canonical_reason())
596 .unwrap_or("")
597 )
598 }
599
600 pub fn parse(buf: &[u8]) -> Result<Option<Self>> {
602 Ok(Self::parse_with_len(buf)?.map(|(_len, res)| res))
603 }
604
605 pub fn parse_with_len(buf: &[u8]) -> Result<Option<(usize, Self)>> {
609 let mut headers = [httparse::EMPTY_HEADER; 64];
610 let mut res = httparse::Response::new(&mut headers);
611 match res
612 .parse(buf)
613 .std_context("Failed to parse HTTP response")?
614 {
615 httparse::Status::Partial => Ok(None),
616 httparse::Status::Complete(header_len) => {
617 let code = res.code.context("Missing response status code")?;
618 let status =
619 StatusCode::from_u16(code).std_context("Invalid response status code")?;
620 let reason = res.reason.map(ToOwned::to_owned);
621 let headers = HeaderMap::from_iter(res.headers.iter().flat_map(|h| {
622 let value = HeaderValue::from_bytes(h.value).ok()?;
623 let name = http::HeaderName::from_bytes(h.name.as_bytes()).ok()?;
624 Some((name, value))
625 }));
626 Ok(Some((
627 header_len,
628 HttpResponse {
629 status,
630 reason,
631 headers,
632 },
633 )))
634 }
635 }
636 }
637
638 pub async fn peek(reader: &mut impl Prebufferable) -> Result<(usize, Self)> {
643 while !reader.is_full() {
644 reader.buffer_more().await?;
645 if let Some(response) = Self::parse_with_len(reader.buffer())? {
646 return Ok(response);
647 }
648 }
649
650 Err(io::Error::new(
651 io::ErrorKind::OutOfMemory,
652 "Buffer size limit reached before end of response header section",
653 )
654 .into())
655 }
656
657 pub async fn read(reader: &mut impl Prebufferable) -> Result<Self> {
661 let (len, response) = Self::peek(reader).await?;
662 reader.discard(len);
663 Ok(response)
664 }
665}