// rust_crypto_utils/keywrap.rs
use aes_gcm::{
    aead::{Aead, KeyInit, OsRng},
    Aes256Gcm, Nonce,
};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};

14#[derive(Error, Debug)]
16pub enum KeyWrapError {
17 #[error("Wrapping failed: {0}")]
18 WrapFailed(String),
19
20 #[error("Unwrapping failed: {0}")]
21 UnwrapFailed(String),
22
23 #[error("Invalid key length: expected {expected}, got {actual}")]
24 InvalidKeyLength { expected: usize, actual: usize },
25
26 #[error("Invalid wrapped key format")]
27 InvalidFormat,
28}
29
30#[derive(Zeroize, ZeroizeOnDrop)]
32pub struct KeyEncryptionKey {
33 key: Vec<u8>,
34}
35
36impl KeyEncryptionKey {
37 pub fn generate() -> Self {
39 let mut key = vec![0u8; 32];
40 OsRng.fill_bytes(&mut key);
41 Self { key }
42 }
43
44 pub fn from_bytes(bytes: &[u8]) -> Result<Self, KeyWrapError> {
46 if bytes.len() != 32 {
47 return Err(KeyWrapError::InvalidKeyLength {
48 expected: 32,
49 actual: bytes.len(),
50 });
51 }
52 Ok(Self {
53 key: bytes.to_vec(),
54 })
55 }
56
57 pub fn as_bytes(&self) -> &[u8] {
59 &self.key
60 }
61}
62
63#[derive(Clone, Serialize, Deserialize)]
65pub struct WrappedKey {
66 pub ciphertext: Vec<u8>,
67 pub nonce: [u8; 12],
68 pub key_id: String,
69 pub algorithm: String,
70 pub wrapped_at: chrono::DateTime<chrono::Utc>,
71}
72
73impl WrappedKey {
74 pub fn to_hex(&self) -> String {
76 hex::encode(&self.ciphertext)
77 }
78
79 pub fn nonce_hex(&self) -> String {
81 hex::encode(self.nonce)
82 }
83
84 pub fn to_json(&self) -> Result<String, serde_json::Error> {
86 serde_json::to_string_pretty(self)
87 }
88}
89
90pub struct KeyWrapper {
92 cipher: Aes256Gcm,
93}
94
95impl KeyWrapper {
96 pub fn new(kek: &KeyEncryptionKey) -> Result<Self, KeyWrapError> {
98 let cipher = Aes256Gcm::new_from_slice(kek.as_bytes())
99 .map_err(|e| KeyWrapError::WrapFailed(e.to_string()))?;
100 Ok(Self { cipher })
101 }
102
103 pub fn wrap(&self, key: &[u8], key_id: &str) -> Result<WrappedKey, KeyWrapError> {
105 let mut nonce_bytes = [0u8; 12];
106 OsRng.fill_bytes(&mut nonce_bytes);
107 let nonce = Nonce::from_slice(&nonce_bytes);
108
109 let ciphertext = self
110 .cipher
111 .encrypt(nonce, key)
112 .map_err(|e| KeyWrapError::WrapFailed(e.to_string()))?;
113
114 Ok(WrappedKey {
115 ciphertext,
116 nonce: nonce_bytes,
117 key_id: key_id.to_string(),
118 algorithm: "AES-256-GCM".to_string(),
119 wrapped_at: chrono::Utc::now(),
120 })
121 }
122
123 pub fn unwrap(&self, wrapped: &WrappedKey) -> Result<Vec<u8>, KeyWrapError> {
125 let nonce = Nonce::from_slice(&wrapped.nonce);
126
127 let plaintext = self
128 .cipher
129 .decrypt(nonce, wrapped.ciphertext.as_ref())
130 .map_err(|e| KeyWrapError::UnwrapFailed(e.to_string()))?;
131
132 Ok(plaintext)
133 }
134
135 pub fn rewrap(
137 &self,
138 wrapped: &WrappedKey,
139 new_wrapper: &KeyWrapper,
140 ) -> Result<WrappedKey, KeyWrapError> {
141 let key = self.unwrap(wrapped)?;
142 new_wrapper.wrap(&key, &wrapped.key_id)
143 }
144}
145
146pub struct KeyHierarchy {
148 master_wrapper: KeyWrapper,
149 level_keys: Vec<KeyEncryptionKey>,
150}
151
152impl KeyHierarchy {
153 pub fn new(master_kek: KeyEncryptionKey) -> Result<Self, KeyWrapError> {
155 let master_wrapper = KeyWrapper::new(&master_kek)?;
156 Ok(Self {
157 master_wrapper,
158 level_keys: Vec::new(),
159 })
160 }
161
162 pub fn add_level(&mut self) -> Result<WrappedKey, KeyWrapError> {
164 let level_kek = KeyEncryptionKey::generate();
165 let level_id = format!("level-{}", self.level_keys.len());
166 let wrapped = self.master_wrapper.wrap(level_kek.as_bytes(), &level_id)?;
167 self.level_keys.push(level_kek);
168 Ok(wrapped)
169 }
170
171 pub fn get_level_wrapper(&self, level: usize) -> Result<KeyWrapper, KeyWrapError> {
173 let kek = self
174 .level_keys
175 .get(level)
176 .ok_or(KeyWrapError::InvalidFormat)?;
177 KeyWrapper::new(kek)
178 }
179
180 pub fn wrap_data_key(&self, key: &[u8], level: usize, key_id: &str) -> Result<WrappedKey, KeyWrapError> {
182 let wrapper = self.get_level_wrapper(level)?;
183 wrapper.wrap(key, key_id)
184 }
185
186 pub fn unwrap_data_key(&self, wrapped: &WrappedKey, level: usize) -> Result<Vec<u8>, KeyWrapError> {
188 let wrapper = self.get_level_wrapper(level)?;
189 wrapper.unwrap(wrapped)
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// A wrapper keyed with a freshly generated KEK.
    fn fresh_wrapper() -> KeyWrapper {
        KeyWrapper::new(&KeyEncryptionKey::generate()).unwrap()
    }

    #[test]
    fn test_wrap_unwrap() {
        let wrapper = fresh_wrapper();

        let data_key = vec![0u8; 32];
        let wrapped = wrapper.wrap(&data_key, "key-001").unwrap();
        let unwrapped = wrapper.unwrap(&wrapped).unwrap();

        assert_eq!(data_key, unwrapped);
    }

    #[test]
    fn test_wrapped_key_metadata() {
        let wrapper = fresh_wrapper();

        let data_key = vec![0u8; 32];
        let wrapped = wrapper.wrap(&data_key, "my-key").unwrap();

        assert_eq!(wrapped.key_id, "my-key");
        assert_eq!(wrapped.algorithm, "AES-256-GCM");
        // AES-GCM output is the encrypted key plus a 16-byte authentication tag.
        assert_eq!(wrapped.ciphertext.len(), data_key.len() + 16);
    }

    #[test]
    fn test_fresh_nonce_per_wrap() {
        let wrapper = fresh_wrapper();
        let data_key = [7u8; 32];

        let first = wrapper.wrap(&data_key, "k").unwrap();
        let second = wrapper.wrap(&data_key, "k").unwrap();

        // Reusing a GCM nonce under one KEK is catastrophic; each wrap must differ.
        assert_ne!(first.nonce, second.nonce);
        assert_ne!(first.ciphertext, second.ciphertext);
    }

    #[test]
    fn test_wrong_kek_fails() {
        let wrapper1 = fresh_wrapper();
        let wrapper2 = fresh_wrapper();

        let data_key = vec![0u8; 32];
        let wrapped = wrapper1.wrap(&data_key, "key-001").unwrap();
        let result = wrapper2.unwrap(&wrapped);

        assert!(matches!(result, Err(KeyWrapError::UnwrapFailed(_))));
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let wrapper = fresh_wrapper();
        let mut wrapped = wrapper.wrap(&[1u8; 32], "key-001").unwrap();

        wrapped.ciphertext[0] ^= 0x01;

        assert!(matches!(
            wrapper.unwrap(&wrapped),
            Err(KeyWrapError::UnwrapFailed(_))
        ));
    }

    #[test]
    fn test_tampered_nonce_fails() {
        let wrapper = fresh_wrapper();
        let mut wrapped = wrapper.wrap(&[1u8; 32], "key-001").unwrap();

        wrapped.nonce[0] ^= 0x01;

        assert!(matches!(
            wrapper.unwrap(&wrapped),
            Err(KeyWrapError::UnwrapFailed(_))
        ));
    }

    #[test]
    fn test_rewrap() {
        let wrapper1 = fresh_wrapper();
        let wrapper2 = fresh_wrapper();

        let data_key = vec![0u8; 32];
        let wrapped1 = wrapper1.wrap(&data_key, "key-001").unwrap();
        let wrapped2 = wrapper1.rewrap(&wrapped1, &wrapper2).unwrap();

        assert_eq!(wrapped2.key_id, "key-001");
        assert_eq!(wrapper2.unwrap(&wrapped2).unwrap(), data_key);
        // The rewrapped blob is bound to the new KEK only.
        assert!(wrapper1.unwrap(&wrapped2).is_err());
    }

    #[test]
    fn test_key_hierarchy() {
        let master_kek = KeyEncryptionKey::generate();
        let mut hierarchy = KeyHierarchy::new(master_kek).unwrap();

        let level0 = hierarchy.add_level().unwrap();
        let level1 = hierarchy.add_level().unwrap();
        assert_eq!(level0.key_id, "level-0");
        assert_eq!(level1.key_id, "level-1");

        let data_key = vec![42u8; 32];
        let wrapped = hierarchy
            .wrap_data_key(&data_key, 0, "data-key-001")
            .unwrap();

        let unwrapped = hierarchy.unwrap_data_key(&wrapped, 0).unwrap();
        assert_eq!(data_key, unwrapped);

        // Each level has its own KEK, so a different level cannot unwrap it.
        assert!(hierarchy.unwrap_data_key(&wrapped, 1).is_err());
    }

    #[test]
    fn test_hierarchy_unknown_level() {
        let mut hierarchy = KeyHierarchy::new(KeyEncryptionKey::generate()).unwrap();
        hierarchy.add_level().unwrap();

        assert!(matches!(
            hierarchy.wrap_data_key(&[0u8; 32], 1, "k"),
            Err(KeyWrapError::InvalidFormat)
        ));
    }

    #[test]
    fn test_wrapped_key_json() {
        let wrapper = fresh_wrapper();

        let wrapped = wrapper.wrap(&[0u8; 32], "test-key").unwrap();
        let json = wrapped.to_json().unwrap();

        assert!(json.contains("test-key"));
        assert!(json.contains("AES-256-GCM"));

        // The JSON form must carry everything needed to unwrap again.
        let restored: WrappedKey = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.key_id, "test-key");
        assert_eq!(restored.nonce, wrapped.nonce);
        assert_eq!(wrapper.unwrap(&restored).unwrap(), vec![0u8; 32]);
    }

    #[test]
    fn test_kek_from_bytes_roundtrip() {
        let kek = KeyEncryptionKey::from_bytes(&[9u8; 32]).unwrap();
        assert_eq!(kek.as_bytes(), [9u8; 32].as_slice());
    }

    #[test]
    fn test_invalid_kek_length() {
        let result = KeyEncryptionKey::from_bytes(&[0u8; 16]);
        assert!(matches!(
            result,
            Err(KeyWrapError::InvalidKeyLength {
                expected: 32,
                actual: 16
            })
        ));
    }
}