#![allow(clippy::result_large_err)]
use crate::config::KvConfig;
use crate::error::KvError;
/// What a [`push_secret`] call did to the remote vault.
///
/// `push_secret` reads the current remote value before writing, so an
/// identical value is reported as [`PushOutcome::Unchanged`] without a PUT.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PushOutcome {
    /// No remote value existed; the secret was written for the first time.
    Created,
    /// A different remote value existed and was overwritten.
    Updated,
    /// The remote value already matched; nothing was written.
    Unchanged,
    /// NOTE(review): not returned by `push_secret` in this module; presumably
    /// reported by callers that remove secrets — confirm.
    Deleted,
}
/// Builds a new `ureq` agent with a 10 s connect timeout and a 30 s overall
/// request timeout.
fn http_agent() -> ureq::Agent {
    use std::time::Duration;

    let connect_timeout = Duration::from_secs(10);
    let request_timeout = Duration::from_secs(30);
    ureq::AgentBuilder::new()
        .timeout_connect(connect_timeout)
        .timeout(request_timeout)
        .build()
}
/// Key Vault REST API version appended to every request URL.
const API_VERSION: &str = "7.4";
/// Base delay in seconds for exponential backoff (doubled per attempt), used
/// when no integer `Retry-After` header is available.
const DEFAULT_RETRY_SECS: u64 = 2;
/// Maximum number of retries after an HTTP 429 (Too Many Requests) response.
const MAX_RETRIES_429: u32 = 3;
/// Maximum number of retries after a transport error classified as transient.
const MAX_RETRIES_TRANSIENT: u32 = 5;
fn call_with_retry(
make_request: impl Fn() -> Result<ureq::Response, ureq::Error>,
) -> Result<ureq::Response, ureq::Error> {
let mut throttled_attempt = 0u32;
let mut transient_attempt = 0u32;
loop {
match make_request() {
Ok(resp) => return Ok(resp),
Err(ureq::Error::Status(429, resp)) if throttled_attempt < MAX_RETRIES_429 => {
let retry_after = resp
.header("Retry-After")
.and_then(|v| v.parse::<u64>().ok())
.unwrap_or(DEFAULT_RETRY_SECS * 2u64.pow(throttled_attempt));
let wait = std::cmp::min(jittered_delay_secs(retry_after), 30);
std::thread::sleep(std::time::Duration::from_secs(wait));
throttled_attempt += 1;
}
Err(ureq::Error::Transport(t))
if transient_attempt < MAX_RETRIES_TRANSIENT
&& is_retryable_transport_error(t.to_string().as_str()) =>
{
let backoff = DEFAULT_RETRY_SECS * 2u64.pow(transient_attempt);
let wait = std::cmp::min(jittered_delay_secs(backoff), 30);
std::thread::sleep(std::time::Duration::from_secs(wait));
transient_attempt += 1;
}
Err(e) => return Err(e),
}
}
}
/// Heuristically decides whether a transport error message describes a
/// transient failure worth retrying.
///
/// The check is a case-insensitive substring match against timeouts,
/// connection resets/refusals and "temporary" failures (the `temporar`
/// fragment covers both "temporary" and "temporarily").
fn is_retryable_transport_error(message: &str) -> bool {
    const RETRYABLE_FRAGMENTS: [&str; 7] = [
        "timed out",
        "timeout",
        "connection reset",
        "connection refused",
        "econnreset",
        "econnrefused",
        "temporar",
    ];
    let lowered = message.to_ascii_lowercase();
    RETRYABLE_FRAGMENTS
        .iter()
        .any(|fragment| lowered.contains(fragment))
}
/// Adds jitter to a retry delay so that concurrent clients do not retry in
/// lock-step.
///
/// Returns `base_secs` plus an extra amount in `0..=max(1, base_secs / 4)`.
/// The extra amount is derived from the sub-second nanoseconds of the
/// current wall-clock time, so it is only pseudo-random. A base of `0` returns
/// `0`, so no sleep happens.
///
/// The addition saturates. `base_secs` can come straight from a server's
/// `Retry-After` header, and a huge value must not overflow (which would panic
/// in debug builds). Callers cap the result anyway.
fn jittered_delay_secs(base_secs: u64) -> u64 {
    if base_secs == 0 {
        return 0;
    }
    // base_secs / 4 <= u64::MAX / 4, so `jitter_cap + 1` below cannot overflow.
    let jitter_cap = std::cmp::max(1, base_secs / 4);
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| u64::from(d.subsec_nanos()))
        .unwrap_or(0);
    base_secs.saturating_add(nanos % (jitter_cap + 1))
}
fn map_ureq_error(e: ureq::Error) -> KvError {
match e {
ureq::Error::Status(404, _) => KvError::NotFound(String::new()),
ureq::Error::Status(s, r) => KvError::Http {
status: s,
message: r
.into_string()
.unwrap_or_else(|_| "<unreadable response>".into()),
},
other => KvError::Transport(other.to_string()),
}
}
/// Downloads enabled secrets from the vault and returns them as
/// `(KEY, value)` pairs.
///
/// When `prefix` is given, only secrets whose name starts with it are fetched.
/// The comparison is case-insensitive.
///
/// Keys are derived from secret names by replacing `-` with `_` and
/// upper-casing, e.g. `my-db-password` → `MY_DB_PASSWORD`.
///
/// Names are listed first; `list_secret_names` requests a token per page.
/// One further token from `acquire_token` is then used for every value fetch.
///
/// # Errors
/// Returns the first error encountered, from any of these sources:
/// * token acquisition;
/// * the listing request;
/// * any individual value fetch, including `KvError::NotFound` when a listed
///   secret returns 404 or has no string value.
pub fn pull_secrets(
    cfg: &KvConfig,
    acquire_token: &impl Fn() -> Result<String, KvError>,
    prefix: Option<&str>,
) -> Result<Vec<(String, String)>, KvError> {
    let names = list_secret_names(cfg, acquire_token)?;
    let token = acquire_token()?;
    // Lower-case the prefix once instead of once per listed secret.
    let prefix_lower = prefix.map(str::to_lowercase);
    let mut secrets = Vec::new();
    for name in &names {
        if let Some(p) = prefix_lower.as_deref() {
            if !name.to_lowercase().starts_with(p) {
                continue;
            }
        }
        let value = get_secret(cfg, &token, name)?;
        let key = name.replace('-', "_").to_uppercase();
        secrets.push((key, value));
    }
    Ok(secrets)
}
/// Lists the names of enabled secrets in the vault, following `nextLink`
/// pagination.
///
/// * A token is requested from `acquire_token` before every page, so a token
///   failure on any page aborts the listing with that error.
/// * Items whose `attributes.enabled` is `false` are skipped. Items without
///   the flag count as enabled.
/// * A `nextLink` is followed only if it points at the configured vault (see
///   [`is_same_vault_origin`]). Any other link silently ends pagination, so
///   the bearer token is never sent to a different host.
///
/// # Errors
/// Token errors from `acquire_token`; request errors as mapped by
/// `map_ureq_error`; `KvError::Transport` when a page is not valid JSON.
fn list_secret_names(
    cfg: &KvConfig,
    acquire_token: &impl Fn() -> Result<String, KvError>,
) -> Result<Vec<String>, KvError> {
    // One agent for all pages so pooled connections are reused.
    let agent = http_agent();
    let mut names = Vec::new();
    let mut next: Option<String> = Some(format!(
        "{}/secrets?api-version={API_VERSION}&maxresults=25",
        cfg.vault_url
    ));
    while let Some(url) = next.take() {
        let token = acquire_token()?;
        let auth = format!("Bearer {token}");
        let resp: serde_json::Value =
            call_with_retry(|| agent.get(&url).set("Authorization", &auth).call())
                .map_err(map_ureq_error)?
                .into_json()
                .map_err(|e| KvError::Transport(e.to_string()))?;
        if let Some(items) = resp["value"].as_array() {
            names.extend(items.iter().filter_map(enabled_secret_name));
        }
        next = resp["nextLink"]
            .as_str()
            .filter(|link| is_same_vault_origin(link, &cfg.vault_url))
            .map(str::to_string);
    }
    Ok(names)
}

/// Extracts the secret name from a list item.
///
/// Returns `None` if the item is disabled or its `id` has no non-empty
/// segment after `secrets`.
/// Both `.../secrets/<name>` and `.../secrets/<name>/<version>` ids are
/// accepted.
fn enabled_secret_name(item: &serde_json::Value) -> Option<String> {
    if !item["attributes"]["enabled"].as_bool().unwrap_or(true) {
        return None;
    }
    let id = item["id"].as_str()?;
    let mut segments = id.trim_end_matches('/').split('/');
    segments.find(|segment| *segment == "secrets")?;
    segments
        .next()
        .filter(|name| !name.is_empty())
        .map(str::to_string)
}

/// Returns `true` only when `link` lies under `vault_url`. The vault URL must
/// be followed by end of string, `/`, `?` or `#`.
///
/// A plain `starts_with` would also accept look-alike hosts such as
/// `https://myvault.vault.azure.net.evil.com` and userinfo tricks such as
/// `https://myvault.vault.azure.net@evil.com`, leaking the bearer token.
/// An empty vault URL matches nothing.
fn is_same_vault_origin(link: &str, vault_url: &str) -> bool {
    let base = vault_url.trim_end_matches('/');
    if base.is_empty() {
        return false;
    }
    match link.strip_prefix(base) {
        Some(rest) => {
            rest.is_empty() || rest.starts_with(|c: char| matches!(c, '/' | '?' | '#'))
        }
        None => false,
    }
}
/// Fetches the current value of the secret `name` using the bearer `token`.
///
/// # Errors
/// * `KvError::NotFound(name)` when the service answers 404, or when the
///   response has no string `value` field.
/// * Other request failures as mapped by `map_ureq_error`.
/// * `KvError::Transport` when the body is not valid JSON.
fn get_secret(cfg: &KvConfig, token: &str, name: &str) -> Result<String, KvError> {
    let agent = http_agent();
    let endpoint = format!("{}/secrets/{name}?api-version={API_VERSION}", cfg.vault_url);
    let bearer = format!("Bearer {token}");

    let response = call_with_retry(|| agent.get(&endpoint).set("Authorization", &bearer).call())
        .map_err(|err| match map_ureq_error(err) {
            // map_ureq_error doesn't know the name; attach it here.
            KvError::NotFound(_) => KvError::NotFound(name.to_string()),
            other => other,
        })?;
    let body: serde_json::Value = response
        .into_json()
        .map_err(|err| KvError::Transport(err.to_string()))?;

    match body["value"].as_str() {
        Some(secret) => Ok(secret.to_owned()),
        None => Err(KvError::NotFound(name.to_string())),
    }
}
pub fn push_secret(
cfg: &KvConfig,
acquire_token: &impl Fn() -> Result<String, KvError>,
name: &str,
value: &str,
) -> Result<PushOutcome, KvError> {
let token = acquire_token()?;
let current_value = get_secret_opt(cfg, &token, name)?;
if let Some(ref remote_value) = current_value {
if remote_value == value {
return Ok(PushOutcome::Unchanged);
}
}
let url = format!("{}/secrets/{name}?api-version={API_VERSION}", cfg.vault_url);
let body = serde_json::json!({ "value": value });
let body_str = serde_json::to_string(&body).map_err(|e| KvError::Transport(e.to_string()))?;
let agent = http_agent();
let auth = format!("Bearer {token}");
let url_clone = url.clone();
let name_owned = name.to_string();
let result = call_with_retry(|| {
agent
.put(&url_clone)
.set("Authorization", &auth)
.set("Content-Type", "application/json")
.send_string(&body_str)
});
match result {
Ok(_) => {
if current_value.is_some() {
Ok(PushOutcome::Updated)
} else {
Ok(PushOutcome::Created)
}
}
Err(ureq::Error::Status(409, resp)) => {
let body = resp.into_string().unwrap_or_else(|_| "<unreadable>".into());
let body_lower = body.to_ascii_lowercase();
if body_lower.contains("deleted")
&& (body_lower.contains("recoverable") || body_lower.contains("soft"))
{
Err(KvError::SoftDeleted(name_owned))
} else {
Err(KvError::Http {
status: 409,
message: body,
})
}
}
Err(e) => Err(map_ureq_error(e)),
}
}
pub fn delete_secret(
cfg: &KvConfig,
acquire_token: &impl Fn() -> Result<String, KvError>,
name: &str,
) -> Result<(), KvError> {
let token = acquire_token()?;
let url = format!("{}/secrets/{name}?api-version={API_VERSION}", cfg.vault_url);
let agent = http_agent();
let auth = format!("Bearer {token}");
let url_clone = url.clone();
let result = call_with_retry(|| agent.delete(&url_clone).set("Authorization", &auth).call());
match result {
Ok(_) => Ok(()),
Err(ureq::Error::Status(404, _)) => Ok(()),
Err(e) => Err(map_ureq_error(e)),
}
}
/// Like `get_secret`, but a missing secret (`KvError::NotFound`) becomes
/// `Ok(None)` rather than an error. All other errors propagate.
fn get_secret_opt(cfg: &KvConfig, token: &str, name: &str) -> Result<Option<String>, KvError> {
    get_secret(cfg, token, name)
        .map(Some)
        .or_else(|err| match err {
            KvError::NotFound(_) => Ok(None),
            other => Err(other),
        })
}
#[cfg(test)]
mod tests {
    //! Unit tests for key normalisation, id parsing and the retry helpers,
    //! plus client tests that drive `pull_secrets` against a local `mockito`
    //! server.
    use super::*;

    #[test]
    fn key_normalisation() {
        let name = "my-super-secret";
        let key = name.replace('-', "_").to_uppercase();
        assert_eq!(key, "MY_SUPER_SECRET");
    }

    #[test]
    fn id_parsing_versioned_url() {
        let id = "https://myvault.vault.azure.net/secrets/my-secret/abc123def456";
        let parts: Vec<&str> = id.trim_end_matches('/').split('/').collect();
        let secrets_idx = parts.iter().position(|&p| p == "secrets").unwrap();
        assert_eq!(parts[secrets_idx + 1], "my-secret");
    }

    #[test]
    fn id_parsing_unversioned_url() {
        let id = "https://myvault.vault.azure.net/secrets/my-secret";
        let parts: Vec<&str> = id.trim_end_matches('/').split('/').collect();
        let secrets_idx = parts.iter().position(|&p| p == "secrets").unwrap();
        assert_eq!(parts[secrets_idx + 1], "my-secret");
    }

    // NOTE: the two tests below only exercise `str::starts_with` semantics,
    // not the production guard. The end-to-end check is
    // `nextlink_ssrf_guard_drops_foreign_origin_silently`.
    #[test]
    fn nextlink_ssrf_guard_rejects_foreign_origin() {
        let vault_url = "https://myvault.vault.azure.net";
        let malicious_next = "https://evil.example.com/secrets?api-version=7.4";
        assert!(
            !malicious_next.starts_with(vault_url),
            "SSRF guard should reject nextLink from a different origin"
        );
    }

    #[test]
    fn nextlink_ssrf_guard_accepts_same_origin() {
        let vault_url = "https://myvault.vault.azure.net";
        let valid_next = "https://myvault.vault.azure.net/secrets?api-version=7.4&$skiptoken=abc";
        assert!(valid_next.starts_with(vault_url));
    }

    #[test]
    fn retryable_transport_classifier_detects_timeout() {
        assert!(is_retryable_transport_error("operation timed out"));
        assert!(is_retryable_transport_error("Connection refused"));
        assert!(!is_retryable_transport_error("certificate verify failed"));
    }

    #[test]
    fn jittered_delay_stays_within_25_percent_bound() {
        let base = 20;
        let jittered = jittered_delay_secs(base);
        assert!(jittered >= base);
        assert!(jittered <= base + (base / 4));
    }

    /// Builds a list-page JSON body with one enabled item per name and an
    /// optional `nextLink`.
    fn list_response(names: &[&str], next_link: Option<&str>) -> String {
        let items: Vec<String> = names
            .iter()
            .map(|n| {
                format!(r#"{{"id":"https://vault/secrets/{n}","attributes":{{"enabled":true}}}}"#)
            })
            .collect();
        let next = next_link
            .map(|u| format!(r#","nextLink":"{u}""#))
            .unwrap_or_default();
        format!(r#"{{"value":[{}]{}}}"#, items.join(","), next)
    }

    /// Builds a get-secret JSON body carrying `value`.
    fn secret_response(value: &str) -> String {
        format!(r#"{{"value":"{value}"}}"#)
    }

    /// Config pointing at the given (mock) vault URL.
    fn cfg(url: &str) -> KvConfig {
        KvConfig {
            vault_url: url.to_string(),
        }
    }

    #[test]
    fn pull_secrets_empty_vault() {
        let mut server = mockito::Server::new();
        let _m = server
            .mock("GET", mockito::Matcher::Regex(r"^/secrets".to_string()))
            .with_status(200)
            .with_header("Content-Type", "application/json")
            .with_body(r#"{"value":[]}"#)
            .create();
        let result =
            pull_secrets(&cfg(&server.url()), &|| Ok("test-token".to_string()), None).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn pull_secrets_fetches_and_normalises_key() {
        let mut server = mockito::Server::new();
        let _list = server
            .mock("GET", mockito::Matcher::Regex(r"^/secrets\?".to_string()))
            .with_status(200)
            .with_header("Content-Type", "application/json")
            .with_body(list_response(&["my-db-password"], None))
            .create();
        let _get = server
            .mock(
                "GET",
                mockito::Matcher::Regex(r"^/secrets/my-db-password\?".to_string()),
            )
            .with_status(200)
            .with_header("Content-Type", "application/json")
            .with_body(secret_response("s3cr3t"))
            .create();
        let secrets = pull_secrets(&cfg(&server.url()), &|| Ok("tok".to_string()), None).unwrap();
        assert_eq!(secrets.len(), 1);
        assert_eq!(secrets[0].0, "MY_DB_PASSWORD");
        assert_eq!(secrets[0].1, "s3cr3t");
    }

    #[test]
    fn pull_secrets_skips_disabled_secrets() {
        let mut server = mockito::Server::new();
        let disabled_item = r#"{"value":[{"id":"https://vault/secrets/disabled-key","attributes":{"enabled":false}}]}"#;
        let _m = server
            .mock("GET", mockito::Matcher::Regex(r"^/secrets\?".to_string()))
            .with_status(200)
            .with_header("Content-Type", "application/json")
            .with_body(disabled_item)
            .create();
        let secrets = pull_secrets(&cfg(&server.url()), &|| Ok("tok".to_string()), None).unwrap();
        assert!(secrets.is_empty(), "disabled secrets must be filtered out");
    }

    #[test]
    fn pull_secrets_prefix_filter() {
        let mut server = mockito::Server::new();
        let _list = server
            .mock("GET", mockito::Matcher::Regex(r"^/secrets\?".to_string()))
            .with_status(200)
            .with_header("Content-Type", "application/json")
            .with_body(list_response(&["app-token", "db-password"], None))
            .create();
        let _get = server
            .mock(
                "GET",
                mockito::Matcher::Regex(r"^/secrets/app-token\?".to_string()),
            )
            .with_status(200)
            .with_header("Content-Type", "application/json")
            .with_body(secret_response("tok-xyz"))
            .create();
        let secrets =
            pull_secrets(&cfg(&server.url()), &|| Ok("tok".to_string()), Some("app-")).unwrap();
        assert_eq!(secrets.len(), 1);
        assert_eq!(secrets[0].0, "APP_TOKEN");
    }

    #[test]
    fn pull_secrets_pagination() {
        let mut server = mockito::Server::new();
        let page2_url = format!("{}/secrets?api-version=7.4&skiptoken=p2", server.url());
        let _page1 = server
            .mock("GET", mockito::Matcher::Regex(r"^/secrets\?".to_string()))
            .with_status(200)
            .with_header("Content-Type", "application/json")
            .with_body(list_response(&["secret-a"], Some(&page2_url)))
            .create();
        let _page2 = server
            .mock("GET", mockito::Matcher::Regex(r"skiptoken=p2".to_string()))
            .with_status(200)
            .with_header("Content-Type", "application/json")
            .with_body(list_response(&["secret-b"], None))
            .create();
        let _get_a = server
            .mock(
                "GET",
                mockito::Matcher::Regex(r"^/secrets/secret-a\?".to_string()),
            )
            .with_status(200)
            .with_header("Content-Type", "application/json")
            .with_body(secret_response("val-a"))
            .create();
        let _get_b = server
            .mock(
                "GET",
                mockito::Matcher::Regex(r"^/secrets/secret-b\?".to_string()),
            )
            .with_status(200)
            .with_header("Content-Type", "application/json")
            .with_body(secret_response("val-b"))
            .create();
        let secrets = pull_secrets(&cfg(&server.url()), &|| Ok("tok".to_string()), None).unwrap();
        assert_eq!(secrets.len(), 2);
        let keys: Vec<&str> = secrets.iter().map(|(k, _)| k.as_str()).collect();
        assert!(keys.contains(&"SECRET_A"));
        assert!(keys.contains(&"SECRET_B"));
    }

    #[test]
    fn get_secret_404_returns_not_found() {
        let mut server = mockito::Server::new();
        let _list = server
            .mock("GET", mockito::Matcher::Regex(r"^/secrets\?".to_string()))
            .with_status(200)
            .with_header("Content-Type", "application/json")
            .with_body(list_response(&["ghost"], None))
            .create();
        let _get = server
            .mock(
                "GET",
                mockito::Matcher::Regex(r"^/secrets/ghost\?".to_string()),
            )
            .with_status(404)
            .with_header("Content-Type", "application/json")
            .with_body(r#"{"error":{"code":"SecretNotFound"}}"#)
            .create();
        let err = pull_secrets(&cfg(&server.url()), &|| Ok("tok".to_string()), None).unwrap_err();
        assert!(matches!(err, KvError::NotFound(_)));
    }

    #[test]
    fn list_returns_http_error_on_server_fault() {
        let mut server = mockito::Server::new();
        let _m = server
            .mock("GET", mockito::Matcher::Regex(r"^/secrets".to_string()))
            .with_status(503)
            .with_body("Service Unavailable")
            .create();
        let err = pull_secrets(&cfg(&server.url()), &|| Ok("tok".to_string()), None).unwrap_err();
        assert!(matches!(err, KvError::Http { status: 503, .. }));
    }

    #[test]
    fn token_refresh_failure_before_fetch_phase_propagates_error() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let call_count = AtomicUsize::new(0);
        let mut server = mockito::Server::new();
        let _list = server
            .mock("GET", mockito::Matcher::Regex(r"^/secrets\?".to_string()))
            .with_status(200)
            .with_header("Content-Type", "application/json")
            .with_body(list_response(&["db-password"], None))
            .create();
        // First token (listing) succeeds; the second (fetch phase) fails.
        let err = pull_secrets(
            &cfg(&server.url()),
            &|| {
                let n = call_count.fetch_add(1, Ordering::SeqCst);
                if n == 0 {
                    Ok("valid-token".to_string())
                } else {
                    Err(KvError::Auth("token refresh: 401 Unauthorized".into()))
                }
            },
            None,
        )
        .unwrap_err();
        assert!(
            matches!(err, KvError::Auth(_)),
            "expected Auth error, got {err:?}"
        );
    }

    #[test]
    fn token_acquisition_failure_on_list_propagates_error() {
        let server = mockito::Server::new();
        let err = pull_secrets(
            &cfg(&server.url()),
            &|| Err(KvError::Auth("no credentials".into())),
            None,
        )
        .unwrap_err();
        assert!(matches!(err, KvError::Auth(_)));
    }

    #[test]
    fn list_endpoint_401_returns_http_error() {
        let mut server = mockito::Server::new();
        let _m = server
            .mock("GET", mockito::Matcher::Regex(r"^/secrets\?".to_string()))
            .with_status(401)
            .with_header("Content-Type", "application/json")
            .with_body(r#"{"error":{"code":"Unauthorized","message":"Invalid token"}}"#)
            .create();
        let err =
            pull_secrets(&cfg(&server.url()), &|| Ok("expired-tok".into()), None).unwrap_err();
        assert!(
            matches!(err, KvError::Http { status: 401, .. }),
            "expected Http 401, got {err:?}"
        );
    }

    #[test]
    fn get_secret_401_returns_http_error() {
        let mut server = mockito::Server::new();
        let _list = server
            .mock("GET", mockito::Matcher::Regex(r"^/secrets\?".to_string()))
            .with_status(200)
            .with_header("Content-Type", "application/json")
            .with_body(list_response(&["my-secret"], None))
            .create();
        let _get = server
            .mock(
                "GET",
                mockito::Matcher::Regex(r"^/secrets/my-secret\?".to_string()),
            )
            .with_status(401)
            .with_header("Content-Type", "application/json")
            .with_body(r#"{"error":{"code":"Unauthorized"}}"#)
            .create();
        let err = pull_secrets(&cfg(&server.url()), &|| Ok("tok".into()), None).unwrap_err();
        assert!(
            matches!(err, KvError::Http { status: 401, .. }),
            "expected Http 401 on get, got {err:?}"
        );
    }

    #[test]
    fn list_endpoint_malformed_json_returns_transport_error() {
        let mut server = mockito::Server::new();
        let _m = server
            .mock("GET", mockito::Matcher::Regex(r"^/secrets\?".to_string()))
            .with_status(200)
            .with_header("Content-Type", "application/json")
            .with_body("this is not json {{{{")
            .create();
        let err = pull_secrets(&cfg(&server.url()), &|| Ok("tok".into()), None).unwrap_err();
        assert!(
            matches!(err, KvError::Transport(_)),
            "expected Transport error on malformed JSON, got {err:?}"
        );
    }

    #[test]
    fn get_secret_malformed_json_returns_transport_error() {
        let mut server = mockito::Server::new();
        let _list = server
            .mock("GET", mockito::Matcher::Regex(r"^/secrets\?".to_string()))
            .with_status(200)
            .with_header("Content-Type", "application/json")
            .with_body(list_response(&["broken-secret"], None))
            .create();
        let _get = server
            .mock(
                "GET",
                mockito::Matcher::Regex(r"^/secrets/broken-secret\?".to_string()),
            )
            .with_status(200)
            .with_header("Content-Type", "application/json")
            .with_body("not valid json")
            .create();
        let err = pull_secrets(&cfg(&server.url()), &|| Ok("tok".into()), None).unwrap_err();
        assert!(
            matches!(err, KvError::Transport(_)),
            "expected Transport on malformed get body, got {err:?}"
        );
    }

    #[test]
    fn list_endpoint_429_exhausts_retries_returns_http_error() {
        let mut server = mockito::Server::new();
        // `Retry-After: 0` makes the jittered delay 0, so the test never sleeps.
        let throttled = server
            .mock("GET", mockito::Matcher::Regex(r"^/secrets\?".to_string()))
            .with_status(429)
            .with_header("Retry-After", "0")
            .with_body("Too Many Requests")
            .expect(MAX_RETRIES_429 as usize + 1)
            .create();
        let err = pull_secrets(&cfg(&server.url()), &|| Ok("tok".into()), None).unwrap_err();
        assert!(
            matches!(err, KvError::Http { status: 429, .. }),
            "expected Http 429 after retries exhausted, got {err:?}"
        );
        // `.expect(n)` is only enforced by an explicit `assert()`. Without it
        // the retry count was never actually verified.
        throttled.assert();
    }

    #[test]
    fn authorization_header_contains_bearer_token() {
        let mut server = mockito::Server::new();
        let _list = server
            .mock("GET", mockito::Matcher::Regex(r"^/secrets\?".to_string()))
            .match_header("Authorization", "Bearer my-test-token")
            .with_status(200)
            .with_header("Content-Type", "application/json")
            .with_body(r#"{"value":[]}"#)
            .create();
        let result =
            pull_secrets(&cfg(&server.url()), &|| Ok("my-test-token".into()), None).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn nextlink_ssrf_guard_drops_foreign_origin_silently() {
        let mut server = mockito::Server::new();
        let malicious_next = "https://evil.example.com/secrets?api-version=7.4";
        let _list = server
            .mock("GET", mockito::Matcher::Regex(r"^/secrets\?".to_string()))
            .with_status(200)
            .with_header("Content-Type", "application/json")
            .with_body(format!(r#"{{"value":[],"nextLink":"{malicious_next}"}}"#))
            .create();
        let result = pull_secrets(&cfg(&server.url()), &|| Ok("tok".into()), None).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn token_refresh_failure_on_pagination_page2_propagates_error() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let call_count = AtomicUsize::new(0);
        let mut server = mockito::Server::new();
        let page2_url = format!("{}/secrets?api-version=7.4&skiptoken=p2", server.url());
        let _page1 = server
            .mock("GET", mockito::Matcher::Regex(r"^/secrets\?".to_string()))
            .with_status(200)
            .with_header("Content-Type", "application/json")
            .with_body(list_response(&["secret-a"], Some(&page2_url)))
            .create();
        // Page 1's token succeeds; the token requested for page 2 fails.
        let err = pull_secrets(
            &cfg(&server.url()),
            &|| {
                let n = call_count.fetch_add(1, Ordering::SeqCst);
                if n == 0 {
                    Ok("good-token".into())
                } else {
                    Err(KvError::Auth("401 on second page token".into()))
                }
            },
            None,
        )
        .unwrap_err();
        assert!(
            matches!(err, KvError::Auth(_)),
            "expected Auth error from page-2 token failure, got {err:?}"
        );
    }
}