Skip to main content

fakecloud_cloudfront/
dataplane.rs

1// crates/fakecloud-cloudfront/src/dataplane.rs
2//! In-process CloudFront data plane.
3//!
4//! Mirrors the ELBv2 data plane (`fakecloud-elbv2/src/dataplane.rs`): a
5//! supervisor loop binds one `TcpListener` on `127.0.0.1:0` per *enabled*
6//! distribution, records the OS-allocated port back into distribution state as
7//! `bound_port`, and serves viewer requests via `hyper`. The AWS-shaped
8//! `*.cloudfront.net` domain stays cosmetic; clients discover the real address
9//! via `/_fakecloud/cloudfront/distributions` and connect to
10//! `http://127.0.0.1:{bound_port}/...`.
11//!
12//! `handle_request` selects a cache behavior by path pattern, resolves its
13//! origin, reverse-proxies to it, and applies CustomErrorResponses (e.g. the
14//! SPA `404 -> /index.html` served as `200`). There is no global edge network —
15//! this is a single local origin-serving node, matching the ALB/API Gateway
16//! precedent. Deferred (not implemented): in-path CloudFront Functions /
17//! Lambda@Edge, TTL caching / invalidation, and OAC/SigV4 to private S3.
18
19use std::collections::{BTreeMap, HashSet};
20use std::convert::Infallible;
21use std::time::Duration;
22
23use bytes::Bytes;
24use http::{HeaderMap, Method};
25use http_body_util::{BodyExt, Full};
26use hyper::service::service_fn;
27use hyper::{Request, Response};
28use hyper_util::rt::TokioIo;
29use tokio::net::TcpListener;
30use tokio::task::JoinHandle;
31use tracing::{debug, trace, warn};
32
33use crate::model::DistributionConfig;
34use crate::state::SharedCloudFrontState;
35
/// Name of the environment variable that switches the data plane off when it
/// holds a truthy value (checked by `dataplane_enabled`).
const ENV_DISABLE: &str = "FAKECLOUD_CLOUDFRONT_DISABLE_DATAPLANE";
/// Period, in seconds, between supervisor reconcile passes.
const SUPERVISOR_TICK_SECS: u64 = 1;
38
39/// Whether the data plane should run. Disabled by setting
40/// `FAKECLOUD_CLOUDFRONT_DISABLE_DATAPLANE` to a truthy value (mirrors the
41/// ELBv2 flag), for environments that only exercise the control plane.
42pub fn dataplane_enabled() -> bool {
43    !matches!(
44        std::env::var(ENV_DISABLE).as_deref(),
45        Ok("1") | Ok("true") | Ok("TRUE") | Ok("yes") | Ok("YES")
46    )
47}
48
/// Per-distribution listener handle. Dropping it aborts the accept loop and
/// frees the OS port, so a disabled/deleted distribution stops serving.
struct BoundListener {
    /// Task running the distribution's accept loop. The `TcpListener` is moved
    /// into that task when it is spawned, so the task holds the socket.
    handle: JoinHandle<()>,
}

impl Drop for BoundListener {
    /// Aborts the accept-loop task when the handle goes away (e.g. when the
    /// supervisor removes the entry from its bindings map).
    fn drop(&mut self) {
        self.handle.abort();
    }
}
60
/// State shared across the supervisor and per-connection handlers. It is
/// `Clone` because the supervisor hands a copy to every accept loop, which in
/// turn hands a copy to every connection and request.
#[derive(Clone)]
struct DataPlane {
    /// Shared CloudFront control-plane state: read to resolve routes, written
    /// to record or clear each distribution's `bound_port`.
    state: SharedCloudFrontState,
    /// HTTP client used to fetch from origins (reverse-proxy).
    upstream: reqwest::Client,
    /// `host:port` of fakecloud's own server. An S3-website origin is served by
    /// this same process on the main port, so those origins are reached here
    /// with the website domain preserved in the `Host` header (real CloudFront
    /// likewise treats an S3-website endpoint as an HTTP custom origin).
    s3_endpoint: String,
}
73
74/// Spawn the CloudFront data-plane supervisor. No-op (returns without spawning)
75/// when disabled via the env flag. `server_port` is fakecloud's own listen port,
76/// used to reach S3-website origins served by this process.
77pub fn spawn_dataplane(state: SharedCloudFrontState, server_port: u16) {
78    if !dataplane_enabled() {
79        debug!("CloudFront data plane disabled via {ENV_DISABLE}");
80        return;
81    }
82    let upstream = match reqwest::Client::builder()
83        .danger_accept_invalid_certs(true)
84        .redirect(reqwest::redirect::Policy::none())
85        .timeout(Duration::from_secs(30))
86        .build()
87    {
88        Ok(c) => c,
89        Err(e) => {
90            warn!("CloudFront data plane: failed to build reqwest client: {e}");
91            return;
92        }
93    };
94    let dp = DataPlane {
95        state,
96        upstream,
97        s3_endpoint: format!("127.0.0.1:{server_port}"),
98    };
99    tokio::spawn(supervisor_loop(dp));
100}
101
102async fn supervisor_loop(dp: DataPlane) {
103    let mut bindings: BTreeMap<String, BoundListener> = BTreeMap::new();
104    let mut tick = tokio::time::interval(Duration::from_secs(SUPERVISOR_TICK_SECS));
105    tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
106    loop {
107        tick.tick().await;
108        reconcile(&dp, &mut bindings).await;
109    }
110}
111
112/// Reconcile bound listeners against the set of enabled distributions. Binds
113/// newly-enabled distributions, tears down listeners for ones that were
114/// disabled or deleted, and keeps each distribution's `bound_port` in sync with
115/// whether the supervisor currently holds its listener. Because the want-set is
116/// derived purely from persisted state, enabled distributions loaded from a
117/// snapshot on startup are re-bound on the first tick (startup rebind).
118async fn reconcile(dp: &DataPlane, bindings: &mut BTreeMap<String, BoundListener>) {
119    // 1. Snapshot the (distribution id, owning account) pairs that want a listener.
120    let want: Vec<(String, String)> = {
121        let accs = dp.state.read();
122        accs.all_distributions()
123            .filter(|(_acct, d)| d.config.enabled)
124            .map(|(acct, d)| (d.id.clone(), acct.clone()))
125            .collect()
126    };
127    let want_set: HashSet<&String> = want.iter().map(|(id, _)| id).collect();
128
129    // 2. Drop bindings for distributions no longer wanted (disabled/deleted).
130    bindings.retain(|id, _| want_set.contains(id));
131
132    // 3. Bind any newly-enabled distribution.
133    for (dist_id, account_id) in want.iter() {
134        if bindings.contains_key(dist_id) {
135            continue;
136        }
137        match TcpListener::bind(("127.0.0.1", 0)).await {
138            Ok(listener) => {
139                let port = listener.local_addr().map(|a| a.port()).unwrap_or(0);
140                if port == 0 {
141                    warn!("CloudFront data plane: bind returned port 0 for {dist_id}; skipping");
142                    continue;
143                }
144                {
145                    let mut accs = dp.state.write();
146                    if let Some(st) = accs.accounts.get_mut(account_id) {
147                        if let Some(d) = st.distributions.get_mut(dist_id) {
148                            d.bound_port = Some(port);
149                        }
150                    }
151                }
152                let dp2 = dp.clone();
153                let id2 = dist_id.clone();
154                let handle = tokio::spawn(async move {
155                    accept_loop(dp2, id2, listener).await;
156                });
157                bindings.insert(dist_id.clone(), BoundListener { handle });
158                trace!(dist = %dist_id, port, "CloudFront data plane: bound listener");
159            }
160            Err(e) => {
161                warn!("CloudFront data plane: failed to bind for {dist_id}: {e}");
162            }
163        }
164    }
165
166    // 4. Clear bound_port for any distribution the supervisor no longer holds.
167    let mut accs = dp.state.write();
168    let account_ids: Vec<String> = accs.accounts.keys().cloned().collect();
169    for acct in account_ids {
170        if let Some(st) = accs.accounts.get_mut(&acct) {
171            for d in st.distributions.values_mut() {
172                if !bindings.contains_key(&d.id) {
173                    d.bound_port = None;
174                }
175            }
176        }
177    }
178}
179
180async fn accept_loop(dp: DataPlane, dist_id: String, listener: TcpListener) {
181    loop {
182        let (sock, _peer) = match listener.accept().await {
183            Ok(p) => p,
184            Err(e) => {
185                debug!(dist = %dist_id, "accept error: {e}");
186                continue;
187            }
188        };
189        let dp2 = dp.clone();
190        let id2 = dist_id.clone();
191        tokio::spawn(async move {
192            let io = TokioIo::new(sock);
193            let svc = service_fn(move |req| {
194                let dp3 = dp2.clone();
195                let id3 = id2.clone();
196                async move { Ok::<_, Infallible>(handle_request(&dp3, &id3, req).await) }
197            });
198            if let Err(e) = hyper::server::conn::http1::Builder::new()
199                .serve_connection(io, svc)
200                .await
201            {
202                debug!("CloudFront data plane: connection error: {e}");
203            }
204        });
205    }
206}
207
208/// Serve one viewer request: select a cache behavior by path pattern, resolve
209/// its origin, and reverse-proxy to it. Task 4 interposes CustomErrorResponses
210/// on the origin status before returning.
211async fn handle_request(
212    dp: &DataPlane,
213    dist_id: &str,
214    req: Request<hyper::body::Incoming>,
215) -> Response<Full<Bytes>> {
216    let (parts, body) = req.into_parts();
217    let method = parts.method;
218    let path = parts.uri.path().to_string();
219    let path_and_query = parts
220        .uri
221        .path_and_query()
222        .map(|p| p.as_str())
223        .unwrap_or("/")
224        .to_string();
225    let req_headers = parts.headers;
226    let body_bytes = body
227        .collect()
228        .await
229        .map(|c| c.to_bytes())
230        .unwrap_or_default();
231
232    // Resolve the route under the read lock (owned snapshot so the guard drops
233    // at the end of the block).
234    let route: Option<RouteResolution> = {
235        let accs = dp.state.read();
236        let resolved = accs
237            .all_distributions()
238            .find(|(_, d)| d.id == dist_id)
239            .and_then(|(_, d)| resolve_route(&d.config, &path, &dp.s3_endpoint));
240        resolved
241    };
242    let Some(route) = route else {
243        return canned(502, "distribution or matching origin not found");
244    };
245
246    let url = format!("{}{path_and_query}", route.upstream.url_base);
247    trace!(dist = %dist_id, %path, origin = %route.upstream.host_header, "CloudFront data plane: proxying");
248    let resp = fetch_origin(
249        dp,
250        &method,
251        &url,
252        &route.upstream.host_header,
253        &req_headers,
254        &body_bytes,
255    )
256    .await;
257
258    // CustomErrorResponses: if the origin status matches a configured rule with
259    // a response page path, serve that page from the DEFAULT origin and return
260    // it with the rule's response code (the SPA deep-link fallback, e.g.
261    // 404 -> /index.html returned as 200).
262    if let Some(rule) = match_error_rule(&route.error_rules, resp.status().as_u16()) {
263        let origin_status = resp.status();
264        let url = format!("{}{}", route.default_upstream.url_base, rule.page_path);
265        let err_resp = fetch_origin(
266            dp,
267            &Method::GET,
268            &url,
269            &route.default_upstream.host_header,
270            &HeaderMap::new(),
271            &Bytes::new(),
272        )
273        .await;
274        // Only interpose the custom error page when the fallback fetch itself
275        // succeeded. If fetching the page failed (e.g. the default origin is
276        // down, or the page path 404s), returning it under the rule's success
277        // ResponseCode would mask an error body with a 200; keep the ORIGINAL
278        // origin response instead.
279        if err_resp.status().is_success() {
280            let mut err_resp = err_resp;
281            // Status = the rule's ResponseCode if set, else the ORIGINAL origin
282            // error status (AWS: an omitted ResponseCode keeps the origin's code).
283            let final_status = rule
284                .response_code
285                .and_then(|c| http::StatusCode::from_u16(c).ok())
286                .unwrap_or(origin_status);
287            *err_resp.status_mut() = final_status;
288            return err_resp;
289        }
290        return resp;
291    }
292    resp
293}
294
/// Owned per-request routing snapshot (taken under the state read lock), so
/// the request can be served after the lock has been released.
struct RouteResolution {
    /// Resolved upstream for the matched cache behavior.
    upstream: UpstreamTarget,
    /// Resolved upstream for the default cache behavior (where
    /// CustomErrorResponse pages are fetched from). Falls back to `upstream`
    /// when the default behavior's origin id matches no configured origin.
    default_upstream: UpstreamTarget,
    /// CustomErrorResponses that have a response page path, in configuration
    /// order.
    error_rules: Vec<ErrorRule>,
}
305
/// A resolved origin address: the scheme+authority to connect to and the `Host`
/// header to send. The two differ for S3-website origins, which connect to
/// fakecloud's own port but keep the website domain as `Host`.
#[derive(Clone)]
struct UpstreamTarget {
    /// `scheme://authority` (no trailing slash); the request path is appended.
    url_base: String,
    /// `Host` header sent upstream (the origin domain name).
    host_header: String,
}
315
/// One CustomErrorResponse rule that carries a response page path, reduced to
/// the fields the request path needs.
#[derive(Clone)]
struct ErrorRule {
    /// Origin HTTP status that triggers this rule.
    error_code: u16,
    /// Path of the replacement page; appended to the default origin's URL base.
    page_path: String,
    /// Status to return with the page. `None` (not configured, or not parseable
    /// as a number) keeps the origin's original status.
    response_code: Option<u16>,
}
322
323/// Resolve the matched origin, the default origin, and the custom-error rules
324/// for a request path.
325fn resolve_route(
326    cfg: &DistributionConfig,
327    path: &str,
328    s3_endpoint: &str,
329) -> Option<RouteResolution> {
330    let items = cfg.origins.items.as_ref()?;
331    let target = select_target_origin(cfg, path);
332    let upstream = items
333        .origin
334        .iter()
335        .find(|o| o.id == target)
336        .map(|o| upstream_for(o, s3_endpoint))?;
337    let default_target = cfg.default_cache_behavior.target_origin_id.as_str();
338    let default_upstream = items
339        .origin
340        .iter()
341        .find(|o| o.id == default_target)
342        .map(|o| upstream_for(o, s3_endpoint))
343        .unwrap_or_else(|| upstream.clone());
344    let error_rules = cfg
345        .custom_error_responses
346        .as_ref()
347        .and_then(|c| c.items.as_ref())
348        .map(|it| {
349            it.custom_error_response
350                .iter()
351                .filter_map(|r| {
352                    r.response_page_path.as_ref().map(|p| ErrorRule {
353                        error_code: r.error_code as u16,
354                        page_path: p.clone(),
355                        response_code: r.response_code.as_ref().and_then(|s| s.parse().ok()),
356                    })
357                })
358                .collect()
359        })
360        .unwrap_or_default();
361    Some(RouteResolution {
362        upstream,
363        default_upstream,
364        error_rules,
365    })
366}
367
368/// First custom-error rule whose error code matches the origin status.
369fn match_error_rule(rules: &[ErrorRule], status: u16) -> Option<ErrorRule> {
370    rules.iter().find(|r| r.error_code == status).cloned()
371}
372
373fn select_target_origin<'a>(cfg: &'a DistributionConfig, path: &str) -> &'a str {
374    if let Some(cbs) = &cfg.cache_behaviors {
375        if let Some(items) = &cbs.items {
376            for cb in &items.cache_behavior {
377                if path_pattern_matches(&cb.path_pattern, path) {
378                    return &cb.target_origin_id;
379                }
380            }
381        }
382    }
383    &cfg.default_cache_behavior.target_origin_id
384}
385
/// Whether `domain` is an S3 static-website endpoint
/// (`bucket.s3-website-<region>.amazonaws.com` or
/// `bucket.s3-website.<region>.amazonaws.com`, including the China-partition
/// form ending in `.amazonaws.com.cn`).
///
/// Both the `.s3-website` label and an AWS domain suffix are required, so a
/// custom origin that merely contains the substring — e.g.
/// `my.s3-website.example.com` — is NOT rerouted to the local fakecloud port.
/// The comparison ignores ASCII case because DNS names are case-insensitive.
fn is_s3_website(domain: &str) -> bool {
    /// Domain suffixes under which AWS publishes S3 website endpoints.
    const AWS_SUFFIXES: [&str; 2] = [".amazonaws.com", ".amazonaws.com.cn"];

    let host = domain.to_ascii_lowercase();
    host.contains(".s3-website") && AWS_SUFFIXES.iter().any(|suffix| host.ends_with(*suffix))
}
394
395/// Resolve an [`crate::model::Origin`] to the upstream to connect to.
396///
397/// - S3-website origins are served by this same fakecloud process, so connect to
398///   its own port while preserving the website domain in `Host`.
399/// - Custom origins honor `CustomOriginConfig`: an `https-only` protocol policy
400///   is fetched over HTTPS (else HTTP), and the configured `HTTPPort`/`HTTPSPort`
401///   is appended UNLESS the `domain_name` already carries an explicit `:port`
402///   (as local test origins do) or the port is the scheme default.
403/// - Bare origins (no config) are reached over HTTP at their domain verbatim.
404fn upstream_for(origin: &crate::model::Origin, s3_endpoint: &str) -> UpstreamTarget {
405    let domain = &origin.domain_name;
406    if is_s3_website(domain) {
407        return UpstreamTarget {
408            url_base: format!("http://{s3_endpoint}"),
409            host_header: domain.clone(),
410        };
411    }
412    if let Some(cfg) = &origin.custom_origin_config {
413        let https = cfg
414            .origin_protocol_policy
415            .eq_ignore_ascii_case("https-only");
416        let (scheme, port) = if https {
417            ("https", cfg.https_port)
418        } else {
419            ("http", cfg.http_port)
420        };
421        // A domain that already encodes a port (host:port, as local origins do)
422        // wins over the config port; otherwise append a non-default port.
423        let has_explicit_port = domain.rsplit(':').next().is_some_and(|s| {
424            !s.is_empty() && s.bytes().all(|b| b.is_ascii_digit()) && domain.contains(':')
425        });
426        let default_port = (scheme == "http" && port == 80) || (scheme == "https" && port == 443);
427        let authority = if has_explicit_port || port <= 0 || default_port {
428            domain.clone()
429        } else {
430            format!("{domain}:{port}")
431        };
432        return UpstreamTarget {
433            url_base: format!("{scheme}://{authority}"),
434            host_header: domain.clone(),
435        };
436    }
437    UpstreamTarget {
438        url_base: format!("http://{domain}"),
439        host_header: domain.clone(),
440    }
441}
442
443/// Reverse-proxy the request to the resolved origin and copy the response back.
444async fn fetch_origin(
445    dp: &DataPlane,
446    method: &Method,
447    url: &str,
448    host_header: &str,
449    req_headers: &HeaderMap,
450    body: &Bytes,
451) -> Response<Full<Bytes>> {
452    let mut rb = dp.upstream.request(reqwest_method(method), url);
453    for (k, v) in req_headers.iter() {
454        let n = k.as_str();
455        if is_hop_by_hop(n) || n.eq_ignore_ascii_case("host") {
456            continue;
457        }
458        rb = rb.header(k.as_str(), v.as_bytes());
459    }
460    rb = rb.header("host", host_header);
461    if !body.is_empty() {
462        rb = rb.body(body.to_vec());
463    }
464    match rb.send().await {
465        Ok(up) => {
466            let status = up.status();
467            let headers = up.headers().clone();
468            let bytes = up.bytes().await.unwrap_or_default();
469            let mut resp = Response::new(Full::new(bytes));
470            *resp.status_mut() = status;
471            for (k, v) in headers.iter() {
472                if !is_hop_by_hop(k.as_str()) {
473                    resp.headers_mut().append(k.clone(), v.clone());
474                }
475            }
476            resp
477        }
478        Err(e) => canned(502, &format!("origin error: {e}")),
479    }
480}
481
482/// Match a CloudFront cache-behavior path pattern (`*` = any sequence, `?` = one
483/// character) against a request path. AWS path patterns are relative (no leading
484/// slash, e.g. `api/*`); normalize both sides so a canonical `api/*` and a
485/// slash-prefixed `/api/*` both match a request path like `/api/orders`.
486fn path_pattern_matches(pattern: &str, path: &str) -> bool {
487    let pat = pattern.trim_start_matches('/');
488    let p = path.trim_start_matches('/');
489    glob_match(pat.as_bytes(), p.as_bytes())
490}
491
/// Byte-wise glob match of `text` against `pat`.
///
/// In the pattern, `*` matches any run of bytes (including none) and `?`
/// matches exactly one byte; every other byte matches only itself, so matching
/// is case-sensitive. The whole of `text` must be consumed for a match.
///
/// The pattern's `*` is always treated as a wildcard, even when the text byte
/// at that position happens to be a literal `*`; otherwise a pattern such as
/// `images/*` would fail to match the path `images/*.jpg`.
fn glob_match(pat: &[u8], text: &[u8]) -> bool {
    // Iterative glob with backtracking to the most recent `*`.
    let (mut p, mut t) = (0usize, 0usize);
    // `star` is the position of the last `*` seen in the pattern; `mark` is the
    // text position that `*` is currently assumed to extend up to.
    let (mut star, mut mark) = (None, 0usize);
    while t < text.len() {
        match pat.get(p).copied() {
            // Wildcard first: start by letting it match nothing.
            Some(b'*') => {
                star = Some(p);
                mark = t;
                p += 1;
            }
            Some(c) if c == b'?' || c == text[t] => {
                p += 1;
                t += 1;
            }
            // Mismatch (or pattern exhausted): let the last `*` swallow one
            // more byte and retry from just after it.
            _ => match star {
                Some(sp) => {
                    p = sp + 1;
                    mark += 1;
                    t = mark;
                }
                None => return false,
            },
        }
    }
    // Text is exhausted; only trailing `*`s may remain in the pattern.
    while p < pat.len() && pat[p] == b'*' {
        p += 1;
    }
    p == pat.len()
}
517
518fn canned(status: u16, msg: &str) -> Response<Full<Bytes>> {
519    Response::builder()
520        .status(status)
521        .body(Full::new(Bytes::from(msg.to_string())))
522        .expect("canned response builds")
523}
524
525fn reqwest_method(m: &Method) -> reqwest::Method {
526    reqwest::Method::from_bytes(m.as_str().as_bytes()).unwrap_or(reqwest::Method::GET)
527}
528
/// Headers that apply to a single transport connection and must therefore not
/// be forwarded by a proxy, in lowercase.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];

/// Whether `name` is one of the `HOP_BY_HOP` headers, compared ignoring ASCII
/// case.
fn is_hop_by_hop(name: &str) -> bool {
    for candidate in HOP_BY_HOP {
        if name.eq_ignore_ascii_case(candidate) {
            return true;
        }
    }
    false
}
543
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{CustomOriginConfig, Origin};

    /// fakecloud's own `host:port`, as passed to `upstream_for` in every test.
    const LOCAL: &str = "127.0.0.1:4566";

    /// An origin with the given domain and optional custom-origin config; all
    /// other fields take their defaults.
    fn origin(domain: &str, custom: Option<CustomOriginConfig>) -> Origin {
        let mut built = Origin::default();
        built.id = "o".into();
        built.domain_name = domain.into();
        built.custom_origin_config = custom;
        built
    }

    /// A custom-origin config with the given protocol policy and ports.
    fn custom(policy: &str, http_port: i32, https_port: i32) -> CustomOriginConfig {
        let mut built = CustomOriginConfig::default();
        built.http_port = http_port;
        built.https_port = https_port;
        built.origin_protocol_policy = policy.into();
        built
    }

    #[test]
    fn s3_website_detection_is_precise() {
        let cases = [
            ("b.s3-website-us-east-1.amazonaws.com", true),
            ("b.s3-website.us-east-1.amazonaws.com", true),
            // A custom origin that merely contains the substring must NOT match.
            ("my.s3-website.example.com", false),
            ("api.example.com", false),
            ("127.0.0.1:8080", false),
        ];
        for (domain, expected) in cases {
            assert_eq!(is_s3_website(domain), expected, "domain: {}", domain);
        }
    }

    #[test]
    fn s3_website_origin_routes_to_local_port() {
        let website = origin("b.s3-website-us-east-1.amazonaws.com", None);
        let target = upstream_for(&website, LOCAL);
        assert_eq!(target.url_base, "http://127.0.0.1:4566");
        assert_eq!(target.host_header, "b.s3-website-us-east-1.amazonaws.com");
    }

    #[test]
    fn https_only_custom_origin_uses_https_and_port() {
        let api = origin("api.example.com", Some(custom("https-only", 80, 8443)));
        let target = upstream_for(&api, LOCAL);
        assert_eq!(target.url_base, "https://api.example.com:8443");
    }

    #[test]
    fn http_custom_origin_default_port_omits_port() {
        let api = origin("api.example.com", Some(custom("http-only", 80, 443)));
        let target = upstream_for(&api, LOCAL);
        assert_eq!(target.url_base, "http://api.example.com");
    }

    #[test]
    fn explicit_port_in_domain_wins_over_config_port() {
        // Local origins encode the port in the domain; the config port (80) must
        // not be appended on top of it.
        let local = origin("127.0.0.1:52111", Some(custom("http-only", 80, 443)));
        let target = upstream_for(&local, LOCAL);
        assert_eq!(target.url_base, "http://127.0.0.1:52111");
    }

    #[test]
    fn bare_origin_defaults_to_http() {
        let bare = origin("origin.internal", None);
        let target = upstream_for(&bare, LOCAL);
        assert_eq!(target.url_base, "http://origin.internal");
    }
}