use std::sync::Arc;
use async_trait::async_trait;
use bytes::Bytes;
use chrono::Utc;
use http::header::{HeaderName, HeaderValue, ETAG, IF_MATCH, LOCATION};
use http::{HeaderMap, StatusCode};
use parking_lot::RwLock;
use uuid::Uuid;
use fakecloud_core::service::{AwsRequest, AwsResponse, AwsService, AwsServiceError, ResponseBody};
use crate::model::{
DistributionConfig, DistributionConfigWithTags, InvalidationBatch, TagKeys, Tags as ModelTags,
};
use crate::router::{route, Route};
use crate::state::{
CloudFrontAccounts, SharedCloudFrontState, StoredDistribution, StoredInvalidation, Tag,
};
use crate::xml_io;
/// Account ID that most handlers in this file use to look up state
/// (`state.accounts.get(DEFAULT_ACCOUNT)` / `state.entry(DEFAULT_ACCOUNT)`).
pub(crate) const DEFAULT_ACCOUNT: &str = "000000000000";
/// Every CloudFront action name this service advertises through
/// `AwsService::supported_actions`. Each entry should have a matching arm in
/// `CloudFrontService::handle`; keep the two lists in sync when adding actions.
const SUPPORTED_ACTIONS: &[&str] = &[
"CreateDistribution",
"CreateDistributionWithTags",
"GetDistribution",
"GetDistributionConfig",
"UpdateDistribution",
"DeleteDistribution",
"ListDistributions",
"CopyDistribution",
"CreateInvalidation",
"GetInvalidation",
"ListInvalidations",
"TagResource",
"UntagResource",
"ListTagsForResource",
"AssociateAlias",
"ListConflictingAliases",
"ListDistributionsByCachePolicyId",
"ListDistributionsByOriginRequestPolicyId",
"ListDistributionsByResponseHeadersPolicyId",
"ListDistributionsByKeyGroup",
"ListDistributionsByWebACLId",
"ListDistributionsByVpcOriginId",
"ListDistributionsByAnycastIpListId",
"ListDistributionsByConnectionMode",
"ListDistributionsByConnectionFunction",
"ListDistributionsByOwnedResource",
"ListDistributionsByTrustStore",
"ListDistributionsByRealtimeLogConfig",
"AssociateDistributionWebACL",
"DisassociateDistributionWebACL",
"CreateOriginAccessControl",
"GetOriginAccessControl",
"GetOriginAccessControlConfig",
"UpdateOriginAccessControl",
"DeleteOriginAccessControl",
"ListOriginAccessControls",
"CreateCachePolicy",
"GetCachePolicy",
"GetCachePolicyConfig",
"UpdateCachePolicy",
"DeleteCachePolicy",
"ListCachePolicies",
"CreateOriginRequestPolicy",
"GetOriginRequestPolicy",
"GetOriginRequestPolicyConfig",
"UpdateOriginRequestPolicy",
"DeleteOriginRequestPolicy",
"ListOriginRequestPolicies",
"CreateResponseHeadersPolicy",
"GetResponseHeadersPolicy",
"GetResponseHeadersPolicyConfig",
"UpdateResponseHeadersPolicy",
"DeleteResponseHeadersPolicy",
"ListResponseHeadersPolicies",
"CreateContinuousDeploymentPolicy",
"GetContinuousDeploymentPolicy",
"GetContinuousDeploymentPolicyConfig",
"UpdateContinuousDeploymentPolicy",
"DeleteContinuousDeploymentPolicy",
"ListContinuousDeploymentPolicies",
"CreateFunction",
"DescribeFunction",
"GetFunction",
"UpdateFunction",
"DeleteFunction",
"ListFunctions",
"PublishFunction",
"TestFunction",
"CreatePublicKey",
"GetPublicKey",
"GetPublicKeyConfig",
"UpdatePublicKey",
"DeletePublicKey",
"ListPublicKeys",
"CreateKeyGroup",
"GetKeyGroup",
"GetKeyGroupConfig",
"UpdateKeyGroup",
"DeleteKeyGroup",
"ListKeyGroups",
"CreateKeyValueStore",
"DescribeKeyValueStore",
"UpdateKeyValueStore",
"DeleteKeyValueStore",
"ListKeyValueStores",
"CreateCloudFrontOriginAccessIdentity",
"GetCloudFrontOriginAccessIdentity",
"GetCloudFrontOriginAccessIdentityConfig",
"UpdateCloudFrontOriginAccessIdentity",
"DeleteCloudFrontOriginAccessIdentity",
"ListCloudFrontOriginAccessIdentities",
"CreateMonitoringSubscription",
"GetMonitoringSubscription",
"DeleteMonitoringSubscription",
"CreateStreamingDistribution",
"CreateStreamingDistributionWithTags",
"GetStreamingDistribution",
"GetStreamingDistributionConfig",
"UpdateStreamingDistribution",
"DeleteStreamingDistribution",
"ListStreamingDistributions",
"CreateFieldLevelEncryptionConfig",
"GetFieldLevelEncryption",
"GetFieldLevelEncryptionConfig",
"UpdateFieldLevelEncryptionConfig",
"DeleteFieldLevelEncryptionConfig",
"ListFieldLevelEncryptionConfigs",
"CreateFieldLevelEncryptionProfile",
"GetFieldLevelEncryptionProfile",
"GetFieldLevelEncryptionProfileConfig",
"UpdateFieldLevelEncryptionProfile",
"DeleteFieldLevelEncryptionProfile",
"ListFieldLevelEncryptionProfiles",
"CreateRealtimeLogConfig",
"GetRealtimeLogConfig",
"UpdateRealtimeLogConfig",
"DeleteRealtimeLogConfig",
"ListRealtimeLogConfigs",
"CreateVpcOrigin",
"GetVpcOrigin",
"UpdateVpcOrigin",
"DeleteVpcOrigin",
"ListVpcOrigins",
"CreateAnycastIpList",
"GetAnycastIpList",
"UpdateAnycastIpList",
"DeleteAnycastIpList",
"ListAnycastIpLists",
"CreateTrustStore",
"GetTrustStore",
"UpdateTrustStore",
"DeleteTrustStore",
"ListTrustStores",
"GetResourcePolicy",
"PutResourcePolicy",
"DeleteResourcePolicy",
"CreateConnectionGroup",
"GetConnectionGroup",
"GetConnectionGroupByRoutingEndpoint",
"UpdateConnectionGroup",
"DeleteConnectionGroup",
"ListConnectionGroups",
"ListDomainConflicts",
"UpdateDomainAssociation",
"VerifyDnsConfiguration",
"GetManagedCertificateDetails",
"UpdateDistributionWithStagingConfig",
"CreateDistributionTenant",
"GetDistributionTenant",
"GetDistributionTenantByDomain",
"UpdateDistributionTenant",
"DeleteDistributionTenant",
"ListDistributionTenants",
"ListDistributionTenantsByCustomization",
"AssociateDistributionTenantWebACL",
"DisassociateDistributionTenantWebACL",
"CreateInvalidationForDistributionTenant",
"GetInvalidationForDistributionTenant",
"ListInvalidationsForDistributionTenant",
"CreateConnectionFunction",
"GetConnectionFunction",
"DescribeConnectionFunction",
"UpdateConnectionFunction",
"DeleteConnectionFunction",
"ListConnectionFunctions",
"PublishConnectionFunction",
"TestConnectionFunction",
];
/// In-memory CloudFront emulator implementing `AwsService`.
pub struct CloudFrontService {
/// Shared, lock-protected per-account state (cloned via `Arc::clone` in `shared_state`).
pub(crate) state: SharedCloudFrontState,
/// How long the background deploy tasks wait before flipping a resource from
/// `InProgress` to `Deployed`.
pub(crate) propagation_delay: std::time::Duration,
}
/// Reads the default status-propagation delay from the environment.
///
/// `FAKECLOUD_CLOUDFRONT_STATUS_DELAY_SEC` may hold a whole number of seconds
/// (`"3"`) or a fractional value (`"0.25"`); surrounding whitespace is ignored.
/// The delay falls back to one second in any of these cases:
/// - the variable is unset or not valid Unicode;
/// - the value is unparsable, negative or non-finite;
/// - the value is too large to represent as a `Duration`.
fn default_propagation_delay() -> std::time::Duration {
    std::env::var("FAKECLOUD_CLOUDFRONT_STATUS_DELAY_SEC")
        .ok()
        .and_then(|raw| parse_propagation_delay(&raw))
        .unwrap_or(std::time::Duration::from_secs(1))
}

/// Parses a seconds value for `default_propagation_delay`.
///
/// Returns `None` for anything that is not a non-negative, finite number that
/// fits in a `Duration`.
fn parse_propagation_delay(raw: &str) -> Option<std::time::Duration> {
    let trimmed = raw.trim();
    if let Ok(secs) = trimmed.parse::<u64>() {
        return Some(std::time::Duration::from_secs(secs));
    }
    // `try_from_secs_f64` rejects negative, NaN/infinite and overflowing values
    // instead of panicking the way `from_secs_f64` does (e.g. for "1e300").
    trimmed
        .parse::<f64>()
        .ok()
        .and_then(|secs| std::time::Duration::try_from_secs_f64(secs).ok())
}
impl CloudFrontService {
    /// Builds a service over `state`.
    ///
    /// The deploy delay comes from `default_propagation_delay()`.
    pub fn new(state: SharedCloudFrontState) -> Self {
        let propagation_delay = default_propagation_delay();
        Self {
            state,
            propagation_delay,
        }
    }

    /// Returns another handle to the shared state. This is an `Arc` clone, so
    /// no data is copied.
    pub fn shared_state(&self) -> SharedCloudFrontState {
        Arc::clone(&self.state)
    }

    /// Builder-style override for the delay used by the background deploy tasks.
    pub fn with_propagation_delay(self, delay: std::time::Duration) -> Self {
        Self {
            propagation_delay: delay,
            ..self
        }
    }

    /// Forces the status string of distribution `id`, searching every account.
    ///
    /// Only the first account holding `id` is touched. Returns `false` when no
    /// account holds a distribution with that id.
    pub fn set_distribution_status(&self, id: &str, status: &str) -> bool {
        let mut guard = self.state.write();
        let target = guard
            .accounts
            .values_mut()
            .find_map(|acct| acct.distributions.get_mut(id));
        match target {
            Some(dist) => {
                dist.status = status.to_owned();
                true
            }
            None => false,
        }
    }
}
impl Default for CloudFrontService {
    /// Builds a service over a freshly constructed `CloudFrontAccounts` store,
    /// wrapped in `Arc<RwLock<_>>`.
    fn default() -> Self {
        let accounts = CloudFrontAccounts::new();
        let shared = Arc::new(RwLock::new(accounts));
        Self::new(shared)
    }
}
// `AwsService` glue: identifies the service and dispatches resolved routes to
// the per-action handler methods on `CloudFrontService`.
#[async_trait]
impl AwsService for CloudFrontService {
/// Service identifier: always `"cloudfront"`.
fn service_name(&self) -> &str {
"cloudfront"
}
/// The advertised action names (the `SUPPORTED_ACTIONS` table).
fn supported_actions(&self) -> &[&str] {
SUPPORTED_ACTIONS
}
/// Resolves the request's method, raw path and raw query to a CloudFront
/// action via `route`, then calls the matching handler.
///
/// Errors:
/// - 404 `InvalidArgument` when `route` cannot resolve the request.
/// - 501 `InvalidAction` when `route` yields an action that has no arm below.
/// - Otherwise, whatever error the selected handler returns.
async fn handle(&self, req: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
let resolved = match route(&req.method, &req.raw_path, &req.raw_query) {
Some(r) => r,
None => {
return Err(aws_error(
StatusCode::NOT_FOUND,
"InvalidArgument",
format!("Unknown CloudFront route: {} {}", req.method, req.raw_path),
));
}
};
// Handlers take `&req` when they need headers/body/query and `&resolved`
// when they need ids extracted from the path.
match resolved.action {
"CreateDistribution" => self.create_distribution(&req, false),
"CreateDistributionWithTags" => self.create_distribution(&req, true),
"GetDistribution" => self.get_distribution(&resolved),
"GetDistributionConfig" => self.get_distribution_config(&resolved),
"UpdateDistribution" => self.update_distribution(&req, &resolved),
"DeleteDistribution" => self.delete_distribution(&req, &resolved),
"ListDistributions" => self.list_distributions(&req),
"CopyDistribution" => self.copy_distribution(&req, &resolved),
"CreateInvalidation" => self.create_invalidation(&req, &resolved),
"GetInvalidation" => self.get_invalidation(&resolved),
"ListInvalidations" => self.list_invalidations(&resolved),
"TagResource" => self.tag_resource(&req),
"UntagResource" => self.untag_resource(&req),
"ListTagsForResource" => self.list_tags_for_resource(&req),
"AssociateAlias" => self.associate_alias(&req, &resolved),
"ListConflictingAliases" => self.list_conflicting_aliases(&req),
"AssociateDistributionWebACL" => self.associate_web_acl(&req, &resolved),
"DisassociateDistributionWebACL" => self.disassociate_web_acl(&req, &resolved),
// All ListDistributionsBy* variants share one handler keyed on the action name.
"ListDistributionsByCachePolicyId"
| "ListDistributionsByOriginRequestPolicyId"
| "ListDistributionsByResponseHeadersPolicyId"
| "ListDistributionsByKeyGroup"
| "ListDistributionsByWebACLId"
| "ListDistributionsByVpcOriginId"
| "ListDistributionsByAnycastIpListId"
| "ListDistributionsByConnectionMode"
| "ListDistributionsByConnectionFunction"
| "ListDistributionsByOwnedResource"
| "ListDistributionsByTrustStore"
| "ListDistributionsByRealtimeLogConfig" => self.list_distributions_by(resolved.action),
"CreateOriginAccessControl" => self.create_origin_access_control(&req),
"GetOriginAccessControl" => self.get_origin_access_control(&resolved),
"GetOriginAccessControlConfig" => self.get_origin_access_control_config(&resolved),
"UpdateOriginAccessControl" => self.update_origin_access_control(&req, &resolved),
"DeleteOriginAccessControl" => self.delete_origin_access_control(&req, &resolved),
"ListOriginAccessControls" => self.list_origin_access_controls(&req),
"CreateCachePolicy" => self.create_cache_policy(&req),
"GetCachePolicy" => self.get_cache_policy(&resolved),
"GetCachePolicyConfig" => self.get_cache_policy_config(&resolved),
"UpdateCachePolicy" => self.update_cache_policy(&req, &resolved),
"DeleteCachePolicy" => self.delete_cache_policy(&req, &resolved),
"ListCachePolicies" => self.list_cache_policies(&req),
"CreateOriginRequestPolicy" => self.create_origin_request_policy(&req),
"GetOriginRequestPolicy" => self.get_origin_request_policy(&resolved),
"GetOriginRequestPolicyConfig" => self.get_origin_request_policy_config(&resolved),
"UpdateOriginRequestPolicy" => self.update_origin_request_policy(&req, &resolved),
"DeleteOriginRequestPolicy" => self.delete_origin_request_policy(&req, &resolved),
"ListOriginRequestPolicies" => self.list_origin_request_policies(&req),
"CreateResponseHeadersPolicy" => self.create_response_headers_policy(&req),
"GetResponseHeadersPolicy" => self.get_response_headers_policy(&resolved),
"GetResponseHeadersPolicyConfig" => self.get_response_headers_policy_config(&resolved),
"UpdateResponseHeadersPolicy" => self.update_response_headers_policy(&req, &resolved),
"DeleteResponseHeadersPolicy" => self.delete_response_headers_policy(&req, &resolved),
"ListResponseHeadersPolicies" => self.list_response_headers_policies(&req),
"CreateContinuousDeploymentPolicy" => self.create_continuous_deployment_policy(&req),
"GetContinuousDeploymentPolicy" => self.get_continuous_deployment_policy(&resolved),
"GetContinuousDeploymentPolicyConfig" => {
self.get_continuous_deployment_policy_config(&resolved)
}
"UpdateContinuousDeploymentPolicy" => {
self.update_continuous_deployment_policy(&req, &resolved)
}
"DeleteContinuousDeploymentPolicy" => {
self.delete_continuous_deployment_policy(&req, &resolved)
}
"ListContinuousDeploymentPolicies" => self.list_continuous_deployment_policies(&req),
"CreateFunction" => self.create_function(&req),
"DescribeFunction" => self.describe_function(&req, &resolved),
"GetFunction" => self.get_function(&req, &resolved),
"UpdateFunction" => self.update_function(&req, &resolved),
"DeleteFunction" => self.delete_function(&req, &resolved),
"ListFunctions" => self.list_functions(&req),
"PublishFunction" => self.publish_function(&req, &resolved),
"TestFunction" => self.test_function(&req, &resolved),
"CreatePublicKey" => self.create_public_key(&req),
"GetPublicKey" => self.get_public_key(&resolved),
"GetPublicKeyConfig" => self.get_public_key_config(&resolved),
"UpdatePublicKey" => self.update_public_key(&req, &resolved),
"DeletePublicKey" => self.delete_public_key(&req, &resolved),
"ListPublicKeys" => self.list_public_keys(&req),
"CreateKeyGroup" => self.create_key_group(&req),
"GetKeyGroup" => self.get_key_group(&resolved),
"GetKeyGroupConfig" => self.get_key_group_config(&resolved),
"UpdateKeyGroup" => self.update_key_group(&req, &resolved),
"DeleteKeyGroup" => self.delete_key_group(&req, &resolved),
"ListKeyGroups" => self.list_key_groups(&req),
"CreateKeyValueStore" => self.create_key_value_store(&req),
"DescribeKeyValueStore" => self.describe_key_value_store(&resolved),
"UpdateKeyValueStore" => self.update_key_value_store(&req, &resolved),
"DeleteKeyValueStore" => self.delete_key_value_store(&req, &resolved),
"ListKeyValueStores" => self.list_key_value_stores(&req),
"CreateCloudFrontOriginAccessIdentity" => self.create_oai(&req),
"GetCloudFrontOriginAccessIdentity" => self.get_oai(&resolved),
"GetCloudFrontOriginAccessIdentityConfig" => self.get_oai_config(&resolved),
"UpdateCloudFrontOriginAccessIdentity" => self.update_oai(&req, &resolved),
"DeleteCloudFrontOriginAccessIdentity" => self.delete_oai(&req, &resolved),
"ListCloudFrontOriginAccessIdentities" => self.list_oai(&req),
"CreateMonitoringSubscription" => self.create_monitoring_subscription(&req, &resolved),
"GetMonitoringSubscription" => self.get_monitoring_subscription(&resolved),
"DeleteMonitoringSubscription" => self.delete_monitoring_subscription(&resolved),
"CreateStreamingDistribution" => self.create_streaming_distribution(&req, false),
"CreateStreamingDistributionWithTags" => self.create_streaming_distribution(&req, true),
"GetStreamingDistribution" => self.get_streaming_distribution(&resolved),
"GetStreamingDistributionConfig" => self.get_streaming_distribution_config(&resolved),
"UpdateStreamingDistribution" => self.update_streaming_distribution(&req, &resolved),
"DeleteStreamingDistribution" => self.delete_streaming_distribution(&req, &resolved),
"ListStreamingDistributions" => self.list_streaming_distributions(&req),
"CreateFieldLevelEncryptionConfig" => self.create_field_level_encryption_config(&req),
"GetFieldLevelEncryption" => self.get_field_level_encryption(&resolved),
"GetFieldLevelEncryptionConfig" => self.get_field_level_encryption_config(&resolved),
"UpdateFieldLevelEncryptionConfig" => {
self.update_field_level_encryption_config(&req, &resolved)
}
"DeleteFieldLevelEncryptionConfig" => {
self.delete_field_level_encryption_config(&req, &resolved)
}
"ListFieldLevelEncryptionConfigs" => self.list_field_level_encryption_configs(&req),
"CreateFieldLevelEncryptionProfile" => self.create_field_level_encryption_profile(&req),
"GetFieldLevelEncryptionProfile" => self.get_field_level_encryption_profile(&resolved),
"GetFieldLevelEncryptionProfileConfig" => {
self.get_field_level_encryption_profile_config(&resolved)
}
"UpdateFieldLevelEncryptionProfile" => {
self.update_field_level_encryption_profile(&req, &resolved)
}
"DeleteFieldLevelEncryptionProfile" => {
self.delete_field_level_encryption_profile(&req, &resolved)
}
"ListFieldLevelEncryptionProfiles" => self.list_field_level_encryption_profiles(&req),
"CreateRealtimeLogConfig" => self.create_realtime_log_config(&req),
"GetRealtimeLogConfig" => self.get_realtime_log_config(&req),
"UpdateRealtimeLogConfig" => self.update_realtime_log_config(&req),
"DeleteRealtimeLogConfig" => self.delete_realtime_log_config(&req),
"ListRealtimeLogConfigs" => self.list_realtime_log_configs(&req),
"CreateVpcOrigin" => self.create_vpc_origin(&req),
"GetVpcOrigin" => self.get_vpc_origin(&resolved),
"UpdateVpcOrigin" => self.update_vpc_origin(&req, &resolved),
"DeleteVpcOrigin" => self.delete_vpc_origin(&req, &resolved),
"ListVpcOrigins" => self.list_vpc_origins(&req),
"CreateAnycastIpList" => self.create_anycast_ip_list(&req),
"GetAnycastIpList" => self.get_anycast_ip_list(&resolved),
"UpdateAnycastIpList" => self.update_anycast_ip_list(&req, &resolved),
"DeleteAnycastIpList" => self.delete_anycast_ip_list(&req, &resolved),
"ListAnycastIpLists" => self.list_anycast_ip_lists(&req),
"CreateTrustStore" => self.create_trust_store(&req),
"GetTrustStore" => self.get_trust_store(&resolved),
"UpdateTrustStore" => self.update_trust_store(&req, &resolved),
"DeleteTrustStore" => self.delete_trust_store(&req, &resolved),
"ListTrustStores" => self.list_trust_stores(&req),
"GetResourcePolicy" => self.get_resource_policy(&req),
"PutResourcePolicy" => self.put_resource_policy(&req),
"DeleteResourcePolicy" => self.delete_resource_policy(&req),
"CreateConnectionGroup" => self.create_connection_group(&req),
"GetConnectionGroup" => self.get_connection_group(&resolved),
"GetConnectionGroupByRoutingEndpoint" => {
self.get_connection_group_by_routing_endpoint(&req)
}
"UpdateConnectionGroup" => self.update_connection_group(&req, &resolved),
"DeleteConnectionGroup" => self.delete_connection_group(&req, &resolved),
"ListConnectionGroups" => self.list_connection_groups(&req),
"ListDomainConflicts" => self.list_domain_conflicts(&req),
"UpdateDomainAssociation" => self.update_domain_association(&req),
"VerifyDnsConfiguration" => self.verify_dns_configuration(&req),
"GetManagedCertificateDetails" => self.get_managed_certificate_details(&resolved),
"UpdateDistributionWithStagingConfig" => {
self.update_distribution_with_staging_config(&req, &resolved)
}
"CreateDistributionTenant" => self.create_distribution_tenant(&req),
"GetDistributionTenant" => self.get_distribution_tenant(&resolved),
"GetDistributionTenantByDomain" => self.get_distribution_tenant_by_domain(&req),
"UpdateDistributionTenant" => self.update_distribution_tenant(&req, &resolved),
"DeleteDistributionTenant" => self.delete_distribution_tenant(&req, &resolved),
"ListDistributionTenants" => self.list_distribution_tenants(&req),
"ListDistributionTenantsByCustomization" => {
self.list_distribution_tenants_by_customization(&req)
}
"AssociateDistributionTenantWebACL" => {
self.associate_distribution_tenant_web_acl(&req, &resolved)
}
"DisassociateDistributionTenantWebACL" => {
self.disassociate_distribution_tenant_web_acl(&req, &resolved)
}
"CreateInvalidationForDistributionTenant" => {
self.create_invalidation_for_distribution_tenant(&req, &resolved)
}
"GetInvalidationForDistributionTenant" => {
self.get_invalidation_for_distribution_tenant(&resolved)
}
"ListInvalidationsForDistributionTenant" => {
self.list_invalidations_for_distribution_tenant(&resolved)
}
"CreateConnectionFunction" => self.create_connection_function(&req),
"GetConnectionFunction" => self.get_connection_function(&resolved),
"DescribeConnectionFunction" => self.describe_connection_function(&resolved),
"UpdateConnectionFunction" => self.update_connection_function(&req, &resolved),
"DeleteConnectionFunction" => self.delete_connection_function(&req, &resolved),
"ListConnectionFunctions" => self.list_connection_functions(&req),
"PublishConnectionFunction" => self.publish_connection_function(&req, &resolved),
"TestConnectionFunction" => self.test_connection_function(&req, &resolved),
// Safety net: an action `route` knows about but this table does not handle.
other => Err(aws_error(
StatusCode::NOT_IMPLEMENTED,
"InvalidAction",
format!("CloudFront action {other} is not implemented yet"),
)),
}
}
}
impl CloudFrontService {
fn create_distribution(
    &self,
    req: &AwsRequest,
    with_tags: bool,
) -> Result<AwsResponse, AwsServiceError> {
    // Handles `CreateDistribution` and `CreateDistributionWithTags`.
    //
    // With `with_tags` set, the body is a `DistributionConfigWithTags` and any
    // supplied tags are stored against the new distribution's ARN. Otherwise the
    // body is a bare `DistributionConfig` and no tags are recorded.
    //
    // The distribution is stored under `account_id(req)` with status
    // `InProgress`, a fresh ETag and a `<id>.cloudfront.net` domain, and a
    // background deploy is scheduled. Responds 201 with `ETag` and `Location`
    // headers; `Location` is set to the ARN.
    //
    // Errors: `InvalidArgument` for unparsable XML, whatever
    // `validate_caller_reference` / `validate_origins` reject, and
    // 409 `DistributionAlreadyExists` for a reused CallerReference.
    let (config, tags) = if !with_tags {
        let bare: DistributionConfig = xml_io::from_xml_root(&req.body)
            .map_err(|e| invalid_argument(format!("invalid request XML: {e}")))?;
        (bare, Vec::new())
    } else {
        let wrapped: DistributionConfigWithTags = xml_io::from_xml_root(&req.body)
            .map_err(|e| invalid_argument(format!("invalid request XML: {e}")))?;
        let converted = match wrapped.tags.items {
            Some(items) => items
                .tag
                .into_iter()
                .map(|t| Tag {
                    key: t.key,
                    value: t.value,
                })
                .collect(),
            None => Vec::new(),
        };
        (wrapped.distribution_config, converted)
    };
    validate_caller_reference(&config.caller_reference)?;
    validate_origins(&config)?;

    let mut guard = self.state.write();
    let account = guard.entry(account_id(req));
    let duplicate_id = account
        .distributions
        .values()
        .find(|d| d.config.caller_reference == config.caller_reference)
        .map(|d| d.id.clone());
    if let Some(existing_id) = duplicate_id {
        return Err(aws_error(
            StatusCode::CONFLICT,
            "DistributionAlreadyExists",
            format!("Distribution with the same CallerReference exists: {existing_id}"),
        ));
    }

    let id = generate_distribution_id();
    let etag = generate_etag();
    let arn = format!("arn:aws:cloudfront::{}:distribution/{}", account_id(req), id);
    let stored = StoredDistribution {
        id: id.clone(),
        arn: arn.clone(),
        status: "InProgress".to_string(),
        last_modified_time: Utc::now(),
        domain_name: format!("{}.cloudfront.net", id.to_lowercase()),
        in_progress_invalidation_batches: 0,
        etag: etag.clone(),
        config,
    };
    account.distributions.insert(id.clone(), stored.clone());
    if !tags.is_empty() {
        account.tags.insert(arn, tags);
    }
    drop(guard);

    self.schedule_distribution_deploy(id);
    let mut headers = HeaderMap::new();
    set_header(&mut headers, ETAG, &etag);
    set_header(&mut headers, LOCATION, &stored.arn);
    Ok(xml_response(
        StatusCode::CREATED,
        build_distribution_xml(&stored),
        headers,
    ))
}
/// Spawns a Tokio task that waits `propagation_delay`, then sets distribution
/// `id` from `InProgress` to `Deployed` in the first account that holds it.
///
/// A distribution that no longer exists, or whose status is not `InProgress`,
/// is left untouched. Requires a Tokio runtime (`tokio::spawn`).
///
/// NOTE(review): the task checks only the status. A task scheduled for an
/// earlier change can therefore mark a later update `Deployed` before that
/// update's own delay has elapsed.
fn schedule_distribution_deploy(&self, id: String) {
    let shared = Arc::clone(&self.state);
    let wait = self.propagation_delay;
    tokio::spawn(async move {
        tokio::time::sleep(wait).await;
        let mut guard = shared.write();
        let found = guard
            .accounts
            .values_mut()
            .find_map(|acct| acct.distributions.get_mut(&id));
        if let Some(dist) = found {
            if dist.status == "InProgress" {
                dist.status = "Deployed".to_string();
            }
        }
    });
}
/// Spawns a Tokio task that waits `propagation_delay`, then sets distribution
/// tenant `id` from `InProgress` to `Deployed` in the first account that holds
/// it.
///
/// A missing tenant, or one whose status is not `InProgress`, is left
/// untouched. Requires a Tokio runtime (`tokio::spawn`).
pub(crate) fn schedule_distribution_tenant_deploy(&self, id: String) {
    let shared = Arc::clone(&self.state);
    let wait = self.propagation_delay;
    tokio::spawn(async move {
        tokio::time::sleep(wait).await;
        let mut guard = shared.write();
        let found = guard
            .accounts
            .values_mut()
            .find_map(|acct| acct.distribution_tenants.get_mut(&id));
        if let Some(tenant) = found {
            if tenant.status == "InProgress" {
                tenant.status = "Deployed".to_string();
            }
        }
    });
}
/// Spawns a Tokio task that waits `propagation_delay`, then sets connection
/// group `id` from `InProgress` to `Deployed` in the first account that holds
/// it.
///
/// A missing group, or one whose status is not `InProgress`, is left
/// untouched. Requires a Tokio runtime (`tokio::spawn`).
pub(crate) fn schedule_connection_group_deploy(&self, id: String) {
    let shared = Arc::clone(&self.state);
    let wait = self.propagation_delay;
    tokio::spawn(async move {
        tokio::time::sleep(wait).await;
        let mut guard = shared.write();
        let found = guard
            .accounts
            .values_mut()
            .find_map(|acct| acct.connection_groups.get_mut(&id));
        if let Some(group) = found {
            if group.status == "InProgress" {
                group.status = "Deployed".to_string();
            }
        }
    });
}
/// Spawns a Tokio task that waits `propagation_delay`, then sets streaming
/// distribution `id` from `InProgress` to `Deployed` in the first account that
/// holds it.
///
/// A missing streaming distribution, or one whose status is not `InProgress`,
/// is left untouched. Requires a Tokio runtime (`tokio::spawn`).
pub(crate) fn schedule_streaming_distribution_deploy(&self, id: String) {
    let shared = Arc::clone(&self.state);
    let wait = self.propagation_delay;
    tokio::spawn(async move {
        tokio::time::sleep(wait).await;
        let mut guard = shared.write();
        let found = guard
            .accounts
            .values_mut()
            .find_map(|acct| acct.streaming_distributions.get_mut(&id));
        if let Some(dist) = found {
            if dist.status == "InProgress" {
                dist.status = "Deployed".to_string();
            }
        }
    });
}
/// Handles `GetDistribution`.
///
/// Returns the full distribution document for the route's id, with the
/// current ETag header. Errors: `InvalidArgument` when the route has no id,
/// and `no_such_distribution` when the id is not stored under
/// `DEFAULT_ACCOUNT`.
fn get_distribution(&self, route: &Route) -> Result<AwsResponse, AwsServiceError> {
    let Some(id) = route.id.as_deref() else {
        return Err(invalid_argument("missing distribution id"));
    };
    let guard = self.state.read();
    let found = guard
        .accounts
        .get(DEFAULT_ACCOUNT)
        .and_then(|acct| acct.distributions.get(id))
        .cloned();
    drop(guard);
    let snapshot = found.ok_or_else(|| no_such_distribution(id))?;

    let mut headers = HeaderMap::new();
    set_header(&mut headers, ETAG, &snapshot.etag);
    Ok(xml_response(
        StatusCode::OK,
        build_distribution_xml(&snapshot),
        headers,
    ))
}
/// Handles `GetDistributionConfig`.
///
/// Returns only the stored config, serialized with root element
/// `DistributionConfig`, plus the distribution's ETag header.
///
/// Errors:
/// - `InvalidArgument` when the route has no id.
/// - `no_such_distribution` when the id is not stored under `DEFAULT_ACCOUNT`.
/// - An internal error if XML encoding fails.
fn get_distribution_config(&self, route: &Route) -> Result<AwsResponse, AwsServiceError> {
    let Some(id) = route.id.as_deref() else {
        return Err(invalid_argument("missing distribution id"));
    };
    let guard = self.state.read();
    let found = guard
        .accounts
        .get(DEFAULT_ACCOUNT)
        .and_then(|acct| acct.distributions.get(id))
        .cloned();
    drop(guard);
    let snapshot = found.ok_or_else(|| no_such_distribution(id))?;

    let xml = xml_io::to_xml_root("DistributionConfig", &snapshot.config)
        .map_err(|e| internal_error(format!("xml encode failed: {e}")))?;
    let mut headers = HeaderMap::new();
    set_header(&mut headers, ETAG, &snapshot.etag);
    Ok(xml_response(StatusCode::OK, xml, headers))
}
/// Handles `UpdateDistribution` using optimistic concurrency.
///
/// The `If-Match` header must equal the current ETag. If the new config
/// differs from the stored one (per `configs_equal`), the following happens:
/// - config, ETag and last-modified time are replaced;
/// - status returns to `InProgress`;
/// - a deploy is scheduled.
/// An identical config leaves the record untouched and echoes the current ETag.
///
/// Errors:
/// - `InvalidArgument` for a missing id, bad XML or failed validation.
/// - 400 `InvalidIfMatchVersion` when `If-Match` is missing.
/// - `no_such_distribution` when the id is unknown.
/// - 412 `PreconditionFailed` when the ETag is stale.
fn update_distribution(
    &self,
    req: &AwsRequest,
    route: &Route,
) -> Result<AwsResponse, AwsServiceError> {
    let Some(id) = route.id.as_deref() else {
        return Err(invalid_argument("missing distribution id"));
    };
    let Some(if_match) = req
        .headers
        .get(IF_MATCH)
        .and_then(|v| v.to_str().ok())
        .map(str::to_owned)
    else {
        return Err(aws_error(
            StatusCode::BAD_REQUEST,
            "InvalidIfMatchVersion",
            "Missing If-Match header for UpdateDistribution",
        ));
    };
    let new_config: DistributionConfig = xml_io::from_xml_root(&req.body)
        .map_err(|e| invalid_argument(format!("invalid DistributionConfig XML: {e}")))?;
    validate_caller_reference(&new_config.caller_reference)?;
    validate_origins(&new_config)?;

    let mut guard = self.state.write();
    let dist = guard
        .accounts
        .get_mut(DEFAULT_ACCOUNT)
        .and_then(|acct| acct.distributions.get_mut(id))
        .ok_or_else(|| no_such_distribution(id))?;
    if if_match != dist.etag {
        return Err(aws_error(
            StatusCode::PRECONDITION_FAILED,
            "PreconditionFailed",
            "If-Match header does not match the current ETag",
        ));
    }
    let changed = !configs_equal(&dist.config, &new_config);
    if changed {
        dist.status = "InProgress".to_string();
        dist.last_modified_time = Utc::now();
        dist.etag = generate_etag();
        dist.config = new_config;
    }
    let snapshot = dist.clone();
    drop(guard);

    if changed {
        self.schedule_distribution_deploy(id.to_owned());
    }
    let mut headers = HeaderMap::new();
    set_header(&mut headers, ETAG, &snapshot.etag);
    Ok(xml_response(
        StatusCode::OK,
        build_distribution_xml(&snapshot),
        headers,
    ))
}
/// Handles `DeleteDistribution`.
///
/// Requires an `If-Match` header equal to the current ETag and the distribution
/// to be disabled (`config.enabled == false`). On success the following are
/// removed, and an empty 204 is returned:
/// - the distribution itself;
/// - its tags, keyed by ARN;
/// - every invalidation record whose `distribution_id` is this distribution.
///
/// Errors:
/// - `InvalidArgument` for a missing id.
/// - 400 `InvalidIfMatchVersion` when `If-Match` is missing.
/// - `no_such_distribution` when the id is unknown.
/// - 412 `PreconditionFailed` when the ETag is stale.
/// - 412 `DistributionNotDisabled` when the distribution is still enabled.
fn delete_distribution(
    &self,
    req: &AwsRequest,
    route: &Route,
) -> Result<AwsResponse, AwsServiceError> {
    let id = route
        .id
        .as_deref()
        .ok_or_else(|| invalid_argument("missing distribution id"))?;
    let if_match = req
        .headers
        .get(IF_MATCH)
        .and_then(|v| v.to_str().ok())
        .ok_or_else(|| {
            aws_error(
                StatusCode::BAD_REQUEST,
                "InvalidIfMatchVersion",
                "Missing If-Match header for DeleteDistribution",
            )
        })?
        .to_string();
    let mut state = self.state.write();
    let account = state
        .accounts
        .get_mut(DEFAULT_ACCOUNT)
        .ok_or_else(|| no_such_distribution(id))?;
    let dist = account
        .distributions
        .get(id)
        .ok_or_else(|| no_such_distribution(id))?;
    if dist.etag != if_match {
        return Err(aws_error(
            StatusCode::PRECONDITION_FAILED,
            "PreconditionFailed",
            "If-Match header does not match the current ETag",
        ));
    }
    if dist.config.enabled {
        return Err(aws_error(
            StatusCode::PRECONDITION_FAILED,
            "DistributionNotDisabled",
            "Distribution must be disabled before delete",
        ));
    }
    let removed = account
        .distributions
        .remove(id)
        .expect("distribution was found above while holding the same write lock");
    account.tags.remove(&removed.arn);
    // Invalidations are keyed by invalidation id, not distribution id. Without
    // this sweep they would outlive the distribution and stay in state forever.
    account
        .invalidations
        .retain(|_, inv| inv.distribution_id != id);
    Ok(empty_response(StatusCode::NO_CONTENT))
}
/// Handles `ListDistributions`.
///
/// Returns every stored distribution across all accounts, sorted by
/// last-modified time, oldest first. The request (and any pagination
/// parameters in it) is not read, so the whole list is always returned.
///
/// NOTE(review): this scans every account, while the get/update/delete
/// handlers only look under `DEFAULT_ACCOUNT`.
fn list_distributions(&self, _req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
    let mut all: Vec<StoredDistribution> = Vec::new();
    {
        let guard = self.state.read();
        for account in guard.accounts.values() {
            all.extend(account.distributions.values().cloned());
        }
    }
    all.sort_by_key(|d| d.last_modified_time);
    let xml = build_distribution_list_xml(&all, "DistributionList");
    Ok(xml_response(StatusCode::OK, xml, HeaderMap::new()))
}
/// Shared handler for every `ListDistributionsBy*` action.
///
/// Only the response's root element depends on `action`:
/// - `DistributionIdList` for the cache-policy, origin-request-policy,
///   response-headers-policy, key-group and VPC-origin lookups;
/// - `DistributionIdOwnerList` for `ListDistributionsByOwnedResource`;
/// - `DistributionList` for everything else.
///
/// The body comes from `build_empty_distribution_id_list(root)` and does not
/// consult stored distributions.
fn list_distributions_by(&self, action: &str) -> Result<AwsResponse, AwsServiceError> {
    const ID_LIST_ACTIONS: &[&str] = &[
        "ListDistributionsByCachePolicyId",
        "ListDistributionsByOriginRequestPolicyId",
        "ListDistributionsByResponseHeadersPolicyId",
        "ListDistributionsByKeyGroup",
        "ListDistributionsByVpcOriginId",
    ];
    let root = if ID_LIST_ACTIONS.contains(&action) {
        "DistributionIdList"
    } else if action == "ListDistributionsByOwnedResource" {
        "DistributionIdOwnerList"
    } else {
        "DistributionList"
    };
    Ok(xml_response(
        StatusCode::OK,
        build_empty_distribution_id_list(root),
        HeaderMap::new(),
    ))
}
fn copy_distribution(
&self,
req: &AwsRequest,
route: &Route,
) -> Result<AwsResponse, AwsServiceError> {
let primary_id = route
.id
.as_deref()
.ok_or_else(|| invalid_argument("missing primary distribution id"))?;
let if_match = req
.headers
.get(IF_MATCH)
.and_then(|v| v.to_str().ok())
.ok_or_else(|| {
aws_error(
StatusCode::BAD_REQUEST,
"InvalidIfMatchVersion",
"Missing If-Match header for CopyDistribution",
)
})?
.to_string();
let parsed: CopyDistributionRequest = xml_io::from_xml_root(&req.body)
.map_err(|e| invalid_argument(format!("invalid request XML: {e}")))?;
validate_caller_reference(&parsed.caller_reference)?;
let mut state = self.state.write();
let account = state
.accounts
.get_mut(DEFAULT_ACCOUNT)
.ok_or_else(|| no_such_distribution(primary_id))?;
let primary = account
.distributions
.get(primary_id)
.ok_or_else(|| no_such_distribution(primary_id))?
.clone();
if primary.etag != if_match {
return Err(aws_error(
StatusCode::PRECONDITION_FAILED,
"PreconditionFailed",
"If-Match header does not match the current ETag",
));
}
if account
.distributions
.values()
.any(|d| d.config.caller_reference == parsed.caller_reference)
{
return Err(aws_error(
StatusCode::CONFLICT,
"DistributionAlreadyExists",
"Distribution with the same CallerReference exists",
));
}
let new_id = generate_distribution_id();
let mut config = primary.config.clone();
config.caller_reference = parsed.caller_reference;
config.enabled = parsed.enabled.unwrap_or(false);
config.staging = parsed.staging;
let now = Utc::now();
let etag = generate_etag();
let arn = format!(
"arn:aws:cloudfront::{}:distribution/{}",
account_id(req),
new_id
);
let stored = StoredDistribution {
id: new_id.clone(),
arn: arn.clone(),
status: "InProgress".to_string(),
last_modified_time: now,
domain_name: format!("{}.cloudfront.net", new_id.to_lowercase()),
in_progress_invalidation_batches: 0,
etag: etag.clone(),
config,
};
account.distributions.insert(new_id.clone(), stored.clone());
drop(state);
self.schedule_distribution_deploy(new_id);
let body = build_distribution_xml(&stored);
let mut headers = HeaderMap::new();
set_header(&mut headers, ETAG, &etag);
set_header(&mut headers, LOCATION, &stored.arn);
Ok(xml_response(StatusCode::CREATED, body, headers))
}
}
/// XML request body for `CopyDistribution`.
///
/// Element names are PascalCase (`CallerReference`, `Enabled`, `Staging`).
/// The two optional elements deserialize to `None` when absent.
#[derive(Default, Debug, serde::Deserialize)]
#[serde(rename_all = "PascalCase")]
struct CopyDistributionRequest {
    /// CallerReference for the new distribution. It must not match any
    /// existing distribution's CallerReference in the account.
    caller_reference: String,
    /// `<Enabled>` flag for the copy, if supplied.
    #[serde(default)]
    enabled: Option<bool>,
    /// `<Staging>` flag, if supplied in the body.
    #[serde(default)]
    staging: Option<bool>,
}
impl CloudFrontService {
/// Handles `CreateInvalidation` for the distribution named in the route.
///
/// Invalidations complete instantly here: the record is stored with status
/// `Completed` and returned with 201 and a `Location` header of the form
/// `/2020-05-31/distribution/{dist}/invalidation/{id}`.
///
/// Errors:
/// - `InvalidArgument` for a missing id, bad XML, an empty CallerReference, or
///   `Paths.Quantity` < 1.
/// - `no_such_distribution` when the distribution is not stored under
///   `DEFAULT_ACCOUNT`.
fn create_invalidation(
    &self,
    req: &AwsRequest,
    route: &Route,
) -> Result<AwsResponse, AwsServiceError> {
    let dist_id = route
        .id
        .as_deref()
        .ok_or_else(|| invalid_argument("missing distribution id"))?;
    let batch: InvalidationBatch = xml_io::from_xml_root(&req.body)
        .map_err(|e| invalid_argument(format!("invalid InvalidationBatch XML: {e}")))?;
    if batch.caller_reference.is_empty() {
        return Err(invalid_argument("CallerReference is required"));
    }
    if batch.paths.quantity < 1 {
        return Err(invalid_argument(
            "InvalidationBatch.Paths must be non-empty",
        ));
    }
    let mut state = self.state.write();
    // Use a plain lookup rather than `entry`, so a request for an unknown
    // distribution does not create an empty account as a side effect.
    // This matches the other distribution handlers.
    let account = state
        .accounts
        .get_mut(DEFAULT_ACCOUNT)
        .filter(|a| a.distributions.contains_key(dist_id))
        .ok_or_else(|| no_such_distribution(dist_id))?;
    let stored = StoredInvalidation {
        id: generate_invalidation_id(),
        distribution_id: dist_id.to_string(),
        status: "Completed".to_string(),
        create_time: Utc::now(),
        batch,
    };
    account
        .invalidations
        .insert(stored.id.clone(), stored.clone());
    drop(state);
    let location = format!(
        "/2020-05-31/distribution/{dist_id}/invalidation/{}",
        stored.id
    );
    let mut headers = HeaderMap::new();
    set_header(&mut headers, LOCATION, &location);
    Ok(xml_response(
        StatusCode::CREATED,
        build_invalidation_xml(&stored),
        headers,
    ))
}
/// Handles `GetInvalidation`.
///
/// The invalidation is identified by the route's distribution id (`id`) and
/// invalidation id (`second_id`).
///
/// Errors:
/// - `InvalidArgument` when either id is missing.
/// - `no_such_distribution` when the distribution is not stored under
///   `DEFAULT_ACCOUNT`.
/// - `no_such_invalidation` when the invalidation is unknown or belongs to a
///   different distribution.
fn get_invalidation(&self, route: &Route) -> Result<AwsResponse, AwsServiceError> {
    let dist_id = route
        .id
        .as_deref()
        .ok_or_else(|| invalid_argument("missing distribution id"))?;
    let inv_id = route
        .second_id
        .as_deref()
        .ok_or_else(|| invalid_argument("missing invalidation id"))?;
    let state = self.state.read();
    // Without the account the distribution cannot exist. Report it the same way
    // as the explicit distribution check below, so the error does not depend
    // on whether the account has been created yet.
    let account = state
        .accounts
        .get(DEFAULT_ACCOUNT)
        .ok_or_else(|| no_such_distribution(dist_id))?;
    if !account.distributions.contains_key(dist_id) {
        return Err(no_such_distribution(dist_id));
    }
    let inv = account
        .invalidations
        .get(inv_id)
        .filter(|i| i.distribution_id == dist_id)
        .ok_or_else(|| no_such_invalidation(inv_id))?
        .clone();
    drop(state);
    let body = build_invalidation_xml(&inv);
    Ok(xml_response(StatusCode::OK, body, HeaderMap::new()))
}
/// Handles `ListInvalidations`.
///
/// Returns every invalidation recorded for the route's distribution, sorted by
/// creation time, oldest first. Errors: `InvalidArgument` when the route has
/// no id, and `no_such_distribution` when the distribution is not stored
/// under `DEFAULT_ACCOUNT`.
fn list_invalidations(&self, route: &Route) -> Result<AwsResponse, AwsServiceError> {
    let Some(dist_id) = route.id.as_deref() else {
        return Err(invalid_argument("missing distribution id"));
    };
    let guard = self.state.read();
    let Some(account) = guard
        .accounts
        .get(DEFAULT_ACCOUNT)
        .filter(|acct| acct.distributions.contains_key(dist_id))
    else {
        return Err(no_such_distribution(dist_id));
    };
    let mut matching: Vec<&StoredInvalidation> = account
        .invalidations
        .values()
        .filter(|inv| inv.distribution_id == dist_id)
        .collect();
    matching.sort_by_key(|inv| inv.create_time);
    let xml = build_invalidation_list_xml(&matching);
    drop(guard);
    Ok(xml_response(StatusCode::OK, xml, HeaderMap::new()))
}
}
impl CloudFrontService {
fn parse_arn_query(query: &str) -> Option<String> {
for pair in query.split('&').filter(|p| !p.is_empty()) {
if let Some(rest) = pair.strip_prefix("Resource=") {
return Some(percent_decode(rest));
}
}
None
}
fn tag_resource(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
let arn = Self::parse_arn_query(&req.raw_query)
.ok_or_else(|| invalid_argument("Resource query parameter is required"))?;
let parsed: ModelTags = xml_io::from_xml_root(&req.body)
.map_err(|e| invalid_argument(format!("invalid Tags XML: {e}")))?;
let new_tags: Vec<Tag> = parsed
.items
.map(|i| {
i.tag
.into_iter()
.map(|t| Tag {
key: t.key,
value: t.value,
})
.collect()
})
.unwrap_or_default();
let mut state = self.state.write();
let account = state.entry(DEFAULT_ACCOUNT);
let entry = account.tags.entry(arn).or_default();
for tag in new_tags {
if let Some(existing) = entry.iter_mut().find(|t| t.key == tag.key) {
existing.value = tag.value;
} else {
entry.push(tag);
}
}
Ok(empty_response(StatusCode::NO_CONTENT))
}
fn untag_resource(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
let arn = Self::parse_arn_query(&req.raw_query)
.ok_or_else(|| invalid_argument("Resource query parameter is required"))?;
let parsed: TagKeys = xml_io::from_xml_root(&req.body)
.map_err(|e| invalid_argument(format!("invalid TagKeys XML: {e}")))?;
let keys: Vec<String> = parsed.items.map(|k| k.key).unwrap_or_default();
let mut state = self.state.write();
let account = state.entry(DEFAULT_ACCOUNT);
if let Some(existing) = account.tags.get_mut(&arn) {
existing.retain(|t| !keys.contains(&t.key));
}
Ok(empty_response(StatusCode::NO_CONTENT))
}
fn list_tags_for_resource(&self, req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
let arn = Self::parse_arn_query(&req.raw_query)
.ok_or_else(|| invalid_argument("Resource query parameter is required"))?;
let state = self.state.read();
let tags = state
.accounts
.get(DEFAULT_ACCOUNT)
.and_then(|a| a.tags.get(&arn))
.cloned()
.unwrap_or_default();
drop(state);
let body = build_tags_xml(&tags);
Ok(xml_response(StatusCode::OK, body, HeaderMap::new()))
}
}
impl CloudFrontService {
fn associate_alias(
&self,
req: &AwsRequest,
route: &Route,
) -> Result<AwsResponse, AwsServiceError> {
let id = route
.id
.as_deref()
.ok_or_else(|| invalid_argument("missing distribution id"))?;
let alias = parse_query_value(&req.raw_query, "Alias")
.ok_or_else(|| invalid_argument("Alias query parameter is required"))?;
let mut state = self.state.write();
let account = state
.accounts
.get_mut(DEFAULT_ACCOUNT)
.ok_or_else(|| no_such_distribution(id))?;
if let Some(other) = account.distributions.values().find(|d| {
d.id != id
&& d.config
.aliases
.as_ref()
.and_then(|a| a.items.as_ref())
.is_some_and(|i| i.cname.iter().any(|c| c == &alias))
}) {
return Err(aws_error(
StatusCode::CONFLICT,
"CNAMEAlreadyExists",
format!(
"Alias {alias} is already associated with distribution {}",
other.id
),
));
}
let dist = account
.distributions
.get_mut(id)
.ok_or_else(|| no_such_distribution(id))?;
let aliases = dist.config.aliases.get_or_insert_with(Default::default);
let items = aliases
.items
.get_or_insert_with(crate::model::AliasItems::default);
if !items.cname.iter().any(|c| c == &alias) {
items.cname.push(alias.clone());
aliases.quantity = items.cname.len() as i32;
}
dist.etag = generate_etag();
dist.last_modified_time = Utc::now();
Ok(empty_response(StatusCode::OK))
}
fn list_conflicting_aliases(&self, _req: &AwsRequest) -> Result<AwsResponse, AwsServiceError> {
let body = format!(
"{XML_DECL}<ConflictingAliasesList xmlns=\"{NS}\"><Quantity>0</Quantity></ConflictingAliasesList>",
NS = crate::NAMESPACE,
XML_DECL = "<?xml version=\"1.0\" encoding=\"UTF-8\"?>"
);
Ok(xml_response(StatusCode::OK, body, HeaderMap::new()))
}
fn associate_web_acl(
&self,
req: &AwsRequest,
route: &Route,
) -> Result<AwsResponse, AwsServiceError> {
let id = route
.id
.as_deref()
.ok_or_else(|| invalid_argument("missing distribution id"))?;
let parsed: AssociateAliasRequest = xml_io::from_xml_root(&req.body)
.map_err(|e| invalid_argument(format!("invalid request XML: {e}")))?;
let mut state = self.state.write();
let account = state
.accounts
.get_mut(DEFAULT_ACCOUNT)
.ok_or_else(|| no_such_distribution(id))?;
let dist = account
.distributions
.get_mut(id)
.ok_or_else(|| no_such_distribution(id))?;
dist.config.web_acl_id = Some(parsed.web_acl_arn.clone());
dist.etag = generate_etag();
dist.last_modified_time = Utc::now();
let body = format!(
"{XML_DECL}<AssociateDistributionWebACLResult xmlns=\"{NS}\"><Id>{}</Id><WebACLArn>{}</WebACLArn></AssociateDistributionWebACLResult>",
esc(id), esc(&parsed.web_acl_arn),
NS = crate::NAMESPACE,
XML_DECL = "<?xml version=\"1.0\" encoding=\"UTF-8\"?>"
);
Ok(xml_response(StatusCode::OK, body, HeaderMap::new()))
}
fn disassociate_web_acl(
&self,
_req: &AwsRequest,
route: &Route,
) -> Result<AwsResponse, AwsServiceError> {
let id = route
.id
.as_deref()
.ok_or_else(|| invalid_argument("missing distribution id"))?;
let mut state = self.state.write();
let account = state
.accounts
.get_mut(DEFAULT_ACCOUNT)
.ok_or_else(|| no_such_distribution(id))?;
let dist = account
.distributions
.get_mut(id)
.ok_or_else(|| no_such_distribution(id))?;
dist.config.web_acl_id = None;
dist.etag = generate_etag();
dist.last_modified_time = Utc::now();
let body = format!(
"{XML_DECL}<DisassociateDistributionWebACLResult xmlns=\"{NS}\"><Id>{}</Id></DisassociateDistributionWebACLResult>",
esc(id),
NS = crate::NAMESPACE,
XML_DECL = "<?xml version=\"1.0\" encoding=\"UTF-8\"?>"
);
Ok(xml_response(StatusCode::OK, body, HeaderMap::new()))
}
}
/// XML request body parsed by `associate_web_acl` (despite this type's name,
/// it carries the web ACL ARN, not an alias).
///
/// `WebACLArn` is renamed explicitly because `rename_all = "PascalCase"` alone
/// would expect `WebAclArn`; an absent element deserializes to an empty string.
#[derive(Debug, Default, serde::Deserialize)]
#[serde(rename_all = "PascalCase")]
struct AssociateAliasRequest {
    #[serde(rename = "WebACLArn", default)]
    web_acl_arn: String,
}
/// Escapes the five XML special characters so `s` can be embedded in element
/// text or in a single- or double-quoted attribute value.
///
/// Each of `&`, `<`, `>`, `"` and `'` is replaced by its predefined XML entity
/// (named `amp`, `lt`, `gt`, `quot` and `apos`, written as `&name;`); every
/// other character is copied through unchanged.
pub(crate) fn esc(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        // Only the entity *name* is chosen here; the surrounding '&' and ';'
        // are appended below so each entity is built the same way.
        let entity = match c {
            '&' => "amp",
            '<' => "lt",
            '>' => "gt",
            '"' => "quot",
            '\'' => "apos",
            _ => {
                out.push(c);
                continue;
            }
        };
        out.push('&');
        out.push_str(entity);
        out.push(';');
    }
    out
}
pub(crate) fn build_distribution_xml(dist: &StoredDistribution) -> String {
let mut out = String::with_capacity(2048);
out.push_str("<?xml version=\"1.0\" encoding=\"UTF-8\"?>");
out.push_str(&format!(
"<Distribution xmlns=\"{ns}\">",
ns = crate::NAMESPACE
));
out.push_str(&format!("<Id>{}</Id>", esc(&dist.id)));
out.push_str(&format!("<ARN>{}</ARN>", esc(&dist.arn)));
out.push_str(&format!("<Status>{}</Status>", esc(&dist.status)));
out.push_str(&format!(
"<LastModifiedTime>{}</LastModifiedTime>",
rfc3339(&dist.last_modified_time)
));
out.push_str(&format!(
"<InProgressInvalidationBatches>{}</InProgressInvalidationBatches>",
dist.in_progress_invalidation_batches
));
out.push_str(&format!(
"<DomainName>{}</DomainName>",
esc(&dist.domain_name)
));
out.push_str("<ActiveTrustedSigners><Enabled>false</Enabled><Quantity>0</Quantity></ActiveTrustedSigners>");
out.push_str("<ActiveTrustedKeyGroups><Enabled>false</Enabled><Quantity>0</Quantity></ActiveTrustedKeyGroups>");
let inner = quick_xml::se::to_string_with_root("DistributionConfig", &dist.config)
.unwrap_or_else(|_| String::new());
out.push_str(&inner);
out.push_str("</Distribution>");
out
}
/// Renders a `DistributionList`-shaped document under element `root`, with one
/// `DistributionSummary` per entry of `dists`.
///
/// Pagination is not emulated: `Marker` is empty, `IsTruncated` is false and
/// `MaxItems` is the larger of 100 and the item count. `Items` is omitted when
/// `dists` is empty.
fn build_distribution_list_xml(dists: &[StoredDistribution], root: &str) -> String {
    /// Appends one `<DistributionSummary>` element describing `d` to `out`.
    fn push_summary(out: &mut String, d: &StoredDistribution) {
        let cfg = &d.config;
        out.push_str("<DistributionSummary>");
        out.push_str(&format!(
            "<Id>{}</Id><ARN>{}</ARN><Status>{}</Status>",
            esc(&d.id),
            esc(&d.arn),
            esc(&d.status)
        ));
        out.push_str(&format!(
            "<LastModifiedTime>{}</LastModifiedTime>",
            rfc3339(&d.last_modified_time)
        ));
        out.push_str(&format!("<DomainName>{}</DomainName>", esc(&d.domain_name)));
        out.push_str(&render_inline(
            "Aliases",
            &cfg.aliases.clone().unwrap_or_default(),
        ));
        out.push_str(&render_inline("Origins", &cfg.origins));
        out.push_str(&render_inline(
            "DefaultCacheBehavior",
            &cfg.default_cache_behavior,
        ));
        out.push_str(&render_inline(
            "CacheBehaviors",
            &cfg.cache_behaviors.clone().unwrap_or_default(),
        ));
        out.push_str(&render_inline(
            "CustomErrorResponses",
            &cfg.custom_error_responses.clone().unwrap_or_default(),
        ));
        out.push_str(&format!("<Comment>{}</Comment>", esc(&cfg.comment)));
        out.push_str(&format!(
            "<PriceClass>{}</PriceClass>",
            esc(cfg.price_class.as_deref().unwrap_or("PriceClass_All"))
        ));
        out.push_str(&format!("<Enabled>{}</Enabled>", cfg.enabled));
        out.push_str(&render_inline(
            "ViewerCertificate",
            &cfg.viewer_certificate.clone().unwrap_or_default(),
        ));
        out.push_str(&render_inline(
            "Restrictions",
            &cfg.restrictions.clone().unwrap_or_default(),
        ));
        out.push_str(&format!(
            "<WebACLId>{}</WebACLId>",
            esc(cfg.web_acl_id.as_deref().unwrap_or_default())
        ));
        out.push_str(&format!(
            "<HttpVersion>{}</HttpVersion>",
            esc(cfg.http_version.as_deref().unwrap_or("http2"))
        ));
        out.push_str(&format!(
            "<IsIPV6Enabled>{}</IsIPV6Enabled>",
            cfg.is_ipv6_enabled.unwrap_or(true)
        ));
        out.push_str("<Staging>false</Staging>");
        out.push_str("</DistributionSummary>");
    }

    let mut out = String::with_capacity(2048);
    out.push_str("<?xml version=\"1.0\" encoding=\"UTF-8\"?>");
    out.push_str(&format!("<{root} xmlns=\"{ns}\">", ns = crate::NAMESPACE));
    out.push_str("<Marker></Marker>");
    out.push_str(&format!("<MaxItems>{}</MaxItems>", dists.len().max(100)));
    out.push_str("<IsTruncated>false</IsTruncated>");
    out.push_str(&format!("<Quantity>{}</Quantity>", dists.len()));
    if !dists.is_empty() {
        out.push_str("<Items>");
        for d in dists {
            push_summary(&mut out, d);
        }
        out.push_str("</Items>");
    }
    out.push_str(&format!("</{root}>"));
    out
}
/// Renders an empty, non-truncated list document under element `root`
/// (`MaxItems` 100, `Quantity` 0, no `Items`).
fn build_empty_distribution_id_list(root: &str) -> String {
    format!(
        "<?xml version=\"1.0\" encoding=\"UTF-8\"?><{root} xmlns=\"{ns}\"><Marker></Marker><MaxItems>100</MaxItems><IsTruncated>false</IsTruncated><Quantity>0</Quantity></{root}>",
        ns = crate::NAMESPACE
    )
}
/// Renders a single `Invalidation` response document: id, status, creation
/// time and the original `InvalidationBatch`.
fn build_invalidation_xml(inv: &StoredInvalidation) -> String {
    format!(
        "<?xml version=\"1.0\" encoding=\"UTF-8\"?><Invalidation xmlns=\"{ns}\"><Id>{id}</Id><Status>{status}</Status><CreateTime>{created}</CreateTime>{batch}</Invalidation>",
        ns = crate::NAMESPACE,
        id = esc(&inv.id),
        status = esc(&inv.status),
        created = rfc3339(&inv.create_time),
        batch = render_inline("InvalidationBatch", &inv.batch),
    )
}
/// Renders an `InvalidationList` document with one `InvalidationSummary`
/// (id, creation time, status) per entry, in the order given.
///
/// Pagination is not emulated, and `Items` is omitted when `items` is empty.
fn build_invalidation_list_xml(items: &[&StoredInvalidation]) -> String {
    let summaries: String = items
        .iter()
        .map(|inv| {
            format!(
                "<InvalidationSummary><Id>{}</Id><CreateTime>{}</CreateTime><Status>{}</Status></InvalidationSummary>",
                esc(&inv.id),
                rfc3339(&inv.create_time),
                esc(&inv.status)
            )
        })
        .collect();
    let items_xml = if items.is_empty() {
        String::new()
    } else {
        format!("<Items>{summaries}</Items>")
    };
    format!(
        "<?xml version=\"1.0\" encoding=\"UTF-8\"?><InvalidationList xmlns=\"{ns}\"><Marker></Marker><MaxItems>100</MaxItems><IsTruncated>false</IsTruncated><Quantity>{count}</Quantity>{items_xml}</InvalidationList>",
        ns = crate::NAMESPACE,
        count = items.len(),
    )
}
/// Renders a `Tags` document. Each tag becomes a `<Tag>` with its escaped key,
/// plus a `<Value>` element only when the tag has a value. `Items` is always
/// present, even when `tags` is empty.
fn build_tags_xml(tags: &[Tag]) -> String {
    let rendered: String = tags
        .iter()
        .map(|tag| {
            let value_xml = tag
                .value
                .as_deref()
                .map(|v| format!("<Value>{}</Value>", esc(v)))
                .unwrap_or_default();
            format!("<Tag><Key>{}</Key>{value_xml}</Tag>", esc(&tag.key))
        })
        .collect();
    format!(
        "<?xml version=\"1.0\" encoding=\"UTF-8\"?><Tags xmlns=\"{ns}\"><Items>{rendered}</Items></Tags>",
        ns = crate::NAMESPACE
    )
}
/// Serializes `value` as an XML element named `root` (no XML declaration),
/// yielding an empty string if serialization fails.
fn render_inline<T: serde::Serialize>(root: &str, value: &T) -> String {
    match quick_xml::se::to_string_with_root(root, value) {
        Ok(xml) => xml,
        Err(_) => String::new(),
    }
}
/// Rejects an empty `CallerReference` with `InvalidArgument`.
fn validate_caller_reference(s: &str) -> Result<(), AwsServiceError> {
    if s.is_empty() {
        Err(invalid_argument("CallerReference is required"))
    } else {
        Ok(())
    }
}
/// Rejects a distribution config whose `Origins.Quantity` is below one.
fn validate_origins(config: &DistributionConfig) -> Result<(), AwsServiceError> {
    let has_origin = config.origins.quantity >= 1;
    if has_origin {
        Ok(())
    } else {
        Err(invalid_argument(
            "DistributionConfig.Origins must contain at least one origin",
        ))
    }
}
/// Compares two distribution configs by their serialized XML form.
///
/// If either side fails to serialize, the configs are treated as different.
fn configs_equal(lhs: &DistributionConfig, rhs: &DistributionConfig) -> bool {
    let Ok(left) = xml_io::to_xml_root("DistributionConfig", lhs) else {
        return false;
    };
    xml_io::to_xml_root("DistributionConfig", rhs).is_ok_and(|right| right == left)
}
/// Returns the account id for a request. Every request currently resolves to
/// `DEFAULT_ACCOUNT`; the request argument is ignored.
fn account_id(_req: &AwsRequest) -> &'static str {
    DEFAULT_ACCOUNT
}
/// Generates a distribution id: `E` followed by 13 uppercase hex characters
/// taken from a random v4 UUID.
fn generate_distribution_id() -> String {
    let hex = Uuid::new_v4().simple().to_string();
    let mut id = String::from("E");
    // The simple UUID form is ASCII hex, so slicing before uppercasing is equivalent.
    id.push_str(&hex[..13].to_uppercase());
    id
}
/// Generates an invalidation id: `I` followed by 13 uppercase hex characters
/// taken from a random v4 UUID.
fn generate_invalidation_id() -> String {
    let hex = Uuid::new_v4().simple().to_string();
    let mut id = String::from("I");
    // The simple UUID form is ASCII hex, so slicing before uppercasing is equivalent.
    id.push_str(&hex[..13].to_uppercase());
    id
}
/// Generates an ETag value: `E` followed by 13 uppercase hex characters taken
/// from a random v4 UUID.
pub(crate) fn generate_etag() -> String {
    let hex = Uuid::new_v4().simple().to_string();
    let mut tag = String::from("E");
    // The simple UUID form is ASCII hex, so slicing before uppercasing is equivalent.
    tag.push_str(&hex[..13].to_uppercase());
    tag
}
/// Generates an id of `prefix` followed by uppercase hex characters from a
/// random v4 UUID, padded out to 14 characters in total.
///
/// A prefix of 14 or more bytes is returned with no random suffix.
pub(crate) fn generate_id_with_prefix(prefix: &str) -> String {
    let hex = Uuid::new_v4().simple().to_string().to_uppercase();
    let keep = hex.len().min(14usize.saturating_sub(prefix.len()));
    let mut id = String::with_capacity(prefix.len() + keep);
    id.push_str(prefix);
    id.push_str(&hex[..keep]);
    id
}
/// Formats a UTC timestamp as RFC 3339 with millisecond precision and a `Z` offset.
fn rfc3339(t: &chrono::DateTime<chrono::Utc>) -> String {
    // `true` asks chrono to write UTC as `Z` instead of `+00:00`.
    let use_z_suffix = true;
    t.to_rfc3339_opts(chrono::SecondsFormat::Millis, use_z_suffix)
}
/// Builds a 400 `InvalidArgument` service error carrying `msg`.
pub(crate) fn invalid_argument(msg: impl Into<String>) -> AwsServiceError {
    let status = StatusCode::BAD_REQUEST;
    aws_error(status, "InvalidArgument", msg)
}
/// Builds a 404 `NoSuchDistribution` error naming the missing distribution `id`.
fn no_such_distribution(id: &str) -> AwsServiceError {
    let message = format!("The specified distribution does not exist: {id}");
    aws_error(StatusCode::NOT_FOUND, "NoSuchDistribution", message)
}
/// Builds a 404 `NoSuchInvalidation` error naming the missing invalidation `id`.
fn no_such_invalidation(id: &str) -> AwsServiceError {
    let message = format!("The specified invalidation does not exist: {id}");
    aws_error(StatusCode::NOT_FOUND, "NoSuchInvalidation", message)
}
/// Builds a 500 `InternalError` service error carrying `msg`.
fn internal_error(msg: impl Into<String>) -> AwsServiceError {
    let status = StatusCode::INTERNAL_SERVER_ERROR;
    aws_error(status, "InternalError", msg)
}
/// Builds an `AwsServiceError` with the given HTTP status, error code and message.
pub(crate) fn aws_error(
    status: StatusCode,
    code: impl Into<String>,
    msg: impl Into<String>,
) -> AwsServiceError {
    let code: String = code.into();
    AwsServiceError::aws_error(status, code, msg)
}
/// Inserts (replacing any existing value) header `name` with `value`.
///
/// A value that is not valid header text is silently dropped.
fn set_header(headers: &mut HeaderMap, name: HeaderName, value: &str) {
    let Ok(parsed) = HeaderValue::from_str(value) else {
        return;
    };
    headers.insert(name, parsed);
}
/// Wraps an XML `body` in an `AwsResponse` with content type `text/xml`.
pub(crate) fn xml_response(status: StatusCode, body: String, headers: HeaderMap) -> AwsResponse {
    let payload = ResponseBody::Bytes(Bytes::from(body));
    AwsResponse {
        headers,
        status,
        body: payload,
        content_type: String::from("text/xml"),
    }
}
/// Builds a body-less response with the given status, no headers, and content
/// type `text/xml`.
fn empty_response(status: StatusCode) -> AwsResponse {
    let payload = ResponseBody::Bytes(Bytes::new());
    AwsResponse {
        headers: HeaderMap::new(),
        status,
        body: payload,
        content_type: String::from("text/xml"),
    }
}
/// Decodes a URL query-string component.
///
/// `%XX` escapes (hex digits in either case) become raw bytes and `+` becomes a
/// space; a `%` not followed by two hex digits is kept literally. The resulting
/// bytes are interpreted as UTF-8, so multi-byte escapes such as `%C3%A9`
/// decode to `é`. Byte sequences that are not valid UTF-8 are replaced with
/// U+FFFD.
fn percent_decode(input: &str) -> String {
    let bytes = input.as_bytes();
    // Collect raw bytes first: pushing each decoded byte straight into a String
    // as a `char` treats it as Latin-1 and mangles multi-byte UTF-8, both for
    // escaped sequences and for literal non-ASCII input.
    let mut decoded: Vec<u8> = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        let b = bytes[i];
        if b == b'%' && i + 2 < bytes.len() {
            if let (Some(hi), Some(lo)) = (hex_digit(bytes[i + 1]), hex_digit(bytes[i + 2])) {
                decoded.push((hi << 4) | lo);
                i += 3;
                continue;
            }
        }
        decoded.push(if b == b'+' { b' ' } else { b });
        i += 1;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
/// Returns the numeric value of an ASCII hex digit (`0-9`, `a-f`, `A-F`),
/// or `None` if `b` is not one.
fn hex_digit(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
fn parse_query_value(query: &str, key: &str) -> Option<String> {
let prefix = format!("{key}=");
for pair in query.split('&').filter(|p| !p.is_empty()) {
if let Some(rest) = pair.strip_prefix(&prefix) {
return Some(percent_decode(rest));
}
}
None
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Fresh, empty shared state for a service under test.
    fn make_state() -> SharedCloudFrontState {
        Arc::new(RwLock::new(CloudFrontAccounts::new()))
    }
    /// Builds a minimal `AwsRequest` for `method` and `path` with the given raw
    /// query string and body.
    fn make_request(method: http::Method, path: &str, query: &str, body: &str) -> AwsRequest {
        AwsRequest {
            service: "cloudfront".into(),
            action: String::new(),
            region: "us-east-1".into(),
            account_id: DEFAULT_ACCOUNT.into(),
            request_id: Uuid::new_v4().to_string(),
            headers: HeaderMap::new(),
            query_params: std::collections::HashMap::new(),
            body_stream: parking_lot::Mutex::new(None),
            body: Bytes::from(body.to_string()),
            path_segments: path
                .split('/')
                .filter(|s| !s.is_empty())
                .map(String::from)
                .collect(),
            raw_path: path.into(),
            raw_query: query.into(),
            method,
            is_query_protocol: false,
            access_key_id: None,
            principal: None,
        }
    }
    /// A minimal valid `DistributionConfig` document with a single origin.
    fn minimal_dist_config_xml(caller_ref: &str) -> String {
        format!(
            r#"<?xml version="1.0" encoding="UTF-8"?>
<DistributionConfig xmlns="http://cloudfront.amazonaws.com/doc/2020-05-31/">
<CallerReference>{caller_ref}</CallerReference>
<Origins>
<Quantity>1</Quantity>
<Items>
<Origin>
<Id>primary</Id>
<DomainName>example.com</DomainName>
</Origin>
</Items>
</Origins>
<DefaultCacheBehavior>
<TargetOriginId>primary</TargetOriginId>
<ViewerProtocolPolicy>allow-all</ViewerProtocolPolicy>
</DefaultCacheBehavior>
<Comment></Comment>
<Enabled>true</Enabled>
</DistributionConfig>"#
        )
    }
    #[tokio::test]
    async fn create_then_get_then_delete_distribution() {
        let svc = CloudFrontService::new(make_state());
        let body = minimal_dist_config_xml("ref-1");
        let create = svc
            .handle(make_request(
                http::Method::POST,
                "/2020-05-31/distribution",
                "",
                &body,
            ))
            .await
            .unwrap();
        assert_eq!(create.status, StatusCode::CREATED);
        let etag = create
            .headers
            .get(ETAG)
            .unwrap()
            .to_str()
            .unwrap()
            .to_string();
        let xml = std::str::from_utf8(create.body.expect_bytes()).unwrap();
        let id = xml
            .split("<Id>")
            .nth(1)
            .unwrap()
            .split("</Id>")
            .next()
            .unwrap()
            .to_string();
        let get = svc
            .handle(make_request(
                http::Method::GET,
                &format!("/2020-05-31/distribution/{id}"),
                "",
                "",
            ))
            .await
            .unwrap();
        assert_eq!(get.status, StatusCode::OK);
        let disable_body = body.replace("<Enabled>true</Enabled>", "<Enabled>false</Enabled>");
        let mut update_req = make_request(
            http::Method::PUT,
            &format!("/2020-05-31/distribution/{id}/config"),
            "",
            &disable_body,
        );
        update_req.headers.insert(IF_MATCH, etag.parse().unwrap());
        let updated = svc.handle(update_req).await.unwrap();
        assert_eq!(updated.status, StatusCode::OK);
        let new_etag = updated
            .headers
            .get(ETAG)
            .unwrap()
            .to_str()
            .unwrap()
            .to_string();
        let mut del_req = make_request(
            http::Method::DELETE,
            &format!("/2020-05-31/distribution/{id}"),
            "",
            "",
        );
        del_req.headers.insert(IF_MATCH, new_etag.parse().unwrap());
        let del = svc.handle(del_req).await.unwrap();
        assert_eq!(del.status, StatusCode::NO_CONTENT);
    }
    /// Creates a distribution with `caller_ref` and returns its id, scraped from
    /// the first `<Id>` element of the response.
    async fn create_distribution_returning_id(svc: &CloudFrontService, caller_ref: &str) -> String {
        let body = minimal_dist_config_xml(caller_ref);
        let create = svc
            .handle(make_request(
                http::Method::POST,
                "/2020-05-31/distribution",
                "",
                &body,
            ))
            .await
            .unwrap();
        let xml = std::str::from_utf8(create.body.expect_bytes())
            .unwrap()
            .to_string();
        xml.split("<Id>")
            .nth(1)
            .unwrap()
            .split("</Id>")
            .next()
            .unwrap()
            .to_string()
    }
    /// Reads a distribution's stored status directly from state ("" if unknown).
    fn distribution_status(svc: &CloudFrontService, id: &str) -> String {
        let state = svc.state.read();
        state
            .accounts
            .get(DEFAULT_ACCOUNT)
            .and_then(|a| a.distributions.get(id))
            .map(|d| d.status.clone())
            .unwrap_or_default()
    }
    #[tokio::test]
    async fn create_distribution_starts_in_progress() {
        let svc = CloudFrontService::new(make_state())
            .with_propagation_delay(std::time::Duration::from_secs(60));
        let body = minimal_dist_config_xml("status-ref");
        let create = svc
            .handle(make_request(
                http::Method::POST,
                "/2020-05-31/distribution",
                "",
                &body,
            ))
            .await
            .unwrap();
        let xml = std::str::from_utf8(create.body.expect_bytes())
            .unwrap()
            .to_string();
        assert!(
            xml.contains("<Status>InProgress</Status>"),
            "expected initial status InProgress, got: {xml}"
        );
    }
    #[tokio::test]
    async fn auto_transition_after_tick_marks_deployed() {
        let svc = CloudFrontService::new(make_state())
            .with_propagation_delay(std::time::Duration::from_millis(50));
        let id = create_distribution_returning_id(&svc, "auto-tick-ref").await;
        assert_eq!(distribution_status(&svc, &id), "InProgress");
        tokio::time::sleep(std::time::Duration::from_millis(200)).await;
        assert_eq!(distribution_status(&svc, &id), "Deployed");
    }
    #[tokio::test]
    async fn set_distribution_status_via_admin_flips_synchronously() {
        let svc = CloudFrontService::new(make_state())
            .with_propagation_delay(std::time::Duration::from_secs(60));
        let id = create_distribution_returning_id(&svc, "admin-ref").await;
        assert_eq!(distribution_status(&svc, &id), "InProgress");
        assert!(svc.set_distribution_status(&id, "Deployed"));
        assert_eq!(distribution_status(&svc, &id), "Deployed");
        assert!(svc.set_distribution_status(&id, "InProgress"));
        assert_eq!(distribution_status(&svc, &id), "InProgress");
    }
    #[tokio::test]
    async fn set_distribution_status_unknown_id_returns_false() {
        let svc = CloudFrontService::new(make_state());
        assert!(!svc.set_distribution_status("E-DOES-NOT-EXIST", "Deployed"));
    }
    #[tokio::test]
    async fn update_distribution_resets_to_in_progress() {
        let svc = CloudFrontService::new(make_state())
            .with_propagation_delay(std::time::Duration::from_secs(60));
        let body = minimal_dist_config_xml("update-reset-ref");
        let create = svc
            .handle(make_request(
                http::Method::POST,
                "/2020-05-31/distribution",
                "",
                &body,
            ))
            .await
            .unwrap();
        let etag = create
            .headers
            .get(ETAG)
            .unwrap()
            .to_str()
            .unwrap()
            .to_string();
        let xml = std::str::from_utf8(create.body.expect_bytes())
            .unwrap()
            .to_string();
        let id = xml
            .split("<Id>")
            .nth(1)
            .unwrap()
            .split("</Id>")
            .next()
            .unwrap()
            .to_string();
        assert!(svc.set_distribution_status(&id, "Deployed"));
        assert_eq!(distribution_status(&svc, &id), "Deployed");
        let updated_body = body.replace(
            "<ViewerProtocolPolicy>allow-all</ViewerProtocolPolicy>",
            "<ViewerProtocolPolicy>https-only</ViewerProtocolPolicy>",
        );
        let mut update_req = make_request(
            http::Method::PUT,
            &format!("/2020-05-31/distribution/{id}/config"),
            "",
            &updated_body,
        );
        update_req.headers.insert(IF_MATCH, etag.parse().unwrap());
        let updated = svc.handle(update_req).await.unwrap();
        assert_eq!(updated.status, StatusCode::OK);
        assert_eq!(distribution_status(&svc, &id), "InProgress");
    }
    #[tokio::test]
    async fn duplicate_caller_reference_is_rejected() {
        let svc = CloudFrontService::new(make_state());
        let body = minimal_dist_config_xml("dup-ref");
        svc.handle(make_request(
            http::Method::POST,
            "/2020-05-31/distribution",
            "",
            &body,
        ))
        .await
        .unwrap();
        let result = svc
            .handle(make_request(
                http::Method::POST,
                "/2020-05-31/distribution",
                "",
                &body,
            ))
            .await;
        let err = match result {
            Ok(_) => panic!("expected duplicate caller-reference to fail"),
            Err(e) => e,
        };
        assert_eq!(err.code(), "DistributionAlreadyExists");
        assert_eq!(err.status(), StatusCode::CONFLICT);
    }
    #[tokio::test]
    async fn invalidation_lifecycle() {
        let svc = CloudFrontService::new(make_state());
        let body = minimal_dist_config_xml("inv-ref");
        let create = svc
            .handle(make_request(
                http::Method::POST,
                "/2020-05-31/distribution",
                "",
                &body,
            ))
            .await
            .unwrap();
        let xml = std::str::from_utf8(create.body.expect_bytes()).unwrap();
        let dist_id = xml
            .split("<Id>")
            .nth(1)
            .unwrap()
            .split("</Id>")
            .next()
            .unwrap();
        let inv_body = r#"<?xml version="1.0" encoding="UTF-8"?>
<InvalidationBatch xmlns="http://cloudfront.amazonaws.com/doc/2020-05-31/">
<Paths><Quantity>1</Quantity><Items><Path>/*</Path></Items></Paths>
<CallerReference>inv-1</CallerReference>
</InvalidationBatch>"#;
        let inv_resp = svc
            .handle(make_request(
                http::Method::POST,
                &format!("/2020-05-31/distribution/{dist_id}/invalidation"),
                "",
                inv_body,
            ))
            .await
            .unwrap();
        assert_eq!(inv_resp.status, StatusCode::CREATED);
        let inv_xml = std::str::from_utf8(inv_resp.body.expect_bytes()).unwrap();
        let inv_id = inv_xml
            .split("<Id>")
            .nth(1)
            .unwrap()
            .split("</Id>")
            .next()
            .unwrap();
        let get = svc
            .handle(make_request(
                http::Method::GET,
                &format!("/2020-05-31/distribution/{dist_id}/invalidation/{inv_id}"),
                "",
                "",
            ))
            .await
            .unwrap();
        assert_eq!(get.status, StatusCode::OK);
        let list = svc
            .handle(make_request(
                http::Method::GET,
                &format!("/2020-05-31/distribution/{dist_id}/invalidation"),
                "",
                "",
            ))
            .await
            .unwrap();
        let xml = std::str::from_utf8(list.body.expect_bytes()).unwrap();
        assert!(xml.contains("<Quantity>1</Quantity>"));
    }
    #[tokio::test]
    async fn tags_roundtrip() {
        let svc = CloudFrontService::new(make_state());
        let body = minimal_dist_config_xml("tag-ref");
        let create = svc
            .handle(make_request(
                http::Method::POST,
                "/2020-05-31/distribution",
                "",
                &body,
            ))
            .await
            .unwrap();
        let xml = std::str::from_utf8(create.body.expect_bytes()).unwrap();
        let arn = xml
            .split("<ARN>")
            .nth(1)
            .unwrap()
            .split("</ARN>")
            .next()
            .unwrap();
        let tag_body = r#"<?xml version="1.0" encoding="UTF-8"?>
<Tags xmlns="http://cloudfront.amazonaws.com/doc/2020-05-31/">
<Items><Tag><Key>env</Key><Value>prod</Value></Tag></Items>
</Tags>"#;
        let arn_q = format!("Operation=Tag&Resource={}", arn);
        let resp = svc
            .handle(make_request(
                http::Method::POST,
                "/2020-05-31/tagging",
                &arn_q,
                tag_body,
            ))
            .await
            .unwrap();
        assert_eq!(resp.status, StatusCode::NO_CONTENT);
        let list = svc
            .handle(make_request(
                http::Method::GET,
                "/2020-05-31/tagging",
                &format!("Resource={}", arn),
                "",
            ))
            .await
            .unwrap();
        let xml = std::str::from_utf8(list.body.expect_bytes()).unwrap();
        assert!(xml.contains("<Key>env</Key>"));
        assert!(xml.contains("<Value>prod</Value>"));
    }
    #[tokio::test]
    async fn xml_metacharacters_in_user_input_are_escaped() {
        let svc = CloudFrontService::new(make_state());
        let raw = "a&b<c>d";
        // The expected wire form comes from `esc` itself, so the test never has
        // to spell out entity text literally; the sanity checks guard against
        // `esc` degenerating into a no-op.
        let escaped = esc(raw);
        assert_ne!(escaped, raw);
        assert!(!escaped.contains('<'));
        let body = minimal_dist_config_xml("escape-ref").replace(
            "<Comment></Comment>",
            &format!("<Comment><![CDATA[{raw}]]></Comment>"),
        );
        let create = svc
            .handle(make_request(
                http::Method::POST,
                "/2020-05-31/distribution",
                "",
                &body,
            ))
            .await
            .unwrap();
        let xml = std::str::from_utf8(create.body.expect_bytes()).unwrap();
        let dist_id = xml
            .split("<Id>")
            .nth(1)
            .unwrap()
            .split("</Id>")
            .next()
            .unwrap();
        let arn = xml
            .split("<ARN>")
            .nth(1)
            .unwrap()
            .split("</ARN>")
            .next()
            .unwrap();
        // The request body must itself be well-formed XML, so the tag value is
        // sent escaped and decoded by the service's XML parser.
        let tag_body = format!(
            r#"<?xml version="1.0" encoding="UTF-8"?>
<Tags xmlns="http://cloudfront.amazonaws.com/doc/2020-05-31/">
<Items><Tag><Key>env</Key><Value>{escaped}</Value></Tag></Items>
</Tags>"#
        );
        let arn_q = format!("Operation=Tag&Resource={}", arn);
        let tagged = svc
            .handle(make_request(
                http::Method::POST,
                "/2020-05-31/tagging",
                &arn_q,
                &tag_body,
            ))
            .await
            .unwrap();
        assert_eq!(tagged.status, StatusCode::NO_CONTENT);
        let list = svc
            .handle(make_request(
                http::Method::GET,
                "/2020-05-31/tagging",
                &format!("Resource={}", arn),
                "",
            ))
            .await
            .unwrap();
        let xml = std::str::from_utf8(list.body.expect_bytes()).unwrap();
        assert!(xml.contains(&format!("<Value>{escaped}</Value>")));
        assert!(!xml.contains(&format!("<Value>{raw}</Value>")));
        let list_resp = svc
            .handle(make_request(
                http::Method::GET,
                "/2020-05-31/distribution",
                "",
                "",
            ))
            .await
            .unwrap();
        let xml = std::str::from_utf8(list_resp.body.expect_bytes()).unwrap();
        assert!(!xml.contains(&format!("<Comment>{raw}")));
        assert!(xml.contains(dist_id));
    }
}