#![forbid(missing_debug_implementations, missing_copy_implementations)]
#![deny(rust_2018_idioms)]
#![deny(missing_docs)]
#![warn(clippy::pedantic)]
#![allow(
clippy::module_name_repetitions,
clippy::must_use_candidate,
clippy::missing_errors_doc
)]
mod client;
pub mod error;
use aws_sdk_kms::primitives::Blob;
use aws_sdk_kms::Client as KmsClient;
use ring::digest::{digest, SHA256};
use ring::rand::SecureRandom;
use snafu::{ensure, OptionExt, ResultExt};
use std::collections::HashMap;
use std::fmt;
use tough::async_trait;
use tough::key_source::KeySource;
use tough::schema::decoded::{Decoded, RsaPem};
use tough::schema::key::{Key, RsaKey, RsaScheme};
use tough::sign::Sign;
/// Signing algorithms this crate can request from KMS.
///
/// Marked `#[non_exhaustive]` so further algorithms can be added without
/// breaking downstream `match` expressions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum KmsSigningAlgorithm {
    /// Corresponds to `aws_sdk_kms::types::SigningAlgorithmSpec::RsassaPssSha256`.
    RsassaPssSha256,
}
impl KmsSigningAlgorithm {
    /// Converts this algorithm into the matching `aws_sdk_kms` signing
    /// algorithm spec, as needed by KMS request builders.
    fn value(self) -> aws_sdk_kms::types::SigningAlgorithmSpec {
        use aws_sdk_kms::types::SigningAlgorithmSpec;

        // Deliberately exhaustive (no `_` arm): adding a variant to the enum
        // must force an update here.
        match self {
            Self::RsassaPssSha256 => SigningAlgorithmSpec::RsassaPssSha256,
        }
    }
}
/// Describes a signing key held in KMS; implements [`KeySource`] so it can
/// produce a [`Sign`] implementation backed by that key.
pub struct KmsKeySource {
/// Profile name handed to the internal KMS client builder when `client` is `None`.
pub profile: Option<String>,
/// Identifier of the KMS key, sent as `key_id` on the `GetPublicKey` request.
pub key_id: String,
/// Optional pre-built KMS client; when `None`, a client is built from `profile`.
pub client: Option<KmsClient>,
/// Signing algorithm the key must list as supported; also used by the resulting signer.
pub signing_algorithm: KmsSigningAlgorithm,
}
impl fmt::Debug for KmsKeySource {
    /// Writes a debug representation containing only `key_id` and `profile`;
    /// the remaining fields are elided and the output is marked non-exhaustive.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            key_id, profile, ..
        } = self;

        let mut repr = f.debug_struct("KmsKeySource");
        repr.field("key_id", key_id);
        repr.field("profile", profile);
        repr.finish_non_exhaustive()
    }
}
#[async_trait]
impl KeySource for KmsKeySource {
async fn as_sign(
&self,
) -> std::result::Result<Box<dyn Sign>, Box<dyn std::error::Error + Send + Sync + 'static>>
{
let kms_client = match self.client.clone() {
Some(value) => value,
None => client::build_client_kms(self.profile.as_deref()).await,
};
let response = kms_client
.get_public_key()
.key_id(self.key_id.clone())
.send()
.await
.context(error::KmsGetPublicKeySnafu {
profile: self.profile.clone(),
key_id: self.key_id.clone(),
})?;
let key = pem::encode_config(
&pem::Pem::new(
"PUBLIC KEY".to_owned(),
response
.public_key
.context(error::PublicKeyNoneSnafu)?
.into_inner(),
),
pem::EncodeConfig::new().set_line_ending(pem::LineEnding::LF),
);
ensure!(
response
.signing_algorithms
.context(error::MissingSignAlgorithmSnafu)?
.contains(&self.signing_algorithm.value()),
error::ValidSignAlgorithmSnafu
);
Ok(Box::new(KmsRsaKey {
profile: self.profile.clone(),
client: Some(kms_client),
key_id: self.key_id.clone(),
public_key: key.parse().context(error::PublicKeyParseSnafu)?,
signing_algorithm: self.signing_algorithm,
modulus_size_bytes: parse_modulus_length_bytes(
response
.key_spec
.as_ref()
.context(error::MissingKeySpecSnafu)?
.as_str(),
)?,
}))
}
async fn write(
&self,
_value: &str,
_key_id_hex: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
Ok(())
}
}
/// Signer backed by a key held in KMS; implements [`Sign`] by calling the
/// KMS `Sign` operation.
pub struct KmsRsaKey {
/// Identifier of the KMS key, sent with every sign request.
key_id: String,
/// Profile name used to build a client when `client` is `None`.
profile: Option<String>,
/// KMS client to reuse; when `None`, a client is built from `profile` on each `sign` call.
client: Option<KmsClient>,
/// Public key reported by `tuf_key`.
public_key: Decoded<RsaPem>,
/// Algorithm passed to KMS when signing.
signing_algorithm: KmsSigningAlgorithm,
/// Modulus length in bytes; signatures are left-padded with zeros to this length.
modulus_size_bytes: usize,
}
impl fmt::Debug for KmsRsaKey {
    /// Writes a debug representation containing `key_id`, `signing_algorithm`
    /// and `public_key`; the remaining fields are elided and the output is
    /// marked non-exhaustive.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            key_id,
            signing_algorithm,
            public_key,
            ..
        } = self;

        let mut repr = f.debug_struct("KmsRsaKey");
        repr.field("key_id", key_id);
        repr.field("signing_algorithm", signing_algorithm);
        repr.field("public_key", public_key);
        repr.finish_non_exhaustive()
    }
}
#[async_trait]
impl Sign for KmsRsaKey {
    /// Returns the TUF representation of this key: an RSA key carrying the
    /// stored public key with the `RsassaPssSha256` scheme and no extra fields.
    fn tuf_key(&self) -> Key {
        let keyval = RsaKey {
            public: self.public_key.clone(),
            _extra: HashMap::new(),
        };
        Key::Rsa {
            keyval,
            scheme: RsaScheme::RsassaPssSha256,
            _extra: HashMap::new(),
        }
    }

    /// Signs `msg` with the KMS key.
    ///
    /// The message is hashed locally with SHA-256 and the digest is sent to
    /// KMS with message type `Digest`. The returned signature is left-padded
    /// with zero bytes up to the modulus size. `_rng` is unused.
    ///
    /// Errors from the KMS call, a response without a signature, or a
    /// signature longer than the modulus are returned as a boxed error.
    async fn sign(
        &self,
        msg: &[u8],
        _rng: &(dyn SecureRandom + Sync),
    ) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let kms_client = if let Some(existing) = self.client.clone() {
            existing
        } else {
            client::build_client_kms(self.profile.as_deref()).await
        };

        // Only the digest is sent to KMS, not the message itself.
        let blob = Blob::new(digest(&SHA256, msg).as_ref().to_vec());
        let request = kms_client
            .sign()
            .key_id(self.key_id.clone())
            .message(blob)
            .message_type(aws_sdk_kms::types::MessageType::Digest)
            .signing_algorithm(self.signing_algorithm.value());

        let response = request
            .send()
            .await
            .context(error::KmsSignMessageSnafu {
                profile: self.profile.clone(),
                key_id: self.key_id.clone(),
            })?;

        let raw_signature = response
            .signature
            .context(error::SignatureNotFoundSnafu)?
            .into_inner();

        // Exhaustive on purpose: a new algorithm must decide how (or whether)
        // its signatures are padded.
        let padded = match self.signing_algorithm {
            KmsSigningAlgorithm::RsassaPssSha256 => {
                pad_signature(raw_signature, self.modulus_size_bytes)?
            }
        };
        Ok(padded)
    }
}
fn parse_modulus_length_bytes(spec: &str) -> error::Result<usize> {
ensure!(spec.starts_with("RSA_"), error::BadKeySpecSnafu { spec });
ensure!(spec.len() > 4, error::BadKeySpecSnafu { spec });
let mod_len_str = &spec[4..];
let mod_bits = mod_len_str
.parse::<usize>()
.context(error::BadKeySpecIntSnafu { spec })?;
ensure!(
mod_bits % 8 == 0,
error::UnsupportedModulusSizeSnafu {
modulus_size_bits: mod_bits,
spec,
}
);
Ok(mod_bits / 8)
}
fn pad_signature(mut signature: Vec<u8>, modulus_size_bytes: usize) -> error::Result<Vec<u8>> {
ensure!(
signature.len() <= modulus_size_bytes,
error::SignatureTooLongSnafu {
modulus_size_bytes,
signature_size_bytes: signature.len()
},
);
if signature.len() == modulus_size_bytes {
return Ok(signature);
}
let padding_size: usize = modulus_size_bytes - signature.len();
signature.splice(..0, [0].repeat(padding_size));
Ok(signature)
}
// A key spec that does not start with `RSA_` must be rejected.
#[test]
fn parse_modulus_length_wrong_alg() {
    assert!(parse_modulus_length_bytes("ECC_SECG_P256K1").is_err());
}
// The bare prefix with no bit length after it must be rejected.
#[test]
fn parse_modulus_length_bad_str() {
    assert!(parse_modulus_length_bytes("RSA_").is_err());
}
// 3072 bits is 384 bytes.
#[test]
fn parse_modulus_length_3072() {
    let parsed = parse_modulus_length_bytes("RSA_3072");
    assert_eq!(parsed.unwrap(), 384);
}
// A bit length that is not a multiple of 8 must be rejected.
#[test]
fn parse_modulus_length_3073() {
    assert!(parse_modulus_length_bytes("RSA_3073").is_err());
}
// A five-byte signature cannot fit a four-byte modulus.
#[test]
fn pad_signature_too_long() {
    let outcome = pad_signature(vec![1, 2, 3, 4, 5], 4);
    assert!(outcome.is_err());
}
// A signature already at the modulus length is returned as-is.
#[test]
fn pad_signature_no_change() {
    let input: Vec<u8> = vec![1, 2, 3, 4, 5];
    let padded = pad_signature(input.clone(), 5).unwrap();
    assert_eq!(input, padded);
}
// One missing byte yields a single leading zero.
#[test]
fn pad_signature_short_by_one() {
    let padded = pad_signature(vec![1, 2, 3, 4, 5], 6).unwrap();
    let want: Vec<u8> = vec![0, 1, 2, 3, 4, 5];
    assert_eq!(want, padded);
}
// Two missing bytes yield two leading zeros.
#[test]
fn pad_signature_short_by_two() {
    let padded = pad_signature(vec![1, 2, 3, 4], 6).unwrap();
    let want: Vec<u8> = vec![0, 0, 1, 2, 3, 4];
    assert_eq!(want, padded);
}