1use aes_gcm::aead::{Aead, KeyInit, Payload};
15use aes_gcm::{Aes256Gcm, Key, Nonce};
16use sha2::{Digest, Sha256};
17
18use crate::state::{KmsKey, SharedKmsState};
19
20#[derive(Debug, thiserror::Error)]
21pub enum KmsApiError {
22 #[error("KMS key {0} not found")]
23 KeyNotFound(String),
24 #[error("KMS key {0} is not enabled")]
25 KeyDisabled(String),
26 #[error("encryption failed: {0}")]
27 EncryptFailed(String),
28 #[error("decryption failed: {0}")]
29 DecryptFailed(String),
30 #[error("malformed ciphertext envelope")]
31 MalformedCiphertext,
32}
33
34pub fn encrypt_blob(
38 state: &SharedKmsState,
39 account_id: &str,
40 key_ref: &str,
41 plaintext: &[u8],
42) -> Result<Vec<u8>, KmsApiError> {
43 let (key_arn, aes_key) = {
44 let accounts = state.read();
45 let s = accounts
46 .get(account_id)
47 .ok_or_else(|| KmsApiError::KeyNotFound(key_ref.to_string()))?;
48 let key =
49 lookup_key(s, key_ref).ok_or_else(|| KmsApiError::KeyNotFound(key_ref.to_string()))?;
50 if !key.enabled {
51 return Err(KmsApiError::KeyDisabled(key.key_id.clone()));
52 }
53 (key.arn.clone(), derive_aes_key(key))
54 };
55
56 let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&aes_key));
57 let iv = random_iv();
58 let nonce = Nonce::from_slice(&iv);
59 let ciphertext = cipher
60 .encrypt(
61 nonce,
62 Payload {
63 msg: plaintext,
64 aad: key_arn.as_bytes(),
65 },
66 )
67 .map_err(|e| KmsApiError::EncryptFailed(e.to_string()))?;
68
69 let arn_bytes = key_arn.as_bytes();
70 let arn_len = arn_bytes.len() as u16;
71 let mut out = Vec::with_capacity(2 + arn_bytes.len() + 12 + ciphertext.len());
72 out.extend_from_slice(&arn_len.to_be_bytes());
73 out.extend_from_slice(arn_bytes);
74 out.extend_from_slice(&iv);
75 out.extend_from_slice(&ciphertext);
76 Ok(out)
77}
78
79pub fn decrypt_blob(
81 state: &SharedKmsState,
82 account_id: &str,
83 ciphertext: &[u8],
84) -> Result<Vec<u8>, KmsApiError> {
85 if ciphertext.len() < 2 {
86 return Err(KmsApiError::MalformedCiphertext);
87 }
88 let arn_len = u16::from_be_bytes([ciphertext[0], ciphertext[1]]) as usize;
89 let header_end = 2 + arn_len;
90 if ciphertext.len() < header_end + 12 + 16 {
91 return Err(KmsApiError::MalformedCiphertext);
92 }
93 let key_arn = std::str::from_utf8(&ciphertext[2..header_end])
94 .map_err(|_| KmsApiError::MalformedCiphertext)?;
95 let iv = &ciphertext[header_end..header_end + 12];
96 let body = &ciphertext[header_end + 12..];
97
98 let aes_key = {
99 let accounts = state.read();
100 let s = accounts
101 .get(account_id)
102 .ok_or_else(|| KmsApiError::KeyNotFound(key_arn.to_string()))?;
103 let key =
104 lookup_key(s, key_arn).ok_or_else(|| KmsApiError::KeyNotFound(key_arn.to_string()))?;
105 if !key.enabled {
106 return Err(KmsApiError::KeyDisabled(key.key_id.clone()));
107 }
108 derive_aes_key(key)
109 };
110
111 let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&aes_key));
112 let nonce = Nonce::from_slice(iv);
113 cipher
114 .decrypt(
115 nonce,
116 Payload {
117 msg: body,
118 aad: key_arn.as_bytes(),
119 },
120 )
121 .map_err(|e| KmsApiError::DecryptFailed(e.to_string()))
122}
123
124fn lookup_key<'a>(s: &'a crate::state::KmsState, key_ref: &str) -> Option<&'a KmsKey> {
127 if let Some(alias) = key_ref.strip_prefix("alias/") {
128 let full = format!("alias/{alias}");
129 let target = s
130 .aliases
131 .values()
132 .find(|a| a.alias_name == full)
133 .map(|a| a.target_key_id.as_str())?;
134 return s.keys.get(target);
135 }
136 if let Some(id) = key_ref.rsplit(':').next() {
137 if let Some(stripped) = id.strip_prefix("key/") {
138 return s.keys.get(stripped);
139 }
140 if let Some(k) = s.keys.get(id) {
141 return Some(k);
142 }
143 }
144 s.keys.get(key_ref)
145}
146
147fn derive_aes_key(key: &KmsKey) -> [u8; 32] {
153 let source: &[u8] = key
154 .imported_material_bytes
155 .as_deref()
156 .unwrap_or(&key.private_key_seed);
157 let mut hasher = Sha256::new();
158 hasher.update(b"fakecloud-kms-aes256:");
159 hasher.update(key.key_id.as_bytes());
160 hasher.update(b":");
161 hasher.update(source);
162 let out = hasher.finalize();
163 let mut aes_key = [0u8; 32];
164 aes_key.copy_from_slice(&out[..]);
165 aes_key
166}
167
168fn random_iv() -> [u8; 12] {
169 use std::sync::atomic::{AtomicU64, Ordering};
174 static COUNTER: AtomicU64 = AtomicU64::new(0);
175 let ts = std::time::SystemTime::now()
176 .duration_since(std::time::UNIX_EPOCH)
177 .map(|d| d.as_nanos() as u64)
178 .unwrap_or(0);
179 let cnt = COUNTER.fetch_add(1, Ordering::Relaxed);
180 let mut hasher = Sha256::new();
181 hasher.update(ts.to_be_bytes());
182 hasher.update(cnt.to_be_bytes());
183 hasher.update(std::process::id().to_be_bytes());
184 let digest = hasher.finalize();
185 let mut iv = [0u8; 12];
186 iv.copy_from_slice(&digest[..12]);
187 iv
188}
189
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use fakecloud_aws::arn::Arn;
    use fakecloud_core::multi_account::MultiAccountState;
    use parking_lot::RwLock;

    use super::*;
    use crate::state::{KmsKey, KmsState};

    /// Builds state for account `123456789012` holding one enabled symmetric key.
    /// Returns the state and the key's ARN.
    fn make_state_with_key() -> (SharedKmsState, String) {
        let state = Arc::new(RwLock::new(MultiAccountState::<KmsState>::new(
            "123456789012",
            "us-east-1",
            "http://localhost:4566",
        )));
        let key_id = "00000000-0000-0000-0000-000000000001".to_string();
        let arn =
            Arn::new("kms", "us-east-1", "123456789012", &format!("key/{key_id}")).to_string();
        {
            let mut accounts = state.write();
            let s = accounts.get_or_create("123456789012");
            s.keys.insert(
                key_id.clone(),
                KmsKey {
                    key_id: key_id.clone(),
                    arn: arn.clone(),
                    creation_date: 0.0,
                    description: String::new(),
                    enabled: true,
                    key_usage: "ENCRYPT_DECRYPT".into(),
                    key_spec: "SYMMETRIC_DEFAULT".into(),
                    key_manager: "CUSTOMER".into(),
                    key_state: "Enabled".into(),
                    deletion_date: None,
                    tags: Default::default(),
                    policy: String::new(),
                    key_rotation_enabled: false,
                    rotation_period_in_days: None,
                    origin: "AWS_KMS".into(),
                    multi_region: false,
                    rotations: Vec::new(),
                    signing_algorithms: None,
                    encryption_algorithms: None,
                    mac_algorithms: None,
                    custom_key_store_id: None,
                    imported_key_material: false,
                    imported_material_bytes: None,
                    private_key_seed: vec![7; 32],
                    primary_region: None,
                    asymmetric_private_key_der: None,
                    asymmetric_public_key_der: None,
                },
            );
        }
        (state, arn)
    }

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let (state, arn) = make_state_with_key();
        let plaintext = b"hello fakecloud kms";
        let ct = encrypt_blob(&state, "123456789012", &arn, plaintext).unwrap();
        assert_ne!(&ct[..], plaintext, "ciphertext must differ from plaintext");
        let pt = decrypt_blob(&state, "123456789012", &ct).unwrap();
        assert_eq!(pt.as_slice(), plaintext);
    }

    #[test]
    fn encrypt_accepts_bare_key_id() {
        let (state, _) = make_state_with_key();
        let key_id = "00000000-0000-0000-0000-000000000001";
        let ct = encrypt_blob(&state, "123456789012", key_id, b"by id").unwrap();
        let pt = decrypt_blob(&state, "123456789012", &ct).unwrap();
        assert_eq!(pt.as_slice(), b"by id");
    }

    #[test]
    fn encrypt_with_unknown_key_fails() {
        let (state, _) = make_state_with_key();
        assert!(matches!(
            encrypt_blob(&state, "123456789012", "no-such-key", b"x"),
            Err(KmsApiError::KeyNotFound(_))
        ));
    }

    #[test]
    fn each_encrypt_yields_distinct_ciphertext() {
        let (state, arn) = make_state_with_key();
        let a = encrypt_blob(&state, "123456789012", &arn, b"same plaintext").unwrap();
        let b = encrypt_blob(&state, "123456789012", &arn, b"same plaintext").unwrap();
        assert_ne!(a, b, "distinct IVs should produce distinct ciphertext");
    }

    #[test]
    fn decrypt_with_tampered_ciphertext_fails() {
        let (state, arn) = make_state_with_key();
        let mut ct = encrypt_blob(&state, "123456789012", &arn, b"tamper me").unwrap();
        let last = ct.len() - 1;
        // Flip one bit in the final byte so authentication must fail.
        ct[last] ^= 0x01;
        assert!(decrypt_blob(&state, "123456789012", &ct).is_err());
    }

    #[test]
    fn decrypt_rejects_truncated_envelopes() {
        let (state, arn) = make_state_with_key();
        let ct = encrypt_blob(&state, "123456789012", &arn, b"truncate me").unwrap();
        // Cut lengths, in order: empty input; half a length prefix; prefix only;
        // header + IV + one byte short of a full 16-byte tag.
        for len in [0, 1, 2, 2 + arn.len() + 12 + 15] {
            assert!(
                matches!(
                    decrypt_blob(&state, "123456789012", &ct[..len]),
                    Err(KmsApiError::MalformedCiphertext)
                ),
                "expected MalformedCiphertext for a {len}-byte prefix"
            );
        }
    }

    #[test]
    fn decrypt_with_disabled_key_fails() {
        let (state, arn) = make_state_with_key();
        let ct = encrypt_blob(&state, "123456789012", &arn, b"ok").unwrap();
        {
            let mut accounts = state.write();
            let s = accounts.get_mut("123456789012").unwrap();
            for k in s.keys.values_mut() {
                k.enabled = false;
            }
        }
        assert!(matches!(
            decrypt_blob(&state, "123456789012", &ct),
            Err(KmsApiError::KeyDisabled(_))
        ));
    }
}