batata-client 0.0.2

Rust client for Batata/Nacos service discovery and configuration management
Documentation
//! KMS (Key Management Service) provider implementations
//!
//! Provides traits and implementations for decrypting data keys using KMS services.

use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use base64::Engine;
use hmac::{Hmac, Mac};
use sha1::Sha1;

use crate::error::{BatataError, Result};

/// Abstraction over a Key Management Service that can unwrap data keys.
///
/// Implementors must be `Send + Sync` so one provider instance can be shared
/// across tasks, e.g. behind an `Arc<dyn KmsProvider>`.
#[async_trait]
pub trait KmsProvider: Send + Sync {
    /// Short name identifying this backend (for example `"AliyunKMS"`).
    fn name(&self) -> &str;

    /// Turn an encrypted data key into its raw key bytes.
    ///
    /// # Arguments
    /// * `encrypted_data_key` - the encrypted data key, base64 encoded
    ///
    /// # Returns
    /// The decrypted key material, or an error if decryption fails.
    async fn decrypt_data_key(&self, encrypted_data_key: &str) -> Result<Vec<u8>>;
}

/// Alibaba Cloud KMS provider
///
/// Decrypts data keys by calling the Alibaba Cloud KMS `Decrypt` API with
/// signed requests.
///
/// `Debug` is implemented by hand so that `access_key_secret` is redacted and
/// never ends up in logs or panic messages.
pub struct AliyunKmsProvider {
    /// Region ID (e.g., "cn-hangzhou")
    pub region_id: String,
    /// Access key ID
    pub access_key_id: String,
    /// Access key secret (redacted from `Debug` output)
    pub access_key_secret: String,
    /// KMS endpoint (optional, defaults to regional endpoint)
    pub kms_endpoint: Option<String>,
    /// HTTP client
    client: reqwest::Client,
}

// Hand-written rather than derived: a derived impl would print the secret.
impl std::fmt::Debug for AliyunKmsProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AliyunKmsProvider")
            .field("region_id", &self.region_id)
            .field("access_key_id", &self.access_key_id)
            .field("access_key_secret", &"<redacted>")
            .field("kms_endpoint", &self.kms_endpoint)
            .finish_non_exhaustive()
    }
}

impl AliyunKmsProvider {
    /// Build a provider for `region_id` authenticated with the given access key pair.
    ///
    /// No custom endpoint is set, so requests go to the regional endpoint
    /// unless [`AliyunKmsProvider::with_endpoint`] is used. The HTTP client is
    /// configured with a 10 second request timeout.
    ///
    /// # Panics
    ///
    /// Panics if the underlying `reqwest` HTTP client cannot be built.
    pub fn new(
        region_id: impl Into<String>,
        access_key_id: impl Into<String>,
        access_key_secret: impl Into<String>,
    ) -> Self {
        let region_id = region_id.into();
        let access_key_id = access_key_id.into();
        let access_key_secret = access_key_secret.into();
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(10))
            .build()
            .expect("Failed to build HTTP client");

        Self {
            region_id,
            access_key_id,
            access_key_secret,
            kms_endpoint: None,
            client,
        }
    }

    /// Override the KMS endpoint (builder style), replacing the regional default.
    pub fn with_endpoint(self, endpoint: impl Into<String>) -> Self {
        Self {
            kms_endpoint: Some(endpoint.into()),
            ..self
        }
    }

    /// Resolve the endpoint to call: the custom one if set, else
    /// `https://kms.<region_id>.aliyuncs.com`.
    fn get_endpoint(&self) -> String {
        match &self.kms_endpoint {
            Some(custom) => custom.clone(),
            None => format!("https://kms.{}.aliyuncs.com", self.region_id),
        }
    }

    /// Compute the base64 HMAC-SHA1 signature of `params` for a `POST` to `/`.
    ///
    /// The parameters are ordered by name and percent-encoded into a canonical
    /// query string. That string is embedded (encoded again) in
    /// `POST&%2F&<query>` and signed with the key `<access_key_secret>&`.
    fn sign_request(&self, params: &HashMap<String, String>) -> String {
        // Order by parameter name; names are unique map keys, so ties never occur.
        let mut pairs: Vec<(&String, &String)> = params.iter().collect();
        pairs.sort_by_key(|&(name, _)| name);

        let canonical_query = pairs
            .into_iter()
            .map(|(name, value)| format!("{}={}", percent_encode(name), percent_encode(value)))
            .collect::<Vec<_>>()
            .join("&");

        let string_to_sign = format!(
            "POST&{}&{}",
            percent_encode("/"),
            percent_encode(&canonical_query)
        );

        // The signing key is the secret followed by a literal '&'.
        let signing_key = format!("{}&", self.access_key_secret);
        let mut signer = Hmac::<Sha1>::new_from_slice(signing_key.as_bytes())
            .expect("HMAC can take key of any size");
        signer.update(string_to_sign.as_bytes());
        let digest = signer.finalize().into_bytes();

        base64::engine::general_purpose::STANDARD.encode(digest)
    }
}

#[async_trait]
impl KmsProvider for AliyunKmsProvider {
    /// Decrypt `encrypted_data_key` through the KMS `Decrypt` action.
    ///
    /// A signed (HMAC-SHA1, signature version `1.0`) form-encoded `POST` is sent
    /// to the resolved endpoint. The `Plaintext` field of the JSON reply is then
    /// base64-decoded into the returned bytes.
    ///
    /// # Errors
    ///
    /// Returns `BatataError::EncryptionError` if:
    /// - the request cannot be sent;
    /// - KMS replies with a non-success status;
    /// - the reply is not JSON;
    /// - `Plaintext` is missing or not a string;
    /// - `Plaintext` is not valid base64.
    async fn decrypt_data_key(&self, encrypted_data_key: &str) -> Result<Vec<u8>> {
        let timestamp = chrono::Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string();
        let nonce = uuid::Uuid::new_v4().to_string();

        // Request parameters that never vary between calls.
        let mut params: HashMap<String, String> = [
            ("Action", "Decrypt"),
            ("Format", "JSON"),
            ("Version", "2016-01-20"),
            ("SignatureMethod", "HMAC-SHA1"),
            ("SignatureVersion", "1.0"),
        ]
        .iter()
        .map(|&(name, value)| (name.to_string(), value.to_string()))
        .collect();
        params.insert("CiphertextBlob".to_string(), encrypted_data_key.to_string());
        params.insert("AccessKeyId".to_string(), self.access_key_id.clone());
        params.insert("Timestamp".to_string(), timestamp);
        params.insert("SignatureNonce".to_string(), nonce);

        // The signature covers every other parameter, so it is inserted last.
        let signature = self.sign_request(&params);
        params.insert("Signature".to_string(), signature);

        let response = self
            .client
            .post(self.get_endpoint())
            .form(&params)
            .send()
            .await
            .map_err(|e| BatataError::EncryptionError {
                message: format!("KMS request failed: {}", e),
            })?;

        let status = response.status();
        if !status.is_success() {
            // Best effort: an unreadable error body is reported as empty.
            let detail = response.text().await.unwrap_or_default();
            return Err(BatataError::EncryptionError {
                message: format!("KMS decryption failed: {} - {}", status, detail),
            });
        }

        let payload: serde_json::Value = response.json().await.map_err(|e| {
            BatataError::EncryptionError {
                message: format!("Failed to parse KMS response: {}", e),
            }
        })?;

        let encoded_key = payload
            .get("Plaintext")
            .and_then(serde_json::Value::as_str)
            .ok_or_else(|| BatataError::EncryptionError {
                message: "Missing Plaintext in KMS response".to_string(),
            })?;

        match base64::engine::general_purpose::STANDARD.decode(encoded_key) {
            Ok(key) => Ok(key),
            Err(e) => Err(BatataError::EncryptionError {
                message: format!("Failed to decode plaintext: {}", e),
            }),
        }
    }

    fn name(&self) -> &str {
        "AliyunKMS"
    }
}

/// No-op KMS provider for testing or when encryption is disabled.
///
/// It performs no real decryption: the "encrypted" data key is only
/// base64-decoded. The type is stateless, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NoopKmsProvider;

impl NoopKmsProvider {
    /// Create a new no-op KMS provider
    pub fn new() -> Self {
        Self
    }
}

impl Default for NoopKmsProvider {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl KmsProvider for NoopKmsProvider {
    /// "Decrypt" by base64-decoding the input as-is; intended for tests.
    ///
    /// # Errors
    ///
    /// Returns `BatataError::EncryptionError` when the input is not valid
    /// standard base64.
    async fn decrypt_data_key(&self, encrypted_data_key: &str) -> Result<Vec<u8>> {
        let engine = &base64::engine::general_purpose::STANDARD;
        match engine.decode(encrypted_data_key) {
            Ok(key) => Ok(key),
            Err(e) => Err(BatataError::EncryptionError {
                message: format!("Failed to decode data key: {}", e),
            }),
        }
    }

    fn name(&self) -> &str {
        "Noop"
    }
}

/// Create a shared KMS provider backed by Alibaba Cloud KMS.
///
/// The provider uses the default regional endpoint and is returned as a
/// type-erased `Arc<dyn KmsProvider>` so it can be shared across tasks.
///
/// # Panics
///
/// Panics under the same condition as `AliyunKmsProvider::new`: when its HTTP
/// client cannot be built.
pub fn create_aliyun_kms_provider(
    region_id: impl Into<String>,
    access_key_id: impl Into<String>,
    access_key_secret: impl Into<String>,
) -> Arc<dyn KmsProvider> {
    let provider = AliyunKmsProvider::new(region_id, access_key_id, access_key_secret);
    Arc::new(provider)
}

/// Percent-encode `s` as required by the Aliyun request signature.
///
/// The unreserved characters `A-Z`, `a-z`, `0-9`, `-`, `_`, `.` and `~` are
/// copied through unchanged. Every other byte of the UTF-8 encoding is written
/// as `%XX` with uppercase hex digits, so a space becomes `%20` (never `+`).
fn percent_encode(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    // Every input byte produces at least one output character.
    let mut result = String::with_capacity(s.len());
    // All unreserved characters are ASCII, and every byte of a multi-byte UTF-8
    // sequence is >= 0x80. Encoding byte-by-byte therefore gives exactly the
    // output of escaping each char's UTF-8 bytes, without a temporary String
    // per character or a `format!` per byte.
    for &b in s.as_bytes() {
        if b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b'~') {
            result.push(char::from(b));
        } else {
            result.push('%');
            result.push(char::from(HEX[usize::from(b >> 4)]));
            result.push(char::from(HEX[usize::from(b & 0x0F)]));
        }
    }
    result
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_percent_encode() {
        assert_eq!(percent_encode("hello"), "hello");
        assert_eq!(percent_encode("hello world"), "hello%20world");
        assert_eq!(percent_encode("a=b&c=d"), "a%3Db%26c%3Dd");
    }

    #[test]
    fn test_percent_encode_unreserved_reserved_and_multibyte() {
        assert_eq!(percent_encode(""), "");
        assert_eq!(percent_encode("AZaz09-_.~"), "AZaz09-_.~");
        assert_eq!(percent_encode("*+/:"), "%2A%2B%2F%3A");
        assert_eq!(percent_encode("中"), "%E4%B8%AD");
    }

    #[tokio::test]
    async fn test_noop_provider() {
        let provider = NoopKmsProvider::new();
        let data = base64::engine::general_purpose::STANDARD.encode("test-data");
        let result = provider.decrypt_data_key(&data).await.unwrap();
        assert_eq!(String::from_utf8(result).unwrap(), "test-data");
    }

    #[tokio::test]
    async fn test_noop_provider_rejects_invalid_base64() {
        let provider = NoopKmsProvider::new();
        assert!(provider.decrypt_data_key("not base64!").await.is_err());
        assert_eq!(provider.name(), "Noop");
    }

    #[test]
    fn test_aliyun_provider_creation() {
        let provider = AliyunKmsProvider::new("cn-hangzhou", "test-ak", "test-sk");
        assert_eq!(provider.region_id, "cn-hangzhou");
        assert_eq!(provider.access_key_id, "test-ak");
        assert!(provider.kms_endpoint.is_none());
        assert_eq!(provider.name(), "AliyunKMS");
    }

    #[test]
    fn test_aliyun_endpoint_resolution() {
        let provider = AliyunKmsProvider::new("cn-hangzhou", "test-ak", "test-sk");
        assert_eq!(provider.get_endpoint(), "https://kms.cn-hangzhou.aliyuncs.com");

        let provider = provider.with_endpoint("https://kms-vpc.example.com");
        assert_eq!(provider.get_endpoint(), "https://kms-vpc.example.com");
    }

    #[test]
    fn test_sign_request_is_deterministic_hmac_sha1() {
        let provider = AliyunKmsProvider::new("cn-hangzhou", "test-ak", "test-sk");
        let mut params = HashMap::new();
        params.insert("Action".to_string(), "Decrypt".to_string());
        params.insert("Timestamp".to_string(), "2024-01-01T00:00:00Z".to_string());

        let first = provider.sign_request(&params);
        assert_eq!(first, provider.sign_request(&params));

        // An HMAC-SHA1 digest is 20 bytes before base64 encoding.
        let raw = base64::engine::general_purpose::STANDARD
            .decode(&first)
            .unwrap();
        assert_eq!(raw.len(), 20);

        // A different secret must yield a different signature.
        let other = AliyunKmsProvider::new("cn-hangzhou", "test-ak", "other-sk");
        assert_ne!(first, other.sign_request(&params));
    }
}