use aws_config::BehaviorVersion;
use aws_credential_types::{Credentials, provider::ProvideCredentials};
use aws_sigv4::http_request::{
SignableBody, SignableRequest, SignatureLocation, SigningSettings, sign,
};
use aws_sigv4::sign::v4;
use rand::Rng;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use std::time::SystemTime;
use strum_macros::IntoStaticStr;
use tokio::sync::{Notify, RwLock};
use tokio::task::JoinHandle;
use tokio::time::{MissedTickBehavior, interval};
use crate::value::{ErrorKind, Error};
/// Largest refresh interval accepted by `validate_refresh_interval` (12 hours).
const MAX_REFRESH_INTERVAL_SECONDS: u32 = 60 * 60 * 12;
/// Refresh interval used when the caller does not configure one (5 minutes).
const DEFAULT_REFRESH_INTERVAL_SECONDS: u32 = 5 * 60;
/// Refresh intervals at or above this value (15 minutes) are accepted but logged as a warning.
const WARNING_REFRESH_INTERVAL_SECONDS: u32 = 60 * 15;
/// Validity of each presigned token (15 minutes); emitted as `X-Amz-Expires`.
pub const TOKEN_TTL_SECONDS: u64 = 60 * 15;

/// Maximum number of token-generation attempts before the last error is returned.
const TOKEN_GEN_MAX_ATTEMPTS: u32 = 8;
/// Backoff before the first retry, in milliseconds.
const TOKEN_GEN_INITIAL_BACKOFF_MS: u64 = 100;
/// Upper bound for the doubling retry backoff, in milliseconds.
const TOKEN_GEN_MAX_BACKOFF_MS: u64 = 3_000;
/// AWS service the IAM auth token is generated for.
///
/// The `strum` serialization (via `IntoStaticStr`) is used as the SigV4 signing
/// service name in `generate_token_static` and as the credentials' provider name
/// in `get_signing_identity`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, IntoStaticStr)]
pub enum ServiceType {
/// ElastiCache; converts to the static string `"elasticache"`.
#[strum(serialize = "elasticache")]
ElastiCache,
/// MemoryDB; converts to the static string `"memorydb"`.
#[strum(serialize = "memorydb")]
MemoryDB,
}
/// Validates a caller-supplied token refresh interval, in seconds.
///
/// * `None` selects `DEFAULT_REFRESH_INTERVAL_SECONDS`.
/// * `Some(0)` and values above `MAX_REFRESH_INTERVAL_SECONDS` are rejected.
/// * Values at or above `WARNING_REFRESH_INTERVAL_SECONDS` are accepted, but a
///   warning is logged because they leave little margin before the token TTL.
///
/// On success the returned option is always `Some`.
///
/// # Errors
/// Returns an `ErrorKind::ClientError` describing the rejected value.
fn validate_refresh_interval(
    refresh_interval_seconds: Option<u32>,
) -> std::result::Result<Option<u32>, Error> {
    let Some(interval) = refresh_interval_seconds else {
        return Ok(Some(DEFAULT_REFRESH_INTERVAL_SECONDS));
    };

    let rejection = match interval {
        0 => Some("interval must be at least 1 second, got 0".to_string()),
        too_large if too_large > MAX_REFRESH_INTERVAL_SECONDS => Some(format!(
            "actual={too_large} exceeds max={MAX_REFRESH_INTERVAL_SECONDS}"
        )),
        _ => None,
    };
    if let Some(detail) = rejection {
        return Err(Error::from((
            ErrorKind::ClientError,
            "IAM refresh interval validation failed",
            detail,
        )));
    }

    if interval >= WARNING_REFRESH_INTERVAL_SECONDS {
        let interval_min = interval / 60;
        let warning_min = WARNING_REFRESH_INTERVAL_SECONDS / 60;
        tracing::warn!(
            "IAM token refresh interval warning - Refresh interval of {interval} seconds ({interval_min}min) exceeds recommended maximum of {WARNING_REFRESH_INTERVAL_SECONDS} seconds ({warning_min}min). \
             This may increase the risk of token expiration. \
             Consider using a shorter interval for better reliability."
        );
    }
    Ok(Some(interval))
}
async fn get_signing_identity(
region: &str,
service_type: ServiceType,
) -> std::result::Result<aws_credential_types::Credentials, Error> {
let config = aws_config::defaults(BehaviorVersion::latest())
.region(aws_config::Region::new(region.to_string()))
.load()
.await;
let provider = config.credentials_provider().ok_or_else(|| {
Error::from((
ErrorKind::ClientError,
"IAM credentials error",
"No AWS credentials provider found".to_string(),
))
})?;
let creds = provider.provide_credentials().await.map_err(|e| {
Error::from((
ErrorKind::ClientError,
"IAM credentials error",
e.to_string(),
))
})?;
let service_name: &'static str = service_type.into();
Ok(Credentials::new(
creds.access_key_id(),
creds.secret_access_key(),
creds.session_token().map(|s| s.to_string()),
creds.expiry(),
service_name,
))
}
/// Parameters needed to generate, and later regenerate, an IAM auth token.
#[derive(Clone, Debug)]
pub(crate) struct IamTokenState {
/// AWS region used both to load credentials and as the SigV4 signing region.
region: String,
/// Host of the presigned connect URL (also sent as the `host` header).
cluster_name: String,
/// Value of the `User` query parameter (URL-encoded by `build_base_url`).
username: String,
/// Selects the SigV4 signing service name.
service_type: ServiceType,
/// Period of the background refresh task; `IAMTokenManager::new` fills it from
/// `validate_refresh_interval`.
refresh_interval_seconds: u32,
}
/// Holds a cached IAM auth token for one cluster and, once
/// `start_refresh_task` is called, a background task that regenerates it every
/// `refresh_interval_seconds`.
pub struct IAMTokenManager {
/// Current token; shared with the refresh task and with handles from `get_token_handle`.
cached_token: Arc<RwLock<String>>,
/// Time at which `cached_token` was last stored; also shared with token handles.
token_created_at: Arc<RwLock<tokio::time::Instant>>,
/// Inputs used to generate tokens.
iam_token_state: IamTokenState,
/// Background refresh task; `None` until `start_refresh_task` and again after `stop_refresh_task`.
refresh_task: Option<JoinHandle<()>>,
/// Used to tell the background refresh task to exit.
shutdown_notify: Arc<Notify>,
/// Set whenever a new token is stored; cleared by `clear_token_changed`.
token_changed: Arc<AtomicBool>,
}
impl std::fmt::Debug for IAMTokenManager {
    /// Formats the manager without taking any lock and without printing the
    /// signed token itself. Lock-protected fields are shown as type placeholders.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("IAMTokenManager")
            .field("cached_token", &"<RwLock<String>>")
            // Previously omitted; listed like the other lock-protected fields.
            .field("token_created_at", &"<RwLock<Instant>>")
            .field("iam_token_state", &self.iam_token_state)
            .field("refresh_task", &self.refresh_task.is_some())
            .field("shutdown_notify", &"<Notify>")
            .field("token_changed", &self.token_changed.load(Ordering::Relaxed))
            .finish()
    }
}
impl IAMTokenManager {
pub async fn new(
cluster_name: String,
username: String,
region: String,
service_type: ServiceType,
refresh_interval_seconds: Option<u32>,
) -> std::result::Result<Self, Error> {
let validated_refresh_interval = validate_refresh_interval(refresh_interval_seconds)?;
let state = IamTokenState {
region,
cluster_name,
username,
service_type,
refresh_interval_seconds: validated_refresh_interval
.unwrap_or(DEFAULT_REFRESH_INTERVAL_SECONDS),
};
let initial_token = Self::generate_token_with_backoff(&state).await?;
Ok(Self {
cached_token: Arc::new(RwLock::new(initial_token)),
token_created_at: Arc::new(RwLock::new(tokio::time::Instant::now())),
iam_token_state: state,
refresh_task: None,
shutdown_notify: Arc::new(Notify::new()),
token_changed: Arc::new(AtomicBool::new(true)), })
}
pub fn start_refresh_task(&mut self) {
if self.refresh_task.is_some() {
return; }
let iam_token_state = self.iam_token_state.clone();
let cached_token = Arc::clone(&self.cached_token);
let token_created_at = Arc::clone(&self.token_created_at);
let shutdown_notify = Arc::clone(&self.shutdown_notify);
let token_changed = Arc::clone(&self.token_changed);
let task = tokio::spawn(Self::token_refresh_task(
iam_token_state,
cached_token,
token_created_at,
shutdown_notify,
token_changed,
));
self.refresh_task = Some(task);
}
async fn token_refresh_task(
iam_token_state: IamTokenState,
cached_token: Arc<RwLock<String>>,
token_created_at: Arc<RwLock<tokio::time::Instant>>,
shutdown_notify: Arc<Notify>,
token_changed: Arc<AtomicBool>,
) {
let refresh_interval = Duration::from_secs(iam_token_state.refresh_interval_seconds as u64);
let mut interval_timer = interval(refresh_interval);
interval_timer.set_missed_tick_behavior(MissedTickBehavior::Skip);
interval_timer.tick().await;
loop {
tokio::select! {
_ = interval_timer.tick() => {
let _ = Self::handle_token_refresh(&iam_token_state, &cached_token, &token_created_at, &token_changed).await;
}
_ = shutdown_notify.notified() => {
tracing::info!("IAM token refresh task shutting down");
break;
}
}
}
}
async fn handle_token_refresh(
iam_token_state: &IamTokenState,
cached_token: &Arc<RwLock<String>>,
token_created_at: &Arc<RwLock<tokio::time::Instant>>,
token_changed: &Arc<AtomicBool>,
) -> std::result::Result<(), Error> {
match Self::generate_token_with_backoff(iam_token_state).await {
Ok(new_token) => {
Self::set_cached_token_static(cached_token, new_token.clone()).await;
{
let mut ts = token_created_at.write().await;
*ts = tokio::time::Instant::now();
}
token_changed.store(true, Ordering::Release);
Ok(())
}
Err(err) => {
tracing::error!("IAM token refresh failed - Could not refresh token after backoff: {err}");
Err(err)
}
}
}
pub(crate) async fn generate_token_with_backoff(
state: &IamTokenState,
) -> std::result::Result<String, Error> {
let mut attempt: u32 = 0;
let mut backoff_ms = TOKEN_GEN_INITIAL_BACKOFF_MS;
loop {
match Self::generate_token_static(state).await {
Ok(token) => {
return Ok(token);
}
Err(e) => {
attempt += 1;
if attempt >= TOKEN_GEN_MAX_ATTEMPTS {
tracing::error!("IAM token generation failed - Exhausted {TOKEN_GEN_MAX_ATTEMPTS} attempts with exponential backoff. error: {e}");
return Err(e);
}
let sleep_ms = {
let jitter = (backoff_ms as f64 * 0.2) as u64;
let min = backoff_ms.saturating_sub(jitter);
let max = backoff_ms.saturating_add(jitter);
let mut rng = rand::rng();
rng.random_range(min..=max)
};
tracing::warn!("IAM token generation failed - {e}. Retrying in {sleep_ms}ms");
tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
backoff_ms = (backoff_ms.saturating_mul(2)).min(TOKEN_GEN_MAX_BACKOFF_MS);
}
}
}
}
pub async fn refresh_token(&self) -> std::result::Result<(), Error> {
Self::handle_token_refresh(
&self.iam_token_state,
&self.cached_token,
&self.token_created_at,
&self.token_changed,
)
.await
}
pub async fn stop_refresh_task(&mut self) {
if let Some(task) = self.refresh_task.take() {
self.shutdown_notify.notify_one();
let _ = tokio::time::timeout(Duration::from_secs(5), task).await;
}
}
async fn set_cached_token_static(cached_token: &Arc<RwLock<String>>, new_token: String) {
let mut token_guard = cached_token.write().await;
*token_guard = new_token;
}
pub async fn get_token(&self) -> String {
let token_guard = self.cached_token.read().await;
token_guard.clone()
}
pub fn token_changed(&self) -> bool {
self.token_changed.load(Ordering::Acquire)
}
pub fn clear_token_changed(&self) {
self.token_changed.store(false, Ordering::Release)
}
pub fn get_token_handle(&self) -> crate::client::IAMTokenHandle {
crate::client::IAMTokenHandle {
cached_token: Arc::clone(&self.cached_token),
token_created_at: Arc::clone(&self.token_created_at),
iam_token_state: self.iam_token_state.clone(),
}
}
async fn generate_token_static(state: &IamTokenState) -> std::result::Result<String, Error> {
let service_name: &'static str = state.service_type.into();
let signing_time = SystemTime::now();
let hostname = state.cluster_name.clone();
let base_url = build_base_url(&hostname, &state.username);
let creds = get_signing_identity(&state.region, state.service_type).await?;
let identity_value = creds.into();
let mut signing_settings = SigningSettings::default();
signing_settings.signature_location = SignatureLocation::QueryParams;
signing_settings.expires_in = Some(Duration::from_secs(TOKEN_TTL_SECONDS));
let signing_params = v4::SigningParams::builder()
.identity(&identity_value)
.region(&state.region)
.name(service_name)
.time(signing_time)
.settings(signing_settings)
.build()
.map_err(|e| {
Error::from((
ErrorKind::ClientError,
"IAM token generation failed",
format!("Failed to build signing params: {e}"),
))
})?
.into();
let signable_request = SignableRequest::new(
"GET",
&base_url,
std::iter::empty(),
SignableBody::Bytes(b""),
)
.map_err(|e| {
Error::from((
ErrorKind::ClientError,
"IAM token generation failed",
format!("Failed to create signable request: {e}"),
))
})?;
let (instructions, _sig) = sign(signable_request, &signing_params)
.map_err(|e| {
Error::from((
ErrorKind::ClientError,
"IAM token generation failed",
format!("Failed to sign: {e}"),
))
})?
.into_parts();
let mut req = http::Request::builder()
.method("GET")
.uri(&base_url)
.header("host", &hostname)
.body(())
.map_err(|e| {
Error::from((
ErrorKind::ClientError,
"IAM token generation failed",
format!("Build HTTP request failed: {e}"),
))
})?;
instructions.apply_to_request_http1x(&mut req);
let token = strip_scheme(req.uri().to_string());
tracing::debug!("Generated new IAM token");
Ok(token)
}
}
impl Drop for IAMTokenManager {
fn drop(&mut self) {
self.shutdown_notify.notify_one();
}
}
/// Builds the unsigned IAM connect URL for `hostname`.
///
/// `username` is URL-encoded into the `User` query parameter; `hostname` is
/// inserted verbatim.
fn build_base_url(hostname: &str, username: &str) -> String {
    let encoded_user = urlencoding::encode(username);
    format!("https://{hostname}/?Action=connect&User={encoded_user}")
}
/// Removes a leading `https://` (checked first) or `http://` from `full`.
/// Input without either prefix is returned unchanged.
fn strip_scheme(full: String) -> String {
    for scheme in ["https://", "http://"] {
        if let Some(rest) = full.strip_prefix(scheme) {
            return rest.to_string();
        }
    }
    full
}
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::env;
    use std::fs;
    use std::sync::Once;
    use tokio::time::{Duration, sleep};

    /// JSON file that the tests append generated tokens to, for manual inspection.
    const IAM_TOKENS_JSON: &str = "/tmp/iam_tokens.json";
    static INIT: Once = Once::new();

    /// Deletes the token log file once per test-binary run.
    fn initialize_test_environment() {
        INIT.call_once(|| {
            let _ = std::fs::remove_file(IAM_TOKENS_JSON);
            tracing::info!("Test setup - Cleaned up old IAM token log file");
        });
    }

    /// Exports dummy static AWS credentials so that `get_signing_identity` can
    /// resolve credentials in tests. This presumably works through the
    /// environment-variable provider of the SDK's default chain.
    fn setup_test_credentials() {
        // SAFETY: `set_var` is unsafe because other threads may read the environment
        // concurrently. All callers are `#[serial]` tests. NOTE(review): this assumes
        // no non-serial test in the same binary touches the environment concurrently.
        unsafe {
            env::set_var("AWS_ACCESS_KEY_ID", "test_access_key");
            env::set_var("AWS_SECRET_ACCESS_KEY", "test_secret_key");
            env::set_var("AWS_SESSION_TOKEN", "test_session_token");
        }
    }

    /// Appends a record describing `token` to `IAM_TOKENS_JSON`.
    ///
    /// A missing or unparsable file is treated as an empty list. Write failures
    /// are ignored.
    fn save_token_to_file(test_name: &str, token: &str, state: &IamTokenState) {
        let token_data = serde_json::json!({
            "test_name": test_name,
            "token": token,
            "region": state.region,
            "cluster_name": state.cluster_name,
            "username": state.username,
            "service_type": format!("{:?}", state.service_type),
            "refresh_interval_seconds": state.refresh_interval_seconds,
            "timestamp": std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_secs()
        });
        let mut tokens = if let Ok(content) = fs::read_to_string(IAM_TOKENS_JSON) {
            serde_json::from_str::<Vec<serde_json::Value>>(&content).unwrap_or_else(|_| Vec::new())
        } else {
            Vec::new()
        };
        tokens.push(token_data);
        if let Ok(json_string) = serde_json::to_string_pretty(&tokens) {
            let _ = fs::write(IAM_TOKENS_JSON, json_string);
        }
    }

    /// Builds an `IamTokenState` with the default refresh interval, for logging only.
    fn create_test_state(
        region: &str,
        cluster_name: &str,
        username: &str,
        service_type: ServiceType,
    ) -> IamTokenState {
        IamTokenState {
            region: region.to_string(),
            cluster_name: cluster_name.to_string(),
            username: username.to_string(),
            service_type,
            refresh_interval_seconds: DEFAULT_REFRESH_INTERVAL_SECONDS,
        }
    }

    #[tokio::test]
    #[serial]
    async fn test_iam_token_manager_with_atomic_flag() {
        initialize_test_environment();
        setup_test_credentials();
        let cluster_name = "test-cluster".to_string();
        let username = "test-user".to_string();
        let region = "us-east-1".to_string();
        let mut manager = IAMTokenManager::new(
            cluster_name,
            username,
            region,
            ServiceType::ElastiCache,
            Some(2),
        )
        .await
        .unwrap();
        assert!(
            manager.token_changed(),
            "Initial token_changed should be true"
        );
        manager.clear_token_changed();
        assert!(
            !manager.token_changed(),
            "After clear, token_changed should be false"
        );
        manager.start_refresh_task();
        sleep(Duration::from_secs(3)).await;
        assert!(
            manager.token_changed(),
            "After refresh, token_changed should be true"
        );
        manager.stop_refresh_task().await;
        tracing::info!("Test completed successfully! - Atomic flag working as expected");
    }

    #[tokio::test]
    #[serial]
    async fn test_iam_token_manager_manual_refresh_sets_flag() {
        initialize_test_environment();
        setup_test_credentials();
        let cluster_name = "test-cluster".to_string();
        let username = "test-user".to_string();
        let region = "us-east-1".to_string();
        let manager = IAMTokenManager::new(
            cluster_name,
            username,
            region,
            ServiceType::ElastiCache,
            None,
        )
        .await
        .unwrap();
        manager.clear_token_changed();
        assert!(!manager.token_changed(), "Flag should be false after clear");
        manager
            .refresh_token()
            .await
            .expect("refresh_token should succeed in test");
        assert!(
            manager.token_changed(),
            "Flag should be true after manual refresh"
        );
        tracing::info!("Manual refresh test completed successfully!");
    }

    #[tokio::test]
    #[serial]
    async fn test_iam_token_manager_new_creates_initial_token() {
        initialize_test_environment();
        setup_test_credentials();
        let cluster_name = "test-cluster".to_string();
        let username = "test-user".to_string();
        let region = "us-east-1".to_string();
        let result = IAMTokenManager::new(
            cluster_name.clone(),
            username.clone(),
            region.clone(),
            ServiceType::ElastiCache,
            None,
        )
        .await;
        assert!(result.is_ok(), "IAMTokenManager creation should succeed");
        let manager = result.unwrap();
        let token = manager.get_token().await;
        let state = create_test_state(&region, &cluster_name, &username, ServiceType::ElastiCache);
        save_token_to_file(
            "test_iam_token_manager_new_creates_initial_token",
            &token,
            &state,
        );
        assert!(!token.is_empty(), "Initial token should not be empty");
        assert!(
            token.starts_with(&format!("{}/", cluster_name)),
            "Token should start with cluster name"
        );
    }

    #[tokio::test]
    #[serial]
    async fn test_iam_token_manager_get_token_returns_cached_token() {
        initialize_test_environment();
        setup_test_credentials();
        let cluster_name = "test-cluster".to_string();
        let username = "test-user".to_string();
        let region = "us-east-1".to_string();
        let manager = IAMTokenManager::new(
            cluster_name,
            username,
            region,
            ServiceType::ElastiCache,
            None,
        )
        .await
        .unwrap();
        let token1 = manager.get_token().await;
        let token2 = manager.get_token().await;
        assert_eq!(
            token1, token2,
            "get_token should return the same cached token"
        );
    }

    #[tokio::test]
    #[serial]
    async fn test_iam_token_manager_refresh_token_updates_cached_token() {
        initialize_test_environment();
        setup_test_credentials();
        let cluster_name = "test-cluster".to_string();
        let username = "test-user".to_string();
        let region = "us-east-1".to_string();
        let manager = IAMTokenManager::new(
            cluster_name.clone(),
            username.clone(),
            region.clone(),
            ServiceType::ElastiCache,
            None,
        )
        .await
        .unwrap();
        let initial_token = manager.get_token().await;
        let state = create_test_state(&region, &cluster_name, &username, ServiceType::ElastiCache);
        save_token_to_file(
            "test_iam_token_manager_refresh_token_updates_cached_token_initial",
            &initial_token,
            &state,
        );
        // Wait so that the next signature uses a different signing timestamp.
        sleep(Duration::from_secs(1)).await;
        manager
            .refresh_token()
            .await
            .expect("refresh_token should succeed in test");
        let new_token = manager.get_token().await;
        let state = create_test_state(&region, &cluster_name, &username, ServiceType::ElastiCache);
        save_token_to_file(
            "test_iam_token_manager_refresh_token_updates_cached_token_refreshed",
            &new_token,
            &state,
        );
        assert_ne!(
            initial_token, new_token,
            "Refreshed token should be different from initial token"
        );
        assert!(
            new_token.starts_with(&format!("{}/", cluster_name)),
            "New token should still start with cluster name"
        );
    }

    #[tokio::test]
    #[serial]
    async fn test_iam_token_manager_start_and_stop_refresh_task() {
        initialize_test_environment();
        setup_test_credentials();
        let cluster_name = "test-cluster".to_string();
        let username = "test-user".to_string();
        let region = "us-east-1".to_string();
        let mut manager = IAMTokenManager::new(
            cluster_name,
            username,
            region,
            ServiceType::ElastiCache,
            Some(1),
        )
        .await
        .unwrap();
        manager.start_refresh_task();
        assert!(
            manager.refresh_task.is_some(),
            "Refresh task should be started"
        );
        // A second start is a no-op while a task is tracked.
        manager.start_refresh_task();
        assert!(
            manager.refresh_task.is_some(),
            "Refresh task should still exist"
        );
        manager.stop_refresh_task().await;
        assert!(
            manager.refresh_task.is_none(),
            "Refresh task should be stopped"
        );
    }

    #[tokio::test]
    #[serial]
    async fn test_iam_token_manager_refresh_interval_validation() {
        initialize_test_environment();
        setup_test_credentials();
        let cluster_name = "test-cluster".to_string();
        let username = "test-user".to_string();
        let region = "us-east-1".to_string();

        let valid_intervals = [60, 900, 21600, 43200];
        for interval in valid_intervals {
            let result = IAMTokenManager::new(
                cluster_name.clone(),
                username.clone(),
                region.clone(),
                ServiceType::ElastiCache,
                Some(interval),
            )
            .await;
            assert!(
                result.is_ok(),
                "IAMTokenManager creation should succeed with valid interval: {interval} seconds"
            );
        }

        {
            let result = IAMTokenManager::new(
                cluster_name.clone(),
                username.clone(),
                region.clone(),
                ServiceType::ElastiCache,
                Some(0),
            )
            .await;
            assert!(
                result.is_err(),
                "IAMTokenManager creation should fail with interval 0"
            );
            let error = result.unwrap_err();
            assert_eq!(error.kind(), ErrorKind::ClientError);
            let detail = error.detail().unwrap_or_default();
            assert!(
                detail.contains("0"),
                "Expected '0' in error detail, got: {detail}"
            );
        }

        let invalid_intervals = [43201, 86400, 172800];
        for interval in invalid_intervals {
            let result = IAMTokenManager::new(
                cluster_name.clone(),
                username.clone(),
                region.clone(),
                ServiceType::ElastiCache,
                Some(interval),
            )
            .await;
            assert!(
                result.is_err(),
                "IAMTokenManager creation should fail with invalid interval: {interval} seconds"
            );
            let error = result.unwrap_err();
            assert_eq!(error.kind(), ErrorKind::ClientError);
            let detail = error.detail().unwrap_or_default();
            assert!(
                detail.contains(&format!("{interval}")),
                "Expected interval value in error detail, got: {detail}"
            );
            assert!(
                detail.contains(&format!("{MAX_REFRESH_INTERVAL_SECONDS}")),
                "Expected max interval value in error detail, got: {detail}"
            );
        }
    }

    #[tokio::test]
    #[serial]
    async fn test_iam_token_manager_generates_new_token_every_x_seconds() {
        initialize_test_environment();
        setup_test_credentials();
        const REFRESH_TIME_SECONDS: u32 = 2;
        let cluster_name = "test-cluster".to_string();
        let username = "test-user".to_string();
        let region = "us-east-1".to_string();
        let mut manager = IAMTokenManager::new(
            cluster_name.clone(),
            username.clone(),
            region.clone(),
            ServiceType::ElastiCache,
            Some(REFRESH_TIME_SECONDS),
        )
        .await
        .unwrap();
        let initial_token = manager.get_token().await;
        assert!(
            !initial_token.is_empty(),
            "Initial token should not be empty"
        );
        let state = create_test_state(&region, &cluster_name, &username, ServiceType::ElastiCache);
        save_token_to_file(
            "test_iam_token_manager_generates_new_token_every_5_seconds_initial",
            &initial_token,
            &state,
        );
        manager.start_refresh_task();
        sleep(Duration::from_secs(REFRESH_TIME_SECONDS as u64 + 1)).await;
        let first_refresh_token = manager.get_token().await;
        assert_ne!(
            initial_token, first_refresh_token,
            "Token should be different after first refresh interval"
        );
        save_token_to_file(
            "test_iam_token_manager_generates_new_token_every_5_seconds_first_refresh",
            &first_refresh_token,
            &state,
        );
        sleep(Duration::from_secs(REFRESH_TIME_SECONDS as u64 + 1)).await;
        let second_refresh_token = manager.get_token().await;
        assert_ne!(
            first_refresh_token, second_refresh_token,
            "Token should be different after second refresh interval"
        );
        assert_ne!(
            initial_token, second_refresh_token,
            "Second refresh token should be different from initial token"
        );
        save_token_to_file(
            "test_iam_token_manager_generates_new_token_every_5_seconds_second_refresh",
            &second_refresh_token,
            &state,
        );
        for (name, token) in [
            ("initial", &initial_token),
            ("first_refresh", &first_refresh_token),
            ("second_refresh", &second_refresh_token),
        ] {
            assert!(
                token.starts_with(&format!("{}/", cluster_name)),
                "{name} token should start with cluster name"
            );
            assert!(
                token.contains("Action=connect"),
                "{name} token should contain Action=connect"
            );
            assert!(
                token.contains("X-Amz-Expires=900"),
                "{name} token should contain 15-minute expiration"
            );
            assert!(
                token.contains("X-Amz-Signature="),
                "{name} token should contain X-Amz-Signature parameter"
            );
        }
    }
}