proxy_http/
lib.rs

1#![feature(doc_auto_cfg)]
2#![feature(doc_cfg)]
3
4use std::{net::SocketAddr, sync::Arc};
5
6use base64::{Engine, engine::general_purpose};
7use http_body_util::{BodyExt, Full};
8use hyper::{
9  Request, Response, StatusCode,
10  body::{Bytes, Incoming},
11  header::PROXY_AUTHORIZATION,
12};
13use hyper_util::rt::TokioIo;
14use proxy_fetch::Fetch;
15use tokio::net::TcpListener;
16
17mod error;
18
19pub use error::{Error, Result};
20pub async fn run(
21  fetch: impl Into<Arc<Fetch>>,
22  addr: SocketAddr,
23  user: impl AsRef<str>,
24  password: impl AsRef<str>,
25) -> Result<()> {
26  let fetch = fetch.into();
27  let user = user.as_ref().to_string();
28  let password = password.as_ref().to_string();
29
30  let listener = TcpListener::bind(addr).await?;
31
32  let user = Arc::new(user);
33  let password = Arc::new(password);
34
35  loop {
36    let (stream, _) = listener.accept().await?;
37
38    let io = TokioIo::new(stream);
39
40    let fetch = Arc::clone(&fetch);
41    let user = Arc::clone(&user);
42    let password = Arc::clone(&password);
43
44    tokio::task::spawn(async move {
45      let service = hyper::service::service_fn(move |req| {
46        handle(
47          req,
48          Arc::clone(&fetch),
49          Arc::clone(&user),
50          Arc::clone(&password),
51        )
52      });
53
54      if let Err(err) = hyper::server::conn::http1::Builder::new()
55        .serve_connection(io, service)
56        .await
57      {
58        eprintln!("Error serving connection: {:?}", err);
59      }
60    });
61  }
62}
63
64fn is_authorized(req: &Request<Incoming>, user: &str, password: &str) -> bool {
65  if user.is_empty() || password.is_empty() {
66    return true;
67  }
68  match req.headers().get(PROXY_AUTHORIZATION) {
69    Some(header) => {
70      if let Ok(header) = header.to_str()
71        && let Some(credentials) = header.strip_prefix("Basic ")
72        && let Ok(decoded) = general_purpose::STANDARD.decode(credentials)
73        && let Ok(decoded_str) = String::from_utf8(decoded)
74      {
75        let mut parts = decoded_str.splitn(2, ':');
76        if let (Some(u), Some(p)) = (parts.next(), parts.next()) {
77          return u == user && p == password;
78        }
79      }
80      false
81    }
82    None => false,
83  }
84}
85
86async fn handle(
87  mut req: Request<Incoming>,
88  fetch: Arc<Fetch>,
89  user: Arc<String>,
90  password: Arc<String>,
91) -> std::result::Result<Response<Full<Bytes>>, hyper::Error> {
92  if !is_authorized(&req, &user, &password) {
93    let mut res = Response::new(Full::new(Bytes::from("Proxy Authentication Required")));
94    *res.status_mut() = StatusCode::PROXY_AUTHENTICATION_REQUIRED;
95    return Ok(res);
96  }
97
98  let method = req.method().clone();
99  let uri = req.uri().clone();
100  let headers = req.headers().clone();
101
102  let body = match req.body_mut().collect().await {
103    Ok(body) => body.to_bytes(),
104    Err(e) => {
105      eprintln!("Failed to collect request body: {}", e);
106      let response = Response::builder()
107        .status(StatusCode::BAD_REQUEST)
108        .body(Full::new(Bytes::from("Bad Request")))
109        .unwrap();
110      return Ok(response);
111    }
112  };
113
114  let uri = uri.to_string();
115
116  let uri = if let Some(remain) = uri.strip_prefix("http:") {
117    "https:".to_string() + remain
118  } else {
119    uri
120  };
121
122  match fetch.run(method, uri, headers, Some(body)).await {
123    Ok(res) => {
124      let mut builder = Response::builder().status(res.status);
125      for (key, value) in res.headers {
126        if let Some(key) = key {
127          builder = builder.header(key, value);
128        }
129      }
130      Ok(builder.body(Full::new(res.body)).unwrap())
131    }
132    Err(e) => {
133      eprintln!("Fetch error: {}", e);
134      Ok(
135        Response::builder()
136          .status(StatusCode::INTERNAL_SERVER_ERROR)
137          .body(Full::new(Bytes::from("Internal Server Error")))
138          .unwrap(),
139      )
140    }
141  }
142}