1use aes_gcm::aead::generic_array::GenericArray;
26use aes_gcm::aead::{Aead, KeyInit, OsRng, Payload};
27use aes_gcm::{AeadCore, Aes256Gcm, Key};
28
/// Four-byte prefix that `encode` writes at the very start of every blob and
/// that `decode` requires before it parses anything else.
const VERSION_HEADER: [u8; 4] = [0x01, 0x02, 0x02, 0x00];
30
31fn cipher_for(master_key_bytes: &[u8]) -> Option<Aes256Gcm> {
32 if master_key_bytes.len() != 32 {
33 return None;
34 }
35 let key = Key::<Aes256Gcm>::from_slice(master_key_bytes);
36 Some(Aes256Gcm::new(key))
37}
38
39pub fn encode(master_key_bytes: &[u8], key_id: &str, plaintext: &[u8]) -> Vec<u8> {
45 let cipher =
46 cipher_for(master_key_bytes).expect("KMS master key must be 32 bytes for AES-256-GCM");
47 let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
48
49 let combined = cipher
58 .encrypt(
59 &nonce,
60 Payload {
61 msg: plaintext,
62 aad: key_id.as_bytes(),
63 },
64 )
65 .expect("AES-GCM encrypt with 96-bit nonce never fails on valid key");
66 debug_assert!(combined.len() >= 16, "AES-GCM output includes 16-byte tag");
67 let tag_split = combined.len() - 16;
68 let ciphertext = &combined[..tag_split];
69 let tag = &combined[tag_split..];
70
71 let key_bytes = key_id.as_bytes();
72 let mut out = Vec::with_capacity(
73 VERSION_HEADER.len() + 8 + key_bytes.len() + 12 + 4 + ciphertext.len() + 16,
74 );
75 out.extend_from_slice(&VERSION_HEADER);
76 out.extend_from_slice(&(key_bytes.len() as u64).to_be_bytes());
77 out.extend_from_slice(key_bytes);
78 out.extend_from_slice(nonce.as_slice());
79 out.extend_from_slice(&(ciphertext.len() as u32).to_be_bytes());
80 out.extend_from_slice(ciphertext);
81 out.extend_from_slice(tag);
82 out
83}
84
/// Result of a successful `decode` call.
// NOTE(review): no `Debug` derive — presumably deliberate so the decrypted
// plaintext cannot be printed by accident; confirm before adding one.
pub struct Decoded {
    /// Key identifier read from the blob; `decode` feeds these bytes to the
    /// cipher as associated data, so they are authenticated.
    pub key_id: String,
    /// Decrypted payload bytes.
    pub plaintext: Vec<u8>,
}
91
92pub fn decode(master_key_bytes: &[u8], blob: &[u8]) -> Option<Decoded> {
98 if blob.len() < VERSION_HEADER.len() + 8 + 12 + 4 + 16 {
99 return None;
100 }
101 if blob[..VERSION_HEADER.len()] != VERSION_HEADER {
102 return None;
103 }
104 let mut cursor = VERSION_HEADER.len();
105
106 let key_len = u64::from_be_bytes(blob[cursor..cursor + 8].try_into().ok()?) as usize;
107 cursor += 8;
108 if cursor + key_len + 12 + 4 + 16 > blob.len() {
109 return None;
110 }
111 let key_id = std::str::from_utf8(&blob[cursor..cursor + key_len])
112 .ok()?
113 .to_string();
114 cursor += key_len;
115
116 let nonce = GenericArray::from_slice(&blob[cursor..cursor + 12]);
117 cursor += 12;
118
119 let ct_len = u32::from_be_bytes(blob[cursor..cursor + 4].try_into().ok()?) as usize;
120 cursor += 4;
121 if cursor + ct_len + 16 != blob.len() {
122 return None;
123 }
124 let ct_with_tag = &blob[cursor..cursor + ct_len + 16];
125
126 let cipher = cipher_for(master_key_bytes)?;
127 let plaintext = cipher
128 .decrypt(
129 nonce,
130 Payload {
131 msg: ct_with_tag,
132 aad: key_id.as_bytes(),
133 },
134 )
135 .ok()?;
136
137 Some(Decoded { key_id, plaintext })
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic 32-byte master key (bytes 0..32) shared by the tests.
    fn fixed_master() -> Vec<u8> {
        (0..32u8).collect::<Vec<u8>>()
    }

    #[test]
    fn round_trip_recovers_plaintext_and_key_id() {
        let master = fixed_master();
        let secret = b"super-secret-value";
        let Decoded { key_id, plaintext } =
            decode(&master, &encode(&master, "alias/my-key", secret)).unwrap();
        assert_eq!(plaintext, secret);
        assert_eq!(key_id, "alias/my-key");
    }

    #[test]
    fn blob_does_not_leak_plaintext() {
        let secret: &[u8] = b"NOT_TO_BE_FOUND_IN_BYTES";
        let blob = encode(&fixed_master(), "key-1", secret);
        // No window of the blob may equal the plaintext verbatim.
        let leaked = blob.windows(secret.len()).any(|window| window == secret);
        assert!(!leaked);
    }

    #[test]
    fn decode_rejects_random_bytes() {
        let master = fixed_master();
        let text_garbage: &[u8] = b"this-is-not-a-blob";
        let short_zeros = [0u8; 8];
        assert!(decode(&master, text_garbage).is_none());
        assert!(decode(&master, &short_zeros).is_none());
    }

    #[test]
    fn decode_rejects_wrong_header() {
        let master = fixed_master();
        let mut corrupted = encode(&master, "k", b"data");
        corrupted[0] = 0xFF;
        assert!(decode(&master, &corrupted).is_none());
    }

    #[test]
    fn decode_rejects_tampered_ciphertext() {
        let master = fixed_master();
        let mut corrupted = encode(&master, "k", b"data");
        // Flip one bit in the final byte of the blob.
        *corrupted.last_mut().unwrap() ^= 0x01;
        assert!(decode(&master, &corrupted).is_none());
    }

    #[test]
    fn decode_rejects_wrong_master_key() {
        let sealed = encode(&fixed_master(), "k", b"data");
        let wrong_key = (32..64u8).collect::<Vec<u8>>();
        assert!(decode(&wrong_key, &sealed).is_none());
    }

    #[test]
    fn decode_rejects_tampered_key_id_header() {
        let master = fixed_master();
        let mut corrupted = encode(&master, "alias/original-key", b"data");
        // The key id starts right after the 4-byte header and the 8-byte length.
        let first_key_id_byte = VERSION_HEADER.len() + 8;
        corrupted[first_key_id_byte] ^= 0x01;
        assert!(decode(&master, &corrupted).is_none());
    }

    #[test]
    fn distinct_calls_produce_distinct_blobs() {
        let master = fixed_master();
        let first = encode(&master, "k", b"same");
        let second = encode(&master, "k", b"same");
        assert_ne!(first, second, "fresh IV should make ciphertext non-deterministic");
    }
}