use std::collections::{BTreeMap, HashSet};
use std::convert::Infallible;
use std::time::Duration;
use bytes::Bytes;
use http::{HeaderMap, Method};
use http_body_util::{BodyExt, Full};
use hyper::service::service_fn;
use hyper::{Request, Response};
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;
use tokio::task::JoinHandle;
use tracing::{debug, trace, warn};
use crate::model::DistributionConfig;
use crate::state::SharedCloudFrontState;
/// Name of the environment variable consulted by `dataplane_enabled`; a truthy
/// value there turns the data plane off.
const ENV_DISABLE: &str = "FAKECLOUD_CLOUDFRONT_DISABLE_DATAPLANE";
/// Period, in seconds, of the interval timer that drives the supervisor loop's
/// reconcile passes.
const SUPERVISOR_TICK_SECS: u64 = 1;
/// Reports whether the CloudFront data plane should run.
///
/// The data plane is enabled unless the `ENV_DISABLE` environment variable is
/// set to `1`, `true` or `yes`. The comparison is ASCII case-insensitive, so
/// spellings such as `True` or `Yes` disable it as well. An unset variable, a
/// value that is not valid Unicode, or any other value leaves it enabled.
pub fn dataplane_enabled() -> bool {
    const DISABLE_VALUES: [&str; 3] = ["1", "true", "yes"];
    match std::env::var(ENV_DISABLE) {
        Ok(raw) => !DISABLE_VALUES
            .iter()
            .any(|truthy| raw.eq_ignore_ascii_case(truthy)),
        // Unset or non-Unicode: keep the default (enabled).
        Err(_) => true,
    }
}
/// Handle to the spawned accept-loop task that serves one distribution's
/// listener. `reconcile` keeps one of these per bound distribution id.
struct BoundListener {
// Join handle of the task running `accept_loop` for this listener.
handle: JoinHandle<()>,
}
/// Dropping a `BoundListener` aborts its task, so removing an entry from the
/// bindings map is enough to stop that accept loop.
impl Drop for BoundListener {
fn drop(&mut self) {
// Request cancellation of the spawned accept-loop task.
self.handle.abort();
}
}
/// Context shared by the supervisor, the accept loops and the request
/// handlers. It is cloned into every spawned task.
#[derive(Clone)]
struct DataPlane {
// Shared CloudFront state; read to resolve routes, written to record bound ports.
state: SharedCloudFrontState,
// HTTP client used by `fetch_origin` for all origin requests.
upstream: reqwest::Client,
// `host:port` of the local server (built as `127.0.0.1:{server_port}`);
// S3 website origins are proxied to `http://{s3_endpoint}`.
s3_endpoint: String,
}
/// Starts the CloudFront data plane as a background task.
///
/// Does nothing (apart from a debug log) when `dataplane_enabled()` is false.
/// Otherwise builds the shared upstream HTTP client and spawns
/// `supervisor_loop`. The client accepts invalid certificates, never follows
/// redirects and has a 30 second timeout. If the client cannot be built, a
/// warning is logged and the data plane is not started.
///
/// `server_port` is the local port that S3 website origins are routed to.
pub fn spawn_dataplane(state: SharedCloudFrontState, server_port: u16) {
    if !dataplane_enabled() {
        debug!("CloudFront data plane disabled via {ENV_DISABLE}");
        return;
    }

    let built = reqwest::Client::builder()
        .danger_accept_invalid_certs(true)
        .redirect(reqwest::redirect::Policy::none())
        .timeout(Duration::from_secs(30))
        .build();
    let upstream = match built {
        Err(e) => {
            warn!("CloudFront data plane: failed to build reqwest client: {e}");
            return;
        }
        Ok(client) => client,
    };

    let s3_endpoint = format!("127.0.0.1:{server_port}");
    tokio::spawn(supervisor_loop(DataPlane {
        state,
        upstream,
        s3_endpoint,
    }));
}
/// Runs forever, calling `reconcile` once per `SUPERVISOR_TICK_SECS` seconds.
///
/// The map of live listeners is owned by this loop. Missed ticks are skipped
/// rather than replayed in a burst.
async fn supervisor_loop(dp: DataPlane) {
    let period = Duration::from_secs(SUPERVISOR_TICK_SECS);
    let mut ticker = tokio::time::interval(period);
    ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

    let mut listeners = BTreeMap::<String, BoundListener>::new();
    loop {
        ticker.tick().await;
        reconcile(&dp, &mut listeners).await;
    }
}
/// Brings the set of bound listeners in line with the enabled distributions.
///
/// One pass does the following:
/// 1. snapshots `(distribution id, account id)` for every enabled distribution;
/// 2. drops listeners whose distribution is no longer in that snapshot
///    (dropping a `BoundListener` aborts its accept task);
/// 3. for each enabled distribution without a listener, binds `127.0.0.1:0`,
///    records the assigned port in the distribution's `bound_port` and spawns
///    its accept loop;
/// 4. clears `bound_port` on every distribution that has no listener.
///
/// Bind failures are logged and retried on the next pass.
async fn reconcile(dp: &DataPlane, bindings: &mut BTreeMap<String, BoundListener>) {
    // Snapshot under the read lock; the guard is gone before any `.await`.
    let enabled: Vec<(String, String)> = {
        let guard = dp.state.read();
        guard
            .all_distributions()
            .filter(|(_acct, d)| d.config.enabled)
            .map(|(acct, d)| (d.id.clone(), acct.clone()))
            .collect()
    };

    let enabled_ids: HashSet<&String> = enabled.iter().map(|(id, _)| id).collect();
    bindings.retain(|id, _| enabled_ids.contains(id));

    for (dist_id, account_id) in &enabled {
        if bindings.contains_key(dist_id) {
            continue;
        }

        let listener = match TcpListener::bind(("127.0.0.1", 0)).await {
            Ok(bound) => bound,
            Err(e) => {
                warn!("CloudFront data plane: failed to bind for {dist_id}: {e}");
                continue;
            }
        };

        // A failed `local_addr()` is treated the same as an unusable port 0.
        let port = match listener.local_addr() {
            Ok(addr) if addr.port() != 0 => addr.port(),
            _ => {
                warn!("CloudFront data plane: bind returned port 0 for {dist_id}; skipping");
                continue;
            }
        };

        // Publish the port; the write guard is released at the end of this block.
        {
            let mut guard = dp.state.write();
            let slot = guard
                .accounts
                .get_mut(account_id)
                .and_then(|account| account.distributions.get_mut(dist_id));
            if let Some(dist) = slot {
                dist.bound_port = Some(port);
            }
        }

        let handle = tokio::spawn(accept_loop(dp.clone(), dist_id.clone(), listener));
        bindings.insert(dist_id.clone(), BoundListener { handle });
        trace!(dist = %dist_id, port, "CloudFront data plane: bound listener");
    }

    // Clear stale ports: anything without a live listener advertises no port.
    let mut guard = dp.state.write();
    let account_ids: Vec<String> = guard.accounts.keys().cloned().collect();
    for account_id in &account_ids {
        let Some(account) = guard.accounts.get_mut(account_id) else {
            continue;
        };
        account
            .distributions
            .values_mut()
            .filter(|dist| !bindings.contains_key(&dist.id))
            .for_each(|dist| dist.bound_port = None);
    }
}
async fn accept_loop(dp: DataPlane, dist_id: String, listener: TcpListener) {
loop {
let (sock, _peer) = match listener.accept().await {
Ok(p) => p,
Err(e) => {
debug!(dist = %dist_id, "accept error: {e}");
continue;
}
};
let dp2 = dp.clone();
let id2 = dist_id.clone();
tokio::spawn(async move {
let io = TokioIo::new(sock);
let svc = service_fn(move |req| {
let dp3 = dp2.clone();
let id3 = id2.clone();
async move { Ok::<_, Infallible>(handle_request(&dp3, &id3, req).await) }
});
if let Err(e) = hyper::server::conn::http1::Builder::new()
.serve_connection(io, svc)
.await
{
debug!("CloudFront data plane: connection error: {e}");
}
});
}
}
async fn handle_request(
dp: &DataPlane,
dist_id: &str,
req: Request<hyper::body::Incoming>,
) -> Response<Full<Bytes>> {
let (parts, body) = req.into_parts();
let method = parts.method;
let path = parts.uri.path().to_string();
let path_and_query = parts
.uri
.path_and_query()
.map(|p| p.as_str())
.unwrap_or("/")
.to_string();
let req_headers = parts.headers;
let body_bytes = body
.collect()
.await
.map(|c| c.to_bytes())
.unwrap_or_default();
let route: Option<RouteResolution> = {
let accs = dp.state.read();
let resolved = accs
.all_distributions()
.find(|(_, d)| d.id == dist_id)
.and_then(|(_, d)| resolve_route(&d.config, &path, &dp.s3_endpoint));
resolved
};
let Some(route) = route else {
return canned(502, "distribution or matching origin not found");
};
let url = format!("{}{path_and_query}", route.upstream.url_base);
trace!(dist = %dist_id, %path, origin = %route.upstream.host_header, "CloudFront data plane: proxying");
let resp = fetch_origin(
dp,
&method,
&url,
&route.upstream.host_header,
&req_headers,
&body_bytes,
)
.await;
if let Some(rule) = match_error_rule(&route.error_rules, resp.status().as_u16()) {
let origin_status = resp.status();
let url = format!("{}{}", route.default_upstream.url_base, rule.page_path);
let err_resp = fetch_origin(
dp,
&Method::GET,
&url,
&route.default_upstream.host_header,
&HeaderMap::new(),
&Bytes::new(),
)
.await;
if err_resp.status().is_success() {
let mut err_resp = err_resp;
let final_status = rule
.response_code
.and_then(|c| http::StatusCode::from_u16(c).ok())
.unwrap_or(origin_status);
*err_resp.status_mut() = final_status;
return err_resp;
}
return resp;
}
resp
}
/// Result of `resolve_route`: everything `handle_request` needs to proxy one
/// request without holding the state lock.
struct RouteResolution {
// Origin selected for the request path.
upstream: UpstreamTarget,
// Origin of the default cache behavior; custom error pages are fetched from it.
default_upstream: UpstreamTarget,
// Custom error responses that carry a response page path.
error_rules: Vec<ErrorRule>,
}
/// Where and how to reach one origin.
#[derive(Clone)]
struct UpstreamTarget {
// `scheme://authority` prefix; the request's path and query are appended to it.
url_base: String,
// Value sent as the `host` header on requests to this origin.
host_header: String,
}
/// One custom error response, reduced to the fields the data plane uses.
#[derive(Clone)]
struct ErrorRule {
// Origin status code that triggers this rule.
error_code: u16,
// Path of the replacement page; appended to the default origin's URL base.
page_path: String,
// Status to return with the replacement page; `None` keeps the origin's status.
response_code: Option<u16>,
}
fn resolve_route(
cfg: &DistributionConfig,
path: &str,
s3_endpoint: &str,
) -> Option<RouteResolution> {
let items = cfg.origins.items.as_ref()?;
let target = select_target_origin(cfg, path);
let upstream = items
.origin
.iter()
.find(|o| o.id == target)
.map(|o| upstream_for(o, s3_endpoint))?;
let default_target = cfg.default_cache_behavior.target_origin_id.as_str();
let default_upstream = items
.origin
.iter()
.find(|o| o.id == default_target)
.map(|o| upstream_for(o, s3_endpoint))
.unwrap_or_else(|| upstream.clone());
let error_rules = cfg
.custom_error_responses
.as_ref()
.and_then(|c| c.items.as_ref())
.map(|it| {
it.custom_error_response
.iter()
.filter_map(|r| {
r.response_page_path.as_ref().map(|p| ErrorRule {
error_code: r.error_code as u16,
page_path: p.clone(),
response_code: r.response_code.as_ref().and_then(|s| s.parse().ok()),
})
})
.collect()
})
.unwrap_or_default();
Some(RouteResolution {
upstream,
default_upstream,
error_rules,
})
}
/// Returns a copy of the first rule in `rules` whose `error_code` equals
/// `status`, or `None` if no rule applies.
fn match_error_rule(rules: &[ErrorRule], status: u16) -> Option<ErrorRule> {
    for rule in rules {
        if rule.error_code == status {
            return Some(rule.clone());
        }
    }
    None
}
/// Picks the origin id that should serve `path`.
///
/// Cache behaviors are checked in their configured order and the first one
/// whose path pattern matches wins. If none matches (or none are configured),
/// the default cache behavior's target origin is used.
fn select_target_origin<'a>(cfg: &'a DistributionConfig, path: &str) -> &'a str {
    let matched = cfg
        .cache_behaviors
        .as_ref()
        .and_then(|behaviors| behaviors.items.as_ref())
        .and_then(|items| {
            items
                .cache_behavior
                .iter()
                .find(|behavior| path_pattern_matches(&behavior.path_pattern, path))
        });

    match matched {
        Some(behavior) => &behavior.target_origin_id,
        None => &cfg.default_cache_behavior.target_origin_id,
    }
}
/// Returns true when `domain` looks like an S3 static-website endpoint: it
/// ends with `.amazonaws.com` and contains `.s3-website` somewhere.
///
/// The check is case-sensitive.
fn is_s3_website(domain: &str) -> bool {
    const AWS_SUFFIX: &str = ".amazonaws.com";
    const WEBSITE_MARKER: &str = ".s3-website";

    if !domain.ends_with(AWS_SUFFIX) {
        return false;
    }
    domain.contains(WEBSITE_MARKER)
}
/// Translates an origin definition into the URL base and `host` header used
/// to reach it.
///
/// * S3 website domains are routed to `http://{s3_endpoint}` while keeping the
///   original domain as the host header.
/// * Origins with a custom origin config use `https` and the HTTPS port when
///   the protocol policy is `https-only` (case-insensitive), otherwise `http`
///   and the HTTP port. The port is appended unless the domain already ends
///   in `:<digits>`, the port is not positive, or it is the scheme's default
///   (80 for http, 443 for https).
/// * Any other origin is reached at `http://{domain}`.
fn upstream_for(origin: &crate::model::Origin, s3_endpoint: &str) -> UpstreamTarget {
    let domain = &origin.domain_name;
    let host_header = domain.clone();

    if is_s3_website(domain) {
        return UpstreamTarget {
            url_base: format!("http://{s3_endpoint}"),
            host_header,
        };
    }

    let Some(custom) = &origin.custom_origin_config else {
        return UpstreamTarget {
            url_base: format!("http://{domain}"),
            host_header,
        };
    };

    let (scheme, port) = if custom
        .origin_protocol_policy
        .eq_ignore_ascii_case("https-only")
    {
        ("https", custom.https_port)
    } else {
        ("http", custom.http_port)
    };

    // The domain carries its own port when the text after the last ':' is
    // non-empty and all digits.
    let domain_has_port = domain.contains(':')
        && domain
            .rsplit(':')
            .next()
            .is_some_and(|tail| !tail.is_empty() && tail.bytes().all(|b| b.is_ascii_digit()));
    let is_scheme_default = matches!((scheme, port), ("http", 80) | ("https", 443));

    let url_base = if domain_has_port || port <= 0 || is_scheme_default {
        format!("{scheme}://{domain}")
    } else {
        format!("{scheme}://{domain}:{port}")
    };

    UpstreamTarget {
        url_base,
        host_header,
    }
}
async fn fetch_origin(
dp: &DataPlane,
method: &Method,
url: &str,
host_header: &str,
req_headers: &HeaderMap,
body: &Bytes,
) -> Response<Full<Bytes>> {
let mut rb = dp.upstream.request(reqwest_method(method), url);
for (k, v) in req_headers.iter() {
let n = k.as_str();
if is_hop_by_hop(n) || n.eq_ignore_ascii_case("host") {
continue;
}
rb = rb.header(k.as_str(), v.as_bytes());
}
rb = rb.header("host", host_header);
if !body.is_empty() {
rb = rb.body(body.to_vec());
}
match rb.send().await {
Ok(up) => {
let status = up.status();
let headers = up.headers().clone();
let bytes = up.bytes().await.unwrap_or_default();
let mut resp = Response::new(Full::new(bytes));
*resp.status_mut() = status;
for (k, v) in headers.iter() {
if !is_hop_by_hop(k.as_str()) {
resp.headers_mut().append(k.clone(), v.clone());
}
}
resp
}
Err(e) => canned(502, &format!("origin error: {e}")),
}
}
/// Tests a cache-behavior path pattern against a request path.
///
/// Leading `/` characters are removed from both sides first, so `/img/*` and
/// `img/*` behave the same. The remainder is compared bytewise with
/// `glob_match`.
fn path_pattern_matches(pattern: &str, path: &str) -> bool {
    fn without_leading_slashes(s: &str) -> &str {
        s.trim_start_matches('/')
    }

    let pattern_bytes = without_leading_slashes(pattern).as_bytes();
    let path_bytes = without_leading_slashes(path).as_bytes();
    glob_match(pattern_bytes, path_bytes)
}
/// Bytewise wildcard match of `text` against `pat`.
///
/// `*` matches any run of bytes (including none, and including `/`), `?`
/// matches exactly one byte, and every other byte matches itself. The whole
/// of `text` must be consumed for a match.
///
/// Uses single-star backtracking: the position of the most recent `*` is
/// remembered and, on a mismatch, that star is made to swallow one more byte.
fn glob_match(pat: &[u8], text: &[u8]) -> bool {
    let (mut p, mut t) = (0usize, 0usize);
    // `star`: index of the last `*` seen in `pat`.
    // `mark`: text index that star is currently matched up to.
    let (mut star, mut mark): (Option<usize>, usize) = (None, 0);

    while t < text.len() {
        match pat.get(p) {
            // The wildcard must be tested before literal equality. Otherwise
            // a `*` in the pattern facing a literal `*` byte in the text is
            // consumed as a one-byte literal and loses its wildcard meaning.
            Some(b'*') => {
                star = Some(p);
                mark = t;
                p += 1;
            }
            Some(&c) if c == b'?' || c == text[t] => {
                p += 1;
                t += 1;
            }
            _ => match star {
                // Mismatch: let the last `*` absorb one more byte and retry.
                Some(sp) => {
                    p = sp + 1;
                    mark += 1;
                    t = mark;
                }
                None => return false,
            },
        }
    }

    // Text exhausted: only trailing `*`s may remain in the pattern.
    while pat.get(p) == Some(&b'*') {
        p += 1;
    }
    p == pat.len()
}
fn canned(status: u16, msg: &str) -> Response<Full<Bytes>> {
Response::builder()
.status(status)
.body(Full::new(Bytes::from(msg.to_string())))
.expect("canned response builds")
}
/// Converts an `http::Method` into a `reqwest::Method` by re-parsing its
/// textual form. If that parse fails, falls back to `GET`.
fn reqwest_method(m: &Method) -> reqwest::Method {
    let verb = m.as_str().as_bytes();
    match reqwest::Method::from_bytes(verb) {
        Ok(converted) => converted,
        Err(_) => reqwest::Method::GET,
    }
}
/// Header names, in lowercase, that `is_hop_by_hop` treats as hop-by-hop.
/// `fetch_origin` strips them in both directions instead of forwarding them.
const HOP_BY_HOP: &[&str] = &[
"connection",
"keep-alive",
"proxy-authenticate",
"proxy-authorization",
"te",
"trailer",
"transfer-encoding",
"upgrade",
];
/// Returns true if `name` is one of the `HOP_BY_HOP` header names, compared
/// ASCII case-insensitively.
fn is_hop_by_hop(name: &str) -> bool {
    for candidate in HOP_BY_HOP {
        if candidate.eq_ignore_ascii_case(name) {
            return true;
        }
    }
    false
}
// Unit tests for the pure routing helpers: S3 website detection and the
// origin -> upstream URL translation done by `upstream_for`.
#[cfg(test)]
mod tests {
use super::*;
use crate::model::{CustomOriginConfig, Origin};
// Builds an origin with id "o", the given domain and optional custom origin
// config; every other field takes its default.
fn origin(domain: &str, custom: Option<CustomOriginConfig>) -> Origin {
Origin {
id: "o".into(),
domain_name: domain.into(),
custom_origin_config: custom,
..Default::default()
}
}
// Builds a custom origin config with the given protocol policy and ports;
// every other field takes its default.
fn custom(policy: &str, http_port: i32, https_port: i32) -> CustomOriginConfig {
CustomOriginConfig {
http_port,
https_port,
origin_protocol_policy: policy.into(),
..Default::default()
}
}
// Both the dash and dot regional forms are recognised; look-alike domains
// outside amazonaws.com and plain hosts are not.
#[test]
fn s3_website_detection_is_precise() {
assert!(is_s3_website("b.s3-website-us-east-1.amazonaws.com"));
assert!(is_s3_website("b.s3-website.us-east-1.amazonaws.com"));
assert!(!is_s3_website("my.s3-website.example.com"));
assert!(!is_s3_website("api.example.com"));
assert!(!is_s3_website("127.0.0.1:8080"));
}
// S3 website origins are proxied to the local endpoint while the host header
// keeps the original bucket website domain.
#[test]
fn s3_website_origin_routes_to_local_port() {
let up = upstream_for(
&origin("b.s3-website-us-east-1.amazonaws.com", None),
"127.0.0.1:4566",
);
assert_eq!(up.url_base, "http://127.0.0.1:4566");
assert_eq!(up.host_header, "b.s3-website-us-east-1.amazonaws.com");
}
// With an `https-only` policy the HTTPS port is used, and a non-default port
// is appended to the authority.
#[test]
fn https_only_custom_origin_uses_https_and_port() {
let up = upstream_for(
&origin("api.example.com", Some(custom("https-only", 80, 8443))),
"127.0.0.1:4566",
);
assert_eq!(up.url_base, "https://api.example.com:8443");
}
// Port 80 is the http default, so it is left out of the URL.
#[test]
fn http_custom_origin_default_port_omits_port() {
let up = upstream_for(
&origin("api.example.com", Some(custom("http-only", 80, 443))),
"127.0.0.1:4566",
);
assert_eq!(up.url_base, "http://api.example.com");
}
// A `:port` already present in the domain is kept as-is; the configured port
// is not appended on top of it.
#[test]
fn explicit_port_in_domain_wins_over_config_port() {
let up = upstream_for(
&origin("127.0.0.1:52111", Some(custom("http-only", 80, 443))),
"127.0.0.1:4566",
);
assert_eq!(up.url_base, "http://127.0.0.1:52111");
}
// An origin without a custom origin config (and not an S3 website) is reached
// over plain http.
#[test]
fn bare_origin_defaults_to_http() {
let up = upstream_for(&origin("origin.internal", None), "127.0.0.1:4566");
assert_eq!(up.url_base, "http://origin.internal");
}
}