Skip to main content

fakecloud_kms/
hook.rs

1//! Cross-service KMS hook.
2//!
3//! Services that accept a `KmsKeyId` (Secrets Manager, SSM
4//! `SecureString`, S3 SSE-KMS, SQS, SNS, DynamoDB) call into this
5//! module so that:
6//!
7//! 1. The supplied key is resolved (alias `aws/<service>` and bare
8//!    aliases included), auto-provisioning AWS-managed keys on first
9//!    use to match real AWS.
10//! 2. Each call is recorded in [`KmsUsageState`] so test code can
11//!    assert through `/_fakecloud/kms/usage` that the right service
12//!    triggered the right operation on the right key.
13//! 3. The returned ciphertext is a real envelope decryptable by the
14//!    public KMS `Decrypt` API (uses the same `fakecloud-kms:`
15//!    envelope as the existing service-side encrypt path).
16//!
17//! Encryption context, key policy enforcement, and KMS-managed key
18//! rotation come in follow-up PRs.
19
20use std::collections::HashMap;
21use std::sync::Arc;
22
23use chrono::{DateTime, Utc};
24use parking_lot::RwLock;
25
26use base64::Engine;
27
28use crate::state::{KmsKey, KmsState, SharedKmsState};
29
30/// One recorded KMS hook call. Returned by the introspection endpoint
31/// so test code can assert `kms:GenerateDataKey` / `kms:Decrypt` ran
32/// on the expected key + service principal.
33#[derive(Clone, serde::Serialize)]
34pub struct KmsUsageRecord {
35    pub timestamp: DateTime<Utc>,
36    pub operation: String,
37    pub service_principal: String,
38    pub account_id: String,
39    pub key_arn: String,
40    pub encryption_context: HashMap<String, String>,
41}
42
43#[derive(Default)]
44pub struct KmsUsageState {
45    records: Vec<KmsUsageRecord>,
46}
47
48impl KmsUsageState {
49    pub fn records(&self) -> &[KmsUsageRecord] {
50        &self.records
51    }
52
53    pub fn push(&mut self, record: KmsUsageRecord) {
54        self.records.push(record);
55    }
56
57    pub fn clear(&mut self) {
58        self.records.clear();
59    }
60}
61
62pub type SharedKmsUsageState = Arc<RwLock<KmsUsageState>>;
63
/// Hook used by service crates that need to call KMS for encryption /
/// decryption without going through the AWS-shaped HTTP layer.
///
/// Holds shared handles (read/written through `read()` / `write()` locks)
/// rather than owned data, so all clones of the handles observe the same
/// keys and the same usage log.
pub struct KmsServiceHook {
    // Per-account KMS key + alias state. Read on every call; written when an
    // AWS-managed key is auto-provisioned. NOTE(review): presumably an
    // `Arc<RwLock<..>>` like `SharedKmsUsageState` — confirm in `crate::state`.
    state: SharedKmsState,
    // Log that each successful encrypt / decrypt appends one record to.
    usage: SharedKmsUsageState,
}
70
/// Errors returned by the KMS service hook when a cross-service call
/// cannot be satisfied.
///
/// Derives `Clone` / `PartialEq` / `Eq` so callers and tests can compare
/// error values directly (e.g. `assert_eq!` on a returned `Result`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KmsHookError {
    /// Caller supplied a key id / alias / ARN that doesn't resolve to
    /// an existing key (and isn't an AWS-managed alias we auto-create).
    /// Also returned when decrypting a well-formed envelope whose key id
    /// (or account) is not present in the caller's account state.
    KeyNotFound(String),
    /// Ciphertext envelope is malformed: not base64, not UTF-8, missing the
    /// `fakecloud-kms:` prefix or the key separator, or carrying a payload
    /// that is not base64.
    InvalidCiphertext(String),
}

impl std::fmt::Display for KmsHookError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::KeyNotFound(k) => write!(f, "kms key not found: {k}"),
            Self::InvalidCiphertext(msg) => write!(f, "invalid ciphertext: {msg}"),
        }
    }
}

impl std::error::Error for KmsHookError {}
91
92impl KmsServiceHook {
93    pub fn new(state: SharedKmsState, usage: SharedKmsUsageState) -> Self {
94        Self { state, usage }
95    }
96
97    /// Encrypt `plaintext` under `key_id` (raw id, ARN, alias, or
98    /// `aws/<service>` AWS-managed alias). Records the call as a
99    /// `GenerateDataKey`-shaped usage record and returns the base64
100    /// ciphertext envelope.
101    pub fn encrypt(
102        &self,
103        account_id: &str,
104        region: &str,
105        key_id: &str,
106        plaintext: &[u8],
107        service_principal: &str,
108        encryption_context: HashMap<String, String>,
109    ) -> Result<String, KmsHookError> {
110        let key_arn = self.resolve_or_provision(account_id, region, key_id, service_principal)?;
111        let key_short = key_id_from_arn(&key_arn).to_string();
112
113        let plaintext_b64 = base64::engine::general_purpose::STANDARD.encode(plaintext);
114        let envelope = format!("fakecloud-kms:{key_short}:{plaintext_b64}");
115        let ciphertext_b64 = base64::engine::general_purpose::STANDARD.encode(envelope.as_bytes());
116
117        self.usage.write().push(KmsUsageRecord {
118            timestamp: Utc::now(),
119            operation: "GenerateDataKey".to_string(),
120            service_principal: service_principal.to_string(),
121            account_id: account_id.to_string(),
122            key_arn,
123            encryption_context,
124        });
125
126        Ok(ciphertext_b64)
127    }
128
129    /// Decrypt a previously-`encrypt`-produced base64 ciphertext.
130    /// Records the call as a `Decrypt`-shaped usage record.
131    pub fn decrypt(
132        &self,
133        account_id: &str,
134        ciphertext_b64: &str,
135        service_principal: &str,
136        encryption_context: HashMap<String, String>,
137    ) -> Result<Vec<u8>, KmsHookError> {
138        let envelope_bytes = base64::engine::general_purpose::STANDARD
139            .decode(ciphertext_b64)
140            .map_err(|e| KmsHookError::InvalidCiphertext(e.to_string()))?;
141        let envelope = String::from_utf8(envelope_bytes)
142            .map_err(|e| KmsHookError::InvalidCiphertext(e.to_string()))?;
143
144        let rest = envelope
145            .strip_prefix("fakecloud-kms:")
146            .ok_or_else(|| KmsHookError::InvalidCiphertext("unrecognized envelope".into()))?;
147        let (key_short, plaintext_b64) = rest
148            .split_once(':')
149            .ok_or_else(|| KmsHookError::InvalidCiphertext("missing key separator".into()))?;
150
151        let plaintext = base64::engine::general_purpose::STANDARD
152            .decode(plaintext_b64)
153            .map_err(|e| KmsHookError::InvalidCiphertext(e.to_string()))?;
154
155        let key_arn = {
156            let mas = self.state.read();
157            let state = mas
158                .get(account_id)
159                .ok_or_else(|| KmsHookError::KeyNotFound(key_short.into()))?;
160            state
161                .keys
162                .get(key_short)
163                .map(|k| k.arn.clone())
164                .ok_or_else(|| KmsHookError::KeyNotFound(key_short.into()))?
165        };
166
167        self.usage.write().push(KmsUsageRecord {
168            timestamp: Utc::now(),
169            operation: "Decrypt".to_string(),
170            service_principal: service_principal.to_string(),
171            account_id: account_id.to_string(),
172            key_arn,
173            encryption_context,
174        });
175
176        Ok(plaintext)
177    }
178
179    fn resolve_or_provision(
180        &self,
181        account_id: &str,
182        region: &str,
183        key_id: &str,
184        service_principal: &str,
185    ) -> Result<String, KmsHookError> {
186        // Pre-flight read to see if the key resolves cleanly.
187        {
188            let mas = self.state.read();
189            if let Some(state) = mas.get(account_id) {
190                if let Some(arn) = resolve_key(state, key_id) {
191                    return Ok(arn);
192                }
193            }
194        }
195
196        // AWS-managed aliases (`aws/<service>`) auto-provision on
197        // first use. Customer-supplied aliases / IDs that don't
198        // resolve are an error.
199        let alias = normalize_alias(key_id);
200        if !alias.starts_with("aws/") {
201            return Err(KmsHookError::KeyNotFound(key_id.to_string()));
202        }
203        let mut mas = self.state.write();
204        let state = mas.get_or_create(account_id);
205        // Re-check under the write lock in case a concurrent caller won the race.
206        if let Some(arn) = resolve_key(state, key_id) {
207            return Ok(arn);
208        }
209        let key_arn = provision_aws_managed_key(state, region, &alias, service_principal);
210        Ok(key_arn)
211    }
212}
213
/// Strip the `arn:<partition>:kms:<region>:<account>:` ARN prefix and return
/// the resource portion (e.g. `key/<id>` or `alias/<name>`).
///
/// Any AWS partition is accepted (`aws`, `aws-cn`, `aws-us-gov`, ...). Region
/// and account are skipped, not validated. Returns `None` for strings that
/// are not KMS ARNs or that lack the region or account segment.
fn strip_kms_arn_prefix(key_id: &str) -> Option<&str> {
    let rest = key_id.strip_prefix("arn:")?;
    let (partition, rest) = rest.split_once(':')?;
    if partition != "aws" && !partition.starts_with("aws-") {
        return None;
    }
    let rest = rest.strip_prefix("kms:")?;
    // Skip region and account separately so the resource starts cleanly at
    // `key/...` or `alias/...` (a missing segment yields `None`).
    let (_region, rest) = rest.split_once(':')?;
    let (_account, resource) = rest.split_once(':')?;
    Some(resource)
}
226
227/// Resolve `key_id` (raw id, alias name, alias ARN, or key ARN) to the
228/// full key ARN if it currently exists in `state`.
229fn resolve_key(state: &KmsState, key_id: &str) -> Option<String> {
230    if let Some(resource) = strip_kms_arn_prefix(key_id) {
231        if let Some(short) = resource.strip_prefix("key/") {
232            return state.keys.get(short).map(|k| k.arn.clone());
233        }
234        if let Some(alias) = resource.strip_prefix("alias/") {
235            let full = format!("alias/{alias}");
236            if let Some(a) = state.aliases.get(&full) {
237                return state.keys.get(&a.target_key_id).map(|k| k.arn.clone());
238            }
239        }
240    }
241    if let Some(alias) = key_id.strip_prefix("alias/") {
242        let full = format!("alias/{alias}");
243        if let Some(a) = state.aliases.get(&full) {
244            return state.keys.get(&a.target_key_id).map(|k| k.arn.clone());
245        }
246    }
247    state.keys.get(key_id).map(|k| k.arn.clone())
248}
249
250fn normalize_alias(key_id: &str) -> String {
251    if let Some(resource) = strip_kms_arn_prefix(key_id) {
252        if let Some(alias) = resource.strip_prefix("alias/") {
253            return alias.to_string();
254        }
255    }
256    key_id.strip_prefix("alias/").unwrap_or(key_id).to_string()
257}
258
259fn provision_aws_managed_key(
260    state: &mut KmsState,
261    region: &str,
262    alias: &str,
263    service_principal: &str,
264) -> String {
265    let key_id = uuid::Uuid::new_v4().to_string();
266    let arn = format!(
267        "arn:aws:kms:{region}:{account}:key/{key_id}",
268        account = state.account_id,
269        region = region,
270    );
271    let policy = serde_json::json!({
272        "Version": "2012-10-17",
273        "Statement": [{
274            "Sid": "Allow access through service",
275            "Effect": "Allow",
276            "Principal": {"Service": service_principal},
277            "Action": ["kms:GenerateDataKey", "kms:Decrypt", "kms:DescribeKey"],
278            "Resource": "*"
279        }]
280    })
281    .to_string();
282    let key = KmsKey {
283        key_id: key_id.clone(),
284        arn: arn.clone(),
285        creation_date: Utc::now().timestamp() as f64,
286        description: format!(
287            "Default master key that protects {alias} when no other key is defined"
288        ),
289        enabled: true,
290        key_usage: "ENCRYPT_DECRYPT".to_string(),
291        key_spec: "SYMMETRIC_DEFAULT".to_string(),
292        key_manager: "AWS".to_string(),
293        key_state: "Enabled".to_string(),
294        deletion_date: None,
295        tags: HashMap::new(),
296        policy,
297        key_rotation_enabled: true,
298        origin: "AWS_KMS".to_string(),
299        multi_region: false,
300        rotations: Vec::new(),
301        signing_algorithms: None,
302        encryption_algorithms: Some(vec!["SYMMETRIC_DEFAULT".to_string()]),
303        mac_algorithms: None,
304        custom_key_store_id: None,
305        imported_key_material: false,
306        imported_material_bytes: None,
307        private_key_seed: Vec::new(),
308        primary_region: None,
309    };
310    state.keys.insert(key_id.clone(), key);
311    let alias_full = format!("alias/{alias}");
312    state.aliases.insert(
313        alias_full.clone(),
314        crate::state::KmsAlias {
315            alias_name: alias_full,
316            alias_arn: format!(
317                "arn:aws:kms:{region}:{account}:alias/{alias}",
318                account = state.account_id,
319                region = region,
320            ),
321            target_key_id: key_id,
322            creation_date: Utc::now().timestamp() as f64,
323        },
324    );
325    arn
326}
327
/// Return everything after the last `/` in `arn` (the key id for a
/// `...:key/<id>` ARN), or the whole input when it contains no `/`.
fn key_id_from_arn(arn: &str) -> &str {
    match arn.rfind('/') {
        Some(slash) => &arn[slash + 1..],
        None => arn,
    }
}
331
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn strip_arn_prefix_skips_region_and_account() {
        assert_eq!(
            strip_kms_arn_prefix("arn:aws:kms:us-east-1:000000000000:key/abc-123"),
            Some("key/abc-123")
        );
        assert_eq!(
            strip_kms_arn_prefix("arn:aws:kms:us-east-1:000000000000:alias/aws/secretsmanager"),
            Some("alias/aws/secretsmanager")
        );
        assert_eq!(strip_kms_arn_prefix("not-an-arn"), None);
        // Missing one of region/account should return None, not a half-stripped resource.
        assert_eq!(strip_kms_arn_prefix("arn:aws:kms:key/abc"), None);
    }

    #[test]
    fn normalize_alias_handles_arns_correctly() {
        assert_eq!(
            normalize_alias("arn:aws:kms:us-east-1:000000000000:alias/aws/secretsmanager"),
            "aws/secretsmanager"
        );
        assert_eq!(normalize_alias("alias/aws/sqs"), "aws/sqs");
        assert_eq!(normalize_alias("aws/s3"), "aws/s3");
    }

    #[test]
    fn normalize_alias_passes_key_ids_through() {
        // Raw key ids and key ARNs are not aliases and must come back
        // unchanged, so they are never mistaken for `aws/...` managed aliases.
        let raw = "1234abcd-12ab-34cd-56ef-1234567890ab";
        assert_eq!(normalize_alias(raw), raw);
        let key_arn = "arn:aws:kms:us-east-1:000000000000:key/abc-123";
        assert_eq!(normalize_alias(key_arn), key_arn);
    }

    #[test]
    fn key_id_from_arn_returns_trailing_segment() {
        assert_eq!(
            key_id_from_arn("arn:aws:kms:us-east-1:000000000000:key/abc-123"),
            "abc-123"
        );
        assert_eq!(key_id_from_arn("abc-123"), "abc-123");
    }
}