1use std::collections::HashMap;
21use std::sync::Arc;
22
23use chrono::{DateTime, Utc};
24use parking_lot::RwLock;
25
26use base64::Engine;
27
28use crate::state::{KmsKey, KmsState, SharedKmsState};
29
30#[derive(Clone, serde::Serialize)]
34pub struct KmsUsageRecord {
35 pub timestamp: DateTime<Utc>,
36 pub operation: String,
37 pub service_principal: String,
38 pub account_id: String,
39 pub key_arn: String,
40 pub encryption_context: HashMap<String, String>,
41}
42
/// In-memory log of KMS calls made through the hook.
///
/// Records are only appended or cleared wholesale; they are never edited in place.
#[derive(Default)]
pub struct KmsUsageState {
    // Records in the order they were pushed.
    records: Vec<KmsUsageRecord>,
}
47
48impl KmsUsageState {
49 pub fn records(&self) -> &[KmsUsageRecord] {
50 &self.records
51 }
52
53 pub fn push(&mut self, record: KmsUsageRecord) {
54 self.records.push(record);
55 }
56
57 pub fn clear(&mut self) {
58 self.records.clear();
59 }
60}
61
/// Usage log shared between owners behind an `Arc` and a `parking_lot::RwLock`.
pub type SharedKmsUsageState = Arc<RwLock<KmsUsageState>>;
63
/// Encrypt/decrypt entry point for callers acting on behalf of an AWS service principal.
///
/// Holds the shared KMS state, which is used for key lookup and for provisioning
/// AWS-managed keys. It also holds the shared usage log, which receives one record per
/// successful call.
pub struct KmsServiceHook {
    // Per-account KMS state (keys and aliases).
    state: SharedKmsState,
    // Log of calls made through this hook.
    usage: SharedKmsUsageState,
}
70
/// Errors returned by [`KmsServiceHook`] operations.
#[derive(Debug)]
pub enum KmsHookError {
    /// A key identifier could not be resolved. Carries the identifier that failed.
    KeyNotFound(String),
    /// The ciphertext is not a decodable fakecloud envelope. Carries the reason.
    InvalidCiphertext(String),
}
80
81impl std::fmt::Display for KmsHookError {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 match self {
84 Self::KeyNotFound(k) => write!(f, "kms key not found: {k}"),
85 Self::InvalidCiphertext(msg) => write!(f, "invalid ciphertext: {msg}"),
86 }
87 }
88}
89
// Marker impl: relies on the `Display`/`Debug` impls; no underlying `source()` error.
impl std::error::Error for KmsHookError {}
91
92impl KmsServiceHook {
93 pub fn new(state: SharedKmsState, usage: SharedKmsUsageState) -> Self {
94 Self { state, usage }
95 }
96
97 pub fn encrypt(
102 &self,
103 account_id: &str,
104 region: &str,
105 key_id: &str,
106 plaintext: &[u8],
107 service_principal: &str,
108 encryption_context: HashMap<String, String>,
109 ) -> Result<String, KmsHookError> {
110 let key_arn = self.resolve_or_provision(account_id, region, key_id, service_principal)?;
111 let key_short = key_id_from_arn(&key_arn).to_string();
112
113 let plaintext_b64 = base64::engine::general_purpose::STANDARD.encode(plaintext);
114 let envelope = format!("fakecloud-kms:{key_short}:{plaintext_b64}");
115 let ciphertext_b64 = base64::engine::general_purpose::STANDARD.encode(envelope.as_bytes());
116
117 self.usage.write().push(KmsUsageRecord {
118 timestamp: Utc::now(),
119 operation: "GenerateDataKey".to_string(),
120 service_principal: service_principal.to_string(),
121 account_id: account_id.to_string(),
122 key_arn,
123 encryption_context,
124 });
125
126 Ok(ciphertext_b64)
127 }
128
129 pub fn decrypt(
132 &self,
133 account_id: &str,
134 ciphertext_b64: &str,
135 service_principal: &str,
136 encryption_context: HashMap<String, String>,
137 ) -> Result<Vec<u8>, KmsHookError> {
138 let envelope_bytes = base64::engine::general_purpose::STANDARD
139 .decode(ciphertext_b64)
140 .map_err(|e| KmsHookError::InvalidCiphertext(e.to_string()))?;
141 let envelope = String::from_utf8(envelope_bytes)
142 .map_err(|e| KmsHookError::InvalidCiphertext(e.to_string()))?;
143
144 let rest = envelope
145 .strip_prefix("fakecloud-kms:")
146 .ok_or_else(|| KmsHookError::InvalidCiphertext("unrecognized envelope".into()))?;
147 let (key_short, plaintext_b64) = rest
148 .split_once(':')
149 .ok_or_else(|| KmsHookError::InvalidCiphertext("missing key separator".into()))?;
150
151 let plaintext = base64::engine::general_purpose::STANDARD
152 .decode(plaintext_b64)
153 .map_err(|e| KmsHookError::InvalidCiphertext(e.to_string()))?;
154
155 let key_arn = {
156 let mas = self.state.read();
157 let state = mas
158 .get(account_id)
159 .ok_or_else(|| KmsHookError::KeyNotFound(key_short.into()))?;
160 state
161 .keys
162 .get(key_short)
163 .map(|k| k.arn.clone())
164 .ok_or_else(|| KmsHookError::KeyNotFound(key_short.into()))?
165 };
166
167 self.usage.write().push(KmsUsageRecord {
168 timestamp: Utc::now(),
169 operation: "Decrypt".to_string(),
170 service_principal: service_principal.to_string(),
171 account_id: account_id.to_string(),
172 key_arn,
173 encryption_context,
174 });
175
176 Ok(plaintext)
177 }
178
179 fn resolve_or_provision(
180 &self,
181 account_id: &str,
182 region: &str,
183 key_id: &str,
184 service_principal: &str,
185 ) -> Result<String, KmsHookError> {
186 {
188 let mas = self.state.read();
189 if let Some(state) = mas.get(account_id) {
190 if let Some(arn) = resolve_key(state, key_id) {
191 return Ok(arn);
192 }
193 }
194 }
195
196 let alias = normalize_alias(key_id);
200 if !alias.starts_with("aws/") {
201 return Err(KmsHookError::KeyNotFound(key_id.to_string()));
202 }
203 let mut mas = self.state.write();
204 let state = mas.get_or_create(account_id);
205 if let Some(arn) = resolve_key(state, key_id) {
207 return Ok(arn);
208 }
209 let key_arn = provision_aws_managed_key(state, region, &alias, service_principal);
210 Ok(key_arn)
211 }
212}
213
/// Returns the resource part of an `arn:aws:kms:<region>:<account>:<resource>` ARN,
/// e.g. `key/<id>` or `alias/<name>`.
///
/// Region and account are skipped without validation. Any colons inside the resource
/// are preserved. Returns `None` if `key_id` is not such an ARN.
fn strip_kms_arn_prefix(key_id: &str) -> Option<&str> {
    let after_service = key_id.strip_prefix("arn:aws:kms:")?;
    // At most three pieces: region, account, and everything after the second ':'.
    let mut pieces = after_service.splitn(3, ':');
    let _region = pieces.next()?;
    let _account = pieces.next()?;
    pieces.next()
}
226
227fn resolve_key(state: &KmsState, key_id: &str) -> Option<String> {
230 if let Some(resource) = strip_kms_arn_prefix(key_id) {
231 if let Some(short) = resource.strip_prefix("key/") {
232 return state.keys.get(short).map(|k| k.arn.clone());
233 }
234 if let Some(alias) = resource.strip_prefix("alias/") {
235 let full = format!("alias/{alias}");
236 if let Some(a) = state.aliases.get(&full) {
237 return state.keys.get(&a.target_key_id).map(|k| k.arn.clone());
238 }
239 }
240 }
241 if let Some(alias) = key_id.strip_prefix("alias/") {
242 let full = format!("alias/{alias}");
243 if let Some(a) = state.aliases.get(&full) {
244 return state.keys.get(&a.target_key_id).map(|k| k.arn.clone());
245 }
246 }
247 state.keys.get(key_id).map(|k| k.arn.clone())
248}
249
250fn normalize_alias(key_id: &str) -> String {
251 if let Some(resource) = strip_kms_arn_prefix(key_id) {
252 if let Some(alias) = resource.strip_prefix("alias/") {
253 return alias.to_string();
254 }
255 }
256 key_id.strip_prefix("alias/").unwrap_or(key_id).to_string()
257}
258
259fn provision_aws_managed_key(
260 state: &mut KmsState,
261 region: &str,
262 alias: &str,
263 service_principal: &str,
264) -> String {
265 let key_id = uuid::Uuid::new_v4().to_string();
266 let arn = format!(
267 "arn:aws:kms:{region}:{account}:key/{key_id}",
268 account = state.account_id,
269 region = region,
270 );
271 let policy = serde_json::json!({
272 "Version": "2012-10-17",
273 "Statement": [{
274 "Sid": "Allow access through service",
275 "Effect": "Allow",
276 "Principal": {"Service": service_principal},
277 "Action": ["kms:GenerateDataKey", "kms:Decrypt", "kms:DescribeKey"],
278 "Resource": "*"
279 }]
280 })
281 .to_string();
282 let key = KmsKey {
283 key_id: key_id.clone(),
284 arn: arn.clone(),
285 creation_date: Utc::now().timestamp() as f64,
286 description: format!(
287 "Default master key that protects {alias} when no other key is defined"
288 ),
289 enabled: true,
290 key_usage: "ENCRYPT_DECRYPT".to_string(),
291 key_spec: "SYMMETRIC_DEFAULT".to_string(),
292 key_manager: "AWS".to_string(),
293 key_state: "Enabled".to_string(),
294 deletion_date: None,
295 tags: HashMap::new(),
296 policy,
297 key_rotation_enabled: true,
298 origin: "AWS_KMS".to_string(),
299 multi_region: false,
300 rotations: Vec::new(),
301 signing_algorithms: None,
302 encryption_algorithms: Some(vec!["SYMMETRIC_DEFAULT".to_string()]),
303 mac_algorithms: None,
304 custom_key_store_id: None,
305 imported_key_material: false,
306 imported_material_bytes: None,
307 private_key_seed: Vec::new(),
308 primary_region: None,
309 };
310 state.keys.insert(key_id.clone(), key);
311 let alias_full = format!("alias/{alias}");
312 state.aliases.insert(
313 alias_full.clone(),
314 crate::state::KmsAlias {
315 alias_name: alias_full,
316 alias_arn: format!(
317 "arn:aws:kms:{region}:{account}:alias/{alias}",
318 account = state.account_id,
319 region = region,
320 ),
321 target_key_id: key_id,
322 creation_date: Utc::now().timestamp() as f64,
323 },
324 );
325 arn
326}
327
/// Returns the text after the last `/` in `arn`, e.g. the key id of a `...:key/<id>` ARN.
///
/// Returns `arn` unchanged when it contains no `/`.
fn key_id_from_arn(arn: &str) -> &str {
    // `rsplit` always yields at least one piece, so the fallback is never taken.
    arn.rsplit('/').next().unwrap_or(arn)
}
331
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn strip_arn_prefix_skips_region_and_account() {
        assert_eq!(
            strip_kms_arn_prefix("arn:aws:kms:us-east-1:000000000000:key/abc-123"),
            Some("key/abc-123")
        );
        assert_eq!(
            strip_kms_arn_prefix("arn:aws:kms:us-east-1:000000000000:alias/aws/secretsmanager"),
            Some("alias/aws/secretsmanager")
        );
        assert_eq!(strip_kms_arn_prefix("not-an-arn"), None);
        assert_eq!(strip_kms_arn_prefix("arn:aws:kms:key/abc"), None);
    }

    #[test]
    fn normalize_alias_handles_arns_correctly() {
        assert_eq!(
            normalize_alias("arn:aws:kms:us-east-1:000000000000:alias/aws/secretsmanager"),
            "aws/secretsmanager"
        );
        assert_eq!(normalize_alias("alias/aws/sqs"), "aws/sqs");
        assert_eq!(normalize_alias("aws/s3"), "aws/s3");
    }

    // Non-alias ids must pass through unchanged so they never look like `aws/...`
    // aliases eligible for auto-provisioning.
    #[test]
    fn normalize_alias_leaves_key_ids_and_key_arns_untouched() {
        assert_eq!(normalize_alias("abc-123"), "abc-123");
        assert_eq!(
            normalize_alias("arn:aws:kms:us-east-1:000000000000:key/abc-123"),
            "arn:aws:kms:us-east-1:000000000000:key/abc-123"
        );
    }

    #[test]
    fn key_id_from_arn_takes_last_path_segment() {
        assert_eq!(
            key_id_from_arn("arn:aws:kms:us-east-1:000000000000:key/abc-123"),
            "abc-123"
        );
        assert_eq!(key_id_from_arn("abc-123"), "abc-123");
    }

    #[test]
    fn hook_error_messages() {
        assert_eq!(
            KmsHookError::KeyNotFound("abc-123".into()).to_string(),
            "kms key not found: abc-123"
        );
        assert_eq!(
            KmsHookError::InvalidCiphertext("bad".into()).to_string(),
            "invalid ciphertext: bad"
        );
    }
}