1use std::{
2 cmp::min,
3 collections::HashSet,
4 convert::{TryFrom, TryInto},
5 io::Read,
6 sync::{Arc, Mutex, OnceLock, atomic::{AtomicBool, Ordering}},
7 time::Duration,
8};
9
10use futures::StreamExt;
11use hyper::{
12 Body, Client, Request, Response, StatusCode, Uri,
13 body::{Bytes, HttpBody},
14 header::{self, HeaderValue},
15 http::uri::Scheme,
16};
17use hyper_tls::HttpsConnector;
18use tokio::sync::broadcast::Sender;
19
20use crate::{Action, Config, ProxyTarget, inject};
21
22use super::{Context, SERVER_HEADER};
23
24
25const PROXY_ERROR_HTML: &str = include_str!("../assets/proxy-error.html");
27
/// State used by the proxy functions of this module.
pub(crate) struct ProxyContext {
    /// `true` while a `start_polling` task is running. It is used so that at
    /// most one such task exists at a time.
    is_polling_target: Arc<AtomicBool>,
}

impl ProxyContext {
    /// Creates a context in which no polling task is active.
    pub(crate) fn new() -> Self {
        let is_polling_target = Arc::new(AtomicBool::default());
        Self { is_polling_target }
    }
}
39
40pub(crate) async fn forward(
46 mut req: Request<Body>,
47 target: &ProxyTarget,
48 ctx: &Context,
49 actions: Sender<Action>,
50) -> Response<Body> {
51 adjust_request(&mut req, target);
52 let uri = req.uri().clone();
53
54 log::trace!("Forwarding request to proxy target {}", uri);
55 let client = Client::builder().build::<_, hyper::Body>(HttpsConnector::new());
56 match client.request(req).await {
57 Ok(response) => adjust_response(response, ctx, &uri, target, &ctx.config).await,
58 Err(e) => {
59 log::warn!("Failed to reach proxy target '{}': {}", uri, e);
60 let msg = format!("Failed to reach {}\n\n{}", uri, e);
61 start_polling(&ctx.proxy, target, actions);
62 gateway_error(&msg, e, &ctx.config)
63 }
64 }
65}
66
67fn adjust_request(req: &mut Request<Body>, target: &ProxyTarget) {
68 let uri = {
70 let mut parts = req.uri().clone().into_parts();
71 parts.scheme = Some(target.scheme.clone());
72 parts.authority = Some(target.authority.clone());
73 Uri::from_parts(parts).expect("bug: invalid URI")
74 };
75 *req.uri_mut() = uri.clone();
76
77 if let Some(host) = req.headers_mut().get_mut(header::HOST) {
79 *host = HeaderValue::from_str(target.authority.as_str())
82 .expect("bug: URI authority should be ASCII");
83 }
84
85 if let Some(header) = req.headers_mut().get_mut(header::ACCEPT_ENCODING) {
87 let value = header.to_str()
94 .expect("'accept-encoding' header value contains non-ASCII bytes");
95 let new_value = filter_encodings(&value);
96
97 if new_value.is_empty() {
98 req.headers_mut().remove(header::ACCEPT_ENCODING);
99 } else {
100 *header = HeaderValue::try_from(new_value)
102 .expect("bug: non-ASCII values in new 'accept-encoding' header");
103 }
104 }
105}
106
/// Content codings the proxy is willing to receive from the target.
///
/// `filter_encodings` strips every other coding from the browser's
/// `Accept-Encoding` before the request is forwarded.
const SUPPORTED_COMPRESSIONS: &[&str] = &[
    "gzip",
    "br",
    "identity",
];
111
112fn download_body_error(e: hyper::Error, uri: &Uri, ctx: &Context) -> Response<Body> {
113 log::warn!("Failed to download full response from proxy target");
114 let msg = format!("Failed to download response from {}\n\n{}", uri, e);
115 return gateway_error(&msg, e, &ctx.config);
116}
117
118async fn adjust_response(
119 mut response: Response<Body>,
120 ctx: &Context,
121 uri: &Uri,
122 target: &ProxyTarget,
123 config: &Config,
124) -> Response<Body> {
125 if let Some(header) = response.headers_mut().get_mut(header::LOCATION) {
127 rewrite_location(header, target, config);
128 }
129
130 let (mut parts, mut body) = response.into_parts();
132 let mut body_start = vec![];
133 while !body.is_end_stream() && body_start.len() < 512 {
134 match body.data().await {
135 None => break,
136 Some(Err(e)) => return download_body_error(e, uri, ctx),
137 Some(Ok(bytes)) => body_start.extend_from_slice(&bytes),
138 }
139 }
140
141 let html_content_type = parts.headers.get(header::CONTENT_TYPE).map(|v| {
142 v.as_bytes().starts_with(b"text/html")
143 || v.as_bytes().starts_with(b"application/xhtml+xml")
144 });
145 let looks_like_html = body_start.iter().all(|b| *b != 0)
146 && infer::text::is_html(&body_start);
147
148 let uri_pq = uri.path_and_query().map(|pq| pq.to_string()).unwrap_or_default();
149 macro_rules! warn_once {
150 ($($t:tt)*) => {
151 static ALREADY_WARNED: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();
152 let newly_inserted = ALREADY_WARNED
153 .get_or_init(|| Mutex::new(HashSet::new()))
154 .lock()
155 .unwrap()
156 .insert(uri_pq.clone());
157 if newly_inserted {
158 log::warn!($($t)*);
159 }
160 };
161 }
162
163 let adjust_body = match (html_content_type, looks_like_html) {
165 (None, true) => {
166 warn_once!("Proxy response to '{uri_pq}' looks like HTML, but no 'Content-Type' \
167 header exists. I will treat it as HTML (injecting reload script), but setting \
168 the correct 'Content-Type' header is recommended.",
169 );
170 true
171 }
172 (None, false) => false,
173 (Some(true), true) => true,
174 (Some(false), true) => {
175 let header_bytes = parts.headers.get(header::CONTENT_TYPE).unwrap().as_bytes();
176 warn_once!("Proxy response to '{uri_pq}' looks like HTML, but the 'Content-Type' \
177 header indicates otherwise: '{}'. Not injecting reload script.",
178 String::from_utf8_lossy(header_bytes),
179 );
180 false
181 }
182 (Some(v), false) => v,
183 };
184
185 if !adjust_body {
186 let recombined_body = Body::wrap_stream(
187 futures::stream::once(async { Ok(Bytes::from(body_start)) }).chain(body)
188 );
189
190 return Response::from_parts(parts, recombined_body);
191 }
192
193
194 log::trace!("Response from proxy is HTML: injecting script");
195
196 while let Some(buf) = body.data().await {
199 match buf {
200 Ok(buf) => body_start.extend_from_slice(&buf),
201 Err(e) => return download_body_error(e, uri, ctx),
202 }
203 }
204 let body = body_start;
205
206 let new_body = match parts.headers.get(header::CONTENT_ENCODING).map(|v| v.as_bytes()) {
210 None => Bytes::from(inject::into(&body, &ctx.config)),
211
212 Some(b"gzip") => {
213 let mut decompressed = Vec::new();
214 flate2::read::GzDecoder::new(&*body).read_to_end(&mut decompressed)
215 .expect("unexpected error while decompressing GZIP");
216 let injected = inject::into(&decompressed, &ctx.config);
217 let mut out = Vec::new();
218 flate2::read::GzEncoder::new(&*injected, flate2::Compression::best())
219 .read_to_end(&mut out)
220 .expect("unexpected error while compressing GZIP");
221 Bytes::from(out)
222 }
223
224 Some(b"br") => {
225 let mut decompressed = Vec::new();
226 brotli::BrotliDecompress(&mut &*body, &mut decompressed)
227 .expect("unexpected error while decompressing Brotli");
228 let injected = inject::into(&decompressed, &ctx.config);
229 let mut out = Vec::new();
230 brotli::BrotliCompress(&mut &*injected, &mut out, &Default::default())
231 .expect("unexpected error while compressing Brotli");
232 Bytes::from(out)
233 }
234
235 Some(other) => {
236 log::warn!(
237 "Unsupported content encoding '{}'. Not injecting script!",
238 String::from_utf8_lossy(other),
239 );
240 Bytes::from(body)
241 }
242 };
243
244 if let Some(content_len) = parts.headers.get_mut(header::CONTENT_LENGTH) {
245 *content_len = new_body.len().into();
246 }
247
248 if let header::Entry::Occupied(mut e) = parts.headers.entry(header::CONTENT_SECURITY_POLICY) {
254 e.iter_mut().for_each(rewrite_csp);
255 }
256
257
258 Response::from_parts(parts, new_body.into())
259}
260
261fn rewrite_csp(header: &mut HeaderValue) {
266 use std::collections::{BTreeMap, btree_map::Entry};
267
268 let mut directives = BTreeMap::new();
271 header.as_bytes()
272 .split(|b| *b == b';')
274 .filter(|part| !part.is_empty())
276 .filter_map(|part| std::str::from_utf8(part).ok())
277 .for_each(|part| {
278 let mut split = part.trim().split_whitespace();
282 let name = split.next()
283 .expect("empty split iterator for non-empty string")
284 .to_ascii_lowercase();
285
286 match directives.entry(name) {
287 Entry::Occupied(entry) => {
293 log::warn!("CSP malformed, second {} directive ignored", entry.key());
294 }
295
296 Entry::Vacant(entry) => {
298 entry.insert(split.collect::<Vec<_>>());
299 }
300 }
301 });
302
303
304 let scripts_from_self_allowed = directives.get("script-src")
308 .or_else(|| directives.get("default-src"))
309 .map_or(true, |v| v.contains(&"'self'") || v.contains(&"*"));
310
311 let connect_to_self_allowed = directives.get("connect-src")
312 .or_else(|| directives.get("default-src"))
313 .map_or(true, |v| v.contains(&"'self'") || v.contains(&"*"));
314
315
316 if scripts_from_self_allowed && connect_to_self_allowed {
317 log::trace!("CSP header already allows scripts from and connect to 'self', not modifying");
318 return;
319 }
320
321 if !scripts_from_self_allowed {
323 let script_sources = directives.entry("script-src".to_owned()).or_default();
324 script_sources.retain(|src| *src != "'none'");
325 script_sources.push("'self'");
326 }
327 if !connect_to_self_allowed {
328 let script_sources = directives.entry("connect-src".to_owned()).or_default();
329 script_sources.retain(|src| *src != "'none'");
330 script_sources.push("'self'");
331 }
332
333 let mut out = String::new();
335 for (name, values) in directives {
336 use std::fmt::Write;
337
338 out.push_str(&name);
339 values.iter().for_each(|v| write!(out, " {v}").unwrap());
340 out.push_str("; ");
341 }
342
343 log::trace!("Modified CSP header \nfrom {header:?} \nto \"{out}\"");
346 *header = HeaderValue::from_str(&out)
347 .expect("modified CSP header has non-ASCII chars");
348}
349
350fn rewrite_location(header: &mut HeaderValue, target: &ProxyTarget, config: &Config) {
351 let value = match std::str::from_utf8(header.as_bytes()) {
352 Err(_) => {
353 log::warn!("Non UTF-8 'location' header: not rewriting");
354 return;
355 }
356 Ok(v) => v,
357 };
358
359 let mut uri = match value.parse::<Uri>() {
360 Err(_) => {
361 log::warn!("Could not parse 'location' header as URI: not rewriting");
362 return;
363 }
364 Ok(uri) => uri.into_parts(),
365 };
366
367 if uri.authority.as_ref() == Some(&target.authority) {
371 uri.scheme = Some(Scheme::HTTP);
373 let authority = config.bind_addr.to_string()
374 .try_into()
375 .expect("bind addr is not a valid authority");
376 uri.authority = Some(authority);
377
378 let uri = Uri::from_parts(uri).expect("bug: failed to build URI");
379 *header = HeaderValue::from_bytes(uri.to_string().as_bytes())
380 .expect("bug: new 'location' is invalid header value");
381 }
382}
383
384fn gateway_error(msg: &str, e: hyper::Error, config: &Config) -> Response<Body> {
385 let html = PROXY_ERROR_HTML
386 .replace("{{ error }}", msg)
387 .replace("{{ control_path }}", config.control_path());
388
389 let status = if e.is_timeout() {
390 StatusCode::GATEWAY_TIMEOUT
391 } else {
392 StatusCode::BAD_GATEWAY
393 };
394
395 Response::builder()
396 .status(status)
397 .header("Server", SERVER_HEADER)
398 .header("Content-Type", "text/html")
399 .body(html.into())
400 .unwrap()
401}
402
403fn start_polling(ctx: &ProxyContext, target: &ProxyTarget, actions: Sender<Action>) {
407 let is_polling = Arc::clone(&ctx.is_polling_target);
409 if is_polling.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst).is_err() {
410 return;
411 }
412
413 let client = Client::builder().build::<_, hyper::Body>(HttpsConnector::new());
414 let uri = Uri::builder()
415 .scheme(target.scheme.clone())
416 .authority(target.authority.clone())
417 .path_and_query("/")
418 .build()
419 .unwrap();
420
421 log::info!("Start regularly polling '{}' until it is available...", uri);
422 tokio::spawn(async move {
423 const MAX_SLEEP_DURATION: Duration = Duration::from_secs(3);
425 let mut sleep_duration = Duration::from_millis(250);
426
427 loop {
428 tokio::time::sleep(sleep_duration).await;
429 sleep_duration = min(sleep_duration.mul_f32(1.5), MAX_SLEEP_DURATION);
430
431 log::trace!("Trying to connect to '{}' again", uri);
432 if client.get(uri.clone()).await.is_ok() {
433 log::debug!("Reconnected to proxy target, reloading all active browser sessions");
434 let _ = actions.send(Action::Reload);
435 is_polling.store(false, Ordering::SeqCst);
436 break;
437 }
438 }
439 });
440}
441
442fn filter_encodings(orig: &str) -> String {
445 let allowed_values = orig.split(',')
446 .map(|part| part.trim())
447 .filter(|part| {
448 let encoding = part.split_once(';').map(|p| p.0).unwrap_or(part);
449 SUPPORTED_COMPRESSIONS.contains(&encoding)
450 });
451
452 let mut new_value = String::new();
453 for (i, part) in allowed_values.enumerate() {
454 if i != 0 {
455 new_value.push_str(", ");
456 }
457 new_value.push_str(part);
458 }
459 new_value
460}
461
462
#[cfg(test)]
mod tests {
    use super::{filter_encodings, rewrite_csp};
    use hyper::header::HeaderValue;

    #[test]
    fn encoding_filter() {
        // (original `Accept-Encoding` value, expected filtered value)
        let cases = [
            ("", ""),
            ("gzip", "gzip"),
            ("br", "br"),
            ("gzip, br", "gzip, br"),
            ("gzip, deflate", "gzip"),
            ("deflate, gzip", "gzip"),
            ("gzip, deflate, br", "gzip, br"),
            ("deflate, gzip, br", "gzip, br"),
            ("gzip, br, deflate", "gzip, br"),
            ("deflate", ""),
            ("br;q=1.0, deflate;q=0.5, gzip;q=0.8, *;q=0.1", "br;q=1.0, gzip;q=0.8"),
        ];
        for (input, expected) in cases {
            assert_eq!(filter_encodings(input), expected, "filtering {:?}", input);
        }
    }

    #[test]
    fn modify_csp() {
        #[track_caller]
        fn assert_rewritten(original: &str, expected_rewritten: &str) {
            let mut header = HeaderValue::from_str(original).unwrap();
            rewrite_csp(&mut header);
            let actual = header.to_str().unwrap();
            assert!(
                actual == expected_rewritten,
                "unexpected rewritten CSP header:\n\
                 original: {}\n\
                 expected: {}\n\
                 actual:   {}\n",
                original,
                expected_rewritten,
                actual,
            );
        }

        let unchanged = [
            "default-src *",
            "default-src 'self'",
            "default-src 'self' https://google.com",
            "default-src 'none' https://google.com; script-src 'self'; connect-src *",
        ];
        for original in unchanged {
            assert_rewritten(original, original);
        }

        let rewritten = [
            (
                "default-src 'none'",
                "connect-src 'self'; default-src 'none'; script-src 'self'; ",
            ),
            (
                "default-src 'none'; script-src http:",
                "connect-src 'self'; default-src 'none'; script-src http: 'self'; ",
            ),
            (
                "default-src 'self'; connect-src 'none'",
                "connect-src 'self'; default-src 'self'; ",
            ),
            (
                "default-src 'self'; script-src https:",
                "default-src 'self'; script-src https: 'self'; ",
            ),
        ];
        for (original, expected) in rewritten {
            assert_rewritten(original, expected);
        }
    }
}