use aws_config::SdkConfig;
use aws_sdk_kms::error::SdkError;
use aws_sdk_kms::operation::create_key::CreateKeyError;
use aws_sdk_kms::operation::get_public_key::GetPublicKeyError;
use aws_sdk_kms::operation::schedule_key_deletion::ScheduleKeyDeletionError;
use aws_sdk_kms::operation::sign::SignError;
use aws_sdk_kms::operation::verify::VerifyError;
use aws_sdk_kms::primitives::Blob;
use aws_sdk_kms::types::{KeySpec, KeyUsageType, MessageType, SigningAlgorithmSpec};
use aws_sdk_kms::Client;
use ockam_core::errcode::{Kind, Origin};
use ockam_core::{async_trait, Result};
use ockam_vault::{KeyId, PublicKey, SecretType, Signature};
use sha2::{Digest, Sha256};
use thiserror::Error;
use tracing as log;
/// Client for AWS KMS that creates, deletes, and uses NIST P-256 sign/verify keys.
#[derive(Debug, Clone)]
pub struct AwsKmsClient {
// AWS SDK KMS client, built from `config.sdk_config` in `AwsKmsClient::new`.
client: Client,
// Configuration this client was built with; `create_key` reads its `multi_region` flag.
config: AwsKmsConfig,
}
/// Configuration used to construct an [`AwsKmsClient`].
#[derive(Debug, Clone)]
pub struct AwsKmsConfig {
// When `true`, newly created keys are requested as multi-region keys.
// `AwsKmsConfig::new` initialises it to `false`.
multi_region: bool,
// AWS SDK configuration (credentials, region, ...) used to build the KMS `Client`.
sdk_config: SdkConfig,
}
impl AwsKmsConfig {
pub async fn default() -> Result<AwsKmsConfig> {
Ok(Self::new(aws_config::load_from_env().await))
}
pub fn new(sdk_config: SdkConfig) -> AwsKmsConfig {
AwsKmsConfig {
multi_region: false,
sdk_config,
}
}
pub fn multi_region(mut self, val: bool) -> Self {
self.multi_region = val;
self
}
}
impl AwsKmsClient {
pub async fn new(config: AwsKmsConfig) -> Result<AwsKmsClient> {
let client = Client::new(&config.sdk_config);
Ok(Self { client, config })
}
pub async fn default() -> Result<Self> {
Self::new(AwsKmsConfig::default().await?).await
}
pub async fn create_key(&self) -> Result<KeyId> {
log::trace!("create new key");
let mut client = self
.client
.create_key()
.key_usage(KeyUsageType::SignVerify)
.key_spec(KeySpec::EccNistP256);
if self.config.multi_region {
client = client.multi_region(true)
}
let output = match client.send().await {
Ok(out) => out,
Err(err) => {
log::error!(%err, "failed to create new key");
return Err(Into::<ockam_core::Error>::into(Error::Create(err)));
}
};
if let Some(kid) = output.key_metadata().and_then(|meta| meta.key_id()) {
log::debug!(%kid, "created new key");
return Ok(kid.to_string());
}
Err(Error::MissingKeyId.into())
}
pub async fn delete_key(&self, key_id: &KeyId) -> Result<bool> {
log::trace!(%key_id, "schedule key for deletion");
const DAYS: i32 = 7;
let client = self
.client
.schedule_key_deletion()
.key_id(key_id)
.pending_window_in_days(DAYS);
match client.send().await {
Err(SdkError::ServiceError(err))
if matches!(err.err(), ScheduleKeyDeletionError::NotFoundException(_)) =>
{
log::debug!(%key_id, "key does not exist");
Ok(false)
}
Err(err) => {
log::error!(%key_id, %err, "failed to schedule key for deletion");
Err(Error::Delete {
keyid: key_id.to_string(),
error: err,
}
.into())
}
Ok(_) => {
log::debug!(%key_id, "key is scheduled for deletion in {DAYS} days");
Ok(true)
}
}
}
pub async fn public_key(&self, key_id: &KeyId) -> Result<PublicKey> {
log::trace!(%key_id, "get public key");
let output = self
.client
.get_public_key()
.key_id(key_id)
.send()
.await
.map_err(|err| {
log::error!(%key_id, %err, "failed to get public key");
Error::Export {
keyid: key_id.to_string(),
error: err,
}
})?;
if output.key_spec() != Some(&KeySpec::EccNistP256) {
log::error!(%key_id, "key spec not supported to get a public key");
return Err(Error::UnsupportedKeyType.into());
}
if output.key_usage() != Some(&KeyUsageType::SignVerify) {
log::error!(%key_id, "usage type not supported to get a public key");
return Err(Error::UnsupportedKeyType.into());
}
if let Some(k) = output.public_key() {
log::debug!(%key_id, "received public key");
return Ok(PublicKey::new(k.as_ref().to_vec(), SecretType::NistP256));
}
log::error!(%key_id, "key type not supported to get a public key");
Err(Error::UnsupportedKeyType.into())
}
pub async fn verify(
&self,
key_id: &KeyId,
message: &[u8],
signature: &Signature,
) -> Result<bool> {
log::trace!(%key_id, "verify message signature");
let client = self
.client
.verify()
.key_id(key_id)
.signature(Blob::new(signature.as_ref()))
.signing_algorithm(SigningAlgorithmSpec::EcdsaSha256)
.message(digest(message))
.message_type(MessageType::Digest);
let output = client.send().await.map_err(|err| {
log::error!(%key_id, %err, "failed to verify message signature");
Error::Verify {
keyid: key_id.to_string(),
error: err,
}
})?;
let is_valid = output.signature_valid();
log::debug!(%key_id, %is_valid, "verified message signature");
Ok(is_valid)
}
pub async fn sign(&self, key_id: &KeyId, message: &[u8]) -> Result<Signature> {
log::trace!(%key_id, "sign message");
let client = self
.client
.sign()
.key_id(key_id)
.signing_algorithm(SigningAlgorithmSpec::EcdsaSha256)
.message(digest(message))
.message_type(MessageType::Digest);
let output = client.send().await.map_err(|err| {
log::error!(%key_id, %err, "failed to sign message");
Error::Sign {
keyid: key_id.to_string(),
error: err,
}
})?;
if let Some(sig) = output.signature() {
log::debug!(%key_id, "signed message");
return Ok(Signature::new(sig.as_ref().to_vec()));
}
log::error!(%key_id, "no signature received from aws");
Err(Error::MissingSignature.into())
}
}
/// Operations on asymmetric signing keys held by a key management service.
#[async_trait]
pub(crate) trait KmsClient {
    /// Creates a new key and returns its id.
    async fn create_key(&self) -> Result<KeyId>;

    /// Returns the ids of the keys visible to this client.
    async fn list_keys(&self) -> Result<Vec<KeyId>>;

    /// Requests deletion of `key_id`.
    ///
    /// In the AWS implementation, `Ok(false)` means the key did not exist.
    async fn delete_key(&self, key_id: &KeyId) -> Result<bool>;

    /// Returns the public half of `key_id`.
    async fn public_key(&self, key_id: &KeyId) -> Result<PublicKey>;

    /// Signs `message` with `key_id`.
    async fn sign(&self, key_id: &KeyId, message: &[u8]) -> Result<Signature>;

    /// Checks `signature` over `message` against `key_id`; returns whether it is valid.
    async fn verify(&self, key_id: &KeyId, message: &[u8], signature: &Signature) -> Result<bool>;
}
#[async_trait]
impl KmsClient for AwsKmsClient {
async fn create_key(&self) -> Result<KeyId> {
self.create_key().await
}
async fn delete_key(&self, key_id: &KeyId) -> Result<bool> {
self.delete_key(key_id).await
}
async fn list_keys(&self) -> Result<Vec<KeyId>> {
let output = self.client.list_keys().send().await.map_err(|err| {
log::error!(%err, "failed to list all keys");
Error::MissingKeys
})?;
if let Some(keys) = output.keys() {
let mut result = vec![];
for key in keys {
if let Some(key_id) = key.key_id() {
result.push(key_id.to_string())
}
}
return Ok(result);
}
Ok(vec![])
}
async fn public_key(&self, key_id: &KeyId) -> Result<PublicKey> {
self.public_key(key_id).await
}
async fn verify(&self, key_id: &KeyId, message: &[u8], signature: &Signature) -> Result<bool> {
self.verify(key_id, message, signature).await
}
async fn sign(&self, key_id: &KeyId, message: &[u8]) -> Result<Signature> {
self.sign(key_id, message).await
}
}
/// Hashes `data` with SHA-256 and wraps the 32-byte digest in a KMS `Blob`.
///
/// Sign/verify requests send this digest with `MessageType::Digest`, so the full
/// message never leaves the process.
fn digest(data: &[u8]) -> Blob {
    let hash = Sha256::digest(data);
    Blob::new(Vec::from(hash.as_slice()))
}
#[derive(Error, Debug)]
pub(crate) enum Error {
#[error("aws sdk error creating new key")]
Create(#[from] SdkError<CreateKeyError>),
#[error("aws sdk error signing message with key {keyid}")]
Sign {
keyid: String,
#[source]
error: SdkError<SignError>,
},
#[error("aws sdk error verifying message with key {keyid}")]
Verify {
keyid: String,
#[source]
error: SdkError<VerifyError>,
},
#[error("aws sdk error exporting public key {keyid}")]
Export {
keyid: String,
#[source]
error: SdkError<GetPublicKeyError>,
},
#[error("aws sdk error exporting public key {keyid}")]
Delete {
keyid: String,
#[source]
error: SdkError<ScheduleKeyDeletionError>,
},
#[error("aws did not return a key id")]
MissingKeyId,
#[error("aws did not return the list of existing keys")]
MissingKeys,
#[error("aws did not return a signature")]
MissingSignature,
#[error("key type is not supported")]
UnsupportedKeyType,
}
impl From<Error> for ockam_core::Error {
fn from(e: Error) -> Self {
ockam_core::Error::new(Origin::Other, Kind::Io, e)
}
}