use crate::core::providers::unified_provider::ProviderError;
use std::collections::HashMap;
use std::env;
#[derive(Debug, Clone)]
pub struct AwsCredentials {
pub access_key_id: String,
pub secret_access_key: String,
pub session_token: Option<String>,
pub region: String,
}
#[derive(Debug, Clone)]
pub struct AwsAuth {
credentials: AwsCredentials,
}
impl AwsAuth {
pub fn new(
access_key_id: String,
secret_access_key: String,
session_token: Option<String>,
region: String,
) -> Self {
Self {
credentials: AwsCredentials {
access_key_id,
secret_access_key,
session_token,
region,
},
}
}
pub fn from_env() -> Result<Self, ProviderError> {
let access_key_id = env::var("AWS_ACCESS_KEY_ID").map_err(|_| {
ProviderError::configuration(
"bedrock",
"AWS_ACCESS_KEY_ID environment variable not found".to_string(),
)
})?;
let secret_access_key = env::var("AWS_SECRET_ACCESS_KEY").map_err(|_| {
ProviderError::configuration(
"bedrock",
"AWS_SECRET_ACCESS_KEY environment variable not found".to_string(),
)
})?;
let session_token = env::var("AWS_SESSION_TOKEN").ok();
let region = env::var("AWS_REGION")
.or_else(|_| env::var("AWS_DEFAULT_REGION"))
.unwrap_or_else(|_| "us-east-1".to_string());
Ok(Self::new(
access_key_id,
secret_access_key,
session_token,
region,
))
}
pub fn credentials(&self) -> &AwsCredentials {
&self.credentials
}
pub fn validate(&self) -> Result<(), ProviderError> {
if self.credentials.access_key_id.is_empty() {
return Err(ProviderError::configuration(
"bedrock",
"AWS access key ID cannot be empty".to_string(),
));
}
if self.credentials.secret_access_key.is_empty() {
return Err(ProviderError::configuration(
"bedrock",
"AWS secret access key cannot be empty".to_string(),
));
}
if self.credentials.region.is_empty() {
return Err(ProviderError::configuration(
"bedrock",
"AWS region cannot be empty".to_string(),
));
}
if !self.credentials.access_key_id.starts_with("AKIA")
&& !self.credentials.access_key_id.starts_with("ASIA")
{
return Err(ProviderError::configuration(
"bedrock",
"Invalid AWS access key format".to_string(),
));
}
Ok(())
}
pub fn get_mapped_auth_params(&self) -> HashMap<String, String> {
let mut params = HashMap::new();
params.insert(
"aws_access_key_id".to_string(),
self.credentials.access_key_id.clone(),
);
params.insert(
"aws_secret_access_key".to_string(),
self.credentials.secret_access_key.clone(),
);
params.insert(
"aws_region_name".to_string(),
self.credentials.region.clone(),
);
if let Some(ref token) = self.credentials.session_token {
params.insert("aws_session_token".to_string(), token.clone());
}
params
}
pub fn is_temporary_credentials(&self) -> bool {
self.credentials.session_token.is_some()
}
pub fn credential_type(&self) -> &'static str {
if self.credentials.access_key_id.starts_with("AKIA") {
"long-term"
} else if self.credentials.access_key_id.starts_with("ASIA") {
"temporary"
} else {
"unknown"
}
}
}
#[cfg(test)]
pub fn extract_credentials_from_params(
params: &HashMap<String, String>,
) -> Result<AwsCredentials, ProviderError> {
let access_key_id = params
.get("aws_access_key_id")
.or_else(|| params.get("access_key"))
.ok_or_else(|| {
ProviderError::configuration(
"bedrock",
"AWS access key ID not found in parameters".to_string(),
)
})?
.clone();
let secret_access_key = params
.get("aws_secret_access_key")
.or_else(|| params.get("secret_key"))
.ok_or_else(|| {
ProviderError::configuration(
"bedrock",
"AWS secret access key not found in parameters".to_string(),
)
})?
.clone();
let session_token = params
.get("aws_session_token")
.or_else(|| params.get("session_token"))
.cloned();
let region = params
.get("aws_region_name")
.or_else(|| params.get("region"))
.or_else(|| params.get("aws_region"))
.cloned()
.unwrap_or_else(|| "us-east-1".to_string());
Ok(AwsCredentials {
access_key_id,
secret_access_key,
session_token,
region,
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_aws_auth_creation() {
let auth = AwsAuth::new(
"AKIATEST123456789012".to_string(),
"test-secret-key".to_string(),
None,
"us-east-1".to_string(),
);
assert_eq!(auth.credentials().access_key_id, "AKIATEST123456789012");
assert_eq!(auth.credentials().region, "us-east-1");
assert!(!auth.is_temporary_credentials());
assert_eq!(auth.credential_type(), "long-term");
}
#[test]
fn test_temporary_credentials() {
let auth = AwsAuth::new(
"ASIATEST123456789012".to_string(),
"test-secret-key".to_string(),
Some("session-token".to_string()),
"us-west-2".to_string(),
);
assert!(auth.is_temporary_credentials());
assert_eq!(auth.credential_type(), "temporary");
}
#[test]
fn test_credential_validation() {
let auth = AwsAuth::new(
"AKIATEST123456789012".to_string(),
"test-secret-key".to_string(),
None,
"us-east-1".to_string(),
);
assert!(auth.validate().is_ok());
let invalid_auth = AwsAuth::new(
"INVALID_KEY".to_string(),
"test-secret-key".to_string(),
None,
"us-east-1".to_string(),
);
assert!(invalid_auth.validate().is_err());
}
#[test]
fn test_mapped_auth_params() {
let auth = AwsAuth::new(
"AKIATEST123456789012".to_string(),
"test-secret-key".to_string(),
Some("session-token".to_string()),
"us-east-1".to_string(),
);
let params = auth.get_mapped_auth_params();
assert_eq!(
params.get("aws_access_key_id").unwrap(),
"AKIATEST123456789012"
);
assert_eq!(
params.get("aws_secret_access_key").unwrap(),
"test-secret-key"
);
assert_eq!(params.get("aws_session_token").unwrap(), "session-token");
assert_eq!(params.get("aws_region_name").unwrap(), "us-east-1");
}
#[test]
fn test_extract_credentials_from_params() {
let mut params = HashMap::new();
params.insert(
"aws_access_key_id".to_string(),
"AKIATEST123456789012".to_string(),
);
params.insert(
"aws_secret_access_key".to_string(),
"test-secret-key".to_string(),
);
params.insert("aws_region_name".to_string(), "us-west-2".to_string());
let credentials = extract_credentials_from_params(¶ms).unwrap();
assert_eq!(credentials.access_key_id, "AKIATEST123456789012");
assert_eq!(credentials.secret_access_key, "test-secret-key");
assert_eq!(credentials.region, "us-west-2");
assert!(credentials.session_token.is_none());
}
}