1use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, AES_256_GCM};
7use ring::rand::{SecureRandom, SystemRandom};
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10use zeroize::{Zeroize, ZeroizeOnDrop};
11
12pub mod key_derivation;
13
/// Errors returned by the key-handling, encryption and decryption routines of this module.
#[derive(Error, Debug)]
pub enum CryptoError {
    /// Encryption could not be performed; the payload is a human-readable reason
    /// (unsupported key algorithm, or the debug form of the sealing error).
    #[error("Encryption failed: {0}")]
    EncryptionFailed(String),

    /// Decryption could not be performed; the payload is a human-readable reason
    /// (unsupported algorithm, key algorithm mismatch, or the debug form of the opening error).
    #[error("Decryption failed: {0}")]
    DecryptionFailed(String),

    /// A key could not be produced: either the random generator failed or a
    /// hex / base64 encoded key could not be decoded.
    #[error("Key generation failed: {0}")]
    KeyGenerationFailed(String),

    /// Raw key material did not have the length required by its algorithm (lengths in bytes).
    #[error("Invalid key length: expected {expected}, got {actual}")]
    InvalidKeyLength { expected: usize, actual: usize },

    /// A stored nonce did not have the expected length (lengths in bytes).
    #[error("Invalid nonce length: expected {expected}, got {actual}")]
    InvalidNonceLength { expected: usize, actual: usize },

    /// An error reported by `ring` that was converted through the `From` impl;
    /// the payload is the debug rendering of that error.
    #[error("Ring error: {0}")]
    RingError(String),
}
34
35impl From<ring::error::Unspecified> for CryptoError {
36 fn from(err: ring::error::Unspecified) -> Self {
37 CryptoError::RingError(format!("{:?}", err))
38 }
39}
40
/// Result alias used throughout this module; the error type is always [`CryptoError`].
pub type Result<T> = std::result::Result<T, CryptoError>;
42
/// Required key length in bytes for [`Algorithm::Aes256Gcm`] keys.
pub const KEY_SIZE: usize = 32;

/// Length in bytes of the nonce generated by `encrypt` and required by `decrypt`.
pub const NONCE_SIZE: usize = 12;

/// NOTE(review): presumably the length in bytes of the AES-GCM authentication tag;
/// it is not referenced by the code visible in this file — confirm against callers.
pub const TAG_SIZE: usize = 16;
51
/// Encryption algorithms understood by this module.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum Algorithm {
    /// AES-256 in GCM mode; serialized as the string `"aes-256-gcm"`.
    #[serde(rename = "aes-256-gcm")]
    Aes256Gcm,
}
58
59impl Default for Algorithm {
60 fn default() -> Self {
61 Algorithm::Aes256Gcm
62 }
63}
64
/// Symmetric key material tagged with the algorithm it belongs to.
///
/// Derives `Zeroize` and `ZeroizeOnDrop`, so the key bytes are wiped when the
/// value is dropped. Note that `Clone` produces an independent copy of the bytes.
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct SecretKey {
    // The algorithm tag is not secret, so it is excluded from zeroization.
    #[zeroize(skip)]
    algorithm: Algorithm,
    // Raw key bytes; constructors check the length against the algorithm.
    bytes: Vec<u8>,
}
72
73impl SecretKey {
74 pub fn from_bytes(algorithm: Algorithm, bytes: Vec<u8>) -> Result<Self> {
76 let expected_len = match algorithm {
77 Algorithm::Aes256Gcm => KEY_SIZE,
78 };
79
80 if bytes.len() != expected_len {
81 return Err(CryptoError::InvalidKeyLength {
82 expected: expected_len,
83 actual: bytes.len(),
84 });
85 }
86
87 Ok(Self { algorithm, bytes })
88 }
89
90 pub fn generate(algorithm: Algorithm) -> Result<Self> {
92 let rng = SystemRandom::new();
93 let mut bytes = vec![0u8; KEY_SIZE];
94 rng.fill(&mut bytes)
95 .map_err(|e| CryptoError::KeyGenerationFailed(format!("{:?}", e)))?;
96
97 Ok(Self { algorithm, bytes })
98 }
99
100 pub fn algorithm(&self) -> Algorithm {
102 self.algorithm
103 }
104
105 pub fn as_bytes(&self) -> &[u8] {
107 &self.bytes
108 }
109
110 pub fn to_hex(&self) -> String {
112 hex::encode(&self.bytes)
113 }
114
115 pub fn from_hex(algorithm: Algorithm, hex_str: &str) -> Result<Self> {
117 let bytes = hex::decode(hex_str)
118 .map_err(|e| CryptoError::KeyGenerationFailed(format!("Invalid hex: {}", e)))?;
119 Self::from_bytes(algorithm, bytes)
120 }
121
122 pub fn from_base64(algorithm: Algorithm, b64_str: &str) -> Result<Self> {
124 use base64::Engine;
125 let bytes = base64::engine::general_purpose::STANDARD
126 .decode(b64_str)
127 .map_err(|e| CryptoError::KeyGenerationFailed(format!("Invalid base64: {}", e)))?;
128 Self::from_bytes(algorithm, bytes)
129 }
130
131 pub fn to_base64(&self) -> String {
133 use base64::Engine;
134 base64::engine::general_purpose::STANDARD.encode(&self.bytes)
135 }
136}
137
138impl std::fmt::Debug for SecretKey {
139 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140 f.debug_struct("SecretKey")
141 .field("algorithm", &self.algorithm)
142 .field("bytes", &"<redacted>")
143 .finish()
144 }
145}
146
/// Serializable result of `encrypt`: everything `decrypt` needs except the key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedData {
    /// Algorithm the payload was encrypted with.
    pub algorithm: Algorithm,

    /// Nonce used for this encryption; serialized as a hex string.
    #[serde(with = "hex_serde")]
    pub nonce: Vec<u8>,

    /// Ciphertext with the authentication tag appended; serialized as a hex string.
    #[serde(with = "hex_serde")]
    pub ciphertext: Vec<u8>,

    /// Key version tag. Falls back to `0` when absent from the serialized form;
    /// `encrypt` in this file always writes `1`.
    #[serde(default)]
    pub key_version: u32,

    /// Context string that was bound to the ciphertext as additional
    /// authenticated data; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub aad_context: Option<String>,
}
169
170mod hex_serde {
171 use serde::{Deserialize, Deserializer, Serializer};
172
173 pub fn serialize<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
174 where
175 S: Serializer,
176 {
177 serializer.serialize_str(&hex::encode(bytes))
178 }
179
180 pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
181 where
182 D: Deserializer<'de>,
183 {
184 let s = String::deserialize(deserializer)?;
185 hex::decode(&s).map_err(serde::de::Error::custom)
186 }
187}
188
189pub fn encrypt(
191 key: &SecretKey,
192 plaintext: &[u8],
193 aad_context: Option<&str>,
194) -> Result<EncryptedData> {
195 if key.algorithm() != Algorithm::Aes256Gcm {
196 return Err(CryptoError::EncryptionFailed(
197 "Only AES-256-GCM is currently supported".to_string(),
198 ));
199 }
200
201 let rng = SystemRandom::new();
203 let mut nonce_bytes = [0u8; NONCE_SIZE];
204 rng.fill(&mut nonce_bytes)?;
205
206 let unbound_key = UnboundKey::new(&AES_256_GCM, key.as_bytes())?;
208 let less_safe_key = LessSafeKey::new(unbound_key);
209 let nonce = Nonce::try_assume_unique_for_key(&nonce_bytes)?;
210
211 let mut in_out = plaintext.to_vec();
213
214 let aad = match aad_context {
216 Some(ctx) => Aad::from(ctx.as_bytes()),
217 None => Aad::from(&[] as &[u8]),
218 };
219
220 less_safe_key
222 .seal_in_place_append_tag(nonce, aad, &mut in_out)
223 .map_err(|e| CryptoError::EncryptionFailed(format!("{:?}", e)))?;
224
225 Ok(EncryptedData {
226 algorithm: Algorithm::Aes256Gcm,
227 nonce: nonce_bytes.to_vec(),
228 ciphertext: in_out,
229 key_version: 1,
230 aad_context: aad_context.map(String::from),
231 })
232}
233
234pub fn decrypt(
236 key: &SecretKey,
237 encrypted: &EncryptedData,
238) -> Result<Vec<u8>> {
239 if encrypted.algorithm != Algorithm::Aes256Gcm {
240 return Err(CryptoError::DecryptionFailed(
241 "Only AES-256-GCM is currently supported".to_string(),
242 ));
243 }
244
245 if key.algorithm() != Algorithm::Aes256Gcm {
246 return Err(CryptoError::DecryptionFailed(
247 "Key algorithm mismatch".to_string(),
248 ));
249 }
250
251 if encrypted.nonce.len() != NONCE_SIZE {
252 return Err(CryptoError::InvalidNonceLength {
253 expected: NONCE_SIZE,
254 actual: encrypted.nonce.len(),
255 });
256 }
257
258 let unbound_key = UnboundKey::new(&AES_256_GCM, key.as_bytes())?;
260 let less_safe_key = LessSafeKey::new(unbound_key);
261 let mut nonce_bytes = [0u8; NONCE_SIZE];
262 nonce_bytes.copy_from_slice(&encrypted.nonce);
263 let nonce = Nonce::try_assume_unique_for_key(&nonce_bytes)?;
264
265 let mut in_out = encrypted.ciphertext.clone();
267
268 let aad = match &encrypted.aad_context {
270 Some(ctx) => Aad::from(ctx.as_bytes()),
271 None => Aad::from(&[] as &[u8]),
272 };
273
274 let plaintext = less_safe_key
276 .open_in_place(nonce, aad, &mut in_out)
277 .map_err(|e| CryptoError::DecryptionFailed(format!("{:?}", e)))?;
278
279 Ok(plaintext.to_vec())
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// Generated keys have the expected size and algorithm tag.
    #[test]
    fn test_key_generation() {
        let key = SecretKey::generate(Algorithm::Aes256Gcm).unwrap();
        assert_eq!(key.as_bytes().len(), KEY_SIZE);
        assert_eq!(key.algorithm(), Algorithm::Aes256Gcm);
    }

    /// Raw key material of the wrong length is rejected with the precise lengths.
    #[test]
    fn test_invalid_key_length_rejected() {
        let result = SecretKey::from_bytes(Algorithm::Aes256Gcm, vec![0u8; KEY_SIZE - 1]);
        match result {
            Err(CryptoError::InvalidKeyLength { expected, actual }) => {
                assert_eq!(expected, KEY_SIZE);
                assert_eq!(actual, KEY_SIZE - 1);
            }
            other => panic!("expected InvalidKeyLength, got {:?}", other),
        }
    }

    #[test]
    fn test_key_hex_round_trip() {
        let key1 = SecretKey::generate(Algorithm::Aes256Gcm).unwrap();
        let hex = key1.to_hex();
        let key2 = SecretKey::from_hex(Algorithm::Aes256Gcm, &hex).unwrap();
        assert_eq!(key1.as_bytes(), key2.as_bytes());
    }

    #[test]
    fn test_key_base64_round_trip() {
        let key1 = SecretKey::generate(Algorithm::Aes256Gcm).unwrap();
        let b64 = key1.to_base64();
        let key2 = SecretKey::from_base64(Algorithm::Aes256Gcm, &b64).unwrap();
        assert_eq!(key1.as_bytes(), key2.as_bytes());
    }

    #[test]
    fn test_encrypt_decrypt() {
        let key = SecretKey::generate(Algorithm::Aes256Gcm).unwrap();
        let plaintext = b"Hello, World! This is a secret message.";

        let encrypted = encrypt(&key, plaintext, None).unwrap();
        assert_eq!(encrypted.algorithm, Algorithm::Aes256Gcm);
        assert_eq!(encrypted.nonce.len(), NONCE_SIZE);
        // The authentication tag is appended to the ciphertext.
        assert_eq!(encrypted.ciphertext.len(), plaintext.len() + TAG_SIZE);

        let decrypted = decrypt(&key, &encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    /// An empty plaintext still produces a tag and round-trips to an empty buffer.
    #[test]
    fn test_empty_plaintext_round_trip() {
        let key = SecretKey::generate(Algorithm::Aes256Gcm).unwrap();

        let encrypted = encrypt(&key, b"", None).unwrap();
        assert_eq!(encrypted.ciphertext.len(), TAG_SIZE);

        let decrypted = decrypt(&key, &encrypted).unwrap();
        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_encrypt_decrypt_with_aad() {
        let key = SecretKey::generate(Algorithm::Aes256Gcm).unwrap();
        let plaintext = b"Secret data";
        let aad_context = "tenant-123/config/production";

        let encrypted = encrypt(&key, plaintext, Some(aad_context)).unwrap();
        assert_eq!(encrypted.aad_context.as_deref(), Some(aad_context));

        let decrypted = decrypt(&key, &encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    /// Changing or removing the stored AAD context must break authentication.
    #[test]
    fn test_mismatched_aad_fails() {
        let key = SecretKey::generate(Algorithm::Aes256Gcm).unwrap();
        let plaintext = b"Secret data";

        let mut encrypted = encrypt(&key, plaintext, Some("tenant-123")).unwrap();

        encrypted.aad_context = Some("tenant-456".to_string());
        assert!(decrypt(&key, &encrypted).is_err());

        encrypted.aad_context = None;
        assert!(decrypt(&key, &encrypted).is_err());
    }

    #[test]
    fn test_wrong_key_fails() {
        let key1 = SecretKey::generate(Algorithm::Aes256Gcm).unwrap();
        let key2 = SecretKey::generate(Algorithm::Aes256Gcm).unwrap();
        let plaintext = b"Secret";

        let encrypted = encrypt(&key1, plaintext, None).unwrap();
        let result = decrypt(&key2, &encrypted);
        assert!(result.is_err());
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let key = SecretKey::generate(Algorithm::Aes256Gcm).unwrap();
        let plaintext = b"Secret";

        let mut encrypted = encrypt(&key, plaintext, None).unwrap();
        if let Some(byte) = encrypted.ciphertext.first_mut() {
            *byte ^= 0xFF;
        }

        let result = decrypt(&key, &encrypted);
        assert!(result.is_err());
    }

    /// A nonce of the wrong length is reported before any decryption is attempted.
    #[test]
    fn test_invalid_nonce_length_rejected() {
        let key = SecretKey::generate(Algorithm::Aes256Gcm).unwrap();

        let mut encrypted = encrypt(&key, b"Secret", None).unwrap();
        encrypted.nonce.pop();

        match decrypt(&key, &encrypted) {
            Err(CryptoError::InvalidNonceLength { expected, actual }) => {
                assert_eq!(expected, NONCE_SIZE);
                assert_eq!(actual, NONCE_SIZE - 1);
            }
            other => panic!("expected InvalidNonceLength, got {:?}", other),
        }
    }

    #[test]
    fn test_encrypted_data_serialization() {
        let key = SecretKey::generate(Algorithm::Aes256Gcm).unwrap();
        let plaintext = b"Test data";

        let encrypted = encrypt(&key, plaintext, Some("test-context")).unwrap();

        let json = serde_json::to_string(&encrypted).unwrap();

        let deserialized: EncryptedData = serde_json::from_str(&json).unwrap();

        let decrypted = decrypt(&key, &deserialized).unwrap();
        assert_eq!(decrypted, plaintext);
    }
}