use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use base64::Engine;
use hmac::{Hmac, Mac};
use sha1::Sha1;
use crate::error::{BatataError, Result};
/// Abstraction over a key management service that can unwrap encrypted data keys.
///
/// Implementations are required to be `Send + Sync` so a single provider can be
/// shared across async tasks (for example behind an `Arc<dyn KmsProvider>`).
#[async_trait]
pub trait KmsProvider: Send + Sync {
    /// Short, human-readable identifier of this provider (e.g. `"AliyunKMS"`).
    fn name(&self) -> &str;

    /// Decrypts `encrypted_data_key` and returns the raw key bytes.
    ///
    /// # Errors
    ///
    /// Returns an error when the key cannot be decrypted or decoded.
    async fn decrypt_data_key(&self, encrypted_data_key: &str) -> Result<Vec<u8>>;
}
/// [`KmsProvider`] backed by Alibaba Cloud (Aliyun) KMS.
///
/// Data keys are decrypted through the KMS `Decrypt` RPC action; requests are
/// authenticated with an AccessKey pair.
pub struct AliyunKmsProvider {
    /// Region used to build the default endpoint, e.g. `cn-hangzhou`.
    pub region_id: String,
    /// AccessKey ID sent as a request parameter.
    pub access_key_id: String,
    /// AccessKey secret used only as the request-signing key. It is redacted
    /// from the `Debug` representation so it cannot leak into logs.
    pub access_key_secret: String,
    /// Optional endpoint override; `None` selects the regional public endpoint.
    pub kms_endpoint: Option<String>,
    /// HTTP client reused for every KMS request.
    client: reqwest::Client,
}

// Hand-written instead of derived so the AccessKey secret never appears in
// `{:?}` output (e.g. when the provider is logged or included in a panic).
impl std::fmt::Debug for AliyunKmsProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AliyunKmsProvider")
            .field("region_id", &self.region_id)
            .field("access_key_id", &self.access_key_id)
            .field("access_key_secret", &"<redacted>")
            .field("kms_endpoint", &self.kms_endpoint)
            .finish_non_exhaustive()
    }
}
impl AliyunKmsProvider {
    /// Creates a provider for `region_id`, authenticated with the given AccessKey pair.
    ///
    /// The regional public endpoint is used unless overridden via
    /// [`AliyunKmsProvider::with_endpoint`]. HTTP requests time out after 10 seconds.
    ///
    /// # Panics
    ///
    /// Panics if the underlying `reqwest` client cannot be built.
    pub fn new(
        region_id: impl Into<String>,
        access_key_id: impl Into<String>,
        access_key_secret: impl Into<String>,
    ) -> Self {
        let http_client = reqwest::Client::builder()
            .timeout(Duration::from_secs(10))
            .build()
            .expect("Failed to build HTTP client");

        Self {
            region_id: region_id.into(),
            access_key_id: access_key_id.into(),
            access_key_secret: access_key_secret.into(),
            kms_endpoint: None,
            client: http_client,
        }
    }

    /// Returns this provider configured to send requests to `endpoint` instead
    /// of the default regional endpoint.
    pub fn with_endpoint(self, endpoint: impl Into<String>) -> Self {
        Self {
            kms_endpoint: Some(endpoint.into()),
            ..self
        }
    }

    /// Resolves the URL requests are posted to: the explicit override if one was
    /// set, otherwise `https://kms.{region_id}.aliyuncs.com`.
    fn get_endpoint(&self) -> String {
        match &self.kms_endpoint {
            Some(custom) => custom.clone(),
            None => format!("https://kms.{}.aliyuncs.com", self.region_id),
        }
    }

    /// Computes the base64 HMAC-SHA1 signature of `params` for a POST to `/`.
    ///
    /// Parameters are sorted by key, percent-encoded and joined into a canonical
    /// query string; that string is then encoded again and embedded in
    /// `POST&%2F&<query>`. The signing key is the AccessKey secret followed by `&`.
    fn sign_request(&self, params: &HashMap<String, String>) -> String {
        let mut pairs: Vec<(&String, &String)> = params.iter().collect();
        pairs.sort_by(|left, right| left.0.cmp(right.0));

        let canonical_query = pairs
            .into_iter()
            .map(|(key, value)| format!("{}={}", percent_encode(key), percent_encode(value)))
            .collect::<Vec<_>>()
            .join("&");

        let string_to_sign = format!(
            "POST&{}&{}",
            percent_encode("/"),
            percent_encode(&canonical_query)
        );

        let signing_key = format!("{}&", self.access_key_secret);
        let mut mac = Hmac::<Sha1>::new_from_slice(signing_key.as_bytes())
            .expect("HMAC can take key of any size");
        mac.update(string_to_sign.as_bytes());

        let digest = mac.finalize().into_bytes();
        base64::engine::general_purpose::STANDARD.encode(digest)
    }
}
#[async_trait]
impl KmsProvider for AliyunKmsProvider {
    /// Decrypts `encrypted_data_key` by calling the Aliyun KMS `Decrypt` action.
    ///
    /// The request carries a UTC timestamp and a random UUID nonce. It is signed
    /// with HMAC-SHA1 (signature version 1.0) and sent as a form-encoded POST to
    /// the configured endpoint. The `Plaintext` field of the JSON response is
    /// base64-decoded and returned.
    ///
    /// # Errors
    ///
    /// Returns `BatataError::EncryptionError` when:
    /// - the HTTP request fails;
    /// - KMS answers with a non-success status (the body is included in the message);
    /// - the response is not valid JSON, or lacks a string `Plaintext`;
    /// - `Plaintext` is not valid base64.
    async fn decrypt_data_key(&self, encrypted_data_key: &str) -> Result<Vec<u8>> {
        let timestamp = chrono::Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string();
        // Unique per request, as required by the signature scheme.
        let nonce = uuid::Uuid::new_v4().to_string();

        let mut params = HashMap::new();
        params.insert("Action".to_string(), "Decrypt".to_string());
        params.insert("CiphertextBlob".to_string(), encrypted_data_key.to_string());
        params.insert("Format".to_string(), "JSON".to_string());
        params.insert("Version".to_string(), "2016-01-20".to_string());
        params.insert("AccessKeyId".to_string(), self.access_key_id.clone());
        params.insert("SignatureMethod".to_string(), "HMAC-SHA1".to_string());
        params.insert("Timestamp".to_string(), timestamp);
        params.insert("SignatureVersion".to_string(), "1.0".to_string());
        params.insert("SignatureNonce".to_string(), nonce);

        // The signature covers every parameter above and is itself added afterwards,
        // so it must not be part of the signed set.
        let signature = self.sign_request(&params);
        params.insert("Signature".to_string(), signature);

        let endpoint = self.get_endpoint();
        let response = self
            .client
            .post(&endpoint)
            .form(&params)
            .send()
            .await
            .map_err(|e| BatataError::EncryptionError {
                message: format!("KMS request failed: {}", e),
            })?;

        if !response.status().is_success() {
            let status = response.status();
            // Best effort: an unreadable body should not hide the status code.
            let body = response.text().await.unwrap_or_default();
            return Err(BatataError::EncryptionError {
                message: format!("KMS decryption failed: {} - {}", status, body),
            });
        }

        let body: serde_json::Value =
            response
                .json()
                .await
                .map_err(|e| BatataError::EncryptionError {
                    message: format!("Failed to parse KMS response: {}", e),
                })?;

        let plaintext = body["Plaintext"]
            .as_str()
            .ok_or_else(|| BatataError::EncryptionError {
                message: "Missing Plaintext in KMS response".to_string(),
            })?;

        base64::engine::general_purpose::STANDARD
            .decode(plaintext)
            .map_err(|e| BatataError::EncryptionError {
                message: format!("Failed to decode plaintext: {}", e),
            })
    }

    fn name(&self) -> &str {
        "AliyunKMS"
    }
}
/// A [`KmsProvider`] that performs no real decryption.
///
/// The "encrypted" data key is treated as plain base64 and simply decoded.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopKmsProvider;

impl NoopKmsProvider {
    /// Creates a new no-op provider (equivalent to `NoopKmsProvider::default()`).
    pub const fn new() -> Self {
        Self
    }
}
#[async_trait]
impl KmsProvider for NoopKmsProvider {
    /// Base64-decodes `encrypted_data_key` and returns the bytes unchanged.
    ///
    /// # Errors
    ///
    /// Returns `BatataError::EncryptionError` if the input is not valid base64.
    async fn decrypt_data_key(&self, encrypted_data_key: &str) -> Result<Vec<u8>> {
        let decoded = base64::engine::general_purpose::STANDARD.decode(encrypted_data_key);
        match decoded {
            Ok(bytes) => Ok(bytes),
            Err(e) => Err(BatataError::EncryptionError {
                message: format!("Failed to decode data key: {}", e),
            }),
        }
    }

    fn name(&self) -> &str {
        "Noop"
    }
}
/// Builds an [`AliyunKmsProvider`] for the given region and credentials and
/// returns it as a shareable trait object.
pub fn create_aliyun_kms_provider(
    region_id: impl Into<String>,
    access_key_id: impl Into<String>,
    access_key_secret: impl Into<String>,
) -> Arc<dyn KmsProvider> {
    let provider = AliyunKmsProvider::new(region_id, access_key_id, access_key_secret);
    Arc::new(provider)
}
/// Percent-encodes `s` with the RFC 3986 unreserved set, as the Aliyun RPC
/// signature algorithm expects.
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` are kept verbatim. Every other
/// byte of the UTF-8 encoding is written as `%XX` with uppercase hex digits, so a
/// space becomes `%20` (not `+`) and `*` becomes `%2A`.
fn percent_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len());
    // Walking bytes instead of chars gives the same result: every unreserved
    // character is ASCII, and each byte of a multi-byte UTF-8 sequence is >= 0x80,
    // so those bytes always take the escape branch. This also avoids allocating
    // a temporary String per escaped character.
    for byte in s.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_percent_encode() {
        assert_eq!(percent_encode("hello"), "hello");
        assert_eq!(percent_encode("hello world"), "hello%20world");
        assert_eq!(percent_encode("a=b&c=d"), "a%3Db%26c%3Dd");
    }

    #[test]
    fn test_percent_encode_reserved_and_unicode() {
        assert_eq!(percent_encode(""), "");
        assert_eq!(percent_encode("AZaz09-_.~"), "AZaz09-_.~");
        assert_eq!(percent_encode("*+/"), "%2A%2B%2F");
        assert_eq!(percent_encode("中"), "%E4%B8%AD");
    }

    #[tokio::test]
    async fn test_noop_provider() {
        let provider = NoopKmsProvider::new();
        let data = base64::engine::general_purpose::STANDARD.encode("test-data");
        let result = provider.decrypt_data_key(&data).await.unwrap();
        assert_eq!(String::from_utf8(result).unwrap(), "test-data");
    }

    #[tokio::test]
    async fn test_noop_provider_rejects_invalid_base64() {
        let provider = NoopKmsProvider::new();
        assert!(provider.decrypt_data_key("not base64!").await.is_err());
    }

    #[test]
    fn test_aliyun_provider_creation() {
        let provider = AliyunKmsProvider::new("cn-hangzhou", "test-ak", "test-sk");
        assert_eq!(provider.region_id, "cn-hangzhou");
        assert_eq!(provider.access_key_id, "test-ak");
        assert_eq!(provider.name(), "AliyunKMS");
    }

    #[test]
    fn test_aliyun_endpoint_default_and_override() {
        let provider = AliyunKmsProvider::new("cn-hangzhou", "test-ak", "test-sk");
        assert_eq!(
            provider.get_endpoint(),
            "https://kms.cn-hangzhou.aliyuncs.com"
        );

        let provider = provider.with_endpoint("https://kms-vpc.example.com");
        assert_eq!(provider.get_endpoint(), "https://kms-vpc.example.com");
    }

    #[test]
    fn test_sign_request_is_deterministic_and_param_sensitive() {
        let provider = AliyunKmsProvider::new("cn-hangzhou", "test-ak", "test-sk");
        let mut params = HashMap::new();
        params.insert("Action".to_string(), "Decrypt".to_string());
        params.insert("CiphertextBlob".to_string(), "abc".to_string());

        let first = provider.sign_request(&params);
        assert_eq!(first, provider.sign_request(&params));
        // HMAC-SHA1 yields 20 bytes: 28 base64 characters with one '=' of padding.
        assert_eq!(first.len(), 28);
        assert!(first.ends_with('='));

        params.insert("CiphertextBlob".to_string(), "abd".to_string());
        assert_ne!(provider.sign_request(&params), first);
    }
}