use crate::error::{AmiError, Result};
use crate::provider::{AwsProvider, CloudProvider};
use crate::store::traits::IdentityProviderStore;
use crate::types::Tag;
use crate::wami::identity::identity_provider::{
builder, operations, AddClientIDToOpenIDConnectProviderRequest,
CreateOpenIDConnectProviderRequest, CreateSAMLProviderRequest,
ListOpenIDConnectProvidersRequest, ListSAMLProvidersRequest, OidcProvider,
RemoveClientIDFromOpenIDConnectProviderRequest, SamlProvider,
UpdateOpenIDConnectProviderThumbprintRequest, UpdateSAMLProviderRequest,
};
use std::sync::{Arc, RwLock};
/// Service for managing SAML and OpenID Connect identity providers.
///
/// Each operation validates its request, builds or updates the provider record
/// with the `builder` helpers, and persists it through a store (the `impl` block
/// requires `S: IdentityProviderStore`) that is shared behind an `Arc<RwLock<_>>`.
pub struct IdentityProviderService<S> {
    /// Shared backing store; each operation takes the lock for the duration of
    /// its store call(s).
    store: Arc<RwLock<S>>,
    /// Cloud provider handed to the `builder` functions when new providers are built.
    provider: Arc<dyn CloudProvider>,
    /// Account ID handed to the `builder` functions when new providers are built.
    account_id: String,
}
impl<S: IdentityProviderStore> IdentityProviderService<S> {
pub fn new(store: Arc<RwLock<S>>, account_id: String) -> Self {
Self {
store,
provider: Arc::new(AwsProvider::new()),
account_id,
}
}
pub fn with_provider(&self, provider: Arc<dyn CloudProvider>) -> Self {
Self {
store: Arc::clone(&self.store),
provider,
account_id: self.account_id.clone(),
}
}
pub async fn create_saml_provider(
&self,
request: CreateSAMLProviderRequest,
) -> Result<SamlProvider> {
SamlProvider::validate_name(&request.name)?;
operations::validate_saml_metadata(&request.saml_metadata_document)?;
let mut provider = builder::build_saml_provider(
request.name,
request.saml_metadata_document.clone(),
self.provider.as_ref(),
&self.account_id,
);
if let Ok(Some(valid_until)) =
operations::extract_saml_validity(&request.saml_metadata_document)
{
provider = builder::set_saml_valid_until(provider, valid_until);
}
if let Some(tags) = request.tags {
provider = builder::add_saml_tags(provider, tags);
}
let mut store = self.store.write().unwrap();
store.create_saml_provider(provider).await
}
pub async fn get_saml_provider(&self, arn: &str) -> Result<Option<SamlProvider>> {
let store = self.store.read().unwrap();
store.get_saml_provider(arn).await
}
pub async fn update_saml_provider(
&self,
request: UpdateSAMLProviderRequest,
) -> Result<SamlProvider> {
operations::validate_saml_metadata(&request.saml_metadata_document)?;
let existing = {
let store = self.store.read().unwrap();
store.get_saml_provider(&request.arn).await?
};
let existing = existing.ok_or_else(|| AmiError::ResourceNotFound {
resource: format!("SamlProvider: {}", request.arn),
})?;
let mut updated =
builder::update_saml_metadata(existing, request.saml_metadata_document.clone());
if let Ok(Some(valid_until)) =
operations::extract_saml_validity(&request.saml_metadata_document)
{
updated = builder::set_saml_valid_until(updated, valid_until);
}
let mut store = self.store.write().unwrap();
store.update_saml_provider(updated).await
}
pub async fn delete_saml_provider(&self, arn: &str) -> Result<()> {
let mut store = self.store.write().unwrap();
store.delete_saml_provider(arn).await
}
pub async fn list_saml_providers(
&self,
request: ListSAMLProvidersRequest,
) -> Result<(Vec<SamlProvider>, bool, Option<String>)> {
let store = self.store.read().unwrap();
store.list_saml_providers(request.pagination.as_ref()).await
}
pub async fn create_oidc_provider(
&self,
request: CreateOpenIDConnectProviderRequest,
) -> Result<OidcProvider> {
operations::validate_oidc_url(&request.url)?;
operations::validate_client_id_list(&request.client_id_list)?;
operations::validate_thumbprint_list(&request.thumbprint_list)?;
let mut provider = builder::build_oidc_provider(
request.url,
request.client_id_list,
request.thumbprint_list,
self.provider.as_ref(),
&self.account_id,
);
if let Some(tags) = request.tags {
provider = builder::add_oidc_tags(provider, tags);
}
let mut store = self.store.write().unwrap();
store.create_oidc_provider(provider).await
}
pub async fn get_oidc_provider(&self, arn: &str) -> Result<Option<OidcProvider>> {
let store = self.store.read().unwrap();
store.get_oidc_provider(arn).await
}
pub async fn update_oidc_thumbprints(
&self,
request: UpdateOpenIDConnectProviderThumbprintRequest,
) -> Result<OidcProvider> {
operations::validate_thumbprint_list(&request.thumbprint_list)?;
let existing = {
let store = self.store.read().unwrap();
store.get_oidc_provider(&request.arn).await?
};
let existing = existing.ok_or_else(|| AmiError::ResourceNotFound {
resource: format!("OidcProvider: {}", request.arn),
})?;
let updated = builder::update_thumbprints(existing, request.thumbprint_list);
let mut store = self.store.write().unwrap();
store.update_oidc_provider(updated).await
}
pub async fn add_client_id(
&self,
request: AddClientIDToOpenIDConnectProviderRequest,
) -> Result<OidcProvider> {
let existing = {
let store = self.store.read().unwrap();
store.get_oidc_provider(&request.arn).await?
};
let existing = existing.ok_or_else(|| AmiError::ResourceNotFound {
resource: format!("OidcProvider: {}", request.arn),
})?;
let updated = builder::add_client_id(existing, request.client_id);
let mut store = self.store.write().unwrap();
store.update_oidc_provider(updated).await
}
pub async fn remove_client_id(
&self,
request: RemoveClientIDFromOpenIDConnectProviderRequest,
) -> Result<OidcProvider> {
let existing = {
let store = self.store.read().unwrap();
store.get_oidc_provider(&request.arn).await?
};
let existing = existing.ok_or_else(|| AmiError::ResourceNotFound {
resource: format!("OidcProvider: {}", request.arn),
})?;
let updated = builder::remove_client_id(existing, &request.client_id);
let mut store = self.store.write().unwrap();
store.update_oidc_provider(updated).await
}
pub async fn delete_oidc_provider(&self, arn: &str) -> Result<()> {
let mut store = self.store.write().unwrap();
store.delete_oidc_provider(arn).await
}
pub async fn list_oidc_providers(
&self,
request: ListOpenIDConnectProvidersRequest,
) -> Result<(Vec<OidcProvider>, bool, Option<String>)> {
let store = self.store.read().unwrap();
store.list_oidc_providers(request.pagination.as_ref()).await
}
pub async fn tag_identity_provider(&self, arn: &str, tags: Vec<Tag>) -> Result<()> {
let mut store = self.store.write().unwrap();
store.tag_identity_provider(arn, tags).await
}
pub async fn list_identity_provider_tags(&self, arn: &str) -> Result<Vec<Tag>> {
let store = self.store.read().unwrap();
store.list_identity_provider_tags(arn).await
}
pub async fn untag_identity_provider(&self, arn: &str, tag_keys: Vec<String>) -> Result<()> {
let mut store = self.store.write().unwrap();
store.untag_identity_provider(arn, tag_keys).await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::memory::InMemoryWamiStore;

    const ACCOUNT_ID: &str = "123456789012";

    /// Minimal SAML metadata document used by the SAML tests.
    const SAML_METADATA: &str = r#"<?xml version="1.0"?>
<EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata">
<IDPSSODescriptor />
</EntityDescriptor>"#;

    /// Builds a service over a fresh, empty in-memory store.
    fn new_service() -> IdentityProviderService<InMemoryWamiStore> {
        let store = Arc::new(RwLock::new(InMemoryWamiStore::default()));
        IdentityProviderService::new(store, ACCOUNT_ID.to_string())
    }

    /// Create request for an untagged SAML provider named `name`.
    fn saml_request(name: &str) -> CreateSAMLProviderRequest {
        CreateSAMLProviderRequest {
            name: name.to_string(),
            saml_metadata_document: SAML_METADATA.to_string(),
            tags: None,
        }
    }

    /// Create request for an untagged OIDC provider with one client and thumbprint.
    fn oidc_request() -> CreateOpenIDConnectProviderRequest {
        CreateOpenIDConnectProviderRequest {
            url: "https://accounts.google.com".to_string(),
            client_id_list: vec!["client-123".to_string()],
            thumbprint_list: vec!["0123456789abcdef0123456789abcdef01234567".to_string()],
            tags: None,
        }
    }

    #[tokio::test]
    async fn test_saml_provider_service() {
        let service = new_service();
        let created = service
            .create_saml_provider(saml_request("TestSAML"))
            .await
            .unwrap();
        assert_eq!(created.saml_provider_name, "TestSAML");
        let retrieved = service.get_saml_provider(&created.arn).await.unwrap();
        assert!(retrieved.is_some());
        let update_req = UpdateSAMLProviderRequest {
            arn: created.arn.clone(),
            saml_metadata_document: SAML_METADATA.to_string(),
        };
        let updated = service.update_saml_provider(update_req).await.unwrap();
        assert_eq!(updated.arn, created.arn);
        let (providers, _, _) = service
            .list_saml_providers(ListSAMLProvidersRequest::default())
            .await
            .unwrap();
        assert_eq!(providers.len(), 1);
        service.delete_saml_provider(&created.arn).await.unwrap();
        let after_delete = service.get_saml_provider(&created.arn).await.unwrap();
        assert!(after_delete.is_none());
    }

    #[tokio::test]
    async fn test_update_deleted_saml_provider_is_not_found() {
        let service = new_service();
        let created = service
            .create_saml_provider(saml_request("GoneSAML"))
            .await
            .unwrap();
        service.delete_saml_provider(&created.arn).await.unwrap();
        let result = service
            .update_saml_provider(UpdateSAMLProviderRequest {
                arn: created.arn.clone(),
                saml_metadata_document: SAML_METADATA.to_string(),
            })
            .await;
        assert!(matches!(result, Err(AmiError::ResourceNotFound { .. })));
    }

    #[tokio::test]
    async fn test_oidc_provider_service() {
        let service = new_service();
        let created = service.create_oidc_provider(oidc_request()).await.unwrap();
        assert_eq!(created.url, "https://accounts.google.com");
        let retrieved = service.get_oidc_provider(&created.arn).await.unwrap();
        assert!(retrieved.is_some());
        let update_req = UpdateOpenIDConnectProviderThumbprintRequest {
            arn: created.arn.clone(),
            thumbprint_list: vec!["AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA".to_string()],
        };
        let updated = service.update_oidc_thumbprints(update_req).await.unwrap();
        assert_eq!(updated.thumbprint_list.len(), 1);
        let add_req = AddClientIDToOpenIDConnectProviderRequest {
            arn: created.arn.clone(),
            client_id: "client-456".to_string(),
        };
        let with_client = service.add_client_id(add_req).await.unwrap();
        assert_eq!(with_client.client_id_list.len(), 2);
        let remove_req = RemoveClientIDFromOpenIDConnectProviderRequest {
            arn: created.arn.clone(),
            client_id: "client-123".to_string(),
        };
        let without_client = service.remove_client_id(remove_req).await.unwrap();
        assert_eq!(without_client.client_id_list.len(), 1);
        let (providers, _, _) = service
            .list_oidc_providers(ListOpenIDConnectProvidersRequest::default())
            .await
            .unwrap();
        assert_eq!(providers.len(), 1);
        service.delete_oidc_provider(&created.arn).await.unwrap();
        let after_delete = service.get_oidc_provider(&created.arn).await.unwrap();
        assert!(after_delete.is_none());
    }

    #[tokio::test]
    async fn test_oidc_mutations_on_deleted_provider_are_not_found() {
        let service = new_service();
        let created = service.create_oidc_provider(oidc_request()).await.unwrap();
        service.delete_oidc_provider(&created.arn).await.unwrap();

        let thumbprints = service
            .update_oidc_thumbprints(UpdateOpenIDConnectProviderThumbprintRequest {
                arn: created.arn.clone(),
                thumbprint_list: vec!["AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA".to_string()],
            })
            .await;
        assert!(matches!(thumbprints, Err(AmiError::ResourceNotFound { .. })));

        let added = service
            .add_client_id(AddClientIDToOpenIDConnectProviderRequest {
                arn: created.arn.clone(),
                client_id: "client-456".to_string(),
            })
            .await;
        assert!(matches!(added, Err(AmiError::ResourceNotFound { .. })));

        let removed = service
            .remove_client_id(RemoveClientIDFromOpenIDConnectProviderRequest {
                arn: created.arn.clone(),
                client_id: "client-123".to_string(),
            })
            .await;
        assert!(matches!(removed, Err(AmiError::ResourceNotFound { .. })));
    }
}