Skip to main content

fakecloud_kms/
api.rs

1//! Crate-level KMS encrypt/decrypt API for cross-service callers.
2//!
3//! Real AES-256-GCM with a fresh 12-byte IV per call and an
4//! authenticated tag. Envelope format:
5//!
6//! ```text
7//! | key_arn_len:u16_be | key_arn_utf8 | iv:12 | ciphertext_with_tag |
8//! ```
9//!
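//! The key ARN is embedded so decryption callers can pass opaque
//! ciphertext back without tracking the key separately, mirroring how
//! AWS KMS ciphertext blobs identify their own key. The ARN is also bound
//! into the GCM tag as additional authenticated data, so an envelope whose
//! header has been altered fails to decrypt.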
13
14use aes_gcm::aead::{Aead, KeyInit, Payload};
15use aes_gcm::{Aes256Gcm, Key, Nonce};
16use sha2::{Digest, Sha256};
17
18use crate::state::{KmsKey, SharedKmsState};
19
20#[derive(Debug, thiserror::Error)]
21pub enum KmsApiError {
22    #[error("KMS key {0} not found")]
23    KeyNotFound(String),
24    #[error("KMS key {0} is not enabled")]
25    KeyDisabled(String),
26    #[error("encryption failed: {0}")]
27    EncryptFailed(String),
28    #[error("decryption failed: {0}")]
29    DecryptFailed(String),
30    #[error("malformed ciphertext envelope")]
31    MalformedCiphertext,
32}
33
34/// Encrypt `plaintext` under the AES-256 key derived from `key_ref`
35/// (key id or ARN). Returns an envelope that `decrypt_blob` accepts
36/// without needing the key-ref passed again.
37pub fn encrypt_blob(
38    state: &SharedKmsState,
39    account_id: &str,
40    key_ref: &str,
41    plaintext: &[u8],
42) -> Result<Vec<u8>, KmsApiError> {
43    let (key_arn, aes_key) = {
44        let accounts = state.read();
45        let s = accounts
46            .get(account_id)
47            .ok_or_else(|| KmsApiError::KeyNotFound(key_ref.to_string()))?;
48        let key =
49            lookup_key(s, key_ref).ok_or_else(|| KmsApiError::KeyNotFound(key_ref.to_string()))?;
50        if !key.enabled {
51            return Err(KmsApiError::KeyDisabled(key.key_id.clone()));
52        }
53        (key.arn.clone(), derive_aes_key(key))
54    };
55
56    let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&aes_key));
57    let iv = random_iv();
58    let nonce = Nonce::from_slice(&iv);
59    let ciphertext = cipher
60        .encrypt(
61            nonce,
62            Payload {
63                msg: plaintext,
64                aad: key_arn.as_bytes(),
65            },
66        )
67        .map_err(|e| KmsApiError::EncryptFailed(e.to_string()))?;
68
69    let arn_bytes = key_arn.as_bytes();
70    let arn_len = arn_bytes.len() as u16;
71    let mut out = Vec::with_capacity(2 + arn_bytes.len() + 12 + ciphertext.len());
72    out.extend_from_slice(&arn_len.to_be_bytes());
73    out.extend_from_slice(arn_bytes);
74    out.extend_from_slice(&iv);
75    out.extend_from_slice(&ciphertext);
76    Ok(out)
77}
78
79/// Decrypt a blob produced by `encrypt_blob`.
80pub fn decrypt_blob(
81    state: &SharedKmsState,
82    account_id: &str,
83    ciphertext: &[u8],
84) -> Result<Vec<u8>, KmsApiError> {
85    if ciphertext.len() < 2 {
86        return Err(KmsApiError::MalformedCiphertext);
87    }
88    let arn_len = u16::from_be_bytes([ciphertext[0], ciphertext[1]]) as usize;
89    let header_end = 2 + arn_len;
90    if ciphertext.len() < header_end + 12 + 16 {
91        return Err(KmsApiError::MalformedCiphertext);
92    }
93    let key_arn = std::str::from_utf8(&ciphertext[2..header_end])
94        .map_err(|_| KmsApiError::MalformedCiphertext)?;
95    let iv = &ciphertext[header_end..header_end + 12];
96    let body = &ciphertext[header_end + 12..];
97
98    let aes_key = {
99        let accounts = state.read();
100        let s = accounts
101            .get(account_id)
102            .ok_or_else(|| KmsApiError::KeyNotFound(key_arn.to_string()))?;
103        let key =
104            lookup_key(s, key_arn).ok_or_else(|| KmsApiError::KeyNotFound(key_arn.to_string()))?;
105        if !key.enabled {
106            return Err(KmsApiError::KeyDisabled(key.key_id.clone()));
107        }
108        derive_aes_key(key)
109    };
110
111    let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&aes_key));
112    let nonce = Nonce::from_slice(iv);
113    cipher
114        .decrypt(
115            nonce,
116            Payload {
117                msg: body,
118                aad: key_arn.as_bytes(),
119            },
120        )
121        .map_err(|e| KmsApiError::DecryptFailed(e.to_string()))
122}
123
124/// Resolve `key_ref` against a KMS state. Accepts key id, ARN, or
125/// alias-name (`alias/<name>`). Returns the canonical key.
126fn lookup_key<'a>(s: &'a crate::state::KmsState, key_ref: &str) -> Option<&'a KmsKey> {
127    if let Some(alias) = key_ref.strip_prefix("alias/") {
128        let full = format!("alias/{alias}");
129        let target = s
130            .aliases
131            .values()
132            .find(|a| a.alias_name == full)
133            .map(|a| a.target_key_id.as_str())?;
134        return s.keys.get(target);
135    }
136    if let Some(id) = key_ref.rsplit(':').next() {
137        if let Some(stripped) = id.strip_prefix("key/") {
138            return s.keys.get(stripped);
139        }
140        if let Some(k) = s.keys.get(id) {
141            return Some(k);
142        }
143    }
144    s.keys.get(key_ref)
145}
146
147/// Derive a stable 32-byte AES key from a KmsKey. Priority:
148/// `imported_material_bytes` (when caller used `ImportKeyMaterial`),
149/// else the `private_key_seed` which every CreateKey populates.
150/// Hashed with SHA-256 so we always end up with the right length
151/// regardless of the source length.
152fn derive_aes_key(key: &KmsKey) -> [u8; 32] {
153    let source: &[u8] = key
154        .imported_material_bytes
155        .as_deref()
156        .unwrap_or(&key.private_key_seed);
157    let mut hasher = Sha256::new();
158    hasher.update(b"fakecloud-kms-aes256:");
159    hasher.update(key.key_id.as_bytes());
160    hasher.update(b":");
161    hasher.update(source);
162    let out = hasher.finalize();
163    let mut aes_key = [0u8; 32];
164    aes_key.copy_from_slice(&out[..]);
165    aes_key
166}
167
168fn random_iv() -> [u8; 12] {
169    // aes-gcm re-exports an rng trait; keep a small local RNG using
170    // timestamp + key-independent entropy. CI fakes don't need
171    // cryptographic-quality randomness, but each IV must be unique
172    // per key+plaintext pair, so mix in a monotonic counter.
173    use std::sync::atomic::{AtomicU64, Ordering};
174    static COUNTER: AtomicU64 = AtomicU64::new(0);
175    let ts = std::time::SystemTime::now()
176        .duration_since(std::time::UNIX_EPOCH)
177        .map(|d| d.as_nanos() as u64)
178        .unwrap_or(0);
179    let cnt = COUNTER.fetch_add(1, Ordering::Relaxed);
180    let mut hasher = Sha256::new();
181    hasher.update(ts.to_be_bytes());
182    hasher.update(cnt.to_be_bytes());
183    hasher.update(std::process::id().to_be_bytes());
184    let digest = hasher.finalize();
185    let mut iv = [0u8; 12];
186    iv.copy_from_slice(&digest[..12]);
187    iv
188}
189
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use fakecloud_aws::arn::Arn;
    use fakecloud_core::multi_account::MultiAccountState;
    use parking_lot::RwLock;

    use super::*;
    use crate::state::{KmsKey, KmsState};

    const ACCOUNT: &str = "123456789012";
    const KEY_ID: &str = "00000000-0000-0000-0000-000000000001";

    /// Build a shared state holding one enabled symmetric key in `ACCOUNT`,
    /// returning the state and that key's ARN.
    fn make_state_with_key() -> (SharedKmsState, String) {
        let state = Arc::new(RwLock::new(MultiAccountState::<KmsState>::new(
            ACCOUNT,
            "us-east-1",
            "http://localhost:4566",
        )));
        let key_id = KEY_ID.to_string();
        let arn = Arn::new("kms", "us-east-1", ACCOUNT, &format!("key/{key_id}")).to_string();
        {
            let mut accounts = state.write();
            let s = accounts.get_or_create(ACCOUNT);
            s.keys.insert(
                key_id.clone(),
                KmsKey {
                    key_id: key_id.clone(),
                    arn: arn.clone(),
                    creation_date: 0.0,
                    description: String::new(),
                    enabled: true,
                    key_usage: "ENCRYPT_DECRYPT".into(),
                    key_spec: "SYMMETRIC_DEFAULT".into(),
                    key_manager: "CUSTOMER".into(),
                    key_state: "Enabled".into(),
                    deletion_date: None,
                    tags: Default::default(),
                    policy: String::new(),
                    key_rotation_enabled: false,
                    rotation_period_in_days: None,
                    origin: "AWS_KMS".into(),
                    multi_region: false,
                    rotations: Vec::new(),
                    signing_algorithms: None,
                    encryption_algorithms: None,
                    mac_algorithms: None,
                    custom_key_store_id: None,
                    imported_key_material: false,
                    imported_material_bytes: None,
                    private_key_seed: vec![7; 32],
                    primary_region: None,
                    asymmetric_private_key_der: None,
                    asymmetric_public_key_der: None,
                },
            );
        }
        (state, arn)
    }

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let (state, arn) = make_state_with_key();
        let plaintext = b"hello fakecloud kms";
        let ct = encrypt_blob(&state, ACCOUNT, &arn, plaintext).unwrap();
        assert_ne!(&ct[..], plaintext, "ciphertext must differ from plaintext");
        let pt = decrypt_blob(&state, ACCOUNT, &ct).unwrap();
        assert_eq!(pt.as_slice(), plaintext);
    }

    #[test]
    fn encrypt_accepts_bare_key_id() {
        let (state, _arn) = make_state_with_key();
        let ct = encrypt_blob(&state, ACCOUNT, KEY_ID, b"by id").unwrap();
        assert_eq!(decrypt_blob(&state, ACCOUNT, &ct).unwrap(), b"by id");
    }

    #[test]
    fn encrypt_with_unknown_key_fails() {
        let (state, _arn) = make_state_with_key();
        assert!(matches!(
            encrypt_blob(&state, ACCOUNT, "no-such-key", b"x"),
            Err(KmsApiError::KeyNotFound(_))
        ));
    }

    #[test]
    fn each_encrypt_yields_distinct_ciphertext() {
        let (state, arn) = make_state_with_key();
        let a = encrypt_blob(&state, ACCOUNT, &arn, b"same plaintext").unwrap();
        let b = encrypt_blob(&state, ACCOUNT, &arn, b"same plaintext").unwrap();
        assert_ne!(a, b, "distinct IVs should produce distinct ciphertext");
    }

    #[test]
    fn decrypt_with_tampered_ciphertext_fails() {
        let (state, arn) = make_state_with_key();
        let mut ct = encrypt_blob(&state, ACCOUNT, &arn, b"tamper me").unwrap();
        // Flip a bit in the payload region.
        let last = ct.len() - 1;
        ct[last] ^= 0x01;
        assert!(decrypt_blob(&state, ACCOUNT, &ct).is_err());
    }

    #[test]
    fn decrypt_rejects_malformed_envelopes() {
        let (state, _arn) = make_state_with_key();
        let is_malformed = |blob: &[u8]| {
            matches!(
                decrypt_blob(&state, ACCOUNT, blob),
                Err(KmsApiError::MalformedCiphertext)
            )
        };
        // Too short to hold the u16 length prefix.
        assert!(is_malformed(&[]));
        assert!(is_malformed(&[0x00]));
        // Header claims a 5-byte ARN but the blob ends early.
        assert!(is_malformed(&[0x00, 0x05, b'a']));
        // Long enough for ARN + IV + tag, but the ARN byte is not UTF-8.
        let mut bad_utf8 = vec![0x00, 0x01, 0xff];
        bad_utf8.extend_from_slice(&[0u8; 28]);
        assert!(is_malformed(&bad_utf8));
    }

    #[test]
    fn decrypt_truncated_envelope_is_malformed() {
        let (state, arn) = make_state_with_key();
        let ct = encrypt_blob(&state, ACCOUNT, &arn, b"ok").unwrap();
        // "ok" encrypts to 2 bytes plus a 16-byte tag; dropping 10 bytes
        // leaves less than IV + tag after the header.
        assert!(matches!(
            decrypt_blob(&state, ACCOUNT, &ct[..ct.len() - 10]),
            Err(KmsApiError::MalformedCiphertext)
        ));
    }

    #[test]
    fn decrypt_with_disabled_key_fails() {
        let (state, arn) = make_state_with_key();
        let ct = encrypt_blob(&state, ACCOUNT, &arn, b"ok").unwrap();
        {
            let mut accounts = state.write();
            let s = accounts.get_mut(ACCOUNT).unwrap();
            for k in s.keys.values_mut() {
                k.enabled = false;
            }
        }
        assert!(matches!(
            decrypt_blob(&state, ACCOUNT, &ct),
            Err(KmsApiError::KeyDisabled(_))
        ));
    }
}