use std::fmt::{Debug, Display};
use azure_core::auth::TokenCredential;
use azure_core::error::{Error, ErrorKind};
use base64::{CharacterSet, Config};
use chrono::serde::ts_seconds_option;
use chrono::{DateTime, Utc};
use getset::Getters;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::{Map, Value};
use crate::client::API_VERSION_PARAM;
use crate::KeyClient;
/// A key as returned by Key Vault: its properties together with the JSON Web Key material.
#[derive(Debug, Deserialize, Getters)]
#[getset(get = "pub")]
pub struct KeyVaultKey {
    /// Attributes, `managed` flag and tags. These are flattened, so they are read from the
    /// top-level response object rather than from a nested member.
    #[serde(flatten)]
    properties: KeyProperties,

    /// Key material, read from the `key` member of the response.
    key: JsonWebKey,
}
/// Non-cryptographic properties of a key.
#[derive(Debug, Deserialize, Getters)]
#[getset(get = "pub")]
pub struct KeyProperties {
    /// The `attributes` object (timestamps, enabled flag, recovery settings).
    attributes: KeyAttributes,

    /// The `managed` flag; `None` when the response omits it.
    managed: Option<bool>,

    /// Free-form `tags` object; `None` when the response omits it.
    tags: Option<Map<String, Value>>,
}
/// Lifecycle and recovery attributes of a key (the `attributes` object of a key response).
///
/// Timestamps are read with `chrono`'s `ts_seconds_option` (epoch seconds). A missing
/// timestamp deserializes to `None` because of `default`.
#[derive(Debug, Deserialize, Getters)]
#[getset(get = "pub")]
#[serde(rename_all = "camelCase")]
pub struct KeyAttributes {
    /// Creation time, read from `created`.
    #[serde(rename = "created")]
    #[serde(default, with = "ts_seconds_option")]
    created_on: Option<DateTime<Utc>>,

    /// The `enabled` flag.
    enabled: Option<bool>,

    /// Expiry time, read from `exp`.
    #[serde(rename = "exp")]
    #[serde(default, with = "ts_seconds_option")]
    expires_on: Option<DateTime<Utc>>,

    /// Not-before time, read from `nbf`.
    #[serde(rename = "nbf")]
    #[serde(default, with = "ts_seconds_option")]
    not_before: Option<DateTime<Utc>>,

    /// Read from `recoverableDays` (presumably the soft-delete retention period — confirm).
    recoverable_days: Option<u8>,

    /// Read from `recoveryLevel`, e.g. `"Recoverable+Purgeable"`.
    recovery_level: Option<String>,

    /// Last-update time, read from `updated`.
    #[serde(rename = "updated")]
    #[serde(default, with = "ts_seconds_option")]
    updated_on: Option<DateTime<Utc>>,
}
/// A JSON Web Key (JWK) as exchanged with Key Vault.
///
/// Binary members are held as raw bytes and (de)serialized as unpadded base64url strings via
/// `ser_base64_opt` / `deser_base64_opt`. Every optional member uses `default`, so it may be
/// absent on input. It is also skipped on output when `None`, so serializing produces a JWK
/// that omits absent members instead of writing them as `null`.
#[derive(Debug, Serialize, Deserialize, Getters)]
#[getset(get = "pub")]
pub struct JsonWebKey {
    /// Elliptic-curve name (JWK `crv`).
    #[serde(rename = "crv", skip_serializing_if = "Option::is_none")]
    curve_name: Option<String>,

    /// JWK `d` (RSA private exponent / EC private key, per RFC 7518).
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        serialize_with = "ser_base64_opt",
        deserialize_with = "deser_base64_opt"
    )]
    d: Option<Vec<u8>>,

    /// JWK `dp` (RSA first-factor CRT exponent).
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        serialize_with = "ser_base64_opt",
        deserialize_with = "deser_base64_opt"
    )]
    dp: Option<Vec<u8>>,

    /// JWK `dq` (RSA second-factor CRT exponent).
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        serialize_with = "ser_base64_opt",
        deserialize_with = "deser_base64_opt"
    )]
    dq: Option<Vec<u8>>,

    /// JWK `e` (RSA public exponent).
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        serialize_with = "ser_base64_opt",
        deserialize_with = "deser_base64_opt"
    )]
    e: Option<Vec<u8>>,

    /// JWK `k` (symmetric key value).
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        serialize_with = "ser_base64_opt",
        deserialize_with = "deser_base64_opt"
    )]
    k: Option<Vec<u8>>,

    /// Carried on the wire as `key_hsm` (not the JWK name `t`). Presumably HSM-protected key
    /// material for bring-your-own-key imports — confirm against the Key Vault REST reference.
    #[serde(
        rename = "key_hsm",
        default,
        skip_serializing_if = "Option::is_none",
        serialize_with = "ser_base64_opt",
        deserialize_with = "deser_base64_opt"
    )]
    t: Option<Vec<u8>>,

    /// Permitted operations (`key_ops`), e.g. `"sign"`, `"decrypt"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    key_ops: Option<Vec<String>>,

    /// Key identifier (JWK `kid`).
    #[serde(rename = "kid", skip_serializing_if = "Option::is_none")]
    id: Option<String>,

    /// Key type (JWK `kty`), e.g. `"RSA"`. This is the only required member.
    #[serde(rename = "kty")]
    key_type: String,

    /// JWK `n` (RSA modulus).
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        serialize_with = "ser_base64_opt",
        deserialize_with = "deser_base64_opt"
    )]
    n: Option<Vec<u8>>,

    /// JWK `p` (RSA first prime factor).
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        serialize_with = "ser_base64_opt",
        deserialize_with = "deser_base64_opt"
    )]
    p: Option<Vec<u8>>,

    /// JWK `q` (RSA second prime factor).
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        serialize_with = "ser_base64_opt",
        deserialize_with = "deser_base64_opt"
    )]
    q: Option<Vec<u8>>,

    /// JWK `qi` (RSA first CRT coefficient).
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        serialize_with = "ser_base64_opt",
        deserialize_with = "deser_base64_opt"
    )]
    qi: Option<Vec<u8>>,

    /// JWK `x` (EC x coordinate).
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        serialize_with = "ser_base64_opt",
        deserialize_with = "deser_base64_opt"
    )]
    x: Option<Vec<u8>>,

    /// JWK `y` (EC y coordinate).
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        serialize_with = "ser_base64_opt",
        deserialize_with = "deser_base64_opt"
    )]
    y: Option<Vec<u8>>,
}
/// base64 configuration used for every binary field in this module: URL-safe alphabet
/// (`CharacterSet::UrlSafe`) with padding disabled (second argument `false`).
const BASE64_URL_SAFE: Config = Config::new(CharacterSet::UrlSafe, false);
/// serde `serialize_with` helper that writes `bytes` as an unpadded base64url string.
fn ser_base64<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    serializer.serialize_str(&base64::encode_config(bytes, BASE64_URL_SAFE))
}
/// serde `serialize_with` helper for optional binary fields.
///
/// `Some(bytes)` is written as an unpadded base64url string; `None` is written as a serde none
/// value (JSON `null`). The parameter is `&Option<Vec<u8>>` because that is what serde passes
/// for an `Option<Vec<u8>>` field.
fn ser_base64_opt<S>(bytes: &Option<Vec<u8>>, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    match bytes.as_deref() {
        Some(raw) => serializer.serialize_str(&base64::encode_config(raw, BASE64_URL_SAFE)),
        None => serializer.serialize_none(),
    }
}
fn deser_base64<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
where
D: Deserializer<'de>,
{
let s: String = String::deserialize(deserializer)?;
let res = base64::decode_config(s, BASE64_URL_SAFE).map_err(serde::de::Error::custom)?;
Ok(res)
}
fn deser_base64_opt<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
where
D: Deserializer<'de>,
{
let s: Option<&str> = Option::deserialize(deserializer)?;
let res = match s {
Some(s) => {
Some(base64::decode_config(s, BASE64_URL_SAFE).map_err(serde::de::Error::custom)?)
}
None => None,
};
Ok(res)
}
/// Result of [`KeyClient::sign`].
#[derive(Debug, Deserialize, Getters)]
#[getset(get = "pub")]
pub struct SignResult {
    /// Signature bytes, decoded from the base64url `value` member of the response.
    #[serde(rename = "value")]
    #[serde(serialize_with = "ser_base64", deserialize_with = "deser_base64")]
    signature: Vec<u8>,

    /// Algorithm used to sign. It is not read from the response (`skip` starts it at its
    /// `Default`); `KeyClient::sign` fills it in from its own argument.
    #[serde(skip)]
    algorithm: SignatureAlgorithm,

    /// Identifier of the signing key, read from `kid`.
    #[serde(rename = "kid")]
    key_id: String,
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum SignatureAlgorithm {
ES256, ES256K, ES384, ES512, PS256, PS384, PS512, RS256, RS384, RS512, Custom(String),
}
impl Default for SignatureAlgorithm {
fn default() -> Self {
SignatureAlgorithm::Custom("".to_string())
}
}
impl Display for SignatureAlgorithm {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Debug::fmt(self, f)
}
}
/// Encryption algorithm used by the decrypt parameter types and [`KeyClient::decrypt`].
///
/// Each variant (de)serializes as the string given in its `rename` attribute. The enum is
/// fieldless, so it is also `Copy` and comparable, which lets callers reuse and compare values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum EncryptionAlgorithm {
    // AES-CBC family: 128/192/256-bit keys, with and without the `PAD` suffix.
    #[serde(rename = "A128CBC")]
    A128Cbc,
    #[serde(rename = "A128CBCPAD")]
    A128CbcPad,
    #[serde(rename = "A128GCM")]
    A128Gcm,
    #[serde(rename = "A192CBC")]
    A192Cbc,
    #[serde(rename = "A192CBCPAD")]
    A192CbcPad,
    #[serde(rename = "A192GCM")]
    A192Gcm,
    #[serde(rename = "A256CBC")]
    A256Cbc,
    #[serde(rename = "A256CBCPAD")]
    A256CbcPad,
    #[serde(rename = "A256GCM")]
    A256Gcm,
    // RSA family.
    #[serde(rename = "RSA-OAEP")]
    RsaOaep,
    #[serde(rename = "RSA-OAEP-256")]
    RsaOaep256,
    #[serde(rename = "RSA1_5")]
    Rsa15,
}
impl Default for EncryptionAlgorithm {
fn default() -> Self {
EncryptionAlgorithm::A128Cbc
}
}
impl Display for EncryptionAlgorithm {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Debug::fmt(self, f)
}
}
/// Input to [`KeyClient::decrypt`]: the ciphertext plus algorithm-specific parameters.
#[derive(Debug, Serialize, Deserialize)]
pub struct DecryptParameters {
    /// Algorithm selection and any algorithm-specific inputs (IV, authentication tag, ...).
    pub decrypt_parameters_encryption: DecryptParametersEncryption,

    /// Raw ciphertext bytes. They are written as unpadded base64url when this struct itself is
    /// serialized with serde.
    #[serde(deserialize_with = "deser_base64", serialize_with = "ser_base64")]
    pub ciphertext: Vec<u8>,
}
/// Algorithm-specific decryption parameters.
///
/// The enum is `untagged`. When deserializing, serde tries the variants in declaration order and
/// keeps the first one that succeeds, so the most specific shapes must come first:
/// AES-GCM (`iv` + `authentication_tag`), then AES-CBC (`iv`), then RSA (`algorithm` only).
/// With `Rsa` listed first, as before, every payload matched it, because
/// `RsaDecryptParameters` ignores unknown fields.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum DecryptParametersEncryption {
    /// AES-GCM: algorithm, IV, authentication tag and optional additional authenticated data.
    AesGcm(AesGcmDecryptParameters),
    /// AES-CBC: algorithm and IV.
    AesCbc(AesCbcDecryptParameters),
    /// RSA: only the algorithm is needed.
    Rsa(RsaDecryptParameters),
}
/// Parameters for RSA decryption: only the algorithm.
///
/// Construct it with `RsaDecryptParameters::new`, which rejects non-RSA algorithms.
#[derive(Debug, Serialize, Deserialize)]
pub struct RsaDecryptParameters {
    /// One of the RSA algorithms; private so it can only be set through validation.
    algorithm: EncryptionAlgorithm,
}
impl RsaDecryptParameters {
pub fn new(algorithm: EncryptionAlgorithm) -> Result<Self, Error> {
match algorithm {
EncryptionAlgorithm::Rsa15
| EncryptionAlgorithm::RsaOaep
| EncryptionAlgorithm::RsaOaep256 => Ok(Self { algorithm }),
_ => Err(Error::with_message(ErrorKind::Other, || {
format!("unexpected encryption algorithm: {algorithm}")
})),
}
}
}
/// Parameters for AES-GCM decryption.
///
/// Construct it with `AesGcmDecryptParameters::new`, which rejects non-GCM algorithms.
#[derive(Debug, Serialize, Deserialize)]
pub struct AesGcmDecryptParameters {
    /// One of the AES-GCM algorithms; private so it can only be set through validation.
    algorithm: EncryptionAlgorithm,

    /// Initialization vector (base64url when serialized with serde).
    #[serde(serialize_with = "ser_base64", deserialize_with = "deser_base64")]
    pub iv: Vec<u8>,

    /// GCM authentication tag (base64url when serialized with serde).
    #[serde(serialize_with = "ser_base64", deserialize_with = "deser_base64")]
    pub authentication_tag: Vec<u8>,

    /// Optional additional authenticated data.
    ///
    /// `deserialize_with` turns off serde's implicit "missing `Option` field => `None`" rule, so
    /// `default` is required for AAD to be optional on input. Without it, a payload without AAD
    /// failed to deserialize here, and in the untagged `DecryptParametersEncryption` it fell
    /// through to the AES-CBC variant. It is skipped on output when absent.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        serialize_with = "ser_base64_opt",
        deserialize_with = "deser_base64_opt"
    )]
    pub additional_authenticated_data: Option<Vec<u8>>,
}
impl AesGcmDecryptParameters {
pub fn new(
algorithm: EncryptionAlgorithm,
iv: Vec<u8>,
authentication_tag: Vec<u8>,
additional_authenticated_data: Option<Vec<u8>>,
) -> Result<Self, Error> {
match algorithm {
EncryptionAlgorithm::A128Gcm
| EncryptionAlgorithm::A192Gcm
| EncryptionAlgorithm::A256Gcm => Ok(Self {
algorithm,
iv,
authentication_tag,
additional_authenticated_data,
}),
_ => Err(Error::with_message(ErrorKind::Other, || {
format!("unexpected encryption algorithm: {algorithm}")
})),
}
}
}
/// Parameters for AES-CBC decryption (with or without the `Pad` variants).
///
/// Construct it with `AesCbcDecryptParameters::new`, which rejects non-CBC algorithms.
#[derive(Debug, Serialize, Deserialize)]
pub struct AesCbcDecryptParameters {
    /// One of the AES-CBC algorithms; private so it can only be set through validation.
    algorithm: EncryptionAlgorithm,

    /// Initialization vector (base64url when serialized with serde).
    #[serde(deserialize_with = "deser_base64", serialize_with = "ser_base64")]
    pub iv: Vec<u8>,
}
impl AesCbcDecryptParameters {
pub fn new(algorithm: EncryptionAlgorithm, iv: Vec<u8>) -> Result<Self, Error> {
match algorithm {
EncryptionAlgorithm::A128Cbc
| EncryptionAlgorithm::A192Cbc
| EncryptionAlgorithm::A256Cbc
| EncryptionAlgorithm::A128CbcPad
| EncryptionAlgorithm::A192CbcPad
| EncryptionAlgorithm::A256CbcPad => Ok(Self { algorithm, iv }),
_ => Err(Error::with_message(ErrorKind::Other, || {
format!("unexpected encryption algorithm: {algorithm}")
})),
}
}
}
/// Result of [`KeyClient::decrypt`].
#[derive(Debug, Deserialize, Getters)]
#[getset(get = "pub")]
pub struct DecryptResult {
    /// Algorithm used to decrypt. It is not read from the response (`skip` starts it at its
    /// `Default`); `KeyClient::decrypt` fills it in from the request parameters.
    #[serde(skip)]
    algorithm: EncryptionAlgorithm,

    /// Identifier of the key used, read from `kid`.
    #[serde(rename = "kid")]
    key_id: String,

    /// Decrypted bytes, decoded from the base64url `value` member of the response.
    #[serde(rename = "value")]
    #[serde(serialize_with = "ser_base64", deserialize_with = "deser_base64")]
    result: Vec<u8>,
}
impl<'a, T: TokenCredential> KeyClient<'a, T> {
pub async fn get_key(
&mut self,
key_name: &str,
key_version: Option<&str>,
) -> Result<KeyVaultKey, Error> {
let mut uri = self.vault_url.clone();
let path = if let Some(ver) = key_version {
format!("keys/{}/{}", key_name, ver)
} else {
format!("keys/{}", key_name)
};
uri.set_path(&path);
uri.set_query(Some(API_VERSION_PARAM));
let resp_body = self.get_authed(uri.to_string()).await?;
let response = serde_json::from_str::<KeyVaultKey>(&resp_body)?;
Ok(response)
}
pub async fn sign(
&mut self,
algorithm: SignatureAlgorithm,
key_name: &str,
key_version: &str,
digest: &str,
) -> Result<SignResult, Error> {
let mut uri = self.vault_url.clone();
uri.set_path(&format!("keys/{}/{}/sign", key_name, key_version));
uri.set_query(Some(API_VERSION_PARAM));
let mut request_body = Map::new();
request_body.insert("alg".to_owned(), Value::String(algorithm.to_string()));
request_body.insert("value".to_owned(), Value::String(digest.to_owned()));
let response = self
.post_authed(
uri.to_string(),
Some(Value::Object(request_body).to_string()),
)
.await?;
let mut result = serde_json::from_str::<SignResult>(&response)?;
result.algorithm = algorithm;
Ok(result)
}
pub async fn decrypt(
&mut self,
key_name: &str,
key_version: Option<&str>,
decrypt_parameters: DecryptParameters,
) -> Result<DecryptResult, Error> {
let mut uri = self.vault_url.clone();
let path = format!("keys/{}/{}/decrypt", key_name, key_version.unwrap_or(""));
uri.set_path(&path);
uri.set_query(Some(API_VERSION_PARAM));
let mut request_body = Map::new();
request_body.insert(
"value".to_owned(),
Value::String(base64::encode(&decrypt_parameters.ciphertext)),
);
let algorithm = match decrypt_parameters.decrypt_parameters_encryption {
DecryptParametersEncryption::Rsa(RsaDecryptParameters { algorithm }) => {
request_body.insert("alg".to_owned(), serde_json::to_value(&algorithm).unwrap());
algorithm
}
DecryptParametersEncryption::AesGcm(AesGcmDecryptParameters {
algorithm,
iv,
authentication_tag,
additional_authenticated_data,
}) => {
request_body.insert("alg".to_owned(), serde_json::to_value(&algorithm).unwrap());
request_body.insert("iv".to_owned(), serde_json::to_value(iv).unwrap());
request_body.insert(
"tag".to_owned(),
serde_json::to_value(authentication_tag).unwrap(),
);
if let Some(aad) = additional_authenticated_data {
request_body.insert("aad".to_owned(), serde_json::to_value(aad).unwrap());
};
algorithm
}
DecryptParametersEncryption::AesCbc(AesCbcDecryptParameters { algorithm, iv }) => {
request_body.insert("alg".to_owned(), serde_json::to_value(&algorithm).unwrap());
request_body.insert("iv".to_owned(), serde_json::to_value(iv).unwrap());
algorithm
}
};
let response = self
.post_authed(
uri.to_string(),
Some(Value::Object(request_body).to_string()),
)
.await?;
let mut result = serde_json::from_str::<DecryptResult>(&response)?;
result.algorithm = algorithm;
Ok(result)
}
}
#[cfg(test)]
mod tests {
    //! Tests for `KeyClient` against mocked Key Vault endpoints. Each mock matches only the HTTP
    //! method, the path and the `api-version` query parameter; request bodies are not checked.
    use super::*;
    use chrono::{DateTime, Duration, Utc};
    use mockito::{mock, Matcher};
    use serde_json::json;
    use crate::client::API_VERSION;
    use crate::mock_key_client;
    use crate::tests::MockCredential;

    /// Absolute difference between two instants.
    ///
    /// The mocked responses carry timestamps as whole epoch seconds (`.timestamp()`), so the
    /// tests compare them with the original `Utc::now()` values using a sub-second tolerance.
    fn diff(first: DateTime<Utc>, second: DateTime<Utc>) -> Duration {
        if first > second {
            first - second
        } else {
            second - first
        }
    }

    /// `get_key` parses a mocked GET response into `KeyVaultKey`: the base64url `n` is decoded,
    /// `kid`, tags and attributes are mapped, and the absent `managed` flag stays `None`.
    #[tokio::test]
    async fn can_get_key() {
        let time_created = Utc::now() - Duration::days(7);
        let time_updated = Utc::now();
        let _m = mock("GET", "/keys/test-key/78deebed173b48e48f55abf87ed4cf71")
            .match_query(Matcher::UrlEncoded("api-version".into(), API_VERSION.into()))
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "key": {
                        "kid": "https://test-keyvault.vault.azure.net/keys/test-key/78deebed173b48e48f55abf87ed4cf71",
                        "kty": "RSA",
                        "key_ops": [
                            "encrypt",
                            "decrypt",
                            "sign",
                            "verify",
                            "wrapKey",
                            "unwrapKey",
                            "destroy!"
                        ],
                        "n": "2HJAE5fU3Cw2Rt9hEuq-F6XjINKGa-zskfISVqopqUy60GOs2eyhxbWbJBeUXNor_gf-tXtNeuqeBgitLeVa640UDvnEjYTKWjCniTxZRaU7ewY8BfTSk-7KxoDdLsPSpX_MX4rwlAx-_1UGk5t4sQgTbm9T6Fm2oqFd37dsz5-Gj27UP2GTAShfJPFD7MqU_zIgOI0pfqsbNL5xTQVM29K6rX4jSPtylZV3uWJtkoQIQnrIHhk1d0SC0KwlBV3V7R_LVYjiXLyIXsFzSNYgQ68ZjAwt8iL7I8Osa-ehQLM13DVvLASaf7Jnu3sC3CWl3Gyirgded6cfMmswJzY87w",
                        "e": "AQAB"
                    },
                    "attributes": {
                        "enabled": true,
                        "created": time_created.timestamp(),
                        "updated": time_updated.timestamp(),
                        "recoveryLevel": "Recoverable+Purgeable"
                    },
                    "tags": {
                        "purpose": "unit test",
                        "test name ": "CreateGetDeleteKeyTest"
                    }
                })
                .to_string(),
            )
            .with_status(200)
            .create();
        let creds = MockCredential;
        let mut client = mock_key_client!(&"test-keyvault", &creds,);
        let key = client
            .get_key("test-key", Some("78deebed173b48e48f55abf87ed4cf71"))
            .await
            .unwrap();
        let JsonWebKey { id, n, .. } = key.key();
        let KeyProperties {
            attributes,
            managed,
            tags,
        } = key.properties();
        let KeyAttributes {
            created_on,
            enabled,
            updated_on,
            ..
        } = attributes;
        // The modulus must equal the base64url-decoded `n` from the mocked body.
        let expected_n = base64::decode_config("2HJAE5fU3Cw2Rt9hEuq-F6XjINKGa-zskfISVqopqUy60GOs2eyhxbWbJBeUXNor_gf-tXtNeuqeBgitLeVa640UDvnEjYTKWjCniTxZRaU7ewY8BfTSk-7KxoDdLsPSpX_MX4rwlAx-_1UGk5t4sQgTbm9T6Fm2oqFd37dsz5-Gj27UP2GTAShfJPFD7MqU_zIgOI0pfqsbNL5xTQVM29K6rX4jSPtylZV3uWJtkoQIQnrIHhk1d0SC0KwlBV3V7R_LVYjiXLyIXsFzSNYgQ68ZjAwt8iL7I8Osa-ehQLM13DVvLASaf7Jnu3sC3CWl3Gyirgded6cfMmswJzY87w", BASE64_URL_SAFE).unwrap();
        assert_eq!(expected_n, n.to_owned().unwrap());
        assert_eq!(
            "https://test-keyvault.vault.azure.net/keys/test-key/78deebed173b48e48f55abf87ed4cf71",
            id.to_owned().unwrap()
        );
        assert!(managed.is_none());
        assert_eq!(
            tags.to_owned().unwrap().get("purpose").unwrap(),
            "unit test"
        );
        assert!(enabled.unwrap());
        // Timestamps round-trip through whole seconds, hence the tolerance.
        assert!(diff(time_created, created_on.unwrap()) < Duration::seconds(1));
        assert!(diff(time_updated, updated_on.unwrap()) < Duration::seconds(1));
    }

    /// `sign` returns the decoded signature and `kid` from the mocked response, and records
    /// the algorithm that was passed in (the response does not carry it).
    #[tokio::test]
    async fn can_sign() {
        let _m = mock("POST", "/keys/test-key/78deebed173b48e48f55abf87ed4cf71/sign")
            .match_query(Matcher::UrlEncoded("api-version".into(), API_VERSION.into()))
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "kid": "https://myvault.vault.azure.net/keys/testkey/9885aa558e8d448789683188f8c194b0",
                    "value": "aKFG8NXcfTzqyR44rW42484K_zZI_T7zZuebvWuNgAoEI1gXYmxrshp42CunSmmu4oqo4-IrCikPkNIBkHXnAW2cv03Ad0UpwXhVfepK8zzDBaJPMKVGS-ZRz8CshEyGDKaLlb3J3zEkXpM3RrSEr0mdV6hndHD_mznLB5RmFui5DsKAhez4vUqajgtkgcPfCekMqeSwp6r9ItVL-gEoAohx8XMDsPedqu-7BuZcBcdayaPuBRL4wWoTDULA11P-UN_sJ5qMj3BbiRYhIlBWGR04wIGfZ3pkJjHJUpOvgH2QajdYPzUBauOCewMYbq9XkLRSzI_A7HkkDVycugSeAA"
                })
                .to_string(),
            )
            .with_status(200)
            .create();
        let creds = MockCredential;
        let mut client = mock_key_client!(&"test-keyvault", &creds,);
        let res = client
            .sign(
                SignatureAlgorithm::RS512,
                "test-key",
                "78deebed173b48e48f55abf87ed4cf71",
                "base64msg2sign",
            )
            .await
            .unwrap();
        let kid = res.key_id();
        let sig = res.signature();
        let alg = res.algorithm();
        assert_eq!(
            kid,
            "https://myvault.vault.azure.net/keys/testkey/9885aa558e8d448789683188f8c194b0"
        );
        let expected_sig = base64::decode_config("aKFG8NXcfTzqyR44rW42484K_zZI_T7zZuebvWuNgAoEI1gXYmxrshp42CunSmmu4oqo4-IrCikPkNIBkHXnAW2cv03Ad0UpwXhVfepK8zzDBaJPMKVGS-ZRz8CshEyGDKaLlb3J3zEkXpM3RrSEr0mdV6hndHD_mznLB5RmFui5DsKAhez4vUqajgtkgcPfCekMqeSwp6r9ItVL-gEoAohx8XMDsPedqu-7BuZcBcdayaPuBRL4wWoTDULA11P-UN_sJ5qMj3BbiRYhIlBWGR04wIGfZ3pkJjHJUpOvgH2QajdYPzUBauOCewMYbq9XkLRSzI_A7HkkDVycugSeAA", BASE64_URL_SAFE).unwrap();
        assert_eq!(expected_sig, sig.to_owned());
        assert!(matches!(alg, SignatureAlgorithm::RS512));
    }

    /// `decrypt` with RSA parameters returns the decoded plaintext and `kid` from the mocked
    /// response, and records the algorithm taken from the request parameters.
    #[tokio::test]
    async fn can_decrypt() {
        let _m = mock("POST", "/keys/test-key/78deebed173b48e48f55abf87ed4cf71/decrypt")
            .match_query(Matcher::UrlEncoded("api-version".into(), API_VERSION.into()))
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "kid": "https://myvault.vault.azure.net/keys/test-key/78deebed173b48e48f55abf87ed4cf71",
                    "value": "dvDmrSBpjRjtYg"
                })
                .to_string(),
            )
            .with_status(200)
            .create();
        let creds = MockCredential;
        let mut client = mock_key_client!(&"test-keyvault", &creds,);
        // The ciphertext content is irrelevant here: the mock does not inspect the request body.
        let decrypt_parameters = DecryptParameters {
            ciphertext: base64::decode("dvDmrSBpjRjtYg").unwrap(),
            decrypt_parameters_encryption: DecryptParametersEncryption::Rsa(
                RsaDecryptParameters::new(EncryptionAlgorithm::RsaOaep256).unwrap(),
            ),
        };
        let res = client
            .decrypt(
                "test-key",
                Some("78deebed173b48e48f55abf87ed4cf71"),
                decrypt_parameters,
            )
            .await
            .unwrap();
        let kid = res.key_id();
        let val = res.result();
        let alg = res.algorithm();
        assert_eq!(
            kid,
            "https://myvault.vault.azure.net/keys/test-key/78deebed173b48e48f55abf87ed4cf71"
        );
        let expected_val = base64::decode_config("dvDmrSBpjRjtYg", BASE64_URL_SAFE).unwrap();
        assert_eq!(expected_val, val.to_owned());
        assert!(matches!(alg, &EncryptionAlgorithm::RsaOaep256));
    }
}