use std::collections::HashMap;
use tempfile::TempDir;
/// Encrypts a small JSON keyring and checks that decrypting it with the same
/// session key yields the original bytes and every original entry.
#[test]
fn test_encrypt_decrypt_roundtrip() {
    let session_key = generate_session_key();

    let mut keys = HashMap::new();
    keys.insert("parallel_api_key".to_string(), "pk_test_12345".to_string());
    keys.insert("epo_api_key".to_string(), "epo_secret_67890".to_string());
    keys.insert(
        "cerebras_api_key".to_string(),
        "csk-abc123def456".to_string(),
    );
    let plaintext = serde_json::to_vec(&keys).unwrap();

    // `expect` reports the underlying error on failure instead of a bare
    // "assertion failed" followed by a separate unwrap.
    let encrypted =
        encrypt_keyring(&session_key, &plaintext).expect("Encryption should succeed");

    // The output must differ from the plaintext and be strictly longer than it
    // (at minimum the nonce is prepended to the ciphertext).
    assert_ne!(encrypted, plaintext);
    assert!(encrypted.len() > plaintext.len());

    let decrypted =
        decrypt_keyring(&session_key, &encrypted).expect("Decryption should succeed");
    assert_eq!(decrypted, plaintext);

    // Compare the whole map so that every inserted entry is verified, not
    // just a hand-picked subset of the keys.
    let recovered: HashMap<String, String> = serde_json::from_slice(&decrypted).unwrap();
    assert_eq!(recovered, keys);
}
/// Data encrypted under one session key must not decrypt under a different one.
#[test]
fn test_wrong_key_fails_decryption() {
    let encrypting_key = generate_session_key();
    let other_key = generate_session_key();

    let sealed = encrypt_keyring(&encrypting_key, br#"{"test_key":"secret_value"}"#).unwrap();

    assert!(
        decrypt_keyring(&other_key, &sealed).is_err(),
        "Decryption with wrong key should fail"
    );
}
/// Flipping a byte inside the encrypted blob must make decryption fail.
#[test]
fn test_tampered_ciphertext_fails() {
    let session_key = generate_session_key();
    let plaintext = br#"{"key":"value"}"#;
    let mut encrypted = encrypt_keyring(&session_key, plaintext).unwrap();

    // Fail loudly instead of silently skipping the tampering step: a blob this
    // short would be rejected by the length check in `decrypt_keyring`, so the
    // test would pass without ever exercising authentication.
    assert!(
        encrypted.len() > 15,
        "encrypted blob is too short to tamper with: {} bytes",
        encrypted.len()
    );
    // Index 14 lies past the 12-byte nonce prefix, i.e. inside the ciphertext.
    encrypted[14] ^= 0xFF;

    let result = decrypt_keyring(&session_key, &encrypted);
    assert!(
        result.is_err(),
        "Tampered ciphertext should fail authentication"
    );
}
/// Writes an encrypted keyring and a base64-encoded session key to a temp
/// directory, reads the key back and deletes its file, then decrypts the
/// keyring read from disk with the recovered key.
#[test]
fn test_keyring_file_roundtrip() {
    let b64 = &base64::engine::general_purpose::STANDARD;

    let tmp = TempDir::new().unwrap();
    let keyring_path = tmp.path().join("keyring.enc");
    let key_path = tmp.path().join(".key");

    // Encrypt a single-entry keyring and persist it.
    let session_key = generate_session_key();
    let mut keys: HashMap<String, String> = HashMap::new();
    keys.insert("api_key".into(), "secret123".into());
    let sealed = encrypt_keyring(&session_key, &serde_json::to_vec(&keys).unwrap()).unwrap();
    std::fs::write(&keyring_path, &sealed).unwrap();

    // Persist the session key as base64 text next to the keyring.
    std::fs::write(&key_path, base64::Engine::encode(b64, session_key)).unwrap();
    assert!(key_path.exists());

    // Read the key once, then remove the file before using the key.
    let key_text = std::fs::read_to_string(&key_path).unwrap();
    std::fs::remove_file(&key_path).unwrap();
    let key_bytes = base64::Engine::decode(b64, key_text.trim()).unwrap();
    assert_eq!(key_bytes.len(), 32);
    assert!(!key_path.exists(), "Key file should be deleted");

    let mut recovered_key = [0u8; 32];
    recovered_key.copy_from_slice(&key_bytes);

    // The keyring on disk must decrypt with the key recovered from the file.
    let on_disk = std::fs::read(&keyring_path).unwrap();
    let opened = decrypt_keyring(&recovered_key, &on_disk).unwrap();
    let recovered: HashMap<String, String> = serde_json::from_slice(&opened).unwrap();
    assert_eq!(recovered.get("api_key").unwrap(), "secret123");
}
/// A 10-byte input is shorter than the minimum encrypted length and must be
/// rejected by `decrypt_keyring`.
#[test]
fn test_too_small_encrypted_data() {
    let undersized = [0u8; 10];
    let session_key = generate_session_key();

    let outcome = decrypt_keyring(&session_key, &undersized);

    assert!(outcome.is_err());
}
/// An empty JSON object survives an encrypt/decrypt round trip and parses
/// back into an empty map.
#[test]
fn test_empty_keyring() {
    let session_key = generate_session_key();

    let sealed = encrypt_keyring(&session_key, br#"{}"#).unwrap();
    let opened = decrypt_keyring(&session_key, &sealed).unwrap();

    let recovered: HashMap<String, String> = serde_json::from_slice(&opened).unwrap();
    assert!(recovered.is_empty());
}
use aes_gcm::{
aead::{Aead, KeyInit, OsRng},
AeadCore, Aes256Gcm, Nonce,
};
use rand::RngCore;
/// Length in bytes of the nonce that `encrypt_keyring` prepends to its output
/// and that `decrypt_keyring` splits off the front of its input.
const NONCE_SIZE: usize = 12;
fn generate_session_key() -> [u8; 32] {
let mut key = [0u8; 32];
OsRng.fill_bytes(&mut key);
key
}
/// Encrypts `plaintext` with AES-256-GCM under `session_key`.
///
/// A new nonce is generated for every call. The returned buffer is the nonce
/// followed by the bytes produced by the cipher.
///
/// # Errors
///
/// Returns an error if the cipher cannot be constructed from the key or if
/// the encryption call itself fails.
fn encrypt_keyring(
    session_key: &[u8; 32],
    plaintext: &[u8],
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
    let cipher = Aes256Gcm::new_from_slice(session_key)?;
    let nonce = Aes256Gcm::generate_nonce(&mut OsRng);

    let sealed = cipher
        .encrypt(&nonce, plaintext)
        .map_err(|err| err.to_string())?;

    // Output layout: nonce || sealed bytes.
    let mut framed = Vec::with_capacity(NONCE_SIZE + sealed.len());
    framed.extend_from_slice(nonce.as_slice());
    framed.extend_from_slice(sealed.as_slice());
    Ok(framed)
}
/// Decrypts a buffer produced by `encrypt_keyring` using `session_key`.
///
/// The first `NONCE_SIZE` bytes of `encrypted` are taken as the nonce and the
/// remainder is handed to AES-256-GCM for decryption.
///
/// # Errors
///
/// Returns `"Too small"` when `encrypted` is shorter than `NONCE_SIZE + 16`
/// bytes, and a `"Decryption failed: ..."` error when the cipher rejects the
/// input. Cipher construction errors are propagated as-is.
fn decrypt_keyring(
    session_key: &[u8; 32],
    encrypted: &[u8],
) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
    // Minimum length is the nonce plus 16 more bytes (presumably the AES-GCM
    // authentication tag), i.e. the size of an encrypted empty plaintext.
    let min_len = NONCE_SIZE + 16;
    if encrypted.len() < min_len {
        return Err("Too small".into());
    }

    let nonce = Nonce::from_slice(&encrypted[..NONCE_SIZE]);
    let sealed = &encrypted[NONCE_SIZE..];
    let cipher = Aes256Gcm::new_from_slice(session_key)?;

    let plaintext = cipher
        .decrypt(nonce, sealed)
        .map_err(|err| format!("Decryption failed: {err}"))?;
    Ok(plaintext)
}
/// Expects `Keyring::load_credentials` to keep a plain value as written and to
/// resolve an `@file:<path>` value to the contents of the referenced file.
#[test]
fn test_load_credentials_with_file_reference() {
    let tmp = TempDir::new().unwrap();
    let root = tmp.path();

    // The secret that the credentials file points at.
    let secret_path = root.join("gcp-creds.json");
    let secret_body = r#"{"type":"service_account","project_id":"test"}"#;
    std::fs::write(&secret_path, secret_body).unwrap();

    // Credentials file with one inline value and one file reference.
    let creds_path = root.join("credentials");
    let creds_body = format!(
        r#"{{"inline_key": "simple_value", "gcp_credentials": "@file:{}"}}"#,
        secret_path.display()
    );
    std::fs::write(&creds_path, creds_body).unwrap();

    let keyring = ati::core::keyring::Keyring::load_credentials(&creds_path).unwrap();

    assert_eq!(keyring.get("inline_key").unwrap(), "simple_value");
    let gcp = keyring.get("gcp_credentials").unwrap();
    assert!(gcp.contains("service_account"));
}
/// Expects the value loaded through an `@file:` reference to have the
/// referenced file's trailing newline removed.
#[test]
fn test_load_credentials_file_reference_trims_whitespace() {
    let tmp = TempDir::new().unwrap();
    let root = tmp.path();

    // Secret file whose content ends with a newline.
    let secret_path = root.join("api-key.txt");
    std::fs::write(&secret_path, "my_secret_key\n").unwrap();

    let creds_path = root.join("credentials");
    let creds_body = format!(r#"{{"api_key": "@file:{}"}}"#, secret_path.display());
    std::fs::write(&creds_path, creds_body).unwrap();

    let keyring = ati::core::keyring::Keyring::load_credentials(&creds_path).unwrap();

    assert_eq!(keyring.get("api_key").unwrap(), "my_secret_key");
}
/// Expects `Keyring::load_credentials` to return an error when an `@file:`
/// reference points at a path that does not exist.
#[test]
fn test_load_credentials_file_reference_missing_file() {
    let tmp = TempDir::new().unwrap();
    let creds_path = tmp.path().join("credentials");

    let creds_body = r#"{"bad_ref": "@file:/nonexistent/path/secret.txt"}"#;
    std::fs::write(&creds_path, creds_body).unwrap();

    let outcome = ati::core::keyring::Keyring::load_credentials(&creds_path);

    assert!(
        outcome.is_err(),
        "should fail when referenced file is missing"
    );
}