1use aes_gcm::aead::generic_array::GenericArray;
26use aes_gcm::aead::{Aead, KeyInit, OsRng, Payload};
27use aes_gcm::{AeadCore, Aes256Gcm, Key};
28
/// Four-byte marker written at the very start of every encoded blob.
///
/// The decoder rejects any input whose first four bytes differ from this value.
/// Expressed as one big-endian integer; the bytes are `01 02 02 00`.
const VERSION_HEADER: [u8; 4] = 0x0102_0200_u32.to_be_bytes();
30
31fn cipher_for(master_key_bytes: &[u8]) -> Option<Aes256Gcm> {
32 if master_key_bytes.len() != 32 {
33 return None;
34 }
35 let key = Key::<Aes256Gcm>::from_slice(master_key_bytes);
36 Some(Aes256Gcm::new(key))
37}
38
39pub fn encode(master_key_bytes: &[u8], key_id: &str, plaintext: &[u8]) -> Vec<u8> {
45 encode_with_context(master_key_bytes, key_id, plaintext, &[])
46}
47
48pub fn encode_with_context(
56 master_key_bytes: &[u8],
57 key_id: &str,
58 plaintext: &[u8],
59 extra_aad: &[u8],
60) -> Vec<u8> {
61 let cipher =
62 cipher_for(master_key_bytes).expect("KMS master key must be 32 bytes for AES-256-GCM");
63 let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
64
65 let mut aad = Vec::with_capacity(key_id.len() + extra_aad.len());
73 aad.extend_from_slice(key_id.as_bytes());
74 aad.extend_from_slice(extra_aad);
75 let combined = cipher
76 .encrypt(
77 &nonce,
78 Payload {
79 msg: plaintext,
80 aad: &aad,
81 },
82 )
83 .expect("AES-GCM encrypt with 96-bit nonce never fails on valid key");
84 debug_assert!(combined.len() >= 16, "AES-GCM output includes 16-byte tag");
85 let tag_split = combined.len() - 16;
86 let ciphertext = &combined[..tag_split];
87 let tag = &combined[tag_split..];
88
89 let key_bytes = key_id.as_bytes();
90 let mut out = Vec::with_capacity(
91 VERSION_HEADER.len() + 8 + key_bytes.len() + 12 + 4 + ciphertext.len() + 16,
92 );
93 out.extend_from_slice(&VERSION_HEADER);
94 out.extend_from_slice(&(key_bytes.len() as u64).to_be_bytes());
95 out.extend_from_slice(key_bytes);
96 out.extend_from_slice(nonce.as_slice());
97 out.extend_from_slice(&(ciphertext.len() as u32).to_be_bytes());
98 out.extend_from_slice(ciphertext);
99 out.extend_from_slice(tag);
100 out
101}
102
/// Result of successfully decoding and authenticating a blob.
pub struct Decoded {
    /// Key identifier read from the blob; it is part of the authenticated
    /// associated data, so it matches what was supplied at encode time.
    pub key_id: String,
    /// Decrypted payload bytes.
    pub plaintext: Vec<u8>,
}
109
110pub fn decode(master_key_bytes: &[u8], blob: &[u8]) -> Option<Decoded> {
116 decode_with_context(master_key_bytes, blob, &[])
117}
118
119pub fn decode_with_context(
125 master_key_bytes: &[u8],
126 blob: &[u8],
127 extra_aad: &[u8],
128) -> Option<Decoded> {
129 if blob.len() < VERSION_HEADER.len() + 8 + 12 + 4 + 16 {
130 return None;
131 }
132 if blob[..VERSION_HEADER.len()] != VERSION_HEADER {
133 return None;
134 }
135 let mut cursor = VERSION_HEADER.len();
136
137 let key_len = u64::from_be_bytes(blob[cursor..cursor + 8].try_into().ok()?) as usize;
138 cursor += 8;
139 if cursor + key_len + 12 + 4 + 16 > blob.len() {
140 return None;
141 }
142 let key_id = std::str::from_utf8(&blob[cursor..cursor + key_len])
143 .ok()?
144 .to_string();
145 cursor += key_len;
146
147 let nonce = GenericArray::from_slice(&blob[cursor..cursor + 12]);
148 cursor += 12;
149
150 let ct_len = u32::from_be_bytes(blob[cursor..cursor + 4].try_into().ok()?) as usize;
151 cursor += 4;
152 if cursor + ct_len + 16 != blob.len() {
153 return None;
154 }
155 let ct_with_tag = &blob[cursor..cursor + ct_len + 16];
156
157 let cipher = cipher_for(master_key_bytes)?;
158 let mut aad = Vec::with_capacity(key_id.len() + extra_aad.len());
159 aad.extend_from_slice(key_id.as_bytes());
160 aad.extend_from_slice(extra_aad);
161 let plaintext = cipher
162 .decrypt(
163 nonce,
164 Payload {
165 msg: ct_with_tag,
166 aad: &aad,
167 },
168 )
169 .ok()?;
170
171 Some(Decoded { key_id, plaintext })
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic 32-byte master key made of the bytes 0 through 31.
    fn fixed_master() -> Vec<u8> {
        (0..32u8).collect()
    }

    #[test]
    fn round_trip_recovers_plaintext_and_key_id() {
        let master = fixed_master();
        let secret = b"super-secret-value";
        let sealed = encode(&master, "alias/my-key", secret);
        let Decoded { key_id, plaintext } = decode(&master, &sealed).unwrap();
        assert_eq!(plaintext, secret);
        assert_eq!(key_id, "alias/my-key");
    }

    #[test]
    fn blob_does_not_leak_plaintext() {
        let secret = b"NOT_TO_BE_FOUND_IN_BYTES";
        let sealed = encode(&fixed_master(), "key-1", secret);
        let leaked = sealed.windows(secret.len()).any(|window| window == secret);
        assert!(!leaked);
    }

    #[test]
    fn decode_rejects_random_bytes() {
        let master = fixed_master();
        for junk in [&b"this-is-not-a-blob"[..], &[0u8; 8][..]] {
            assert!(decode(&master, junk).is_none());
        }
    }

    #[test]
    fn decode_rejects_wrong_header() {
        let master = fixed_master();
        let mut sealed = encode(&master, "k", b"data");
        *sealed.first_mut().unwrap() = 0xFF;
        assert!(decode(&master, &sealed).is_none());
    }

    #[test]
    fn decode_rejects_tampered_ciphertext() {
        let master = fixed_master();
        let mut sealed = encode(&master, "k", b"data");
        // Flip one bit in the final byte of the blob.
        *sealed.last_mut().unwrap() ^= 0x01;
        assert!(decode(&master, &sealed).is_none());
    }

    #[test]
    fn decode_rejects_wrong_master_key() {
        let sealed = encode(&fixed_master(), "k", b"data");
        let unrelated_key = (32..64u8).collect::<Vec<u8>>();
        assert!(decode(&unrelated_key, &sealed).is_none());
    }

    #[test]
    fn decode_rejects_tampered_key_id_header() {
        let master = fixed_master();
        let mut sealed = encode(&master, "alias/original-key", b"data");
        // The key id starts right after the version header and the 8-byte length.
        let first_key_id_byte = VERSION_HEADER.len() + std::mem::size_of::<u64>();
        sealed[first_key_id_byte] ^= 0x01;
        assert!(decode(&master, &sealed).is_none());
    }

    #[test]
    fn distinct_calls_produce_distinct_blobs() {
        let master = fixed_master();
        let first = encode(&master, "k", b"same");
        let second = encode(&master, "k", b"same");
        assert_ne!(
            first, second,
            "fresh IV should make ciphertext non-deterministic"
        );
    }

    #[test]
    fn decode_with_context_round_trips_when_ec_matches() {
        let master = fixed_master();
        let context = b"{\"app\":\"prod\"}";
        let sealed = encode_with_context(&master, "k", b"secret", context);
        let opened =
            decode_with_context(&master, &sealed, context).expect("matching EC must decode");
        assert_eq!(opened.plaintext, b"secret");
    }

    #[test]
    fn decode_with_context_rejects_mismatched_ec() {
        let master = fixed_master();
        let sealed = encode_with_context(&master, "k", b"secret", b"{\"app\":\"prod\"}");
        let wrong_contexts: [&[u8]; 2] = [b"{\"app\":\"staging\"}", b""];
        for context in wrong_contexts {
            assert!(decode_with_context(&master, &sealed, context).is_none());
        }
    }

    #[test]
    fn decode_without_context_rejects_blob_encoded_with_ec() {
        let master = fixed_master();
        let sealed = encode_with_context(&master, "k", b"secret", b"{\"x\":\"y\"}");
        assert!(decode(&master, &sealed).is_none());
    }
}