use serde::{Deserialize, Serialize};
use crate::error::{CipherError, Result};
/// Version marker stored in the `version` field of every serialized [`Envelope`].
///
/// `Envelope::new` writes it, `Envelope::parse` requires an exact match, and
/// `Envelope::is_envelope` searches for it as a substring.
const ENVELOPE_V2: &str = "dugout-envelope-v2";
/// Key-management services recognised from the shape of a key identifier.
#[derive(Debug, Clone, PartialEq)]
#[allow(dead_code)]
pub enum KmsProvider {
    /// Identifier starts with `arn:aws:kms:`.
    Aws,
    /// Identifier starts with `projects/` and contains `/cryptoKeys/`.
    Gcp,
}

#[allow(dead_code)]
impl KmsProvider {
    /// Classifies `key` by its textual shape.
    ///
    /// Returns `Some(Aws)` for anything prefixed with `arn:aws:kms:`,
    /// `Some(Gcp)` for anything prefixed with `projects/` that also contains
    /// `/cryptoKeys/` somewhere in the string, and `None` otherwise. The AWS
    /// rule takes precedence. No further validation of the identifier is done.
    pub fn detect(key: &str) -> Option<Self> {
        let is_aws_arn = key.starts_with("arn:aws:kms:");
        let is_gcp_resource = key.starts_with("projects/") && key.contains("/cryptoKeys/");

        match (is_aws_arn, is_gcp_resource) {
            (true, _) => Some(Self::Aws),
            (false, true) => Some(Self::Gcp),
            (false, false) => None,
        }
    }

    /// Short lowercase identifier for the provider (`"aws"` or `"gcp"`).
    pub fn name(&self) -> &'static str {
        match *self {
            KmsProvider::Aws => "aws",
            KmsProvider::Gcp => "gcp",
        }
    }
}
/// Serializable container holding an age ciphertext and, optionally, a KMS
/// ciphertext together with the name of the provider that produced it.
///
/// Serialized via serde; `Envelope::seal` emits it as JSON and
/// `Envelope::parse` reads it back.
#[derive(Debug, Serialize, Deserialize)]
#[allow(dead_code)]
pub struct Envelope {
/// Format marker. Private, so outside this module an `Envelope` can only be
/// obtained through `Envelope::new` or `Envelope::parse`.
version: String,
/// The age ciphertext passed to `Envelope::new`.
pub age: String,
/// Optional KMS ciphertext; the key is omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub kms: Option<String>,
/// Optional provider name (as returned by `KmsProvider::name`); the key is
/// omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub provider: Option<String>,
}
#[allow(dead_code)]
impl Envelope {
    /// Builds a v2 envelope from its parts.
    ///
    /// `provider`, when given, is stored as its short name
    /// (see `KmsProvider::name`).
    pub fn new(
        age_ciphertext: String,
        kms_ciphertext: Option<String>,
        provider: Option<&KmsProvider>,
    ) -> Self {
        let provider = provider.map(|kind| kind.name().to_string());
        Self {
            version: ENVELOPE_V2.to_string(),
            age: age_ciphertext,
            kms: kms_ciphertext,
            provider,
        }
    }

    /// Serializes the envelope to a JSON string.
    ///
    /// # Errors
    ///
    /// Returns `CipherError::EncryptionFailed` (converted into the crate error
    /// type) if serialization fails.
    pub fn seal(&self) -> Result<String> {
        match serde_json::to_string(self) {
            Ok(json) => Ok(json),
            Err(e) => Err(CipherError::EncryptionFailed(format!(
                "failed to serialize envelope: {}",
                e
            ))
            .into()),
        }
    }

    /// Attempts to read `ciphertext` as a v2 envelope.
    ///
    /// Returns `None` when the input is not valid JSON for this struct or when
    /// its `version` field differs from the expected marker.
    pub fn parse(ciphertext: &str) -> Option<Self> {
        serde_json::from_str::<Self>(ciphertext)
            .ok()
            .filter(|candidate| candidate.version == ENVELOPE_V2)
    }

    /// Cheap textual check: does `ciphertext` look like a sealed envelope?
    ///
    /// This only tests that the string begins with `{` and contains the
    /// version marker; it does not parse the JSON, so a `true` result does not
    /// guarantee that `parse` will succeed.
    pub fn is_envelope(ciphertext: &str) -> bool {
        if !ciphertext.starts_with('{') {
            return false;
        }
        ciphertext.contains(ENVELOPE_V2)
    }
}
/// Abstraction over a key-management service that can encrypt and decrypt
/// string payloads. Implementors must also be `Debug`.
#[allow(dead_code)]
pub trait KmsBackend: std::fmt::Debug {
/// Encrypts `plaintext` and returns the backend's ciphertext as a string.
/// NOTE(review): the encoding of the returned string is backend-defined —
/// confirm against each implementation.
fn encrypt(&self, plaintext: &str) -> Result<String>;
/// Decrypts a ciphertext string and returns the recovered plaintext.
/// NOTE(review): presumably expects a value previously returned by `encrypt`
/// on the same backend — confirm.
fn decrypt(&self, ciphertext: &str) -> Result<String>;
/// Reports which [`KmsProvider`] this backend corresponds to.
fn provider(&self) -> &KmsProvider;
}
/// In-process stand-in for a real KMS, compiled only for tests or when the
/// `test-kms` feature is enabled.
///
/// "Encryption" is plain lowercase hex encoding behind a `stub-kms:` prefix; it
/// provides no confidentiality and exists only to exercise code paths that
/// need a [`KmsBackend`].
#[cfg(any(test, feature = "test-kms"))]
#[derive(Debug)]
pub struct StubKms;

#[cfg(any(test, feature = "test-kms"))]
impl KmsBackend for StubKms {
    /// Hex-encodes the UTF-8 bytes of `plaintext` and prefixes the result with
    /// `stub-kms:`. Never fails.
    fn encrypt(&self, plaintext: &str) -> Result<String> {
        let hex: String = plaintext.bytes().map(|b| format!("{:02x}", b)).collect();
        Ok(format!("stub-kms:{}", hex))
    }

    /// Reverses [`StubKms::encrypt`].
    ///
    /// # Errors
    ///
    /// Returns `CipherError::DecryptionFailed` when the `stub-kms:` prefix is
    /// missing, when the payload is not an even-length run of ASCII hex
    /// digits, or when the decoded bytes are not valid UTF-8.
    fn decrypt(&self, ciphertext: &str) -> Result<String> {
        let hex = ciphertext.strip_prefix("stub-kms:").ok_or_else(|| {
            CipherError::DecryptionFailed("not a stub-kms ciphertext".to_string())
        })?;

        // Validate before slicing: `&hex[i..i + 2]` panics on an odd-length
        // payload (range past the end) and on non-ASCII input (range not on a
        // char boundary). Requiring hex digits also rejects sign-prefixed
        // pairs such as "+f", which `from_str_radix` would otherwise accept.
        if hex.len() % 2 != 0 || !hex.bytes().all(|b| b.is_ascii_hexdigit()) {
            return Err(CipherError::DecryptionFailed(
                "invalid hex: expected an even number of hex digits".to_string(),
            )
            .into());
        }

        let bytes = (0..hex.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&hex[i..i + 2], 16))
            .collect::<std::result::Result<Vec<u8>, _>>()
            .map_err(|e| CipherError::DecryptionFailed(format!("invalid hex: {}", e)))?;

        String::from_utf8(bytes)
            .map_err(|e| CipherError::DecryptionFailed(format!("invalid utf8: {}", e)).into())
    }

    /// The stub always reports itself as the AWS provider.
    fn provider(&self) -> &KmsProvider {
        &KmsProvider::Aws
    }
}
// Unit tests for provider detection, envelope (de)serialization and the stub
// KMS backend. Gated on `cfg(test)` only: `#[test]` functions are stripped from
// non-test builds, so compiling this module under the `test-kms` feature alone
// would leave nothing but unused imports.
#[cfg(test)]
mod tests {
    use super::{Envelope, KmsBackend, KmsProvider, StubKms};

    #[test]
    fn test_detect_aws_arn() {
        let key = "arn:aws:kms:us-east-1:123456789012:key/abc-123";
        assert_eq!(KmsProvider::detect(key), Some(KmsProvider::Aws));
    }

    #[test]
    fn test_detect_aws_alias() {
        let key = "arn:aws:kms:eu-west-1:999:alias/my-key";
        assert_eq!(KmsProvider::detect(key), Some(KmsProvider::Aws));
    }

    #[test]
    fn test_detect_gcp_resource() {
        let key = "projects/my-project/locations/global/keyRings/my-ring/cryptoKeys/my-key";
        assert_eq!(KmsProvider::detect(key), Some(KmsProvider::Gcp));
    }

    #[test]
    fn test_detect_invalid() {
        assert_eq!(KmsProvider::detect("not-a-kms-key"), None);
        assert_eq!(KmsProvider::detect(""), None);
        assert_eq!(KmsProvider::detect("arn:aws:s3:::bucket"), None);
        assert_eq!(KmsProvider::detect("projects/foo"), None);
    }

    #[test]
    fn test_envelope_roundtrip() {
        let envelope = Envelope::new(
            "age-ciphertext-here".to_string(),
            Some("kms-ciphertext-here".to_string()),
            Some(&KmsProvider::Aws),
        );
        let sealed = envelope.seal().unwrap();
        let parsed = Envelope::parse(&sealed).unwrap();
        assert_eq!(parsed.age, "age-ciphertext-here");
        assert_eq!(parsed.kms.unwrap(), "kms-ciphertext-here");
        assert_eq!(parsed.provider.unwrap(), "aws");
    }

    #[test]
    fn test_envelope_age_only() {
        let envelope = Envelope::new("age-ciphertext".to_string(), None, None);
        let sealed = envelope.seal().unwrap();
        let parsed = Envelope::parse(&sealed).unwrap();
        assert_eq!(parsed.age, "age-ciphertext");
        assert!(parsed.kms.is_none());
        assert!(parsed.provider.is_none());
    }

    #[test]
    fn test_envelope_parse_raw_age_returns_none() {
        let raw = "-----BEGIN AGE ENCRYPTED FILE-----\ntest\n-----END AGE ENCRYPTED FILE-----";
        assert!(Envelope::parse(raw).is_none());
    }

    #[test]
    fn test_envelope_parse_rejects_other_version() {
        // Well-formed JSON for the struct, but the version marker differs.
        let other = r#"{"version":"dugout-envelope-v1","age":"age-ciphertext"}"#;
        assert!(Envelope::parse(other).is_none());
    }

    #[test]
    fn test_envelope_is_envelope() {
        let envelope = Envelope::new("age".to_string(), None, None);
        let sealed = envelope.seal().unwrap();
        assert!(Envelope::is_envelope(&sealed));
        assert!(!Envelope::is_envelope("raw age ciphertext"));
    }

    #[test]
    fn test_stub_kms_roundtrip() {
        let stub = StubKms;
        let encrypted = stub.encrypt("secret-value").unwrap();
        let decrypted = stub.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, "secret-value");
    }

    #[test]
    fn test_stub_kms_invalid_ciphertext() {
        let stub = StubKms;
        // Rejected because the `stub-kms:` prefix is missing.
        assert!(stub.decrypt("not-valid-base64!!!").is_err());
    }

    #[test]
    fn test_stub_kms_rejects_non_hex_payload() {
        let stub = StubKms;
        // Correct prefix, even length, but not hexadecimal digits.
        assert!(stub.decrypt("stub-kms:zz").is_err());
    }
}