use crate::aws_kms_client::{AwsKmsClient, AwsKmsConfig, KmsClient};
use crate::error::Error;
use ockam_core::compat::sync::{Arc, RwLock};
use ockam_core::{async_trait, Result};
use ockam_vault::{
KeyId, PublicKey, SecretAttributes, SecretType, Signature, SigningVault, VaultError,
};
use tracing::error;
struct AwsKeyPair {
key_id: KeyId,
public_key: PublicKey,
}
pub struct AwsSigningVault {
client: Arc<dyn KmsClient + Send + Sync>,
keys: Arc<RwLock<Vec<AwsKeyPair>>>,
}
impl AwsSigningVault {
pub async fn create() -> Result<Self> {
Self::create_with_config(AwsKmsConfig::default().await?).await
}
pub async fn create_with_config(config: AwsKmsConfig) -> Result<Self> {
let client = AwsKmsClient::new(config).await?;
let mut keys: Vec<AwsKeyPair> = vec![];
let key_ids = client.list_keys().await?;
for key_id in key_ids {
match client.public_key(&key_id).await {
Ok(public_key) => keys.push(AwsKeyPair { key_id, public_key }),
Err(err) => error!("Error exporting public key: {err}"),
}
}
Ok(Self {
client: Arc::new(client),
keys: Arc::new(RwLock::new(keys)),
})
}
pub fn keys(&self) -> Vec<KeyId> {
self.keys
.read()
.unwrap()
.iter()
.map(|x| x.key_id.clone())
.collect()
}
}
#[async_trait]
impl SigningVault for AwsSigningVault {
async fn get_public_key(&self, key_id: &KeyId) -> Result<PublicKey> {
self.keys
.read()
.unwrap()
.iter()
.find_map(|x| {
if &x.key_id == key_id {
Some(x.public_key.clone())
} else {
None
}
})
.ok_or(Error::KeyNotFound.into())
}
async fn get_key_id(&self, public_key: &PublicKey) -> Result<KeyId> {
self.keys
.read()
.unwrap()
.iter()
.find_map(|x| {
if &x.public_key == public_key {
Some(x.key_id.clone())
} else {
None
}
})
.ok_or(Error::KeyNotFound.into())
}
async fn sign(&self, key_id: &KeyId, message: &[u8]) -> Result<Signature> {
self.client.sign(key_id, message).await
}
async fn generate_key(&self, attributes: SecretAttributes) -> Result<KeyId> {
if attributes.secret_type() != SecretType::NistP256 {
return Err(VaultError::InvalidKeyType.into());
}
let key_id = self.client.create_key().await?;
let public_key = self.client.public_key(&key_id).await?;
self.keys.write().unwrap().push(AwsKeyPair {
key_id: key_id.clone(),
public_key,
});
Ok(key_id)
}
async fn delete_key(&self, key_id: KeyId) -> Result<bool> {
if self.client.delete_key(&key_id).await? {
self.keys.write().unwrap().retain(|x| x.key_id != key_id);
Ok(true)
} else {
Ok(false)
}
}
async fn number_of_keys(&self) -> Result<usize> {
Ok(self.keys.read().unwrap().len())
}
}