1#![feature(doc_auto_cfg)]
2#![feature(doc_cfg)]
3
4use std::{net::SocketAddr, sync::Arc};
5
6use base64::{Engine, engine::general_purpose};
7use http_body_util::{BodyExt, Full};
8use hyper::{
9 Request, Response, StatusCode,
10 body::{Bytes, Incoming},
11 header::PROXY_AUTHORIZATION,
12};
13use hyper_util::rt::TokioIo;
14use proxy_fetch::Fetch;
15use tokio::net::TcpListener;
16
17mod error;
18
19pub use error::{Error, Result};
20pub async fn run(
21 fetch: impl Into<Arc<Fetch>>,
22 addr: SocketAddr,
23 user: impl AsRef<str>,
24 password: impl AsRef<str>,
25) -> Result<()> {
26 let fetch = fetch.into();
27 let user = user.as_ref().to_string();
28 let password = password.as_ref().to_string();
29
30 let listener = TcpListener::bind(addr).await?;
31
32 let user = Arc::new(user);
33 let password = Arc::new(password);
34
35 loop {
36 let (stream, _) = listener.accept().await?;
37
38 let io = TokioIo::new(stream);
39
40 let fetch = Arc::clone(&fetch);
41 let user = Arc::clone(&user);
42 let password = Arc::clone(&password);
43
44 tokio::task::spawn(async move {
45 let service = hyper::service::service_fn(move |req| {
46 handle(
47 req,
48 Arc::clone(&fetch),
49 Arc::clone(&user),
50 Arc::clone(&password),
51 )
52 });
53
54 if let Err(err) = hyper::server::conn::http1::Builder::new()
55 .serve_connection(io, service)
56 .await
57 {
58 eprintln!("Error serving connection: {:?}", err);
59 }
60 });
61 }
62}
63
64fn is_authorized(req: &Request<Incoming>, user: &str, password: &str) -> bool {
65 if user.is_empty() || password.is_empty() {
66 return true;
67 }
68 match req.headers().get(PROXY_AUTHORIZATION) {
69 Some(header) => {
70 if let Ok(header) = header.to_str()
71 && let Some(credentials) = header.strip_prefix("Basic ")
72 && let Ok(decoded) = general_purpose::STANDARD.decode(credentials)
73 && let Ok(decoded_str) = String::from_utf8(decoded)
74 {
75 let mut parts = decoded_str.splitn(2, ':');
76 if let (Some(u), Some(p)) = (parts.next(), parts.next()) {
77 return u == user && p == password;
78 }
79 }
80 false
81 }
82 None => false,
83 }
84}
85
86async fn handle(
87 mut req: Request<Incoming>,
88 fetch: Arc<Fetch>,
89 user: Arc<String>,
90 password: Arc<String>,
91) -> std::result::Result<Response<Full<Bytes>>, hyper::Error> {
92 if !is_authorized(&req, &user, &password) {
93 let mut res = Response::new(Full::new(Bytes::from("Proxy Authentication Required")));
94 *res.status_mut() = StatusCode::PROXY_AUTHENTICATION_REQUIRED;
95 return Ok(res);
96 }
97
98 let method = req.method().clone();
99 let uri = req.uri().clone();
100 let headers = req.headers().clone();
101
102 let body = match req.body_mut().collect().await {
103 Ok(body) => body.to_bytes(),
104 Err(e) => {
105 eprintln!("Failed to collect request body: {}", e);
106 let response = Response::builder()
107 .status(StatusCode::BAD_REQUEST)
108 .body(Full::new(Bytes::from("Bad Request")))
109 .unwrap();
110 return Ok(response);
111 }
112 };
113
114 let uri = uri.to_string();
115
116 let uri = if let Some(remain) = uri.strip_prefix("http:") {
117 "https:".to_string() + remain
118 } else {
119 uri
120 };
121
122 match fetch.run(method, uri, headers, Some(body)).await {
123 Ok(res) => {
124 let mut builder = Response::builder().status(res.status);
125 for (key, value) in res.headers {
126 if let Some(key) = key {
127 builder = builder.header(key, value);
128 }
129 }
130 Ok(builder.body(Full::new(res.body)).unwrap())
131 }
132 Err(e) => {
133 eprintln!("Fetch error: {}", e);
134 Ok(
135 Response::builder()
136 .status(StatusCode::INTERNAL_SERVER_ERROR)
137 .body(Full::new(Bytes::from("Internal Server Error")))
138 .unwrap(),
139 )
140 }
141 }
142}