1use aes_gcm::aead::{Aead, KeyInit, Payload};
15use aes_gcm::{Aes256Gcm, Key, Nonce};
16use sha2::{Digest, Sha256};
17
18use crate::state::{KmsKey, SharedKmsState};
19
20#[derive(Debug, thiserror::Error)]
21pub enum KmsApiError {
22 #[error("KMS key {0} not found")]
23 KeyNotFound(String),
24 #[error("KMS key {0} is not enabled")]
25 KeyDisabled(String),
26 #[error("encryption failed: {0}")]
27 EncryptFailed(String),
28 #[error("decryption failed: {0}")]
29 DecryptFailed(String),
30 #[error("malformed ciphertext envelope")]
31 MalformedCiphertext,
32}
33
34pub fn encrypt_blob(
38 state: &SharedKmsState,
39 account_id: &str,
40 key_ref: &str,
41 plaintext: &[u8],
42) -> Result<Vec<u8>, KmsApiError> {
43 let (key_arn, aes_key) = {
44 let accounts = state.read();
45 let s = accounts
46 .get(account_id)
47 .ok_or_else(|| KmsApiError::KeyNotFound(key_ref.to_string()))?;
48 let key =
49 lookup_key(s, key_ref).ok_or_else(|| KmsApiError::KeyNotFound(key_ref.to_string()))?;
50 if !key.enabled {
51 return Err(KmsApiError::KeyDisabled(key.key_id.clone()));
52 }
53 (key.arn.clone(), derive_aes_key(key))
54 };
55
56 let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&aes_key));
57 let iv = random_iv();
58 let nonce = Nonce::from_slice(&iv);
59 let ciphertext = cipher
60 .encrypt(
61 nonce,
62 Payload {
63 msg: plaintext,
64 aad: key_arn.as_bytes(),
65 },
66 )
67 .map_err(|e| KmsApiError::EncryptFailed(e.to_string()))?;
68
69 let arn_bytes = key_arn.as_bytes();
70 let arn_len = arn_bytes.len() as u16;
71 let mut out = Vec::with_capacity(2 + arn_bytes.len() + 12 + ciphertext.len());
72 out.extend_from_slice(&arn_len.to_be_bytes());
73 out.extend_from_slice(arn_bytes);
74 out.extend_from_slice(&iv);
75 out.extend_from_slice(&ciphertext);
76 Ok(out)
77}
78
79pub fn decrypt_blob(
81 state: &SharedKmsState,
82 account_id: &str,
83 ciphertext: &[u8],
84) -> Result<Vec<u8>, KmsApiError> {
85 if ciphertext.len() < 2 {
86 return Err(KmsApiError::MalformedCiphertext);
87 }
88 let arn_len = u16::from_be_bytes([ciphertext[0], ciphertext[1]]) as usize;
89 let header_end = 2 + arn_len;
90 if ciphertext.len() < header_end + 12 + 16 {
91 return Err(KmsApiError::MalformedCiphertext);
92 }
93 let key_arn = std::str::from_utf8(&ciphertext[2..header_end])
94 .map_err(|_| KmsApiError::MalformedCiphertext)?;
95 let iv = &ciphertext[header_end..header_end + 12];
96 let body = &ciphertext[header_end + 12..];
97
98 let aes_key = {
99 let accounts = state.read();
100 let s = accounts
101 .get(account_id)
102 .ok_or_else(|| KmsApiError::KeyNotFound(key_arn.to_string()))?;
103 let key =
104 lookup_key(s, key_arn).ok_or_else(|| KmsApiError::KeyNotFound(key_arn.to_string()))?;
105 if !key.enabled {
106 return Err(KmsApiError::KeyDisabled(key.key_id.clone()));
107 }
108 derive_aes_key(key)
109 };
110
111 let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&aes_key));
112 let nonce = Nonce::from_slice(iv);
113 cipher
114 .decrypt(
115 nonce,
116 Payload {
117 msg: body,
118 aad: key_arn.as_bytes(),
119 },
120 )
121 .map_err(|e| KmsApiError::DecryptFailed(e.to_string()))
122}
123
124fn lookup_key<'a>(s: &'a crate::state::KmsState, key_ref: &str) -> Option<&'a KmsKey> {
127 if let Some(alias) = key_ref.strip_prefix("alias/") {
128 let full = format!("alias/{alias}");
129 let target = s
130 .aliases
131 .values()
132 .find(|a| a.alias_name == full)
133 .map(|a| a.target_key_id.as_str())?;
134 return s.keys.get(target);
135 }
136 if let Some(id) = key_ref.rsplit(':').next() {
137 if let Some(stripped) = id.strip_prefix("key/") {
138 return s.keys.get(stripped);
139 }
140 if let Some(k) = s.keys.get(id) {
141 return Some(k);
142 }
143 }
144 s.keys.get(key_ref)
145}
146
147fn derive_aes_key(key: &KmsKey) -> [u8; 32] {
153 let source: &[u8] = key
154 .imported_material_bytes
155 .as_deref()
156 .unwrap_or(&key.private_key_seed);
157 let mut hasher = Sha256::new();
158 hasher.update(b"fakecloud-kms-aes256:");
159 hasher.update(key.key_id.as_bytes());
160 hasher.update(b":");
161 hasher.update(source);
162 let out = hasher.finalize();
163 let mut aes_key = [0u8; 32];
164 aes_key.copy_from_slice(&out[..]);
165 aes_key
166}
167
168fn random_iv() -> [u8; 12] {
169 use std::sync::atomic::{AtomicU64, Ordering};
174 static COUNTER: AtomicU64 = AtomicU64::new(0);
175 let ts = std::time::SystemTime::now()
176 .duration_since(std::time::UNIX_EPOCH)
177 .map(|d| d.as_nanos() as u64)
178 .unwrap_or(0);
179 let cnt = COUNTER.fetch_add(1, Ordering::Relaxed);
180 let mut hasher = Sha256::new();
181 hasher.update(ts.to_be_bytes());
182 hasher.update(cnt.to_be_bytes());
183 hasher.update(std::process::id().to_be_bytes());
184 let digest = hasher.finalize();
185 let mut iv = [0u8; 12];
186 iv.copy_from_slice(&digest[..12]);
187 iv
188}
189
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use fakecloud_core::multi_account::MultiAccountState;
    use parking_lot::RwLock;

    use super::*;
    use crate::state::{KmsKey, KmsState};

    /// Account that `make_state_with_key` populates.
    const ACCOUNT: &str = "123456789012";
    /// Id of the single key `make_state_with_key` inserts.
    const KEY_ID: &str = "00000000-0000-0000-0000-000000000001";

    /// Builds a state holding one enabled symmetric key in `ACCOUNT`; returns the state and
    /// the key's ARN.
    fn make_state_with_key() -> (SharedKmsState, String) {
        let state = Arc::new(RwLock::new(MultiAccountState::<KmsState>::new(
            ACCOUNT,
            "us-east-1",
            "http://localhost:4566",
        )));
        let key_id = KEY_ID.to_string();
        let arn = format!("arn:aws:kms:us-east-1:123456789012:key/{key_id}");
        {
            let mut accounts = state.write();
            let s = accounts.get_or_create(ACCOUNT);
            s.keys.insert(
                key_id.clone(),
                KmsKey {
                    key_id: key_id.clone(),
                    arn: arn.clone(),
                    creation_date: 0.0,
                    description: String::new(),
                    enabled: true,
                    key_usage: "ENCRYPT_DECRYPT".into(),
                    key_spec: "SYMMETRIC_DEFAULT".into(),
                    key_manager: "CUSTOMER".into(),
                    key_state: "Enabled".into(),
                    deletion_date: None,
                    tags: Default::default(),
                    policy: String::new(),
                    key_rotation_enabled: false,
                    origin: "AWS_KMS".into(),
                    multi_region: false,
                    rotations: Vec::new(),
                    signing_algorithms: None,
                    encryption_algorithms: None,
                    mac_algorithms: None,
                    custom_key_store_id: None,
                    imported_key_material: false,
                    imported_material_bytes: None,
                    private_key_seed: vec![7; 32],
                    primary_region: None,
                },
            );
        }
        (state, arn)
    }

    /// Asserts that `decrypt_blob` rejects `envelope` as structurally malformed.
    fn assert_malformed(state: &SharedKmsState, envelope: &[u8]) {
        assert!(
            matches!(
                decrypt_blob(state, ACCOUNT, envelope),
                Err(KmsApiError::MalformedCiphertext)
            ),
            "envelope {envelope:?} should be rejected as malformed"
        );
    }

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let (state, arn) = make_state_with_key();
        let plaintext = b"hello fakecloud kms";
        let ct = encrypt_blob(&state, ACCOUNT, &arn, plaintext).unwrap();
        assert_ne!(&ct[..], plaintext, "ciphertext must differ from plaintext");
        let pt = decrypt_blob(&state, ACCOUNT, &ct).unwrap();
        assert_eq!(pt.as_slice(), plaintext);
    }

    #[test]
    fn encrypt_accepts_bare_key_id() {
        let (state, _arn) = make_state_with_key();
        let ct = encrypt_blob(&state, ACCOUNT, KEY_ID, b"by id").unwrap();
        let pt = decrypt_blob(&state, ACCOUNT, &ct).unwrap();
        assert_eq!(pt.as_slice(), b"by id");
    }

    #[test]
    fn each_encrypt_yields_distinct_ciphertext() {
        let (state, arn) = make_state_with_key();
        let a = encrypt_blob(&state, ACCOUNT, &arn, b"same plaintext").unwrap();
        let b = encrypt_blob(&state, ACCOUNT, &arn, b"same plaintext").unwrap();
        assert_ne!(a, b, "distinct IVs should produce distinct ciphertext");
    }

    #[test]
    fn encrypt_with_unknown_key_fails() {
        let (state, _arn) = make_state_with_key();
        let missing =
            "arn:aws:kms:us-east-1:123456789012:key/00000000-0000-0000-0000-0000000000ff";
        assert!(matches!(
            encrypt_blob(&state, ACCOUNT, missing, b"x"),
            Err(KmsApiError::KeyNotFound(_))
        ));
    }

    #[test]
    fn decrypt_with_tampered_ciphertext_fails() {
        let (state, arn) = make_state_with_key();
        let mut ct = encrypt_blob(&state, ACCOUNT, &arn, b"tamper me").unwrap();
        let last = ct.len() - 1;
        // Flip one bit in the final byte (part of the GCM authentication tag).
        ct[last] ^= 0x01;
        assert!(decrypt_blob(&state, ACCOUNT, &ct).is_err());
    }

    #[test]
    fn decrypt_rejects_malformed_envelopes() {
        let (state, arn) = make_state_with_key();
        // Too short to hold the 2-byte length prefix.
        assert_malformed(&state, &[]);
        assert_malformed(&state, &[0x00]);
        // Length prefix claims a 65535-byte ARN that is not present.
        assert_malformed(&state, &[0xFF, 0xFF, b'x']);
        // A genuine envelope with the last tag byte cut off.
        let ct = encrypt_blob(&state, ACCOUNT, &arn, b"").unwrap();
        assert_malformed(&state, &ct[..ct.len() - 1]);
        // Long enough for ARN + IV + tag, but the ARN bytes are not UTF-8.
        let mut bad_utf8 = vec![0x00, 0x01, 0xFF];
        bad_utf8.extend_from_slice(&[0u8; 12 + 16]);
        assert_malformed(&state, &bad_utf8);
    }

    #[test]
    fn decrypt_with_disabled_key_fails() {
        let (state, arn) = make_state_with_key();
        let ct = encrypt_blob(&state, ACCOUNT, &arn, b"ok").unwrap();
        {
            let mut accounts = state.write();
            let s = accounts.get_mut(ACCOUNT).unwrap();
            for k in s.keys.values_mut() {
                k.enabled = false;
            }
        }
        assert!(matches!(
            decrypt_blob(&state, ACCOUNT, &ct),
            Err(KmsApiError::KeyDisabled(_))
        ));
    }
}