penguin/serve/
proxy.rs

1use std::{
2    cmp::min,
3    collections::HashSet,
4    convert::{TryFrom, TryInto},
5    io::Read,
6    sync::{Arc, Mutex, OnceLock, atomic::{AtomicBool, Ordering}},
7    time::Duration,
8};
9
10use futures::StreamExt;
11use hyper::{
12    Body, Client, Request, Response, StatusCode, Uri,
13    body::{Bytes, HttpBody},
14    header::{self, HeaderValue},
15    http::uri::Scheme,
16};
17use hyper_tls::HttpsConnector;
18use tokio::sync::broadcast::Sender;
19
20use crate::{Action, Config, ProxyTarget, inject};
21
22use super::{Context, SERVER_HEADER};
23
24
25/// HTML content to reply in case an error occurs when connecting to the proxy.
26const PROXY_ERROR_HTML: &str = include_str!("../assets/proxy-error.html");
27
/// Proxy-related state shared by all requests of one penguin server.
pub(crate) struct ProxyContext {
    /// Whether a background task is currently polling the (unreachable) proxy
    /// target. Used by `start_polling` to ensure at most one polling task
    /// exists at a time.
    is_polling_target: Arc<AtomicBool>,
}
31
impl ProxyContext {
    /// Creates a fresh context with no polling task running.
    pub(crate) fn new() -> Self {
        Self {
            is_polling_target: Arc::new(AtomicBool::new(false)),
        }
    }
}
39
40/// Forwards the given request to the specified proxy target and returns its
41/// response.
42///
43/// If the proxy target cannot be reached, a 502 Bad Gateway or 504 Gateway
44/// Timeout response is returned.
45pub(crate) async fn forward(
46    mut req: Request<Body>,
47    target: &ProxyTarget,
48    ctx: &Context,
49    actions: Sender<Action>,
50) -> Response<Body> {
51    adjust_request(&mut req, target);
52    let uri = req.uri().clone();
53
54    log::trace!("Forwarding request to proxy target {}", uri);
55    let client = Client::builder().build::<_, hyper::Body>(HttpsConnector::new());
56    match client.request(req).await {
57        Ok(response) => adjust_response(response, ctx, &uri, target, &ctx.config).await,
58        Err(e) => {
59            log::warn!("Failed to reach proxy target '{}': {}", uri, e);
60            let msg = format!("Failed to reach {}\n\n{}", uri, e);
61            start_polling(&ctx.proxy, target, actions);
62            gateway_error(&msg, e, &ctx.config)
63        }
64    }
65}
66
67fn adjust_request(req: &mut Request<Body>, target: &ProxyTarget) {
68    // Change the URI to the proxy target.
69    let uri = {
70        let mut parts = req.uri().clone().into_parts();
71        parts.scheme = Some(target.scheme.clone());
72        parts.authority = Some(target.authority.clone());
73        Uri::from_parts(parts).expect("bug: invalid URI")
74    };
75    *req.uri_mut() = uri.clone();
76
77    // If the `host` header is set, we need to adjust it, too.
78    if let Some(host) = req.headers_mut().get_mut(header::HOST) {
79        // `http::Uri` already does not parse non-ASCII hosts. Unicode hosts
80        // have to be encoded as punycode.
81        *host = HeaderValue::from_str(target.authority.as_str())
82            .expect("bug: URI authority should be ASCII");
83    }
84
85    // Deal with compression.
86    if let Some(header) = req.headers_mut().get_mut(header::ACCEPT_ENCODING) {
87        // In a production product, panicking here is not OK. But all encodings
88        // listed in [1] and the syntax described in [2] only contain ASCII
89        // bytes. So non-ASCII bytes here are highly unlikely.
90        //
91        // [1]: https://www.iana.org/assignments/http-parameters/http-parameters.xml
92        // [2]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding
93        let value = header.to_str()
94            .expect("'accept-encoding' header value contains non-ASCII bytes");
95        let new_value = filter_encodings(&value);
96
97        if new_value.is_empty() {
98            req.headers_mut().remove(header::ACCEPT_ENCODING);
99        } else {
100            // It was ASCII before and we do not add any non-ASCII values.
101            *header = HeaderValue::try_from(new_value)
102                .expect("bug: non-ASCII values in new 'accept-encoding' header");
103        }
104    }
105}
106
107/// We support only gzip and brotli. But according to this statistics, those two
108/// make up the vast majority of requests:
109/// https://almanac.httparchive.org/en/2019/compression
110const SUPPORTED_COMPRESSIONS: &[&str] = &["gzip", "br", "identity"];
111
112fn download_body_error(e: hyper::Error, uri: &Uri, ctx: &Context) -> Response<Body> {
113    log::warn!("Failed to download full response from proxy target");
114    let msg = format!("Failed to download response from {}\n\n{}", uri, e);
115    return gateway_error(&msg, e, &ctx.config);
116}
117
/// Post-processes the response from the proxy target before handing it to the
/// browser.
///
/// Rewrites the `location` header, sniffs whether the body is HTML and, if
/// so, downloads the full body, injects the reload script (decompressing and
/// recompressing if necessary), fixes `content-length` and relaxes the
/// `content-security-policy` header so the injected script may run.
async fn adjust_response(
    mut response: Response<Body>,
    ctx: &Context,
    uri: &Uri,
    target: &ProxyTarget,
    config: &Config,
) -> Response<Body> {
    // Rewrite `location` header if it's present.
    if let Some(header) = response.headers_mut().get_mut(header::LOCATION) {
        rewrite_location(header, target, config);
    }

    // Download the beginning of the body for sniffing. We read at least 512
    // bytes (unless the body ends earlier) before deciding.
    let (mut parts, mut body) = response.into_parts();
    let mut body_start = vec![];
    while !body.is_end_stream() && body_start.len() < 512 {
        match body.data().await {
            None => break,
            Some(Err(e)) => return download_body_error(e, uri, ctx),
            Some(Ok(bytes)) => body_start.extend_from_slice(&bytes),
        }
    }

    // `None` if there is no `content-type` header at all; otherwise whether
    // the header declares an HTML-ish type.
    let html_content_type = parts.headers.get(header::CONTENT_TYPE).map(|v| {
        v.as_bytes().starts_with(b"text/html")
            || v.as_bytes().starts_with(b"application/xhtml+xml")
    });
    // Content sniffing: require no NUL bytes (those suggest binary data) and
    // that `infer` recognizes the prefix as HTML.
    let looks_like_html = body_start.iter().all(|b| *b != 0)
        && infer::text::is_html(&body_start);

    let uri_pq = uri.path_and_query().map(|pq| pq.to_string()).unwrap_or_default();
    // Emits a warning only once per path & query *per macro call site*: each
    // expansion gets its own `ALREADY_WARNED` static, so the same path can
    // trigger each distinct warning once, but never repeatedly.
    macro_rules! warn_once {
        ($($t:tt)*) => {
            static ALREADY_WARNED: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();
            let newly_inserted = ALREADY_WARNED
                .get_or_init(|| Mutex::new(HashSet::new()))
                .lock()
                .unwrap()
                .insert(uri_pq.clone());
            if newly_inserted {
                log::warn!($($t)*);
            }
        };
    }

    // Determine if we should treat this as HTML (i.e. inject our script).
    let adjust_body = match (html_content_type, looks_like_html) {
        // No header, but body looks like HTML: treat as HTML, but warn.
        (None, true) => {
            warn_once!("Proxy response to '{uri_pq}' looks like HTML, but no 'Content-Type' \
                header exists. I will treat it as HTML (injecting reload script), but setting \
                the correct 'Content-Type' header is recommended.",
            );
            true
        }
        (None, false) => false,
        (Some(true), true) => true,
        // Header says non-HTML but the body looks like HTML: trust the
        // header, warn, and do not inject.
        (Some(false), true) => {
            let header_bytes = parts.headers.get(header::CONTENT_TYPE).unwrap().as_bytes();
            warn_once!("Proxy response to '{uri_pq}' looks like HTML, but the 'Content-Type' \
                header indicates otherwise: '{}'. Not injecting reload script.",
                String::from_utf8_lossy(header_bytes),
            );
            false
        }
        // Body doesn't look like HTML: let the header decide.
        (Some(v), false) => v,
    };

    if !adjust_body {
        // Not treating as HTML: stream the rest of the body unchanged,
        // prepending the chunk we already downloaded for sniffing.
        let recombined_body = Body::wrap_stream(
            futures::stream::once(async { Ok(Bytes::from(body_start)) }).chain(body)
        );

        return Response::from_parts(parts, recombined_body);
    }


    log::trace!("Response from proxy is HTML: injecting script");

    // The response is HTML: we need to download it completely and
    // inject our script.
    while let Some(buf) = body.data().await {
        match buf {
            Ok(buf) => body_start.extend_from_slice(&buf),
            Err(e) => return download_body_error(e, uri, ctx),
        }
    }
    let body = body_start;

    // Uncompress if necessary. All this allocates more than necessary, but I'd
    // rather keep easier code in this case, as performance is unlikely to
    // matter.
    let new_body = match parts.headers.get(header::CONTENT_ENCODING).map(|v| v.as_bytes()) {
        // Uncompressed body: inject directly.
        None => Bytes::from(inject::into(&body, &ctx.config)),

        // gzip: decompress, inject, recompress.
        Some(b"gzip") => {
            let mut decompressed = Vec::new();
            flate2::read::GzDecoder::new(&*body).read_to_end(&mut decompressed)
                .expect("unexpected error while decompressing GZIP");
            let injected = inject::into(&decompressed, &ctx.config);
            let mut out = Vec::new();
            flate2::read::GzEncoder::new(&*injected, flate2::Compression::best())
                .read_to_end(&mut out)
                .expect("unexpected error while compressing GZIP");
            Bytes::from(out)
        }

        // Brotli: decompress, inject, recompress.
        Some(b"br") => {
            let mut decompressed = Vec::new();
            brotli::BrotliDecompress(&mut &*body, &mut decompressed)
                .expect("unexpected error while decompressing Brotli");
            let injected = inject::into(&decompressed, &ctx.config);
            let mut out = Vec::new();
            brotli::BrotliCompress(&mut &*injected, &mut out, &Default::default())
                .expect("unexpected error while compressing Brotli");
            Bytes::from(out)
        }

        // Should not occur since we filtered `accept-encoding` in
        // `adjust_request`, but the target may still use other encodings.
        Some(other) => {
            log::warn!(
                "Unsupported content encoding '{}'. Not injecting script!",
                String::from_utf8_lossy(other),
            );
            Bytes::from(body)
        }
    };

    // Injection changed the body length; keep the header honest. If the
    // header is absent (e.g. chunked transfer), nothing needs fixing.
    if let Some(content_len) = parts.headers.get_mut(header::CONTENT_LENGTH) {
        *content_len = new_body.len().into();
    }

    // We might need to adjust `Content-Security-Policy` to allow including
    // scripts from `self`. This is most likely already the case, but we have
    // to make sure. If the header appears multiple times, all header values
    // need to allow a thing for it to be allowed. Thus we can just modify all
    // headers independently from one another.
    if let header::Entry::Occupied(mut e) = parts.headers.entry(header::CONTENT_SECURITY_POLICY) {
        e.iter_mut().for_each(rewrite_csp);
    }


    Response::from_parts(parts, new_body.into())
}
260
/// We inject our own JS that connects via WS to the penguin server. These two
/// things need to be allowed by the Content-Security-Policy. Usually they are,
/// but in some cases we need to modify that header to allow for it.
/// Unfortunately, it's a bit involved, but also fairly straight forward.
///
/// If the header already allows script-from-self and connect-to-self, it is
/// left untouched; otherwise it is re-serialized (directives in alphabetical
/// order, each terminated by "; ") with `'self'` added where needed.
fn rewrite_csp(header: &mut HeaderValue) {
    use std::collections::{BTreeMap, btree_map::Entry};

    // We have to parse the CSP. Compare section "2.2.1. Parse a serialized CSP"
    // of TR CSP3: https://www.w3.org/TR/CSP3/#parse-serialized-policy
    //
    // `BTreeMap` keeps directives sorted by name, giving a deterministic
    // serialization below.
    let mut directives = BTreeMap::new();
    header.as_bytes()
        // "strictly splitting on the U+003B SEMICOLON character (;)"
        .split(|b| *b == b';')
        // "If token is an empty string, or if token is not an ASCII string, continue."
        .filter(|part| !part.is_empty())
        .filter_map(|part| std::str::from_utf8(part).ok())
        .for_each(|part| {
            // "Strip leading and trailing ASCII whitespace" and then splitting
            //  by whitespace to separate the directive name and all directive
            //  values.
            let mut split = part.trim().split_whitespace();
            let name = split.next()
                .expect("empty split iterator for non-empty string")
                .to_ascii_lowercase();

            match directives.entry(name) {
                // "If policy’s directive set contains a directive whose name is
                //  directive name, continue. Note: In this case, the user
                //  agent SHOULD notify developers that a duplicate directive
                //  was ignored. A console warning might be appropriate, for
                //  example."
                Entry::Occupied(entry) => {
                    log::warn!("CSP malformed, second {} directive ignored", entry.key());
                }

                // "Append directive to policy’s directive set."
                Entry::Vacant(entry) => {
                    entry.insert(split.collect::<Vec<_>>());
                }
            }
        });


    // Of course, including the script/connect to self might still be allowed
    // via other sources, like `http:`. But it also doesn't hurt to add `self`
    // in those cases.
    //
    // Per CSP semantics, `script-src`/`connect-src` override `default-src`,
    // which is why the specific directive is consulted first. No directive at
    // all means everything is allowed (`map_or(true, …)`).
    let scripts_from_self_allowed = directives.get("script-src")
        .or_else(|| directives.get("default-src"))
        .map_or(true, |v| v.contains(&"'self'") || v.contains(&"*"));

    let connect_to_self_allowed = directives.get("connect-src")
        .or_else(|| directives.get("default-src"))
        .map_or(true, |v| v.contains(&"'self'") || v.contains(&"*"));


    if scripts_from_self_allowed && connect_to_self_allowed {
        log::trace!("CSP header already allows scripts from and connect to 'self', not modifying");
        return;
    }

    // Add `self` to `script-src`/`connect-src`. `'none'` must be removed
    // because it may not be combined with other source expressions.
    if !scripts_from_self_allowed {
        let script_sources = directives.entry("script-src".to_owned()).or_default();
        script_sources.retain(|src| *src != "'none'");
        script_sources.push("'self'");
    }
    if !connect_to_self_allowed {
        let script_sources = directives.entry("connect-src".to_owned()).or_default();
        script_sources.retain(|src| *src != "'none'");
        script_sources.push("'self'");
    }

    // Serialize parsed CSP into header value again.
    let mut out = String::new();
    for (name, values) in directives {
        use std::fmt::Write;

        out.push_str(&name);
        values.iter().for_each(|v| write!(out, " {v}").unwrap());
        out.push_str("; ");
    }

    // Above, we ignored all non-ASCII entries, so there shouldn't be a way our
    // resulting string is non-ASCII.
    log::trace!("Modified CSP header \nfrom {header:?} \nto   \"{out}\"");
    *header = HeaderValue::from_str(&out)
        .expect("modified CSP header has non-ASCII chars");
}
349
350fn rewrite_location(header: &mut HeaderValue, target: &ProxyTarget, config: &Config) {
351    let value = match std::str::from_utf8(header.as_bytes()) {
352        Err(_) => {
353            log::warn!("Non UTF-8 'location' header: not rewriting");
354            return;
355        }
356        Ok(v) => v,
357    };
358
359    let mut uri = match value.parse::<Uri>() {
360        Err(_) => {
361            log::warn!("Could not parse 'location' header as URI: not rewriting");
362            return;
363        }
364        Ok(uri) => uri.into_parts(),
365    };
366
367    // If the redirect points to the proxy target itself (i.e. an internal
368    // redirect), we change the `location` header so that the browser changes
369    // the path & query, but stays on the Penguin host.
370    if uri.authority.as_ref() == Some(&target.authority) {
371        // Penguin itself only listens on HTTP
372        uri.scheme = Some(Scheme::HTTP);
373        let authority = config.bind_addr.to_string()
374            .try_into()
375            .expect("bind addr is not a valid authority");
376        uri.authority = Some(authority);
377
378        let uri = Uri::from_parts(uri).expect("bug: failed to build URI");
379        *header = HeaderValue::from_bytes(uri.to_string().as_bytes())
380            .expect("bug: new 'location' is invalid header value");
381    }
382}
383
384fn gateway_error(msg: &str, e: hyper::Error, config: &Config) -> Response<Body> {
385    let html = PROXY_ERROR_HTML
386        .replace("{{ error }}", msg)
387        .replace("{{ control_path }}", config.control_path());
388
389    let status = if e.is_timeout() {
390        StatusCode::GATEWAY_TIMEOUT
391    } else {
392        StatusCode::BAD_GATEWAY
393    };
394
395    Response::builder()
396        .status(status)
397        .header("Server", SERVER_HEADER)
398        .header("Content-Type", "text/html")
399        .body(html.into())
400        .unwrap()
401}
402
/// Regularly polls the proxy target until it is reachable again. Once it is, it
/// sends a reload action and stops. Makes sure (via `ctx`) that just one
/// polling instance exists per penguin server.
fn start_polling(ctx: &ProxyContext, target: &ProxyTarget, actions: Sender<Action>) {
    // We only need one task polling the target. The compare-exchange succeeds
    // only for the first caller while no polling task runs; concurrent/later
    // callers bail out here.
    let is_polling = Arc::clone(&ctx.is_polling_target);
    if is_polling.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst).is_err() {
        return;
    }

    // Poll the target's root path; we only care about reachability, not about
    // the actual response.
    let client = Client::builder().build::<_, hyper::Body>(HttpsConnector::new());
    let uri = Uri::builder()
        .scheme(target.scheme.clone())
        .authority(target.authority.clone())
        .path_and_query("/")
        .build()
        .unwrap();

    log::info!("Start regularly polling '{}' until it is available...", uri);
    tokio::spawn(async move {
        // We start polling quite quickly, but slow down up to this constant.
        const MAX_SLEEP_DURATION: Duration = Duration::from_secs(3);
        let mut sleep_duration = Duration::from_millis(250);

        loop {
            tokio::time::sleep(sleep_duration).await;
            // Exponential backoff (factor 1.5), capped at MAX_SLEEP_DURATION.
            sleep_duration = min(sleep_duration.mul_f32(1.5), MAX_SLEEP_DURATION);

            log::trace!("Trying to connect to '{}' again", uri);
            if client.get(uri.clone()).await.is_ok() {
                log::debug!("Reconnected to proxy target, reloading all active browser sessions");
                // A send error just means no browser session is listening;
                // that is fine, so the result is deliberately ignored.
                let _ = actions.send(Action::Reload);
                // Allow a future failure to spawn a new polling task.
                is_polling.store(false, Ordering::SeqCst);
                break;
            }
        }
    });
}
441
442/// Filter the "accept-encoding" encodings in the header value `orig` and return
443/// a new value only containing the ones we support.
444fn filter_encodings(orig: &str) -> String {
445    let allowed_values = orig.split(',')
446        .map(|part| part.trim())
447        .filter(|part| {
448            let encoding = part.split_once(';').map(|p| p.0).unwrap_or(part);
449            SUPPORTED_COMPRESSIONS.contains(&encoding)
450        });
451
452    let mut new_value = String::new();
453    for (i, part) in allowed_values.enumerate() {
454        if i != 0 {
455            new_value.push_str(", ");
456        }
457        new_value.push_str(part);
458    }
459    new_value
460}
461
462
#[cfg(test)]
mod tests {
    // Checks that `filter_encodings` keeps only supported encodings (with
    // their parameters) while preserving their order.
    #[test]
    fn encoding_filter() {
        use super::filter_encodings as filter;

        assert_eq!(filter(""), "");
        assert_eq!(filter("gzip"), "gzip");
        assert_eq!(filter("br"), "br");
        assert_eq!(filter("gzip, br"), "gzip, br");
        assert_eq!(filter("gzip, deflate"), "gzip");
        assert_eq!(filter("deflate, gzip"), "gzip");
        assert_eq!(filter("gzip, deflate, br"), "gzip, br");
        assert_eq!(filter("deflate, gzip, br"), "gzip, br");
        assert_eq!(filter("gzip, br, deflate"), "gzip, br");
        assert_eq!(filter("deflate"), "");
        // Parameters (like quality values) are kept for allowed encodings.
        assert_eq!(filter("br;q=1.0, deflate;q=0.5, gzip;q=0.8, *;q=0.1"), "br;q=1.0, gzip;q=0.8");
    }

    // Checks `rewrite_csp` against headers that do and do not need rewriting.
    // Note that rewritten headers are re-serialized with directives in
    // alphabetical order and a trailing "; ".
    #[test]
    fn modify_csp() {
        // Asserts that `rewrite_csp` turns `original` into `expected_rewritten`.
        #[track_caller]
        fn assert_rewritten(original: &str, expected_rewritten: &str) {
            let mut header = hyper::header::HeaderValue::from_str(original).unwrap();
            super::rewrite_csp(&mut header);
            if header.to_str().unwrap() != expected_rewritten {
                panic!(
                    "unexpected rewritten CSP header:\n\
                        original: {}\n\
                        expected: {}\n\
                        actual:   {}\n",
                    original,
                    expected_rewritten,
                    header.to_str().unwrap(),
                );
            }
        }

        // Asserts that `rewrite_csp` leaves `original` untouched.
        #[track_caller]
        fn assert_not_rewritten(original: &str) {
            assert_rewritten(original, original);
        }

        assert_not_rewritten("default-src *");
        assert_not_rewritten("default-src 'self'");
        assert_not_rewritten("default-src 'self' https://google.com");
        assert_not_rewritten("default-src 'none' https://google.com; \
            script-src 'self'; connect-src *");

        assert_rewritten(
            "default-src 'none'",
            "connect-src 'self'; default-src 'none'; script-src 'self'; ",
        );
        assert_rewritten(
            "default-src 'none'; script-src http:",
            "connect-src 'self'; default-src 'none'; script-src http: 'self'; ",
        );
        assert_rewritten(
            "default-src 'self'; connect-src 'none'",
            "connect-src 'self'; default-src 'self'; ",
        );
        assert_rewritten(
            "default-src 'self'; script-src https:",
            "default-src 'self'; script-src https: 'self'; ",
        );
    }
}
529}