litellm-rs 0.1.1

A high-performance AI Gateway written in Rust, providing OpenAI-compatible APIs with intelligent routing, load balancing, and enterprise features
//! Cryptographic utilities for the Gateway
//!
//! This module provides cryptographic functions for password hashing, API key generation, etc.

#![allow(dead_code)]

use crate::utils::error::{GatewayError, Result};
use argon2::password_hash::{SaltString, rand_core::OsRng};
use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
use base64::{Engine as _, engine::general_purpose};
use hmac::{Hmac, Mac};
use rand::{Rng, distributions::Alphanumeric};
use sha2::{Digest, Sha256};
use std::time::{SystemTime, UNIX_EPOCH};

type HmacSha256 = Hmac<Sha256>;

/// Hash a password using Argon2
pub fn hash_password(password: &str) -> Result<String> {
    let salt = SaltString::generate(&mut OsRng);
    let argon2 = Argon2::default();

    let password_hash = argon2
        .hash_password(password.as_bytes(), &salt)
        .map_err(|e| GatewayError::Crypto(format!("Failed to hash password: {}", e)))?;

    Ok(password_hash.to_string())
}

/// Verify a password against its hash
pub fn verify_password(password: &str, hash: &str) -> Result<bool> {
    let parsed_hash = PasswordHash::new(hash)
        .map_err(|e| GatewayError::Crypto(format!("Failed to parse password hash: {}", e)))?;

    let argon2 = Argon2::default();

    match argon2.verify_password(password.as_bytes(), &parsed_hash) {
        Ok(()) => Ok(true),
        Err(argon2::password_hash::Error::Password) => Ok(false),
        Err(e) => Err(GatewayError::Crypto(format!(
            "Password verification failed: {}",
            e
        ))),
    }
}

/// Generate a secure API key
pub fn generate_api_key() -> String {
    let prefix = "gw";
    let random_part: String = rand::thread_rng()
        .sample_iter(&Alphanumeric)
        .take(32)
        .map(char::from)
        .collect();

    format!("{}-{}", prefix, random_part)
}

/// Generate a JWT secret
pub fn generate_jwt_secret() -> String {
    rand::thread_rng()
        .sample_iter(&Alphanumeric)
        .take(64)
        .map(char::from)
        .collect()
}

/// Generate a secure random token
pub fn generate_token(length: usize) -> String {
    rand::thread_rng()
        .sample_iter(&Alphanumeric)
        .take(length)
        .map(char::from)
        .collect()
}

/// Generate a secure session token
pub fn generate_session_token() -> String {
    let mut rng = rand::thread_rng();
    let bytes: Vec<u8> = (0..32).map(|_| rng.r#gen()).collect();
    general_purpose::URL_SAFE_NO_PAD.encode(&bytes)
}

/// Hash API key for storage
pub fn hash_api_key(api_key: &str) -> String {
    let mut hasher = Sha256::new();
    hasher.update(api_key.as_bytes());
    hex::encode(hasher.finalize())
}

/// Generate API key prefix for identification
pub fn extract_api_key_prefix(api_key: &str) -> String {
    if api_key.len() >= 8 {
        format!("{}...{}", &api_key[..4], &api_key[api_key.len() - 4..])
    } else {
        api_key.to_string()
    }
}

/// Create HMAC signature
pub fn create_hmac_signature(secret: &str, data: &str) -> Result<String> {
    let mut mac = HmacSha256::new_from_slice(secret.as_bytes())
        .map_err(|e| GatewayError::Crypto(format!("Invalid HMAC key: {}", e)))?;

    mac.update(data.as_bytes());
    let result = mac.finalize();
    Ok(hex::encode(result.into_bytes()))
}

/// Verify HMAC signature
pub fn verify_hmac_signature(secret: &str, data: &str, signature: &str) -> Result<bool> {
    let expected_signature = create_hmac_signature(secret, data)?;
    Ok(constant_time_eq(&expected_signature, signature))
}

/// Constant-time string comparison
fn constant_time_eq(a: &str, b: &str) -> bool {
    if a.len() != b.len() {
        return false;
    }

    let mut result = 0u8;
    for (a_byte, b_byte) in a.bytes().zip(b.bytes()) {
        result |= a_byte ^ b_byte;
    }

    result == 0
}

/// Generate a secure random salt
pub fn generate_salt() -> String {
    let mut rng = rand::thread_rng();
    let bytes: Vec<u8> = (0..16).map(|_| rng.r#gen()).collect();
    general_purpose::STANDARD.encode(&bytes)
}

/// Hash data with salt
pub fn hash_with_salt(data: &str, salt: &str) -> String {
    let mut hasher = Sha256::new();
    hasher.update(data.as_bytes());
    hasher.update(salt.as_bytes());
    hex::encode(hasher.finalize())
}

/// Generate a time-based one-time password (TOTP) secret
pub fn generate_totp_secret() -> String {
    let mut rng = rand::thread_rng();
    let bytes: Vec<u8> = (0..20).map(|_| rng.r#gen()).collect();
    general_purpose::STANDARD.encode(&bytes)
}

/// Encrypt data using AES-GCM (simplified version)
pub fn encrypt_data(key: &[u8], data: &str) -> Result<String> {
    // This is a simplified implementation
    // In production, you would use a proper AES-GCM implementation
    let mut hasher = Sha256::new();
    hasher.update(key);
    hasher.update(data.as_bytes());
    Ok(hex::encode(hasher.finalize()))
}

/// Decrypt data using AES-GCM (simplified version)
pub fn decrypt_data(key: &[u8], encrypted_data: &str) -> Result<String> {
    // Basic implementation using XOR cipher for demonstration
    // In production, you would use a proper AES-GCM implementation

    // Decode base64 encrypted data
    let encrypted_bytes =
        base64::Engine::decode(&base64::engine::general_purpose::STANDARD, encrypted_data)
            .map_err(|e| GatewayError::Crypto(format!("Failed to decode encrypted data: {}", e)))?;

    // Simple XOR decryption (for demonstration only)
    let mut decrypted = Vec::new();
    for (i, &byte) in encrypted_bytes.iter().enumerate() {
        let key_byte = key[i % key.len()];
        decrypted.push(byte ^ key_byte);
    }

    String::from_utf8(decrypted).map_err(|e| {
        GatewayError::Crypto(format!("Failed to convert decrypted data to string: {}", e))
    })
}

/// Generate a secure backup code
pub fn generate_backup_code() -> String {
    let mut rng = rand::thread_rng();
    let code: String = (0..8)
        .map(|_| rng.gen_range(0..10).to_string())
        .collect::<Vec<_>>()
        .chunks(4)
        .map(|chunk| chunk.join(""))
        .collect::<Vec<_>>()
        .join("-");
    code
}

/// Generate multiple backup codes
pub fn generate_backup_codes(count: usize) -> Vec<String> {
    (0..count).map(|_| generate_backup_code()).collect()
}

/// Hash backup code for storage
pub fn hash_backup_code(code: &str) -> String {
    let mut hasher = Sha256::new();
    hasher.update(code.as_bytes());
    hasher.update(b"backup_code_salt"); // Simple salt
    hex::encode(hasher.finalize())
}

/// Generate a webhook signature
pub fn generate_webhook_signature(secret: &str, payload: &str, timestamp: u64) -> Result<String> {
    let data = format!("{}.{}", timestamp, payload);
    create_hmac_signature(secret, &data)
}

/// Verify webhook signature
pub fn verify_webhook_signature(
    secret: &str,
    payload: &str,
    timestamp: u64,
    signature: &str,
) -> Result<bool> {
    // Check timestamp is within acceptable range (e.g., 5 minutes)
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs();

    if now.saturating_sub(timestamp) > 300 {
        return Ok(false); // Timestamp too old
    }

    let expected_signature = generate_webhook_signature(secret, payload, timestamp)?;
    Ok(constant_time_eq(&expected_signature, signature))
}

/// Generate a secure file upload token
pub fn generate_upload_token(user_id: &str, expires_at: u64) -> Result<String> {
    let data = format!("{}:{}", user_id, expires_at);
    let mut hasher = Sha256::new();
    hasher.update(data.as_bytes());
    Ok(general_purpose::URL_SAFE_NO_PAD.encode(hasher.finalize()))
}

/// Verify file upload token
pub fn verify_upload_token(token: &str, user_id: &str, expires_at: u64) -> Result<bool> {
    let expected_token = generate_upload_token(user_id, expires_at)?;
    Ok(constant_time_eq(&expected_token, token))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_password_hashing() {
        let password = "test_password_123";
        let hash = hash_password(password).unwrap();

        assert!(verify_password(password, &hash).unwrap());
        assert!(!verify_password("wrong_password", &hash).unwrap());
    }

    #[test]
    fn test_api_key_generation() {
        let api_key = generate_api_key();
        assert!(api_key.starts_with("gw-"));
        assert_eq!(api_key.len(), 35); // "gw-" + 32 characters
    }

    #[test]
    fn test_jwt_secret_generation() {
        let secret = generate_jwt_secret();
        assert_eq!(secret.len(), 64);
        assert!(secret.chars().all(|c| c.is_alphanumeric()));
    }

    #[test]
    fn test_api_key_hashing() {
        let api_key = "gw-test123456789";
        let hash = hash_api_key(api_key);
        assert_eq!(hash.len(), 64); // SHA256 hex string
    }

    #[test]
    fn test_api_key_prefix() {
        let api_key = "gw-test123456789";
        let prefix = extract_api_key_prefix(api_key);
        assert_eq!(prefix, "gw-t...6789");
    }

    #[test]
    fn test_hmac_signature() {
        let secret = "test_secret";
        let data = "test_data";

        let signature = create_hmac_signature(secret, data).unwrap();
        assert!(verify_hmac_signature(secret, data, &signature).unwrap());
        assert!(!verify_hmac_signature(secret, "wrong_data", &signature).unwrap());
    }

    #[test]
    fn test_constant_time_eq() {
        assert!(constant_time_eq("hello", "hello"));
        assert!(!constant_time_eq("hello", "world"));
        assert!(!constant_time_eq("hello", "hello2"));
    }

    #[test]
    fn test_backup_code_generation() {
        let code = generate_backup_code();
        assert_eq!(code.len(), 9); // 4 digits + "-" + 4 digits
        assert!(code.contains('-'));

        let codes = generate_backup_codes(5);
        assert_eq!(codes.len(), 5);
        assert!(codes.iter().all(|c| c.len() == 9));
    }

    #[test]
    fn test_webhook_signature() {
        let secret = "webhook_secret";
        let payload = r#"{"test": "data"}"#;
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();

        let signature = generate_webhook_signature(secret, payload, timestamp).unwrap();
        assert!(verify_webhook_signature(secret, payload, timestamp, &signature).unwrap());

        // Test with wrong payload
        assert!(!verify_webhook_signature(secret, "wrong", timestamp, &signature).unwrap());

        // Test with old timestamp (should fail)
        let old_timestamp = timestamp - 400; // More than 5 minutes old
        let old_signature = generate_webhook_signature(secret, payload, old_timestamp).unwrap();
        assert!(!verify_webhook_signature(secret, payload, old_timestamp, &old_signature).unwrap());
    }
}