1use std::collections::{BTreeMap, HashSet};
20use std::convert::Infallible;
21use std::time::Duration;
22
23use bytes::Bytes;
24use http::{HeaderMap, Method};
25use http_body_util::{BodyExt, Full};
26use hyper::service::service_fn;
27use hyper::{Request, Response};
28use hyper_util::rt::TokioIo;
29use tokio::net::TcpListener;
30use tokio::task::JoinHandle;
31use tracing::{debug, trace, warn};
32
33use crate::model::DistributionConfig;
34use crate::state::SharedCloudFrontState;
35
/// Name of the environment variable that switches the CloudFront data plane
/// off; it is read by `dataplane_enabled`.
const ENV_DISABLE: &str = "FAKECLOUD_CLOUDFRONT_DISABLE_DATAPLANE";
/// Seconds between two reconciliation passes of the supervisor loop.
const SUPERVISOR_TICK_SECS: u64 = 1;
38
39pub fn dataplane_enabled() -> bool {
43 !matches!(
44 std::env::var(ENV_DISABLE).as_deref(),
45 Ok("1") | Ok("true") | Ok("TRUE") | Ok("yes") | Ok("YES")
46 )
47}
48
/// Bookkeeping entry for one distribution's listener task.
struct BoundListener {
    /// Join handle of the spawned task that runs `accept_loop` for the
    /// distribution.
    handle: JoinHandle<()>,
}

/// Dropping a binding (for example when `reconcile` removes it from its map)
/// aborts the associated listener task.
impl Drop for BoundListener {
    fn drop(&mut self) {
        // Without the abort the spawned task would keep running detached.
        self.handle.abort();
    }
}
60
/// Context shared by the data-plane tasks (supervisor, accept loops and
/// request handlers); it is cloned into every spawned task.
#[derive(Clone)]
struct DataPlane {
    /// CloudFront state: read to look up distributions, written to record
    /// each distribution's bound port.
    state: SharedCloudFrontState,
    /// HTTP client used to send requests to origins.
    upstream: reqwest::Client,
    /// `host:port` that S3 website origins are redirected to; built as
    /// `127.0.0.1:<server_port>` in `spawn_dataplane`.
    s3_endpoint: String,
}
73
74pub fn spawn_dataplane(state: SharedCloudFrontState, server_port: u16) {
78 if !dataplane_enabled() {
79 debug!("CloudFront data plane disabled via {ENV_DISABLE}");
80 return;
81 }
82 let upstream = match reqwest::Client::builder()
83 .danger_accept_invalid_certs(true)
84 .redirect(reqwest::redirect::Policy::none())
85 .timeout(Duration::from_secs(30))
86 .build()
87 {
88 Ok(c) => c,
89 Err(e) => {
90 warn!("CloudFront data plane: failed to build reqwest client: {e}");
91 return;
92 }
93 };
94 let dp = DataPlane {
95 state,
96 upstream,
97 s3_endpoint: format!("127.0.0.1:{server_port}"),
98 };
99 tokio::spawn(supervisor_loop(dp));
100}
101
102async fn supervisor_loop(dp: DataPlane) {
103 let mut bindings: BTreeMap<String, BoundListener> = BTreeMap::new();
104 let mut tick = tokio::time::interval(Duration::from_secs(SUPERVISOR_TICK_SECS));
105 tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
106 loop {
107 tick.tick().await;
108 reconcile(&dp, &mut bindings).await;
109 }
110}
111
112async fn reconcile(dp: &DataPlane, bindings: &mut BTreeMap<String, BoundListener>) {
119 let want: Vec<(String, String)> = {
121 let accs = dp.state.read();
122 accs.all_distributions()
123 .filter(|(_acct, d)| d.config.enabled)
124 .map(|(acct, d)| (d.id.clone(), acct.clone()))
125 .collect()
126 };
127 let want_set: HashSet<&String> = want.iter().map(|(id, _)| id).collect();
128
129 bindings.retain(|id, _| want_set.contains(id));
131
132 for (dist_id, account_id) in want.iter() {
134 if bindings.contains_key(dist_id) {
135 continue;
136 }
137 match TcpListener::bind(("127.0.0.1", 0)).await {
138 Ok(listener) => {
139 let port = listener.local_addr().map(|a| a.port()).unwrap_or(0);
140 if port == 0 {
141 warn!("CloudFront data plane: bind returned port 0 for {dist_id}; skipping");
142 continue;
143 }
144 {
145 let mut accs = dp.state.write();
146 if let Some(st) = accs.accounts.get_mut(account_id) {
147 if let Some(d) = st.distributions.get_mut(dist_id) {
148 d.bound_port = Some(port);
149 }
150 }
151 }
152 let dp2 = dp.clone();
153 let id2 = dist_id.clone();
154 let handle = tokio::spawn(async move {
155 accept_loop(dp2, id2, listener).await;
156 });
157 bindings.insert(dist_id.clone(), BoundListener { handle });
158 trace!(dist = %dist_id, port, "CloudFront data plane: bound listener");
159 }
160 Err(e) => {
161 warn!("CloudFront data plane: failed to bind for {dist_id}: {e}");
162 }
163 }
164 }
165
166 let mut accs = dp.state.write();
168 let account_ids: Vec<String> = accs.accounts.keys().cloned().collect();
169 for acct in account_ids {
170 if let Some(st) = accs.accounts.get_mut(&acct) {
171 for d in st.distributions.values_mut() {
172 if !bindings.contains_key(&d.id) {
173 d.bound_port = None;
174 }
175 }
176 }
177 }
178}
179
180async fn accept_loop(dp: DataPlane, dist_id: String, listener: TcpListener) {
181 loop {
182 let (sock, _peer) = match listener.accept().await {
183 Ok(p) => p,
184 Err(e) => {
185 debug!(dist = %dist_id, "accept error: {e}");
186 continue;
187 }
188 };
189 let dp2 = dp.clone();
190 let id2 = dist_id.clone();
191 tokio::spawn(async move {
192 let io = TokioIo::new(sock);
193 let svc = service_fn(move |req| {
194 let dp3 = dp2.clone();
195 let id3 = id2.clone();
196 async move { Ok::<_, Infallible>(handle_request(&dp3, &id3, req).await) }
197 });
198 if let Err(e) = hyper::server::conn::http1::Builder::new()
199 .serve_connection(io, svc)
200 .await
201 {
202 debug!("CloudFront data plane: connection error: {e}");
203 }
204 });
205 }
206}
207
208async fn handle_request(
212 dp: &DataPlane,
213 dist_id: &str,
214 req: Request<hyper::body::Incoming>,
215) -> Response<Full<Bytes>> {
216 let (parts, body) = req.into_parts();
217 let method = parts.method;
218 let path = parts.uri.path().to_string();
219 let path_and_query = parts
220 .uri
221 .path_and_query()
222 .map(|p| p.as_str())
223 .unwrap_or("/")
224 .to_string();
225 let req_headers = parts.headers;
226 let body_bytes = body
227 .collect()
228 .await
229 .map(|c| c.to_bytes())
230 .unwrap_or_default();
231
232 let route: Option<RouteResolution> = {
235 let accs = dp.state.read();
236 let resolved = accs
237 .all_distributions()
238 .find(|(_, d)| d.id == dist_id)
239 .and_then(|(_, d)| resolve_route(&d.config, &path, &dp.s3_endpoint));
240 resolved
241 };
242 let Some(route) = route else {
243 return canned(502, "distribution or matching origin not found");
244 };
245
246 let url = format!("{}{path_and_query}", route.upstream.url_base);
247 trace!(dist = %dist_id, %path, origin = %route.upstream.host_header, "CloudFront data plane: proxying");
248 let resp = fetch_origin(
249 dp,
250 &method,
251 &url,
252 &route.upstream.host_header,
253 &req_headers,
254 &body_bytes,
255 )
256 .await;
257
258 if let Some(rule) = match_error_rule(&route.error_rules, resp.status().as_u16()) {
263 let origin_status = resp.status();
264 let url = format!("{}{}", route.default_upstream.url_base, rule.page_path);
265 let err_resp = fetch_origin(
266 dp,
267 &Method::GET,
268 &url,
269 &route.default_upstream.host_header,
270 &HeaderMap::new(),
271 &Bytes::new(),
272 )
273 .await;
274 if err_resp.status().is_success() {
280 let mut err_resp = err_resp;
281 let final_status = rule
284 .response_code
285 .and_then(|c| http::StatusCode::from_u16(c).ok())
286 .unwrap_or(origin_status);
287 *err_resp.status_mut() = final_status;
288 return err_resp;
289 }
290 return resp;
291 }
292 resp
293}
294
/// Everything `handle_request` needs to serve one request, as computed by
/// `resolve_route`.
struct RouteResolution {
    /// Origin selected for the request path.
    upstream: UpstreamTarget,
    /// Origin of the default cache behavior; custom error pages are fetched
    /// from it.
    default_upstream: UpstreamTarget,
    /// Custom error responses of the distribution that define a page path.
    error_rules: Vec<ErrorRule>,
}
305
/// Where and how to reach an origin.
#[derive(Clone)]
struct UpstreamTarget {
    /// Scheme and authority (for example `http://host:port`); the request's
    /// path and query are appended to it.
    url_base: String,
    /// Value sent as the `host` header of the upstream request.
    host_header: String,
}
315
/// One custom error response of a distribution, reduced to the fields the
/// data plane uses.
#[derive(Clone)]
struct ErrorRule {
    /// Origin status code that triggers the rule.
    error_code: u16,
    /// Path of the replacement page; appended to the default origin's URL base.
    page_path: String,
    /// Status code to answer with when the page is served; `None` keeps the
    /// origin's status.
    response_code: Option<u16>,
}
322
323fn resolve_route(
326 cfg: &DistributionConfig,
327 path: &str,
328 s3_endpoint: &str,
329) -> Option<RouteResolution> {
330 let items = cfg.origins.items.as_ref()?;
331 let target = select_target_origin(cfg, path);
332 let upstream = items
333 .origin
334 .iter()
335 .find(|o| o.id == target)
336 .map(|o| upstream_for(o, s3_endpoint))?;
337 let default_target = cfg.default_cache_behavior.target_origin_id.as_str();
338 let default_upstream = items
339 .origin
340 .iter()
341 .find(|o| o.id == default_target)
342 .map(|o| upstream_for(o, s3_endpoint))
343 .unwrap_or_else(|| upstream.clone());
344 let error_rules = cfg
345 .custom_error_responses
346 .as_ref()
347 .and_then(|c| c.items.as_ref())
348 .map(|it| {
349 it.custom_error_response
350 .iter()
351 .filter_map(|r| {
352 r.response_page_path.as_ref().map(|p| ErrorRule {
353 error_code: r.error_code as u16,
354 page_path: p.clone(),
355 response_code: r.response_code.as_ref().and_then(|s| s.parse().ok()),
356 })
357 })
358 .collect()
359 })
360 .unwrap_or_default();
361 Some(RouteResolution {
362 upstream,
363 default_upstream,
364 error_rules,
365 })
366}
367
368fn match_error_rule(rules: &[ErrorRule], status: u16) -> Option<ErrorRule> {
370 rules.iter().find(|r| r.error_code == status).cloned()
371}
372
373fn select_target_origin<'a>(cfg: &'a DistributionConfig, path: &str) -> &'a str {
374 if let Some(cbs) = &cfg.cache_behaviors {
375 if let Some(items) = &cbs.items {
376 for cb in &items.cache_behavior {
377 if path_pattern_matches(&cb.path_pattern, path) {
378 return &cb.target_origin_id;
379 }
380 }
381 }
382 }
383 &cfg.default_cache_behavior.target_origin_id
384}
385
/// Tells whether `domain` looks like an S3 static-website endpoint: it must
/// end in `.amazonaws.com` and contain `.s3-website`. The comparison is
/// case-sensitive.
fn is_s3_website(domain: &str) -> bool {
    domain
        .strip_suffix(".amazonaws.com")
        .is_some_and(|rest| rest.contains(".s3-website"))
}
394
395fn upstream_for(origin: &crate::model::Origin, s3_endpoint: &str) -> UpstreamTarget {
405 let domain = &origin.domain_name;
406 if is_s3_website(domain) {
407 return UpstreamTarget {
408 url_base: format!("http://{s3_endpoint}"),
409 host_header: domain.clone(),
410 };
411 }
412 if let Some(cfg) = &origin.custom_origin_config {
413 let https = cfg
414 .origin_protocol_policy
415 .eq_ignore_ascii_case("https-only");
416 let (scheme, port) = if https {
417 ("https", cfg.https_port)
418 } else {
419 ("http", cfg.http_port)
420 };
421 let has_explicit_port = domain.rsplit(':').next().is_some_and(|s| {
424 !s.is_empty() && s.bytes().all(|b| b.is_ascii_digit()) && domain.contains(':')
425 });
426 let default_port = (scheme == "http" && port == 80) || (scheme == "https" && port == 443);
427 let authority = if has_explicit_port || port <= 0 || default_port {
428 domain.clone()
429 } else {
430 format!("{domain}:{port}")
431 };
432 return UpstreamTarget {
433 url_base: format!("{scheme}://{authority}"),
434 host_header: domain.clone(),
435 };
436 }
437 UpstreamTarget {
438 url_base: format!("http://{domain}"),
439 host_header: domain.clone(),
440 }
441}
442
443async fn fetch_origin(
445 dp: &DataPlane,
446 method: &Method,
447 url: &str,
448 host_header: &str,
449 req_headers: &HeaderMap,
450 body: &Bytes,
451) -> Response<Full<Bytes>> {
452 let mut rb = dp.upstream.request(reqwest_method(method), url);
453 for (k, v) in req_headers.iter() {
454 let n = k.as_str();
455 if is_hop_by_hop(n) || n.eq_ignore_ascii_case("host") {
456 continue;
457 }
458 rb = rb.header(k.as_str(), v.as_bytes());
459 }
460 rb = rb.header("host", host_header);
461 if !body.is_empty() {
462 rb = rb.body(body.to_vec());
463 }
464 match rb.send().await {
465 Ok(up) => {
466 let status = up.status();
467 let headers = up.headers().clone();
468 let bytes = up.bytes().await.unwrap_or_default();
469 let mut resp = Response::new(Full::new(bytes));
470 *resp.status_mut() = status;
471 for (k, v) in headers.iter() {
472 if !is_hop_by_hop(k.as_str()) {
473 resp.headers_mut().append(k.clone(), v.clone());
474 }
475 }
476 resp
477 }
478 Err(e) => canned(502, &format!("origin error: {e}")),
479 }
480}
481
482fn path_pattern_matches(pattern: &str, path: &str) -> bool {
487 let pat = pattern.trim_start_matches('/');
488 let p = path.trim_start_matches('/');
489 glob_match(pat.as_bytes(), p.as_bytes())
490}
491
/// Byte-wise wildcard match of `text` against `pat`.
///
/// `*` matches any run of bytes (including none), `?` matches exactly one
/// byte, and every other byte matches itself. The whole of `text` must be
/// consumed. Matching is iterative with single-star backtracking, so it never
/// recurses.
///
/// A `*` in the pattern is always treated as a wildcard, even when the text
/// has a literal `*` at the same position.
fn glob_match(pat: &[u8], text: &[u8]) -> bool {
    let (mut p, mut t) = (0usize, 0usize);
    // Most recent `*` seen: (its index in `pat`, the text index it currently
    // extends up to). Used to retry with the star swallowing one more byte.
    let mut backtrack: Option<(usize, usize)> = None;

    while t < text.len() {
        match pat.get(p).copied() {
            // The wildcard must be tested before literal equality; otherwise
            // a `*` byte in the text would consume the pattern's `*` as a
            // plain character and lose its wildcard meaning.
            Some(b'*') => {
                backtrack = Some((p, t));
                p += 1;
            }
            Some(c) if c == b'?' || c == text[t] => {
                p += 1;
                t += 1;
            }
            _ => match backtrack {
                Some((star, mark)) => {
                    // Let the last star absorb one more byte and retry.
                    p = star + 1;
                    t = mark + 1;
                    backtrack = Some((star, mark + 1));
                }
                None => return false,
            },
        }
    }

    // Text is exhausted: only trailing stars may remain in the pattern.
    pat[p..].iter().all(|&c| c == b'*')
}
517
518fn canned(status: u16, msg: &str) -> Response<Full<Bytes>> {
519 Response::builder()
520 .status(status)
521 .body(Full::new(Bytes::from(msg.to_string())))
522 .expect("canned response builds")
523}
524
525fn reqwest_method(m: &Method) -> reqwest::Method {
526 reqwest::Method::from_bytes(m.as_str().as_bytes()).unwrap_or(reqwest::Method::GET)
527}
528
/// Lower-case names of the headers that `is_hop_by_hop` reports; these are
/// not forwarded between viewer and origin in either direction.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];
539
540fn is_hop_by_hop(name: &str) -> bool {
541 HOP_BY_HOP.iter().any(|&h| h.eq_ignore_ascii_case(name))
542}
543
/// Unit tests for S3 website detection and for origin-to-upstream translation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{CustomOriginConfig, Origin};

    /// Builds an origin with the given domain and optional custom origin
    /// config; every other field keeps its default.
    fn origin(domain: &str, custom: Option<CustomOriginConfig>) -> Origin {
        Origin {
            id: "o".into(),
            domain_name: domain.into(),
            custom_origin_config: custom,
            ..Default::default()
        }
    }

    /// Builds a custom origin config with the given protocol policy and ports.
    fn custom(policy: &str, http_port: i32, https_port: i32) -> CustomOriginConfig {
        CustomOriginConfig {
            http_port,
            https_port,
            origin_protocol_policy: policy.into(),
            ..Default::default()
        }
    }

    /// Both dash and dot region separators are recognised; look-alike domains
    /// outside `amazonaws.com` are not.
    #[test]
    fn s3_website_detection_is_precise() {
        assert!(is_s3_website("b.s3-website-us-east-1.amazonaws.com"));
        assert!(is_s3_website("b.s3-website.us-east-1.amazonaws.com"));
        // Contains `.s3-website` but does not end in `.amazonaws.com`.
        assert!(!is_s3_website("my.s3-website.example.com"));
        assert!(!is_s3_website("api.example.com"));
        assert!(!is_s3_website("127.0.0.1:8080"));
    }

    /// S3 website origins go to the local endpoint while keeping the bucket
    /// domain as host header.
    #[test]
    fn s3_website_origin_routes_to_local_port() {
        let up = upstream_for(
            &origin("b.s3-website-us-east-1.amazonaws.com", None),
            "127.0.0.1:4566",
        );
        assert_eq!(up.url_base, "http://127.0.0.1:4566");
        assert_eq!(up.host_header, "b.s3-website-us-east-1.amazonaws.com");
    }

    /// `https-only` selects the https scheme and the (non-default) HTTPS port.
    #[test]
    fn https_only_custom_origin_uses_https_and_port() {
        let up = upstream_for(
            &origin("api.example.com", Some(custom("https-only", 80, 8443))),
            "127.0.0.1:4566",
        );
        assert_eq!(up.url_base, "https://api.example.com:8443");
    }

    /// Port 80 is the http default and is therefore left out of the URL.
    #[test]
    fn http_custom_origin_default_port_omits_port() {
        let up = upstream_for(
            &origin("api.example.com", Some(custom("http-only", 80, 443))),
            "127.0.0.1:4566",
        );
        assert_eq!(up.url_base, "http://api.example.com");
    }

    /// A domain that already carries `:<port>` is used as-is.
    #[test]
    fn explicit_port_in_domain_wins_over_config_port() {
        let up = upstream_for(
            // The domain names its own port, so the configured ports are ignored.
            &origin("127.0.0.1:52111", Some(custom("http-only", 80, 443))),
            "127.0.0.1:4566",
        );
        assert_eq!(up.url_base, "http://127.0.0.1:52111");
    }

    /// Without a custom origin config the origin is reached over plain http.
    #[test]
    fn bare_origin_defaults_to_http() {
        let up = upstream_for(&origin("origin.internal", None), "127.0.0.1:4566");
        assert_eq!(up.url_base, "http://origin.internal");
    }
}