#![allow(clippy::result_large_err)]
use super::config::{AwsConfig, AwsCredentials};
use super::error::AwsError;
use super::sigv4::sign_for_service;
/// Extract the `host[:port]` authority from an endpoint URL for SigV4 signing.
///
/// A leading `scheme://` (up to the first `://`) is dropped, then everything
/// from the first `/` onward. Input without a scheme is treated as already
/// starting at the host.
fn endpoint_host(endpoint: &str) -> &str {
    let authority_and_path = match endpoint.split_once("://") {
        Some((_scheme, rest)) => rest,
        None => endpoint,
    };
    // `split` always yields at least one piece, so this fallback is never used.
    authority_and_path.split('/').next().unwrap_or(endpoint)
}
/// Base delay in seconds for exponential back-off (`DEFAULT_RETRY_SECS * 2^attempt`).
/// It is also used when a 429 response carries no usable `Retry-After`.
const DEFAULT_RETRY_SECS: u64 = 2;
/// Maximum number of retries after an HTTP 429 (Too Many Requests) response.
const MAX_RETRIES_429: u32 = 3;
/// Maximum number of retries after a transport error classified as retryable.
const MAX_RETRIES_TRANSIENT: u32 = 5;
/// Build a fresh blocking HTTP agent with two timeouts:
/// a 10 s connect timeout and a 30 s `ureq` request timeout.
fn http_agent() -> ureq::Agent {
    use std::time::Duration;

    let connect_timeout = Duration::from_secs(10);
    let request_timeout = Duration::from_secs(30);
    ureq::AgentBuilder::new()
        .timeout_connect(connect_timeout)
        .timeout(request_timeout)
        .build()
}
/// Invoke `make_request` until it succeeds or a non-retryable error occurs.
///
/// Two independent retry budgets are tracked:
/// - HTTP 429: up to `MAX_RETRIES_429` retries. Each waits for the server's
///   `Retry-After` (integer seconds) or, if that is absent or unparsable,
///   `DEFAULT_RETRY_SECS * 2^attempt`.
/// - Transport errors accepted by `is_retryable_transport_error`: up to
///   `MAX_RETRIES_TRANSIENT` retries with the same exponential back-off.
///
/// Every wait is jittered and capped at 30 s, and is performed with a blocking
/// `std::thread::sleep`. Any other error is returned unchanged, as is a
/// retryable error whose budget is spent. The closure is called once per attempt.
fn call_with_retry(
    make_request: impl Fn() -> Result<ureq::Response, ureq::Error>,
) -> Result<ureq::Response, ureq::Error> {
    // Upper bound on any single sleep between attempts.
    const MAX_WAIT_SECS: u64 = 30;

    let mut throttled_attempt = 0u32;
    let mut transient_attempt = 0u32;
    loop {
        match make_request() {
            Ok(resp) => return Ok(resp),
            Err(ureq::Error::Status(429, resp)) if throttled_attempt < MAX_RETRIES_429 => {
                let retry_after = resp
                    .header("Retry-After")
                    .and_then(|v| v.parse::<u64>().ok())
                    .unwrap_or(DEFAULT_RETRY_SECS * 2u64.pow(throttled_attempt));
                // `Retry-After` is server-controlled. Clamp it *before* jittering so a
                // huge value cannot overflow the jitter arithmetic. The resulting
                // wait is unchanged, because jitter never shortens a delay.
                let base = retry_after.min(MAX_WAIT_SECS);
                let wait = jittered_delay_secs(base).min(MAX_WAIT_SECS);
                std::thread::sleep(std::time::Duration::from_secs(wait));
                throttled_attempt += 1;
            }
            Err(ureq::Error::Transport(t))
                if transient_attempt < MAX_RETRIES_TRANSIENT
                    && is_retryable_transport_error(&t.to_string()) =>
            {
                let backoff =
                    (DEFAULT_RETRY_SECS * 2u64.pow(transient_attempt)).min(MAX_WAIT_SECS);
                let wait = jittered_delay_secs(backoff).min(MAX_WAIT_SECS);
                std::thread::sleep(std::time::Duration::from_secs(wait));
                transient_attempt += 1;
            }
            Err(e) => return Err(e),
        }
    }
}
/// Decide whether a transport-error message looks transient and worth retrying.
///
/// Matching is ASCII case-insensitive substring search against a fixed list of
/// markers: timeouts, connection reset/refused, and anything starting with
/// "temporar" (e.g. "temporary"/"temporarily").
fn is_retryable_transport_error(message: &str) -> bool {
    const RETRYABLE_MARKERS: [&str; 7] = [
        "timed out",
        "timeout",
        "connection reset",
        "connection refused",
        "econnreset",
        "econnrefused",
        "temporar",
    ];
    let lowered = message.to_ascii_lowercase();
    RETRYABLE_MARKERS
        .iter()
        .any(|&marker| lowered.contains(marker))
}
/// Return `base_secs` plus a pseudo-random jitter of `0..=max(1, base_secs / 4)` seconds.
///
/// A base of `0` returns `0`, with no jitter. The jitter source is the sub-second
/// part of the wall clock, which is cheap and not cryptographically random.
/// The sum saturates at `u64::MAX` instead of overflowing, because `base_secs`
/// can come straight from a server-supplied `Retry-After` header.
fn jittered_delay_secs(base_secs: u64) -> u64 {
    if base_secs == 0 {
        return 0;
    }
    let jitter_cap = std::cmp::max(1, base_secs / 4);
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| u64::from(d.subsec_nanos()))
        .unwrap_or(0);
    // `jitter_cap <= u64::MAX / 4`, so `jitter_cap + 1` cannot overflow.
    base_secs.saturating_add(nanos % (jitter_cap + 1))
}
fn map_ureq_error(e: ureq::Error) -> AwsError {
match e {
ureq::Error::Status(400, resp) => {
let body = resp.into_string().unwrap_or_default();
if body.contains("ParameterNotFound") {
AwsError::NotFound(body)
} else {
AwsError::Http {
status: 400,
message: body,
}
}
}
ureq::Error::Status(s, resp) => AwsError::Http {
status: s,
message: resp
.into_string()
.unwrap_or_else(|_| "<unreadable response>".into()),
},
other => AwsError::Transport(other.to_string()),
}
}
/// Convert an SSM parameter name into an upper-case, underscore-separated key.
///
/// All leading `/` are stripped. Every remaining `/` and `-` becomes `_`, and the
/// result is upper-cased, e.g. `/myapp/db-password` becomes `MYAPP_DB_PASSWORD`.
/// Other characters pass through unchanged, apart from case.
pub fn normalize_ssm_name(name: &str) -> String {
    let underscored: String = name
        .trim_start_matches('/')
        .chars()
        .map(|c| match c {
            '/' | '-' => '_',
            other => other,
        })
        .collect();
    underscored.to_uppercase()
}
/// Turn an optional hierarchy path into the absolute form sent to SSM.
///
/// `None` and `Some("")` both mean the root `/`. Any other path is returned
/// unchanged if it already starts with `/`; otherwise a `/` is prepended.
fn normalize_ssm_path(path: Option<&str>) -> String {
    path.filter(|p| !p.is_empty()).map_or_else(
        || "/".to_string(),
        |p| {
            if p.starts_with('/') {
                p.to_string()
            } else {
                format!("/{p}")
            }
        },
    )
}
/// Fetch every SSM parameter under `path` (recursively, decrypted) as
/// `(normalised_key, value)` pairs.
///
/// - `path` is made absolute with `normalize_ssm_path`; `None` or `""` means `/`.
/// - Names are converted with `normalize_ssm_name`.
/// - Pages are requested via `AmazonSSM.GetParametersByPath` with
///   `MaxResults = 10`. They are followed through `NextToken` until a response
///   has no non-empty token.
/// - `get_creds` is called before every page, so credentials can be refreshed
///   mid-pagination.
/// - Entries whose `Name` or `Value` is not a JSON string are skipped.
///
/// # Errors
/// - Any error returned by `get_creds`.
/// - HTTP/transport failures (after retries), as mapped by `map_ureq_error`.
/// - `AwsError::Transport` when a response is not JSON or lacks a `Parameters` array.
pub fn pull_ssm_parameters(
    cfg: &AwsConfig,
    get_creds: &impl Fn() -> Result<AwsCredentials, AwsError>,
    path: Option<&str>,
) -> Result<Vec<(String, String)>, AwsError> {
    const TARGET: &str = "AmazonSSM.GetParametersByPath";
    let root = normalize_ssm_path(path);
    let agent = http_agent();
    let mut results = Vec::new();
    let mut next_token: Option<String> = None;
    loop {
        // Re-fetched per page, so a refresh failure on a later page surfaces immediately.
        let creds = get_creds()?;

        let mut body = serde_json::json!({
            "Path": root,
            "Recursive": true,
            "WithDecryption": true,
            "MaxResults": 10,
        });
        if let Some(tok) = &next_token {
            body["NextToken"] = serde_json::json!(tok);
        }
        let body_str = body.to_string();

        let sig = sign_for_service(
            "ssm",
            &cfg.region,
            endpoint_host(&cfg.endpoint),
            TARGET,
            &body_str,
            &creds.access_key_id,
            &creds.secret_access_key,
            creds.session_token.as_deref(),
        );
        // The closure only borrows; it is re-run on each retry attempt.
        let resp: serde_json::Value = call_with_retry(|| {
            let mut req = agent
                .post(&cfg.endpoint)
                .set("Content-Type", "application/x-amz-json-1.1")
                .set("X-Amz-Target", TARGET)
                .set("X-Amz-Date", &sig.x_amz_date)
                .set("Authorization", &sig.authorization);
            if let Some(ref tok) = sig.x_amz_security_token {
                req = req.set("X-Amz-Security-Token", tok);
            }
            req.send_string(&body_str)
        })
        .map_err(map_ureq_error)?
        .into_json()
        .map_err(|e| AwsError::Transport(e.to_string()))?;

        let params = resp["Parameters"]
            .as_array()
            .ok_or_else(|| AwsError::Transport("SSM response missing 'Parameters' array".into()))?;
        results.extend(params.iter().filter_map(|param| {
            let name = param["Name"].as_str()?;
            let value = param["Value"].as_str()?;
            Some((normalize_ssm_name(name), value.to_string()))
        }));

        // Treat an absent or empty token as the last page, rather than sending "" back.
        next_token = resp["NextToken"]
            .as_str()
            .filter(|t| !t.is_empty())
            .map(str::to_string);
        if next_token.is_none() {
            break;
        }
    }
    Ok(results)
}
/// Result of pushing a single parameter.
///
/// This is a small fieldless enum, so it is `Copy` and `Hash` and can be
/// passed by value or collected into sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PushOutcome {
    /// The parameter did not exist and was written.
    Created,
    /// The parameter existed with a different value and was overwritten.
    Updated,
    /// No write was performed. Either the stored value was already identical,
    /// or it differed and overwriting was disabled.
    Unchanged,
    /// Not produced by any function in this section. Presumably reported by a
    /// delete path elsewhere (TODO confirm).
    Deleted,
}
/// Create or update SSM parameter `name` so that it holds `value`.
///
/// One set of credentials from `get_creds` is used for both the read and the
/// write. The function returns:
/// - `Created` if the read reported `AwsError::NotFound`. The value is then
///   written with overwrite disabled.
/// - `Unchanged` if the stored value already equals `value`, or if it differs
///   but `overwrite` is `false`. No write happens in either case.
/// - `Updated` if the stored value differs and `overwrite` is `true`.
///
/// # Errors
/// Credential, read, or write errors other than `NotFound` are returned as-is.
#[tracing::instrument(skip(cfg, get_creds, value), fields(name = %name))]
pub fn push_ssm_parameter(
    cfg: &AwsConfig,
    get_creds: &impl Fn() -> Result<AwsCredentials, AwsError>,
    name: &str,
    value: &str,
    overwrite: bool,
) -> Result<PushOutcome, AwsError> {
    let creds = get_creds()?;
    let current = match get_ssm_parameter(cfg, &creds, name) {
        Ok(current) => current,
        Err(AwsError::NotFound(_)) => {
            put_ssm_parameter(cfg, &creds, name, value, false)?;
            return Ok(PushOutcome::Created);
        }
        Err(other) => return Err(other),
    };

    // Identical values never need a write; differing values are written only on opt-in.
    if current == value || !overwrite {
        return Ok(PushOutcome::Unchanged);
    }
    put_ssm_parameter(cfg, &creds, name, value, true)?;
    Ok(PushOutcome::Updated)
}
/// Fetch the decrypted value of one SSM parameter (`AmazonSSM.GetParameter`).
///
/// `creds` is used as-is (no refresh). Failed requests are retried via `call_with_retry`.
///
/// # Errors
/// - `AwsError::NotFound` when SSM answers HTTP 400 with a `ParameterNotFound`
///   body (mapped by `map_ureq_error`).
/// - Other HTTP/transport failures, as mapped by `map_ureq_error`.
/// - `AwsError::Transport` when the body is not JSON, or lacks a string
///   `Parameter.Value`. That is a malformed success response, not a missing
///   parameter. Reporting it as `NotFound` would make `push_ssm_parameter`
///   try to create the parameter.
fn get_ssm_parameter(
    cfg: &AwsConfig,
    creds: &AwsCredentials,
    name: &str,
) -> Result<String, AwsError> {
    const TARGET: &str = "AmazonSSM.GetParameter";
    let agent = http_agent();
    let body = serde_json::json!({
        "Name": name,
        "WithDecryption": true,
    })
    .to_string();
    let sig = sign_for_service(
        "ssm",
        &cfg.region,
        endpoint_host(&cfg.endpoint),
        TARGET,
        &body,
        &creds.access_key_id,
        &creds.secret_access_key,
        creds.session_token.as_deref(),
    );
    // The closure only borrows; it is re-run on each retry attempt.
    let resp: serde_json::Value = call_with_retry(|| {
        let mut req = agent
            .post(&cfg.endpoint)
            .set("Content-Type", "application/x-amz-json-1.1")
            .set("X-Amz-Target", TARGET)
            .set("X-Amz-Date", &sig.x_amz_date)
            .set("Authorization", &sig.authorization);
        if let Some(ref tok) = sig.x_amz_security_token {
            req = req.set("X-Amz-Security-Token", tok);
        }
        req.send_string(&body)
    })
    .map_err(map_ureq_error)?
    .into_json()
    .map_err(|e| AwsError::Transport(e.to_string()))?;
    resp["Parameter"]["Value"]
        .as_str()
        .map(str::to_string)
        .ok_or_else(|| {
            AwsError::Transport(format!(
                "SSM GetParameter response for '{name}' missing 'Parameter.Value' string"
            ))
        })
}
/// Write `value` to SSM parameter `name` as a `SecureString` (`AmazonSSM.PutParameter`).
///
/// `overwrite` is sent as the request's `Overwrite` flag. `creds` is used as-is,
/// and failed requests are retried via `call_with_retry`. A successful
/// response body is not inspected.
///
/// # Errors
/// HTTP/transport failures after retries, as mapped by `map_ureq_error`.
fn put_ssm_parameter(
    cfg: &AwsConfig,
    creds: &AwsCredentials,
    name: &str,
    value: &str,
    overwrite: bool,
) -> Result<(), AwsError> {
    const TARGET: &str = "AmazonSSM.PutParameter";
    let agent = http_agent();
    let body = serde_json::json!({
        "Name": name,
        "Value": value,
        "Type": "SecureString",
        "Overwrite": overwrite,
    })
    .to_string();
    let sig = sign_for_service(
        "ssm",
        &cfg.region,
        endpoint_host(&cfg.endpoint),
        TARGET,
        &body,
        &creds.access_key_id,
        &creds.secret_access_key,
        creds.session_token.as_deref(),
    );
    // The closure only borrows; it is re-run on each retry attempt.
    call_with_retry(|| {
        let mut req = agent
            .post(&cfg.endpoint)
            .set("Content-Type", "application/x-amz-json-1.1")
            .set("X-Amz-Target", TARGET)
            .set("X-Amz-Date", &sig.x_amz_date)
            .set("Authorization", &sig.authorization);
        if let Some(ref tok) = sig.x_amz_security_token {
            req = req.set("X-Amz-Security-Token", tok);
        }
        req.send_string(&body)
    })
    .map_err(map_ureq_error)?;
    Ok(())
}
#[cfg(test)]
mod tests {
    // Unit tests for the pure helpers, plus request-level tests of the SSM
    // calls against a local `mockito` HTTP server.
    use super::super::config::{AwsConfig, AwsCredentials};
    use super::*;

    /// Static credentials without a session token, used by every request test.
    fn test_creds() -> AwsCredentials {
        AwsCredentials {
            access_key_id: "AKID-TEST".into(),
            secret_access_key: "secret-test".into(),
            session_token: None,
        }
    }

    /// `us-east-1` config whose endpoint is the given mock-server URL.
    fn cfg(url: &str) -> AwsConfig {
        AwsConfig::with_endpoint("us-east-1", url)
    }

    /// Build a `GetParametersByPath`-shaped JSON body from `(name, value)` pairs.
    /// `NextToken` is added when `next_token` is `Some`. Strings are interpolated
    /// verbatim (no JSON escaping), so callers must pass quote-free values.
    fn params_response(params: &[(&str, &str)], next_token: Option<&str>) -> String {
        let items: Vec<String> = params
            .iter()
            .map(|(name, value)| {
                format!(r#"{{"Name":"{name}","Value":"{value}","Type":"SecureString"}}"#)
            })
            .collect();
        match next_token {
            Some(tok) => format!(
                r#"{{"Parameters":[{}],"NextToken":"{tok}"}}"#,
                items.join(",")
            ),
            None => format!(r#"{{"Parameters":[{}]}}"#, items.join(",")),
        }
    }

    #[test]
    fn normalize_ssm_name_absolute_path() {
        assert_eq!(
            normalize_ssm_name("/myapp/db-password"),
            "MYAPP_DB_PASSWORD"
        );
    }

    #[test]
    fn normalize_ssm_name_relative_path() {
        assert_eq!(normalize_ssm_name("myapp/api-key"), "MYAPP_API_KEY");
    }

    #[test]
    fn normalize_ssm_name_deep_path() {
        assert_eq!(
            normalize_ssm_name("/prod/myapp/DB_URL"),
            "PROD_MYAPP_DB_URL"
        );
    }

    #[test]
    fn retryable_transport_classifier_detects_timeout() {
        assert!(is_retryable_transport_error("operation timed out"));
        assert!(is_retryable_transport_error("Connection reset by peer"));
        assert!(!is_retryable_transport_error("invalid request payload"));
    }

    #[test]
    fn jittered_delay_stays_within_25_percent_bound() {
        // For a base of 20 s the jitter cap is 20 / 4 = 5 s.
        let base = 20;
        let jittered = jittered_delay_secs(base);
        assert!(jittered >= base);
        assert!(jittered <= base + (base / 4));
    }

    #[test]
    fn normalize_ssm_name_simple() {
        assert_eq!(normalize_ssm_name("/DB_URL"), "DB_URL");
    }

    #[test]
    fn normalize_ssm_path_adds_missing_leading_slash() {
        assert_eq!(normalize_ssm_path(Some("myapp/prod/")), "/myapp/prod/");
        assert_eq!(normalize_ssm_path(Some("/myapp/prod/")), "/myapp/prod/");
        assert_eq!(normalize_ssm_path(None), "/");
    }

    #[test]
    fn pull_ssm_parameters_empty() {
        let mut server = mockito::Server::new();
        let _m = server
            .mock("POST", "/")
            .match_header("X-Amz-Target", "AmazonSSM.GetParametersByPath")
            .with_status(200)
            .with_header("Content-Type", "application/x-amz-json-1.1")
            .with_body(r#"{"Parameters":[]}"#)
            .create();
        let result =
            pull_ssm_parameters(&cfg(&server.url()), &|| Ok(test_creds()), Some("/myapp/"))
                .unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn pull_ssm_parameters_fetches_and_normalises() {
        let mut server = mockito::Server::new();
        let _m = server
            .mock("POST", "/")
            .match_header("X-Amz-Target", "AmazonSSM.GetParametersByPath")
            .with_status(200)
            .with_header("Content-Type", "application/x-amz-json-1.1")
            .with_body(params_response(&[("/myapp/db-password", "s3cr3t")], None))
            .create();
        let result =
            pull_ssm_parameters(&cfg(&server.url()), &|| Ok(test_creds()), Some("/myapp/"))
                .unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, "MYAPP_DB_PASSWORD");
        assert_eq!(result[0].1, "s3cr3t");
    }

    #[test]
    fn pull_ssm_parameters_pagination() {
        // Both mocks match the same request. The test presumably relies on mockito
        // serving page 1 until its `expect(1)` is met, then page 2. Confirm against
        // the mockito version in use.
        let mut server = mockito::Server::new();
        let _page1 = server
            .mock("POST", "/")
            .match_header("X-Amz-Target", "AmazonSSM.GetParametersByPath")
            .with_status(200)
            .with_header("Content-Type", "application/x-amz-json-1.1")
            .with_body(params_response(&[("/app/key-a", "val-a")], Some("tok2")))
            .expect(1)
            .create();
        let _page2 = server
            .mock("POST", "/")
            .match_header("X-Amz-Target", "AmazonSSM.GetParametersByPath")
            .with_status(200)
            .with_header("Content-Type", "application/x-amz-json-1.1")
            .with_body(params_response(&[("/app/key-b", "val-b")], None))
            .expect(1)
            .create();
        let result =
            pull_ssm_parameters(&cfg(&server.url()), &|| Ok(test_creds()), Some("/app/")).unwrap();
        assert_eq!(result.len(), 2);
        let keys: Vec<&str> = result.iter().map(|(k, _)| k.as_str()).collect();
        assert!(keys.contains(&"APP_KEY_A"));
        assert!(keys.contains(&"APP_KEY_B"));
    }

    #[test]
    fn pull_ssm_parameters_uses_absolute_path_and_ssm_sigv4_scope() {
        // A relative input path ("app/") must be sent as "/app/". The request must
        // also be signed with the `ssm` service in its SigV4 credential scope.
        let mut server = mockito::Server::new();
        let _m = server
            .mock("POST", "/")
            .match_header("X-Amz-Target", "AmazonSSM.GetParametersByPath")
            .match_header(
                "Authorization",
                mockito::Matcher::Regex(
                    r"Credential=AKID-TEST/\d{8}/us-east-1/ssm/aws4_request".to_string(),
                ),
            )
            .match_body(mockito::Matcher::Regex(r#""Path":"/app/""#.to_string()))
            .with_status(200)
            .with_header("Content-Type", "application/x-amz-json-1.1")
            .with_body(r#"{"Parameters":[]}"#)
            .create();
        let result =
            pull_ssm_parameters(&cfg(&server.url()), &|| Ok(test_creds()), Some("app/")).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn pull_ssm_parameters_403_returns_http_error() {
        let mut server = mockito::Server::new();
        let _m = server
            .mock("POST", "/")
            .with_status(403)
            .with_body(r#"{"__type":"AccessDeniedException","message":"Access denied"}"#)
            .create();
        let err = pull_ssm_parameters(&cfg(&server.url()), &|| Ok(test_creds()), Some("/app/"))
            .unwrap_err();
        assert!(matches!(err, AwsError::Http { status: 403, .. }));
    }

    #[test]
    fn pull_ssm_parameters_429_exhausts_retries() {
        // Expects the initial attempt plus MAX_RETRIES_429 retries. `Retry-After: 0`
        // gives a computed wait of 0 s, which keeps the test fast.
        let mut server = mockito::Server::new();
        let _m = server
            .mock("POST", "/")
            .with_status(429)
            .with_header("Retry-After", "0")
            .with_body("Too Many Requests")
            .expect(MAX_RETRIES_429 as usize + 1)
            .create();
        let err = pull_ssm_parameters(&cfg(&server.url()), &|| Ok(test_creds()), Some("/app/"))
            .unwrap_err();
        assert!(matches!(err, AwsError::Http { status: 429, .. }));
    }

    #[test]
    fn pull_ssm_malformed_json_returns_transport_error() {
        let mut server = mockito::Server::new();
        let _m = server
            .mock("POST", "/")
            .with_status(200)
            .with_header("Content-Type", "application/x-amz-json-1.1")
            .with_body("not json {{{{")
            .create();
        let err = pull_ssm_parameters(&cfg(&server.url()), &|| Ok(test_creds()), Some("/app/"))
            .unwrap_err();
        assert!(matches!(err, AwsError::Transport(_)));
    }

    #[test]
    fn pull_ssm_missing_parameters_array_returns_transport_error() {
        let mut server = mockito::Server::new();
        let _m = server
            .mock("POST", "/")
            .with_status(200)
            .with_header("Content-Type", "application/x-amz-json-1.1")
            .with_body(r#"{"Unexpected":[]}"#)
            .create();
        let err = pull_ssm_parameters(&cfg(&server.url()), &|| Ok(test_creds()), Some("/app/"))
            .unwrap_err();
        assert!(
            matches!(err, AwsError::Transport(_)),
            "expected Transport for malformed SSM schema, got {err:?}"
        );
    }

    #[test]
    fn creds_refresh_failure_on_page2_propagates_error() {
        // The first `get_creds` call succeeds. The second one, made before page 2
        // is requested, fails, and that error must be returned unchanged.
        use std::sync::atomic::{AtomicUsize, Ordering};
        let call_count = AtomicUsize::new(0);
        let mut server = mockito::Server::new();
        let _page1 = server
            .mock("POST", "/")
            .with_status(200)
            .with_header("Content-Type", "application/x-amz-json-1.1")
            .with_body(params_response(&[("/app/key", "val")], Some("tok2")))
            .create();
        let err = pull_ssm_parameters(
            &cfg(&server.url()),
            &|| {
                let n = call_count.fetch_add(1, Ordering::SeqCst);
                if n == 0 {
                    Ok(test_creds())
                } else {
                    Err(AwsError::Auth("creds expired".into()))
                }
            },
            Some("/app/"),
        )
        .unwrap_err();
        assert!(
            matches!(err, AwsError::Auth(_)),
            "expected Auth on page-2 creds failure, got {err:?}"
        );
    }

    #[test]
    fn x_amz_target_header_is_ssm_get_parameters_by_path() {
        let mut server = mockito::Server::new();
        let _m = server
            .mock("POST", "/")
            .match_header("X-Amz-Target", "AmazonSSM.GetParametersByPath")
            .with_status(200)
            .with_header("Content-Type", "application/x-amz-json-1.1")
            .with_body(r#"{"Parameters":[]}"#)
            .create();
        let result =
            pull_ssm_parameters(&cfg(&server.url()), &|| Ok(test_creds()), Some("/")).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn push_ssm_creates_when_not_found() {
        // GetParameter returns 400 with a `ParameterNotFound` body, which maps to
        // `AwsError::NotFound`. That triggers a PutParameter and a `Created` outcome.
        let mut server = mockito::Server::new();
        let _get = server
            .mock("POST", "/")
            .match_header("X-Amz-Target", "AmazonSSM.GetParameter")
            .with_status(400)
            .with_header("Content-Type", "application/x-amz-json-1.1")
            .with_body(r#"{"__type":"ParameterNotFound","message":"Parameter not found"}"#)
            .create();
        let _put = server
            .mock("POST", "/")
            .match_header("X-Amz-Target", "AmazonSSM.PutParameter")
            .with_status(200)
            .with_header("Content-Type", "application/x-amz-json-1.1")
            .with_body(r#"{"Version":1,"Tier":"Standard"}"#)
            .create();
        let outcome = push_ssm_parameter(
            &cfg(&server.url()),
            &|| Ok(test_creds()),
            "/myapp/db-password",
            "val",
            true,
        )
        .unwrap();
        assert_eq!(outcome, PushOutcome::Created);
    }

    #[test]
    fn push_ssm_updates_when_value_differs() {
        let mut server = mockito::Server::new();
        let _get = server
            .mock("POST", "/")
            .match_header("X-Amz-Target", "AmazonSSM.GetParameter")
            .with_status(200)
            .with_header("Content-Type", "application/x-amz-json-1.1")
            .with_body(r#"{"Parameter":{"Name":"/myapp/db","Value":"old","Type":"SecureString"}}"#)
            .create();
        let _put = server
            .mock("POST", "/")
            .match_header("X-Amz-Target", "AmazonSSM.PutParameter")
            .with_status(200)
            .with_header("Content-Type", "application/x-amz-json-1.1")
            .with_body(r#"{"Version":2,"Tier":"Standard"}"#)
            .create();
        let outcome = push_ssm_parameter(
            &cfg(&server.url()),
            &|| Ok(test_creds()),
            "/myapp/db",
            "new",
            true,
        )
        .unwrap();
        assert_eq!(outcome, PushOutcome::Updated);
    }

    #[test]
    fn push_ssm_unchanged_when_value_identical() {
        // `no_put` expects zero hits; `no_put.assert()` verifies no write happened.
        let mut server = mockito::Server::new();
        let _get = server
            .mock("POST", "/")
            .match_header("X-Amz-Target", "AmazonSSM.GetParameter")
            .with_status(200)
            .with_header("Content-Type", "application/x-amz-json-1.1")
            .with_body(
                r#"{"Parameter":{"Name":"/myapp/key","Value":"same","Type":"SecureString"}}"#,
            )
            .create();
        let no_put = server
            .mock("POST", "/")
            .match_header("X-Amz-Target", "AmazonSSM.PutParameter")
            .expect(0)
            .create();
        let outcome = push_ssm_parameter(
            &cfg(&server.url()),
            &|| Ok(test_creds()),
            "/myapp/key",
            "same",
            true,
        )
        .unwrap();
        assert_eq!(outcome, PushOutcome::Unchanged);
        no_put.assert();
    }

    #[test]
    fn push_ssm_auth_error_propagates() {
        // No mocks are registered. The credential error is returned before any
        // HTTP request is made.
        let server = mockito::Server::new();
        let err = push_ssm_parameter(
            &cfg(&server.url()),
            &|| Err(AwsError::Auth("no creds".into())),
            "/myapp/key",
            "val",
            true,
        )
        .unwrap_err();
        assert!(matches!(err, AwsError::Auth(_)));
    }
}