use crate::{DsqlError, Result};
use aws_sdk_dsql::auth_token::{AuthTokenGenerator, Config};
pub(crate) fn build_signer(
host: &str,
region: &aws_config::Region,
sdk_config: &aws_config::SdkConfig,
token_duration_secs: Option<u64>,
) -> Result<AuthTokenGenerator> {
let credentials_provider = sdk_config
.credentials_provider()
.ok_or_else(|| DsqlError::TokenError("No credentials provider found".into()))?;
let mut builder = Config::builder()
.hostname(host)
.region(region.clone())
.credentials(credentials_provider);
if let Some(duration) = token_duration_secs {
builder = builder.expires_in(duration);
}
Ok(AuthTokenGenerator::new(
builder.build().map_err(DsqlError::TokenError)?,
))
}
pub(crate) async fn generate_token(
signer: &AuthTokenGenerator,
user: &str,
sdk_config: &aws_config::SdkConfig,
) -> Result<String> {
let token = if user == "admin" {
signer.db_connect_admin_auth_token(sdk_config).await
} else {
signer.db_connect_auth_token(sdk_config).await
}
.map_err(DsqlError::TokenError)?;
Ok(token.to_string())
}
#[cfg(test)]
mod tests {
use super::*;
use aws_credential_types::provider::SharedCredentialsProvider;
use aws_credential_types::Credentials;
const TEST_HOST: &str = "example.dsql.us-east-1.on.aws";
async fn fake_signer(duration: Option<u64>) -> (AuthTokenGenerator, aws_config::SdkConfig) {
let creds = Credentials::new("fake_key", "fake_secret", None, None, "test");
let sdk_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
.region(aws_config::Region::new("us-east-1"))
.credentials_provider(SharedCredentialsProvider::new(creds))
.load()
.await;
let region = aws_config::Region::new("us-east-1");
let signer = build_signer(TEST_HOST, ®ion, &sdk_config, duration)
.expect("signer creation should succeed");
(signer, sdk_config)
}
#[tokio::test]
async fn test_generate_token_admin_user() {
let (signer, sdk_config) = fake_signer(None).await;
let token = generate_token(&signer, "admin", &sdk_config)
.await
.expect("token generation should succeed with fake credentials");
assert!(!token.is_empty(), "Token should not be empty");
}
#[tokio::test]
async fn test_generate_token_non_admin_user() {
let (signer, sdk_config) = fake_signer(None).await;
let token = generate_token(&signer, "regular_user", &sdk_config)
.await
.expect("token generation should succeed with fake credentials");
assert!(!token.is_empty(), "Token should not be empty");
}
#[tokio::test]
async fn test_generate_token_with_custom_duration() {
let (signer, sdk_config) = fake_signer(Some(600)).await;
let token = generate_token(&signer, "admin", &sdk_config)
.await
.expect("token generation should succeed with custom duration");
assert!(!token.is_empty(), "Token should not be empty");
}
#[test]
fn test_build_signer_requires_credentials_provider() {
let sdk_config = aws_config::SdkConfig::builder()
.region(aws_config::Region::new("us-east-1"))
.build();
let region = aws_config::Region::new("us-east-1");
let result = build_signer("example.dsql.us-east-1.on.aws", ®ion, &sdk_config, None);
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(
msg.contains("credentials"),
"Expected credentials error, got: {}",
msg
);
}
}