use crate::error::{AwsError, Result};
use crate::utils::{with_retry, RetryConfig, RetryableError};
use aws_config::BehaviorVersion;
use aws_sdk_cloudfront::Client as CloudFrontClient;
use aws_sdk_ecr::error::SdkError;
use aws_sdk_ecr::Client as EcrClient;
use aws_sdk_ssm::Client as SsmClient;
use aws_sdk_sts::Client as StsClient;
use base64::Engine;
use bollard::auth::DockerCredentials;
/// Newtype around [`AwsError`] so that it can implement the crate's
/// [`RetryableError`] trait and be driven by `with_retry`.
#[derive(Debug)]
struct AwsRetryableError(AwsError);

impl std::fmt::Display for AwsRetryableError {
    /// Displays the wrapped `AwsError` unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(inner) = self;
        write!(f, "{}", inner)
    }
}
impl RetryableError for AwsRetryableError {
    /// Returns `true` when the wrapped error looks transient, i.e. worth
    /// attempting the operation again.
    ///
    /// Only `AwsError::SdkError` is ever retried. Its message is matched
    /// case-insensitively against markers for throttling, service
    /// unavailability, timeouts and connection problems. Every other variant
    /// (bad credentials, failed ECR auth, missing permissions, missing
    /// repository, missing region) describes a condition that retrying cannot
    /// fix, so it is never retried.
    fn is_retryable(&self) -> bool {
        // Lower-case markers; the message is lower-cased before matching so
        // that e.g. "Connection refused" and "connection reset" both count.
        // "timeout" also covers "RequestTimeout".
        const TRANSIENT_MARKERS: &[&str] = &[
            "throttlingexception",
            "serviceunavailable",
            "internalserviceerror",
            "connection",
            "timeout",
            // The SDK renders its own request timeout as "request has timed
            // out", which contains neither "timeout" nor "RequestTimeout".
            "timed out",
        ];
        match &self.0 {
            AwsError::SdkError(msg) => {
                let msg = msg.to_ascii_lowercase();
                TRANSIENT_MARKERS.iter().any(|marker| msg.contains(marker))
            }
            AwsError::CredentialsInvalid
            | AwsError::EcrAuthFailed(_)
            | AwsError::PermissionDenied(_)
            | AwsError::EcrRepoNotFound(_)
            | AwsError::RegionNotConfigured => false,
        }
    }
}
impl From<AwsRetryableError> for crate::error::ApiForgeError {
fn from(e: AwsRetryableError) -> Self {
crate::error::ApiForgeError::Aws(e.0)
}
}
pub struct AwsClient {
ecr: EcrClient,
sts: StsClient,
ssm: SsmClient,
cloudfront: CloudFrontClient,
region: String,
retry_config: RetryConfig,
}
/// Operations on the wrapped AWS clients.
///
/// Every AWS API call below runs through `with_retry` using the client's
/// `RetryConfig`. SDK failures are converted into `AwsError` values wrapped in
/// `AwsRetryableError`, so the retry helper can decide whether to try again.
impl AwsClient {
    /// Builds a client for `region` using the default AWS configuration chain.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept for API stability.
    pub async fn new(region: &str) -> Result<Self> {
        let config = aws_config::defaults(BehaviorVersion::latest())
            .region(aws_config::Region::new(region.to_string()))
            .load()
            .await;
        Ok(Self::from_sdk_config(&config, region))
    }

    /// Like [`AwsClient::new`], but loads configuration from the named AWS
    /// profile.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept for API stability.
    pub async fn with_profile(region: &str, profile: &str) -> Result<Self> {
        let config = aws_config::defaults(BehaviorVersion::latest())
            .region(aws_config::Region::new(region.to_string()))
            .profile_name(profile)
            .load()
            .await;
        Ok(Self::from_sdk_config(&config, region))
    }

    /// Creates every service client from one shared `SdkConfig` and uses the
    /// default retry policy.
    fn from_sdk_config(config: &aws_config::SdkConfig, region: &str) -> Self {
        Self {
            ecr: EcrClient::new(config),
            sts: StsClient::new(config),
            ssm: SsmClient::new(config),
            cloudfront: CloudFrontClient::new(config),
            region: region.to_string(),
            retry_config: RetryConfig::default(),
        }
    }

    /// Renders an error followed by every error in its `source()` chain,
    /// joined by ": ".
    ///
    /// The SDK's `SdkError` only displays a short category (for a service
    /// error, literally "service error"). The AWS error code, such as
    /// `ThrottlingException` or `ParameterNotFound`, only appears further
    /// down the chain. `is_retryable` and the not-found check in
    /// `get_ssm_parameter` both match on that code, so the whole chain has to
    /// be rendered.
    fn describe_sdk_error(err: &dyn std::error::Error) -> String {
        let mut rendered = err.to_string();
        let mut cause = err.source();
        while let Some(inner) = cause {
            rendered.push_str(": ");
            rendered.push_str(&inner.to_string());
            cause = inner.source();
        }
        rendered
    }

    /// Fetches the value of an SSM parameter, requesting decryption.
    ///
    /// # Errors
    ///
    /// Returns `AwsError::SdkError` in three cases:
    /// - the parameter does not exist ("SSM parameter '<name>' not found");
    /// - the parameter has no value;
    /// - the request keeps failing after retries.
    pub async fn get_ssm_parameter(&self, name: &str) -> Result<String> {
        let ssm = self.ssm.clone();
        let name = name.to_string();
        let retry_config = self.retry_config.clone();
        let value = with_retry(&retry_config, "AWS get_ssm_parameter", || {
            let ssm = ssm.clone();
            let name = name.clone();
            async move {
                let response = ssm
                    .get_parameter()
                    .name(&name)
                    .with_decryption(true)
                    .send()
                    .await
                    .map_err(|e| {
                        let msg = Self::describe_sdk_error(&e);
                        if msg.contains("ParameterNotFound") {
                            AwsRetryableError(AwsError::SdkError(format!(
                                "SSM parameter '{}' not found",
                                name
                            )))
                        } else {
                            AwsRetryableError(AwsError::SdkError(msg))
                        }
                    })?;
                response
                    .parameter()
                    .and_then(|p| p.value())
                    .map(|v| v.to_string())
                    .ok_or_else(|| {
                        AwsRetryableError(AwsError::SdkError(format!(
                            "SSM parameter '{}' has no value",
                            name
                        )))
                    })
            }
        })
        .await?;
        Ok(value)
    }

    /// Creates a CloudFront invalidation for `paths` on `distribution_id` and
    /// returns the new invalidation's id.
    ///
    /// The invalidation batch, including its random `CallerReference`, is
    /// built once before the retry loop, so every attempt sends an identical
    /// request. CloudFront de-duplicates requests by `CallerReference`, so a
    /// retry after a lost response returns the existing invalidation instead
    /// of creating a second one.
    ///
    /// # Errors
    ///
    /// Returns `AwsError::SdkError` in these cases:
    /// - the number of paths does not fit in an `i32`;
    /// - the request payload cannot be built;
    /// - the response carries no invalidation;
    /// - the call keeps failing after retries.
    pub async fn create_cloudfront_invalidation(
        &self,
        distribution_id: &str,
        paths: &[String],
    ) -> Result<String> {
        let quantity = i32::try_from(paths.len()).map_err(|_| {
            AwsRetryableError(AwsError::SdkError(format!(
                "Too many invalidation paths: {}",
                paths.len()
            )))
        })?;
        let cf_paths = aws_sdk_cloudfront::types::Paths::builder()
            .quantity(quantity)
            .set_items(Some(paths.to_vec()))
            .build()
            .map_err(|e| AwsRetryableError(AwsError::SdkError(e.to_string())))?;
        let batch = aws_sdk_cloudfront::types::InvalidationBatch::builder()
            .paths(cf_paths)
            .caller_reference(uuid::Uuid::new_v4().to_string())
            .build()
            .map_err(|e| AwsRetryableError(AwsError::SdkError(e.to_string())))?;

        let cloudfront = self.cloudfront.clone();
        let distribution_id = distribution_id.to_string();
        let retry_config = self.retry_config.clone();
        let invalidation_id =
            with_retry(&retry_config, "AWS create_cloudfront_invalidation", || {
                let cloudfront = cloudfront.clone();
                let distribution_id = distribution_id.clone();
                let batch = batch.clone();
                async move {
                    let response = cloudfront
                        .create_invalidation()
                        .distribution_id(&distribution_id)
                        .invalidation_batch(batch)
                        .send()
                        .await
                        .map_err(|e| {
                            AwsRetryableError(AwsError::SdkError(Self::describe_sdk_error(&e)))
                        })?;
                    response
                        .invalidation()
                        .map(|inv| inv.id().to_string())
                        .ok_or_else(|| {
                            AwsRetryableError(AwsError::SdkError(
                                "No invalidation in response".to_string(),
                            ))
                        })
                }
            })
            .await?;
        Ok(invalidation_id)
    }

    /// Returns `(account_id, caller_arn)` for the active credentials via STS.
    ///
    /// # Errors
    ///
    /// - `AwsError::CredentialsInvalid` if the response lacks the account or
    ///   the ARN.
    /// - `AwsError::SdkError` if the call keeps failing after retries.
    pub async fn get_caller_identity(&self) -> Result<(String, String)> {
        let sts = self.sts.clone();
        let retry_config = self.retry_config.clone();
        let identity = with_retry(&retry_config, "AWS get_caller_identity", || {
            let sts = sts.clone();
            async move {
                let response = sts
                    .get_caller_identity()
                    .send()
                    .await
                    .map_err(|e| {
                        AwsRetryableError(AwsError::SdkError(Self::describe_sdk_error(&e)))
                    })?;
                let account = response
                    .account()
                    .ok_or(AwsRetryableError(AwsError::CredentialsInvalid))?
                    .to_string();
                let arn = response
                    .arn()
                    .ok_or(AwsRetryableError(AwsError::CredentialsInvalid))?
                    .to_string();
                Ok::<(String, String), AwsRetryableError>((account, arn))
            }
        })
        .await?;
        Ok(identity)
    }

    /// Requests an ECR authorization token and converts it into Docker
    /// credentials.
    ///
    /// The token is base64-decoded and split at the first ':' into username
    /// and password. The returned `serveraddress` is the token's proxy
    /// endpoint, when present.
    ///
    /// # Errors
    ///
    /// Returns `AwsError::EcrAuthFailed` when any of the following happens:
    /// - the request fails;
    /// - the response holds no authorization data or token;
    /// - the token is not valid base64, not UTF-8, or has no ':' separator.
    pub async fn get_ecr_authorization(&self) -> Result<DockerCredentials> {
        let ecr = self.ecr.clone();
        let retry_config = self.retry_config.clone();
        let credentials = with_retry(&retry_config, "AWS get_ecr_authorization", || {
            let ecr = ecr.clone();
            async move {
                let response = ecr
                    .get_authorization_token()
                    .send()
                    .await
                    .map_err(|e| {
                        AwsRetryableError(AwsError::EcrAuthFailed(Self::describe_sdk_error(&e)))
                    })?;
                let auth_data = response.authorization_data().first().ok_or_else(|| {
                    AwsRetryableError(AwsError::EcrAuthFailed(
                        "No authorization data returned".to_string(),
                    ))
                })?;
                let token = auth_data.authorization_token().ok_or_else(|| {
                    AwsRetryableError(AwsError::EcrAuthFailed("No token in response".to_string()))
                })?;
                let decoded = base64::engine::general_purpose::STANDARD
                    .decode(token)
                    .map_err(|e| {
                        AwsRetryableError(AwsError::EcrAuthFailed(format!(
                            "Failed to decode token: {}",
                            e
                        )))
                    })?;
                let decoded_str = String::from_utf8(decoded).map_err(|e| {
                    AwsRetryableError(AwsError::EcrAuthFailed(format!(
                        "Invalid token encoding: {}",
                        e
                    )))
                })?;
                let (username, password) = decoded_str.split_once(':').ok_or_else(|| {
                    AwsRetryableError(AwsError::EcrAuthFailed("Invalid token format".to_string()))
                })?;
                let server_address = auth_data.proxy_endpoint().map(|s| s.to_string());
                Ok::<DockerCredentials, AwsRetryableError>(DockerCredentials {
                    username: Some(username.to_string()),
                    password: Some(password.to_string()),
                    serveraddress: server_address,
                    ..Default::default()
                })
            }
        })
        .await?;
        Ok(credentials)
    }

    /// Returns the ECR registry host for `account_id` in this client's region.
    ///
    /// Regions in the China partition (`cn-*`) use the `amazonaws.com.cn` DNS
    /// suffix; all other regions use `amazonaws.com`.
    pub fn get_ecr_registry_url(&self, account_id: &str) -> String {
        let dns_suffix = if self.region.starts_with("cn-") {
            "amazonaws.com.cn"
        } else {
            "amazonaws.com"
        };
        format!("{}.dkr.ecr.{}.{}", account_id, self.region, dns_suffix)
    }

    /// Returns `true` if `error` reports that the ECR repository is missing.
    ///
    /// Service errors are inspected through their `Debug` rendering, which
    /// includes the modeled error name; other error kinds fall back to their
    /// `Display` text.
    fn is_repository_not_found_error<E>(error: &SdkError<E>) -> bool
    where
        E: std::fmt::Debug,
    {
        let rendered = match error {
            SdkError::ServiceError(service_err) => format!("{:?}", service_err),
            _ => error.to_string(),
        };
        // "RepositoryNotFound" also matches "RepositoryNotFoundException".
        rendered.contains("RepositoryNotFound")
    }

    /// Looks up the repository `repo_name` and returns its URI.
    ///
    /// Despite the name, this does not create the repository. A missing
    /// repository is reported as `AwsError::EcrRepoNotFound`, which is not
    /// retried; `create_repository` is the creating counterpart.
    ///
    /// # Errors
    ///
    /// - `AwsError::EcrRepoNotFound` if the repository does not exist.
    /// - `AwsError::SdkError` if it exists but has no URI, or the call keeps
    ///   failing after retries.
    pub async fn ensure_repository_exists(&self, repo_name: &str) -> Result<String> {
        let ecr = self.ecr.clone();
        let repo_name = repo_name.to_string();
        let retry_config = self.retry_config.clone();
        let uri = with_retry(&retry_config, "AWS ensure_repository_exists", || {
            let ecr = ecr.clone();
            let repo_name = repo_name.clone();
            async move {
                match ecr
                    .describe_repositories()
                    .repository_names(&repo_name)
                    .send()
                    .await
                {
                    Ok(response) => {
                        let repo = response.repositories().first().ok_or_else(|| {
                            AwsRetryableError(AwsError::EcrRepoNotFound(repo_name.clone()))
                        })?;
                        repo.repository_uri().map(|s| s.to_string()).ok_or_else(|| {
                            AwsRetryableError(AwsError::SdkError(format!(
                                "Repository '{}' exists but has no URI",
                                repo_name
                            )))
                        })
                    }
                    Err(e) if Self::is_repository_not_found_error(&e) => Err(
                        AwsRetryableError(AwsError::EcrRepoNotFound(repo_name.clone())),
                    ),
                    Err(e) => Err(AwsRetryableError(AwsError::SdkError(
                        Self::describe_sdk_error(&e),
                    ))),
                }
            }
        })
        .await?;
        Ok(uri)
    }

    /// Creates the ECR repository `repo_name` with scan-on-push enabled and
    /// returns its URI.
    ///
    /// NOTE(review): if a first attempt succeeds server-side but its response
    /// is lost, the retry will presumably fail with an "already exists" error —
    /// confirm whether callers need to handle that.
    ///
    /// # Errors
    ///
    /// Returns `AwsError::SdkError` if the call keeps failing after retries or
    /// if the response lacks the repository or its URI.
    pub async fn create_repository(&self, repo_name: &str) -> Result<String> {
        let ecr = self.ecr.clone();
        let repo_name = repo_name.to_string();
        let retry_config = self.retry_config.clone();
        let uri = with_retry(&retry_config, "AWS create_repository", || {
            let ecr = ecr.clone();
            let repo_name = repo_name.clone();
            async move {
                let response = ecr
                    .create_repository()
                    .repository_name(&repo_name)
                    .image_scanning_configuration(
                        aws_sdk_ecr::types::ImageScanningConfiguration::builder()
                            .scan_on_push(true)
                            .build(),
                    )
                    .send()
                    .await
                    .map_err(|e| {
                        AwsRetryableError(AwsError::SdkError(Self::describe_sdk_error(&e)))
                    })?;
                let repo = response.repository().ok_or_else(|| {
                    AwsRetryableError(AwsError::SdkError("No repository in response".to_string()))
                })?;
                repo.repository_uri().map(|s| s.to_string()).ok_or_else(|| {
                    AwsRetryableError(AwsError::SdkError(format!(
                        "Created repository '{}' but it has no URI",
                        repo_name
                    )))
                })
            }
        })
        .await?;
        Ok(uri)
    }

    /// Lists every image tag in `repo_name`, following `next_token`
    /// pagination.
    ///
    /// Image ids without a tag are skipped. The tag list is rebuilt from
    /// scratch on each attempt, so a retry after a failed page does not
    /// produce duplicates.
    ///
    /// # Errors
    ///
    /// Returns `AwsError::SdkError` if a page request keeps failing after
    /// retries.
    pub async fn list_image_tags(&self, repo_name: &str) -> Result<Vec<String>> {
        let ecr = self.ecr.clone();
        let repo_name = repo_name.to_string();
        let retry_config = self.retry_config.clone();
        let tags = with_retry(&retry_config, "AWS list_image_tags", || {
            let ecr = ecr.clone();
            let repo_name = repo_name.clone();
            async move {
                let mut tags = Vec::new();
                let mut next_token: Option<String> = None;
                loop {
                    let mut request = ecr.list_images().repository_name(&repo_name);
                    if let Some(token) = &next_token {
                        request = request.next_token(token);
                    }
                    let response = request.send().await.map_err(|e| {
                        AwsRetryableError(AwsError::SdkError(Self::describe_sdk_error(&e)))
                    })?;
                    tags.extend(
                        response
                            .image_ids()
                            .iter()
                            .filter_map(|image_id| image_id.image_tag())
                            .map(|tag| tag.to_string()),
                    );
                    match response.next_token() {
                        Some(token) => next_token = Some(token.to_string()),
                        None => break,
                    }
                }
                Ok::<Vec<String>, AwsRetryableError>(tags)
            }
        })
        .await?;
        Ok(tags)
    }
}