Skip to main content

fakecloud_kms/
hook.rs

1//! Cross-service KMS hook.
2//!
3//! Services that accept a `KmsKeyId` (Secrets Manager, SSM
4//! `SecureString`, S3 SSE-KMS, SQS, SNS, DynamoDB) call into this
5//! module so that:
6//!
7//! 1. The supplied key is resolved (alias `aws/<service>` and bare
8//!    aliases included), auto-provisioning AWS-managed keys on first
9//!    use to match real AWS.
10//! 2. Each call is recorded in [`KmsUsageState`] so test code can
11//!    assert through `/_fakecloud/kms/usage` that the right service
12//!    triggered the right operation on the right key.
13//! 3. The returned ciphertext is a real envelope decryptable by the
14//!    public KMS `Decrypt` API (uses the same `fakecloud-kms:`
15//!    envelope as the existing service-side encrypt path).
16//!
17//! Encryption context, key policy enforcement, and KMS-managed key
18//! rotation come in follow-up PRs.
19
20use std::collections::{BTreeMap, HashMap};
21use std::sync::Arc;
22
23use chrono::{DateTime, Utc};
24use parking_lot::RwLock;
25
26use base64::Engine;
27
28use crate::state::{KmsKey, KmsState, SharedKmsState};
29
30/// One recorded KMS hook call. Returned by the introspection endpoint
31/// so test code can assert `kms:GenerateDataKey` / `kms:Decrypt` ran
32/// on the expected key + service principal.
33#[derive(Clone, serde::Serialize)]
34pub struct KmsUsageRecord {
35    pub timestamp: DateTime<Utc>,
36    pub operation: String,
37    pub service_principal: String,
38    pub account_id: String,
39    pub key_arn: String,
40    pub encryption_context: HashMap<String, String>,
41}
42
43#[derive(Default)]
44pub struct KmsUsageState {
45    records: Vec<KmsUsageRecord>,
46}
47
48impl KmsUsageState {
49    pub fn records(&self) -> &[KmsUsageRecord] {
50        &self.records
51    }
52
53    pub fn push(&mut self, record: KmsUsageRecord) {
54        self.records.push(record);
55    }
56
57    pub fn clear(&mut self) {
58        self.records.clear();
59    }
60}
61
62pub type SharedKmsUsageState = Arc<RwLock<KmsUsageState>>;
63
64/// Hook used by service crates that need to call KMS for encryption /
65/// decryption without going through the AWS-shaped HTTP layer.
66pub struct KmsServiceHook {
67    state: SharedKmsState,
68    usage: SharedKmsUsageState,
69}
70
/// Failure modes reported by [`KmsServiceHook`] calls.
#[derive(Debug)]
pub enum KmsHookError {
    /// Caller supplied a key id / alias / ARN that doesn't resolve to
    /// an existing key (and isn't an AWS-managed alias we auto-create).
    KeyNotFound(String),
    /// Ciphertext envelope is malformed or signed by a key that no
    /// longer exists.
    InvalidCiphertext(String),
}

impl std::fmt::Display for KmsHookError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every variant renders as "<label>: <detail>".
        let (label, detail) = match self {
            Self::KeyNotFound(key) => ("kms key not found", key),
            Self::InvalidCiphertext(reason) => ("invalid ciphertext", reason),
        };
        write!(f, "{label}: {detail}")
    }
}

impl std::error::Error for KmsHookError {}
91
92impl KmsServiceHook {
93    pub fn new(state: SharedKmsState, usage: SharedKmsUsageState) -> Self {
94        Self { state, usage }
95    }
96
97    /// Encrypt `plaintext` under `key_id` (raw id, ARN, alias, or
98    /// `aws/<service>` AWS-managed alias). Records the call as a
99    /// `GenerateDataKey`-shaped usage record and returns the base64
100    /// ciphertext envelope.
101    pub fn encrypt(
102        &self,
103        account_id: &str,
104        region: &str,
105        key_id: &str,
106        plaintext: &[u8],
107        service_principal: &str,
108        encryption_context: HashMap<String, String>,
109    ) -> Result<String, KmsHookError> {
110        let key_arn = self.resolve_or_provision(account_id, region, key_id, service_principal)?;
111        let key_short = key_id_from_arn(&key_arn).to_string();
112
113        // Default to the AWS-shaped binary blob (AES-256-GCM under the
114        // per-account master key persisted in `KmsState`). The legacy
115        // `fakecloud-kms:<key>:<b64>` textual envelope is still accepted on
116        // the decrypt side for back-compat with older snapshots and
117        // external callers.
118        let master_key_bytes = {
119            let mas = self.state.read();
120            mas.get(account_id)
121                .map(|s| s.master_key_bytes.clone())
122                .ok_or_else(|| KmsHookError::KeyNotFound(key_short.clone()))?
123        };
124        let blob = crate::blob::encode(&master_key_bytes, &key_short, plaintext);
125        let ciphertext_b64 = base64::engine::general_purpose::STANDARD.encode(&blob);
126
127        self.usage.write().push(KmsUsageRecord {
128            timestamp: Utc::now(),
129            operation: "GenerateDataKey".to_string(),
130            service_principal: service_principal.to_string(),
131            account_id: account_id.to_string(),
132            key_arn,
133            encryption_context,
134        });
135
136        Ok(ciphertext_b64)
137    }
138
139    /// Decrypt a previously-`encrypt`-produced base64 ciphertext.
140    /// Records the call as a `Decrypt`-shaped usage record.
141    pub fn decrypt(
142        &self,
143        account_id: &str,
144        ciphertext_b64: &str,
145        service_principal: &str,
146        encryption_context: HashMap<String, String>,
147    ) -> Result<Vec<u8>, KmsHookError> {
148        let envelope_bytes = base64::engine::general_purpose::STANDARD
149            .decode(ciphertext_b64)
150            .map_err(|e| KmsHookError::InvalidCiphertext(e.to_string()))?;
151
152        // Try AWS-shaped binary blob first using the account's master key;
153        // older textual envelopes fall through to the legacy parser below.
154        let master_key_bytes = {
155            let mas = self.state.read();
156            mas.get(account_id)
157                .map(|s| s.master_key_bytes.clone())
158                .unwrap_or_default()
159        };
160        let (key_short, plaintext) =
161            if let Some(decoded) = crate::blob::decode(&master_key_bytes, &envelope_bytes) {
162                (decoded.key_id, decoded.plaintext)
163            } else {
164                let envelope = String::from_utf8(envelope_bytes)
165                    .map_err(|e| KmsHookError::InvalidCiphertext(e.to_string()))?;
166                let rest = envelope.strip_prefix("fakecloud-kms:").ok_or_else(|| {
167                    KmsHookError::InvalidCiphertext("unrecognized envelope".into())
168                })?;
169                let (key_short, plaintext_b64) = rest.split_once(':').ok_or_else(|| {
170                    KmsHookError::InvalidCiphertext("missing key separator".into())
171                })?;
172
173                let plaintext = base64::engine::general_purpose::STANDARD
174                    .decode(plaintext_b64)
175                    .map_err(|e| KmsHookError::InvalidCiphertext(e.to_string()))?;
176                (key_short.to_string(), plaintext)
177            };
178
179        let key_arn = {
180            let mas = self.state.read();
181            let state = mas
182                .get(account_id)
183                .ok_or_else(|| KmsHookError::KeyNotFound(key_short.clone()))?;
184            state
185                .keys
186                .get(&key_short)
187                .map(|k| k.arn.clone())
188                .ok_or_else(|| KmsHookError::KeyNotFound(key_short.clone()))?
189        };
190
191        self.usage.write().push(KmsUsageRecord {
192            timestamp: Utc::now(),
193            operation: "Decrypt".to_string(),
194            service_principal: service_principal.to_string(),
195            account_id: account_id.to_string(),
196            key_arn,
197            encryption_context,
198        });
199
200        Ok(plaintext)
201    }
202
203    fn resolve_or_provision(
204        &self,
205        account_id: &str,
206        region: &str,
207        key_id: &str,
208        service_principal: &str,
209    ) -> Result<String, KmsHookError> {
210        // Pre-flight read to see if the key resolves cleanly.
211        {
212            let mas = self.state.read();
213            if let Some(state) = mas.get(account_id) {
214                if let Some(arn) = resolve_key(state, key_id) {
215                    return Ok(arn);
216                }
217            }
218        }
219
220        // AWS-managed aliases (`aws/<service>`) auto-provision on
221        // first use. Customer-supplied aliases / IDs that don't
222        // resolve are an error.
223        let alias = normalize_alias(key_id);
224        if !alias.starts_with("aws/") {
225            return Err(KmsHookError::KeyNotFound(key_id.to_string()));
226        }
227        let mut mas = self.state.write();
228        let state = mas.get_or_create(account_id);
229        // Re-check under the write lock in case a concurrent caller won the race.
230        if let Some(arn) = resolve_key(state, key_id) {
231            return Ok(arn);
232        }
233        let key_arn = provision_aws_managed_key(state, region, &alias, service_principal);
234        Ok(key_arn)
235    }
236}
237
/// Strip the `arn:<partition>:kms:<region>:<account>:` prefix from a KMS ARN and return
/// the resource portion (e.g. `key/<id>` or `alias/<name>`).
///
/// Any non-empty partition is accepted (`aws`, `aws-cn`, `aws-us-gov`, ...), matching the
/// general `arn:partition:service:region:account-id:resource` ARN layout. Returns `None`
/// for anything that is not a KMS ARN or lacks the region or account segment.
fn strip_kms_arn_prefix(key_id: &str) -> Option<&str> {
    let after_arn = key_id.strip_prefix("arn:")?;
    let (partition, after_partition) = after_arn.split_once(':')?;
    if partition.is_empty() {
        return None;
    }
    let rest = after_partition.strip_prefix("kms:")?;
    // Format after the service: <region>:<account>:<resource>. Skip
    // both `region` and `account` separately so the resource starts
    // cleanly at `key/...` or `alias/...`.
    let (_region, after_region) = rest.split_once(':')?;
    let (_account, resource) = after_region.split_once(':')?;
    Some(resource)
}
250
251/// Resolve `key_id` (raw id, alias name, alias ARN, or key ARN) to the
252/// full key ARN if it currently exists in `state`.
253fn resolve_key(state: &KmsState, key_id: &str) -> Option<String> {
254    if let Some(resource) = strip_kms_arn_prefix(key_id) {
255        if let Some(short) = resource.strip_prefix("key/") {
256            return state.keys.get(short).map(|k| k.arn.clone());
257        }
258        if let Some(alias) = resource.strip_prefix("alias/") {
259            let full = format!("alias/{alias}");
260            if let Some(a) = state.aliases.get(&full) {
261                return state.keys.get(&a.target_key_id).map(|k| k.arn.clone());
262            }
263        }
264    }
265    if let Some(alias) = key_id.strip_prefix("alias/") {
266        let full = format!("alias/{alias}");
267        if let Some(a) = state.aliases.get(&full) {
268            return state.keys.get(&a.target_key_id).map(|k| k.arn.clone());
269        }
270    }
271    state.keys.get(key_id).map(|k| k.arn.clone())
272}
273
274fn normalize_alias(key_id: &str) -> String {
275    if let Some(resource) = strip_kms_arn_prefix(key_id) {
276        if let Some(alias) = resource.strip_prefix("alias/") {
277            return alias.to_string();
278        }
279    }
280    key_id.strip_prefix("alias/").unwrap_or(key_id).to_string()
281}
282
283fn provision_aws_managed_key(
284    state: &mut KmsState,
285    region: &str,
286    alias: &str,
287    service_principal: &str,
288) -> String {
289    let key_id = uuid::Uuid::new_v4().to_string();
290    let arn = format!(
291        "arn:aws:kms:{region}:{account}:key/{key_id}",
292        account = state.account_id,
293        region = region,
294    );
295    let policy = serde_json::json!({
296        "Version": "2012-10-17",
297        "Statement": [{
298            "Sid": "Allow access through service",
299            "Effect": "Allow",
300            "Principal": {"Service": service_principal},
301            "Action": ["kms:GenerateDataKey", "kms:Decrypt", "kms:DescribeKey"],
302            "Resource": "*"
303        }]
304    })
305    .to_string();
306    let key = KmsKey {
307        key_id: key_id.clone(),
308        arn: arn.clone(),
309        creation_date: Utc::now().timestamp() as f64,
310        description: format!(
311            "Default master key that protects {alias} when no other key is defined"
312        ),
313        enabled: true,
314        key_usage: "ENCRYPT_DECRYPT".to_string(),
315        key_spec: "SYMMETRIC_DEFAULT".to_string(),
316        key_manager: "AWS".to_string(),
317        key_state: "Enabled".to_string(),
318        deletion_date: None,
319        tags: BTreeMap::new(),
320        policy,
321        key_rotation_enabled: true,
322        origin: "AWS_KMS".to_string(),
323        multi_region: false,
324        rotations: Vec::new(),
325        signing_algorithms: None,
326        encryption_algorithms: Some(vec!["SYMMETRIC_DEFAULT".to_string()]),
327        mac_algorithms: None,
328        custom_key_store_id: None,
329        imported_key_material: false,
330        imported_material_bytes: None,
331        private_key_seed: Vec::new(),
332        primary_region: None,
333        asymmetric_private_key_der: None,
334        asymmetric_public_key_der: None,
335    };
336    state.keys.insert(key_id.clone(), key);
337    let alias_full = format!("alias/{alias}");
338    state.aliases.insert(
339        alias_full.clone(),
340        crate::state::KmsAlias {
341            alias_name: alias_full,
342            alias_arn: format!(
343                "arn:aws:kms:{region}:{account}:alias/{alias}",
344                account = state.account_id,
345                region = region,
346            ),
347            target_key_id: key_id,
348            creation_date: Utc::now().timestamp() as f64,
349        },
350    );
351    arn
352}
353
/// Return everything after the last `/` in `arn` (the key id of a `...:key/<id>` ARN),
/// or `arn` itself when it contains no `/`.
fn key_id_from_arn(arn: &str) -> &str {
    match arn.rfind('/') {
        Some(slash) => &arn[slash + 1..],
        None => arn,
    }
}
357
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn strip_arn_prefix_skips_region_and_account() {
        let cases: [(&str, Option<&str>); 4] = [
            (
                "arn:aws:kms:us-east-1:000000000000:key/abc-123",
                Some("key/abc-123"),
            ),
            (
                "arn:aws:kms:us-east-1:000000000000:alias/aws/secretsmanager",
                Some("alias/aws/secretsmanager"),
            ),
            ("not-an-arn", None),
            // Missing one of region/account should return None, not a half-stripped resource.
            ("arn:aws:kms:key/abc", None),
        ];
        for (input, expected) in cases {
            assert_eq!(strip_kms_arn_prefix(input), expected, "input: {input}");
        }
    }

    #[test]
    fn normalize_alias_handles_arns_correctly() {
        let cases: [(&str, &str); 3] = [
            (
                "arn:aws:kms:us-east-1:000000000000:alias/aws/secretsmanager",
                "aws/secretsmanager",
            ),
            ("alias/aws/sqs", "aws/sqs"),
            ("aws/s3", "aws/s3"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_alias(input), expected, "input: {input}");
        }
    }
}