1use std::collections::{BTreeMap, HashMap};
21use std::sync::Arc;
22
23use chrono::{DateTime, Utc};
24use parking_lot::RwLock;
25
26use base64::Engine;
27
28use crate::state::{KmsKey, KmsState, SharedKmsState};
29
30#[derive(Clone, serde::Serialize)]
34pub struct KmsUsageRecord {
35 pub timestamp: DateTime<Utc>,
36 pub operation: String,
37 pub service_principal: String,
38 pub account_id: String,
39 pub key_arn: String,
40 pub encryption_context: HashMap<String, String>,
41}
42
43#[derive(Default)]
44pub struct KmsUsageState {
45 records: Vec<KmsUsageRecord>,
46}
47
48impl KmsUsageState {
49 pub fn records(&self) -> &[KmsUsageRecord] {
50 &self.records
51 }
52
53 pub fn push(&mut self, record: KmsUsageRecord) {
54 self.records.push(record);
55 }
56
57 pub fn clear(&mut self) {
58 self.records.clear();
59 }
60}
61
62pub type SharedKmsUsageState = Arc<RwLock<KmsUsageState>>;
63
/// Performs encrypt/decrypt calls against the shared KMS state on behalf of a
/// service principal, appending a record of each successful call to a usage log.
pub struct KmsServiceHook {
    /// Per-account KMS state: keys, aliases and master key bytes.
    state: SharedKmsState,
    /// Log that `encrypt`/`decrypt` push a `KmsUsageRecord` into on success.
    usage: SharedKmsUsageState,
}
70
/// Errors returned by `KmsServiceHook` operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KmsHookError {
    /// No key could be resolved; carries the key id that was looked up.
    KeyNotFound(String),
    /// The ciphertext could not be decoded; carries a description of the failure.
    InvalidCiphertext(String),
}

impl std::fmt::Display for KmsHookError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::KeyNotFound(k) => write!(f, "kms key not found: {k}"),
            Self::InvalidCiphertext(msg) => write!(f, "invalid ciphertext: {msg}"),
        }
    }
}

impl std::error::Error for KmsHookError {}
91
92impl KmsServiceHook {
93 pub fn new(state: SharedKmsState, usage: SharedKmsUsageState) -> Self {
94 Self { state, usage }
95 }
96
97 pub fn encrypt(
102 &self,
103 account_id: &str,
104 region: &str,
105 key_id: &str,
106 plaintext: &[u8],
107 service_principal: &str,
108 encryption_context: HashMap<String, String>,
109 ) -> Result<String, KmsHookError> {
110 let key_arn = self.resolve_or_provision(account_id, region, key_id, service_principal)?;
111 let key_short = key_id_from_arn(&key_arn).to_string();
112
113 let master_key_bytes = {
119 let mas = self.state.read();
120 mas.get(account_id)
121 .map(|s| s.master_key_bytes.clone())
122 .ok_or_else(|| KmsHookError::KeyNotFound(key_short.clone()))?
123 };
124 let blob = crate::blob::encode(&master_key_bytes, &key_short, plaintext);
125 let ciphertext_b64 = base64::engine::general_purpose::STANDARD.encode(&blob);
126
127 self.usage.write().push(KmsUsageRecord {
128 timestamp: Utc::now(),
129 operation: "GenerateDataKey".to_string(),
130 service_principal: service_principal.to_string(),
131 account_id: account_id.to_string(),
132 key_arn,
133 encryption_context,
134 });
135
136 Ok(ciphertext_b64)
137 }
138
139 pub fn decrypt(
142 &self,
143 account_id: &str,
144 ciphertext_b64: &str,
145 service_principal: &str,
146 encryption_context: HashMap<String, String>,
147 ) -> Result<Vec<u8>, KmsHookError> {
148 let envelope_bytes = base64::engine::general_purpose::STANDARD
149 .decode(ciphertext_b64)
150 .map_err(|e| KmsHookError::InvalidCiphertext(e.to_string()))?;
151
152 let master_key_bytes = {
155 let mas = self.state.read();
156 mas.get(account_id)
157 .map(|s| s.master_key_bytes.clone())
158 .unwrap_or_default()
159 };
160 let (key_short, plaintext) =
161 if let Some(decoded) = crate::blob::decode(&master_key_bytes, &envelope_bytes) {
162 (decoded.key_id, decoded.plaintext)
163 } else {
164 let envelope = String::from_utf8(envelope_bytes)
165 .map_err(|e| KmsHookError::InvalidCiphertext(e.to_string()))?;
166 let rest = envelope.strip_prefix("fakecloud-kms:").ok_or_else(|| {
167 KmsHookError::InvalidCiphertext("unrecognized envelope".into())
168 })?;
169 let (key_short, plaintext_b64) = rest.split_once(':').ok_or_else(|| {
170 KmsHookError::InvalidCiphertext("missing key separator".into())
171 })?;
172
173 let plaintext = base64::engine::general_purpose::STANDARD
174 .decode(plaintext_b64)
175 .map_err(|e| KmsHookError::InvalidCiphertext(e.to_string()))?;
176 (key_short.to_string(), plaintext)
177 };
178
179 let key_arn = {
180 let mas = self.state.read();
181 let state = mas
182 .get(account_id)
183 .ok_or_else(|| KmsHookError::KeyNotFound(key_short.clone()))?;
184 state
185 .keys
186 .get(&key_short)
187 .map(|k| k.arn.clone())
188 .ok_or_else(|| KmsHookError::KeyNotFound(key_short.clone()))?
189 };
190
191 self.usage.write().push(KmsUsageRecord {
192 timestamp: Utc::now(),
193 operation: "Decrypt".to_string(),
194 service_principal: service_principal.to_string(),
195 account_id: account_id.to_string(),
196 key_arn,
197 encryption_context,
198 });
199
200 Ok(plaintext)
201 }
202
203 fn resolve_or_provision(
204 &self,
205 account_id: &str,
206 region: &str,
207 key_id: &str,
208 service_principal: &str,
209 ) -> Result<String, KmsHookError> {
210 {
212 let mas = self.state.read();
213 if let Some(state) = mas.get(account_id) {
214 if let Some(arn) = resolve_key(state, key_id) {
215 return Ok(arn);
216 }
217 }
218 }
219
220 let alias = normalize_alias(key_id);
224 if !alias.starts_with("aws/") {
225 return Err(KmsHookError::KeyNotFound(key_id.to_string()));
226 }
227 let mut mas = self.state.write();
228 let state = mas.get_or_create(account_id);
229 if let Some(arn) = resolve_key(state, key_id) {
231 return Ok(arn);
232 }
233 let key_arn = provision_aws_managed_key(state, region, &alias, service_principal);
234 Ok(key_arn)
235 }
236}
237
/// Returns the resource part of a KMS ARN, or `None` if `key_id` is not one.
///
/// The expected layout is `arn:aws:kms:<region>:<account>:<resource>`. Region
/// and account are skipped without validation. Everything after the second
/// colon that follows the prefix is returned, including any further colons.
fn strip_kms_arn_prefix(key_id: &str) -> Option<&str> {
    let mut segments = key_id.strip_prefix("arn:aws:kms:")?.splitn(3, ':');
    let _region = segments.next()?;
    let _account = segments.next()?;
    segments.next()
}
250
251fn resolve_key(state: &KmsState, key_id: &str) -> Option<String> {
254 if let Some(resource) = strip_kms_arn_prefix(key_id) {
255 if let Some(short) = resource.strip_prefix("key/") {
256 return state.keys.get(short).map(|k| k.arn.clone());
257 }
258 if let Some(alias) = resource.strip_prefix("alias/") {
259 let full = format!("alias/{alias}");
260 if let Some(a) = state.aliases.get(&full) {
261 return state.keys.get(&a.target_key_id).map(|k| k.arn.clone());
262 }
263 }
264 }
265 if let Some(alias) = key_id.strip_prefix("alias/") {
266 let full = format!("alias/{alias}");
267 if let Some(a) = state.aliases.get(&full) {
268 return state.keys.get(&a.target_key_id).map(|k| k.arn.clone());
269 }
270 }
271 state.keys.get(key_id).map(|k| k.arn.clone())
272}
273
274fn normalize_alias(key_id: &str) -> String {
275 if let Some(resource) = strip_kms_arn_prefix(key_id) {
276 if let Some(alias) = resource.strip_prefix("alias/") {
277 return alias.to_string();
278 }
279 }
280 key_id.strip_prefix("alias/").unwrap_or(key_id).to_string()
281}
282
/// AWS-managed aliases, written without the leading `alias/`, that
/// `ensure_default_managed_aliases` provisions keys for when they are absent.
///
/// Every entry starts with `aws/`. The remainder is used as the service name
/// in the key policy principal (`<service>.amazonaws.com`).
pub const DEFAULT_AWS_MANAGED_ALIASES: &[&str] = &[
    "aws/dynamodb",
    "aws/s3",
    "aws/sqs",
    "aws/sns",
    "aws/secretsmanager",
    "aws/ssm",
    "aws/rds",
    "aws/lambda",
    "aws/kinesis",
    "aws/logs",
    "aws/ebs",
    "aws/glue",
    "aws/elasticache",
    "aws/backup",
    "aws/es",
    "aws/redshift",
    "aws/xray",
    "aws/elasticfilesystem",
    "aws/cloudtrail",
    "aws/sagemaker",
];
311
312pub fn ensure_default_managed_aliases(state: &mut KmsState, region: &str) {
317 for alias in DEFAULT_AWS_MANAGED_ALIASES {
318 let alias_full = format!("alias/{alias}");
319 if state.aliases.contains_key(&alias_full) {
320 continue;
321 }
322 let service = alias.strip_prefix("aws/").unwrap_or(alias);
327 let principal = format!("{service}.amazonaws.com");
328 provision_aws_managed_key(state, region, alias, &principal);
329 }
330}
331
332fn provision_aws_managed_key(
333 state: &mut KmsState,
334 region: &str,
335 alias: &str,
336 service_principal: &str,
337) -> String {
338 let key_id = uuid::Uuid::new_v4().to_string();
339 let arn = format!(
340 "arn:aws:kms:{region}:{account}:key/{key_id}",
341 account = state.account_id,
342 region = region,
343 );
344 let policy = serde_json::json!({
345 "Version": "2012-10-17",
346 "Statement": [{
347 "Sid": "Allow access through service",
348 "Effect": "Allow",
349 "Principal": {"Service": service_principal},
350 "Action": ["kms:GenerateDataKey", "kms:Decrypt", "kms:DescribeKey"],
351 "Resource": "*"
352 }]
353 })
354 .to_string();
355 let key = KmsKey {
356 key_id: key_id.clone(),
357 arn: arn.clone(),
358 creation_date: Utc::now().timestamp() as f64,
359 description: format!(
360 "Default master key that protects {alias} when no other key is defined"
361 ),
362 enabled: true,
363 key_usage: "ENCRYPT_DECRYPT".to_string(),
364 key_spec: "SYMMETRIC_DEFAULT".to_string(),
365 key_manager: "AWS".to_string(),
366 key_state: "Enabled".to_string(),
367 deletion_date: None,
368 tags: BTreeMap::new(),
369 policy,
370 key_rotation_enabled: true,
371 rotation_period_in_days: None,
372 origin: "AWS_KMS".to_string(),
373 multi_region: false,
374 rotations: Vec::new(),
375 signing_algorithms: None,
376 encryption_algorithms: Some(vec!["SYMMETRIC_DEFAULT".to_string()]),
377 mac_algorithms: None,
378 custom_key_store_id: None,
379 imported_key_material: false,
380 imported_material_bytes: None,
381 private_key_seed: Vec::new(),
382 primary_region: None,
383 asymmetric_private_key_der: None,
384 asymmetric_public_key_der: None,
385 };
386 state.keys.insert(key_id.clone(), key);
387 let alias_full = format!("alias/{alias}");
388 state.aliases.insert(
389 alias_full.clone(),
390 crate::state::KmsAlias {
391 alias_name: alias_full,
392 alias_arn: format!(
393 "arn:aws:kms:{region}:{account}:alias/{alias}",
394 account = state.account_id,
395 region = region,
396 ),
397 target_key_id: key_id,
398 creation_date: Utc::now().timestamp() as f64,
399 },
400 );
401 arn
402}
403
/// Returns the text after the last `/` in `arn`, e.g. the key id of a key
/// ARN, or the whole input when it contains no `/`.
fn key_id_from_arn(arn: &str) -> &str {
    match arn.rfind('/') {
        Some(slash) => &arn[slash + 1..],
        None => arn,
    }
}
407
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn strip_arn_prefix_skips_region_and_account() {
        let cases: &[(&str, Option<&str>)] = &[
            (
                "arn:aws:kms:us-east-1:000000000000:key/abc-123",
                Some("key/abc-123"),
            ),
            (
                "arn:aws:kms:us-east-1:000000000000:alias/aws/secretsmanager",
                Some("alias/aws/secretsmanager"),
            ),
            ("not-an-arn", None),
            ("arn:aws:kms:key/abc", None),
            ("", None),
            // Colons inside the resource part are kept.
            (
                "arn:aws:kms:us-east-1:000000000000:alias/a:b",
                Some("alias/a:b"),
            ),
        ];
        for &(input, expected) in cases {
            assert_eq!(strip_kms_arn_prefix(input), expected, "input: {input}");
        }
    }

    #[test]
    fn normalize_alias_handles_arns_correctly() {
        assert_eq!(
            normalize_alias("arn:aws:kms:us-east-1:000000000000:alias/aws/secretsmanager"),
            "aws/secretsmanager"
        );
        assert_eq!(normalize_alias("alias/aws/sqs"), "aws/sqs");
        assert_eq!(normalize_alias("aws/s3"), "aws/s3");
        // Key ARNs are not aliases and come back unchanged.
        let key_arn = "arn:aws:kms:us-east-1:000000000000:key/abc-123";
        assert_eq!(normalize_alias(key_arn), key_arn);
    }

    #[test]
    fn key_id_from_arn_takes_segment_after_last_slash() {
        assert_eq!(
            key_id_from_arn("arn:aws:kms:us-east-1:000000000000:key/abc-123"),
            "abc-123"
        );
        assert_eq!(key_id_from_arn("no-slash"), "no-slash");
    }

    #[test]
    fn default_managed_aliases_are_unique_aws_aliases() {
        // `ensure_default_managed_aliases` derives the service principal by
        // stripping `aws/`, so every entry must carry that prefix.
        let mut seen = std::collections::HashSet::new();
        for alias in DEFAULT_AWS_MANAGED_ALIASES {
            assert!(alias.starts_with("aws/"), "missing aws/ prefix: {alias}");
            assert!(seen.insert(*alias), "duplicate alias: {alias}");
        }
    }
}
436}