use azure_core::auth::TokenCredential;
use azure_security_keyvault::KeyClient;
use azure_security_keyvault::prelude::{
CryptographParamtersEncryption, DecryptParameters, EncryptParameters,
};
use log::error;
use std::future::IntoFuture;
use std::sync::mpsc::SyncSender;
use std::sync::{Arc, mpsc};
use std::{cell::RefCell, rc::Rc};
use tink_core::TinkError;
use tink_core::utils::wrap_err;
use url::Url;
/// AEAD primitive whose encrypt/decrypt operations are performed remotely by a
/// key stored in Azure Key Vault, accessed through a `KeyClient`.
#[derive(Clone)]
pub struct AzureAead {
    /// Client bound to the vault that holds the key.
    kms: KeyClient,
    /// Encryption parameters attached to every encrypt/decrypt request.
    algorithm: CryptographParamtersEncryption,
    /// Name of the key inside the vault.
    key_name: String,
    /// Key version to pin on each request; when `None`, no version is set.
    key_version: Option<String>,
}
impl AzureAead {
pub(crate) fn new(
key_url: &str,
creds: Arc<dyn TokenCredential>,
algorithm: CryptographParamtersEncryption,
) -> Result<AzureAead, TinkError> {
let (vault_url, key_name, key_version) = get_key_info(key_url)?;
let kms = KeyClient::new(&vault_url, creds)
.map_err(|e| wrap_err("failed to create KeyClient", e))?;
Ok(AzureAead {
kms,
key_name,
key_version,
algorithm,
})
}
async fn encrypt_async(
self,
plaintext: Vec<u8>,
additional_data: Vec<u8>,
sender: SyncSender<Result<Vec<u8>, TinkError>>,
) {
let params = EncryptParameters {
encrypt_parameters_encryption: self.algorithm.clone(),
plaintext,
};
let mut req = self.kms.encrypt(self.key_name.clone(), params);
if let Some(version) = self.key_version.clone() {
req = req.version(version);
}
let result = req
.into_future()
.await
.map(|r| r.result)
.map_err(|e| wrap_err("failed to encrypt", e));
if result.is_err() {
error!("failed to encrypt: {result:?}");
}
if sender.send(result).is_err() {
error!("failed to send result");
}
}
async fn decrypt_async(
self,
ciphertext: Vec<u8>,
additional_data: Vec<u8>,
sender: SyncSender<Result<Vec<u8>, TinkError>>,
) {
let params = DecryptParameters {
decrypt_parameters_encryption: self.algorithm.clone(),
ciphertext,
};
let mut req = self.kms.decrypt(self.key_name.clone(), params);
if let Some(version) = self.key_version.clone() {
req = req.version(version);
}
let result = req
.into_future()
.await
.map(|r| r.result)
.map_err(|e| wrap_err("request failed", e));
if result.is_err() {
error!("failed to decrypt: {result:?}");
}
if sender.send(result).is_err() {
error!("failed to send result");
}
}
}
impl tink_core::Aead for AzureAead {
fn encrypt(&self, plaintext: &[u8], additional_data: &[u8]) -> Result<Vec<u8>, TinkError> {
let (sender, receiver) = mpsc::sync_channel(1);
let this = self.clone();
let plaintext_vec = plaintext.to_vec();
let ad_vec = additional_data.to_vec();
tokio::spawn(async move { this.encrypt_async(plaintext_vec, ad_vec, sender).await });
receiver
.recv()
.map_err(|e| wrap_err("failed to receive", e))?
}
fn decrypt(&self, ciphertext: &[u8], additional_data: &[u8]) -> Result<Vec<u8>, TinkError> {
let (sender, receiver) = mpsc::sync_channel(1);
let this = self.clone();
let cipher_vec = ciphertext.to_vec();
let ad_vec = additional_data.to_vec();
tokio::spawn(async move { this.decrypt_async(cipher_vec, ad_vec, sender).await });
receiver
.recv()
.map_err(|e| wrap_err("failed to receive", e))?
}
}
/// Splits an Azure Key Vault key URL into `(vault_url, key_name, key_version)`.
///
/// The URL path must be `/keys/<name>` or `/keys/<name>/<version>`. An empty
/// version segment (e.g. `/keys/<name>/`) yields `None`. The vault URL is
/// rebuilt from the scheme, the host (`localhost` when the URL has none) and
/// the port when it is not the scheme's default. Query and fragment are ignored.
///
/// # Errors
///
/// Returns an error if `key_uri` cannot be parsed as a URL, if its path does
/// not have the shape above, or if the key name segment is empty.
fn get_key_info(key_uri: &str) -> Result<(String, String, Option<String>), TinkError> {
    let parsed = Url::parse(key_uri).map_err(|e| wrap_err("failed to parse URI", e))?;
    let segments: Vec<&str> = parsed.path().split('/').collect();
    let (key_name, key_version) = match segments.as_slice() {
        ["", "keys", name] => (*name, None),
        ["", "keys", name, version] => {
            let version = (!version.is_empty()).then(|| (*version).to_owned());
            (*name, version)
        }
        _ => return Err("invalid key uri".into()),
    };
    // An empty name (e.g. `/keys//v`) cannot address any key.
    if key_name.is_empty() {
        return Err("invalid key uri".into());
    }
    let host = parsed.host_str().unwrap_or("localhost");
    // `Url::port` is `None` for the scheme's default port, so only explicit
    // non-default ports (e.g. a local emulator) are carried over.
    let vault_url = match parsed.port() {
        Some(port) => format!("{}://{}:{}", parsed.scheme(), host, port),
        None => format!("{}://{}", parsed.scheme(), host),
    };
    Ok((vault_url, key_name.to_owned(), key_version))
}