1#![warn(missing_docs, clippy::all)]
64
65pub use crate::deserializer::Deserializer;
66use aes_gcm::aes::Aes256;
67use aes_gcm::{AeadInPlace, Aes256Gcm, KeyInit, Nonce};
68use aes_gcm::{AesGcm, Tag};
69use base64::display::Base64Display;
70use base64::engine::general_purpose::STANDARD;
71use base64::Engine;
72use rand::{CryptoRng, Rng};
73use serde::{Deserialize, Serialize};
74use std::error;
75use std::fmt;
76use std::fs;
77use std::io;
78use std::path::Path;
79use std::result;
80use std::str::FromStr;
81use std::string::FromUtf8Error;
82use typenum::U32;
83
/// Length in bytes of an AES-256 key.
const KEY_LEN: usize = 32;
/// Prefix that marks the textual form of an AES key.
const KEY_PREFIX: &str = "AES:";
/// Length in bytes of the GCM nonce stored in current (JSON envelope) encrypted values.
const IV_LEN: usize = 12;
/// Length in bytes of the nonce at the start of a legacy encrypted value.
const LEGACY_IV_LEN: usize = 32;
/// Length in bytes of the GCM authentication tag.
const TAG_LEN: usize = 16;
89
/// AES-256-GCM with a 32-byte nonce; used to decrypt values stored in the legacy layout.
type LegacyAes256Gcm = AesGcm<Aes256, U32>;

// Serde deserializer wrapper; its `Deserializer` type is re-exported at the crate root.
mod deserializer;

/// The result type returned by this crate's fallible operations.
pub type Result<T> = result::Result<T, Error>;
96
/// The underlying reason an `Error` was produced.
#[derive(Debug)]
enum ErrorCause {
    /// An AES-GCM encryption or decryption operation failed.
    AesGcm(aes_gcm::Error),
    /// Reading a key file failed for a reason other than the file not existing.
    Io(io::Error),
    /// Text that should have been base64 could not be decoded.
    Base64(base64::DecodeError),
    /// The decrypted plaintext was not valid UTF-8.
    Utf8(FromUtf8Error),
    /// A key string did not start with the `AES:` prefix.
    BadPrefix,
    /// A key, nonce, tag, or legacy encrypted value had the wrong length.
    InvalidLength,
    /// The key's nonce counter has no unused values left.
    KeyExhausted,
}
107
/// The error type returned by this crate's fallible operations.
// The cause is boxed so the error value itself stays pointer-sized.
#[derive(Debug)]
pub struct Error(Box<ErrorCause>);
111
112impl fmt::Display for Error {
113 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
114 match *self.0 {
115 ErrorCause::AesGcm(ref e) => fmt::Display::fmt(e, fmt),
116 ErrorCause::Io(ref e) => fmt::Display::fmt(e, fmt),
117 ErrorCause::Base64(ref e) => fmt::Display::fmt(e, fmt),
118 ErrorCause::Utf8(ref e) => fmt::Display::fmt(e, fmt),
119 ErrorCause::BadPrefix => fmt.write_str("invalid key prefix"),
120 ErrorCause::InvalidLength => fmt.write_str("invalid encrypted value component length"),
121 ErrorCause::KeyExhausted => fmt.write_str("key cannot encrypt more than 2^64 values"),
122 }
123 }
124}
125
// Uses the default `source` (returns `None`); note that `Display` already renders the
// wrapped error's message for the `AesGcm`, `Io`, `Base64`, and `Utf8` causes.
impl error::Error for Error {}
127
/// The JSON envelope of a non-legacy encrypted value (before the outer base64 encoding).
///
/// Serialized with an internal `type` tag, e.g.
/// `{"type":"AES","mode":"GCM","iv":"…","ciphertext":"…","tag":"…"}`, where the byte fields
/// are standard base64 strings.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "SCREAMING_SNAKE_CASE")]
enum EncryptedValue {
    /// An AES-encrypted value.
    Aes {
        /// Cipher mode; only GCM exists.
        mode: AesMode,
        /// Nonce used for encryption.
        #[serde(with = "serde_base64")]
        iv: Vec<u8>,
        /// Encrypted bytes, without the authentication tag.
        #[serde(with = "serde_base64")]
        ciphertext: Vec<u8>,
        /// GCM authentication tag.
        #[serde(with = "serde_base64")]
        tag: Vec<u8>,
    },
}
141
142mod serde_base64 {
143 use base64::engine::general_purpose::STANDARD;
144 use base64::Engine;
145 use serde::de;
146 use serde::{Deserialize, Deserializer, Serialize, Serializer};
147
148 pub fn serialize<S>(buf: &[u8], s: S) -> Result<S::Ok, S::Error>
149 where
150 S: Serializer,
151 {
152 STANDARD.encode(buf).serialize(s)
153 }
154
155 pub fn deserialize<'a, D>(d: D) -> Result<Vec<u8>, D::Error>
156 where
157 D: Deserializer<'a>,
158 {
159 let s = String::deserialize(d)?;
160 STANDARD
161 .decode(&s)
162 .map_err(|_| de::Error::invalid_value(de::Unexpected::Str(&s), &"a base64 string"))
163 }
164}
165
/// Returns the random number generator used to generate keys and base nonces.
///
/// The `CryptoRng` bound makes the need for a cryptographically secure generator part of the
/// signature, so swapping in a non-cryptographic generator fails to compile.
fn secure_rng() -> impl Rng + CryptoRng {
    rand::rng()
}
170
/// Cipher mode recorded in the encrypted-value envelope; serialized as `"GCM"`.
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum AesMode {
    /// Galois/Counter Mode.
    Gcm,
}
176
/// Mode marker for a key that can only decrypt values, such as one parsed from a string or
/// loaded from a key file.
pub struct ReadOnly(());
179
/// Mode marker for a key that can encrypt as well as decrypt values.
///
/// It carries the state needed to derive a distinct nonce for every encryption.
pub struct ReadWrite {
    /// Base nonce; each encryption XORs the current counter into its first bytes.
    iv: [u8; IV_LEN],
    /// Counter value the next encryption will use; incremented by every `encrypt` call.
    counter: u64,
}
191
/// An AES-256 key used to encrypt and decrypt configuration values.
///
/// The mode parameter `T` is `ReadOnly` for keys that can only decrypt, or `ReadWrite` for
/// keys that can also encrypt.
pub struct Key<T> {
    /// Raw AES-256 key bytes.
    key: [u8; KEY_LEN],
    /// Mode-specific state (`ReadOnly` carries none; `ReadWrite` carries nonce state).
    mode: T,
}
206
207impl Key<ReadWrite> {
208 pub fn random_aes() -> Result<Key<ReadWrite>> {
210 Ok(Key {
211 key: secure_rng().random(),
212 mode: ReadWrite {
213 iv: secure_rng().random(),
214 counter: 0,
215 },
216 })
217 }
218
219 pub fn encrypt(&mut self, value: &str) -> Result<String> {
221 let counter = self.mode.counter;
222 self.mode.counter = match self.mode.counter.checked_add(1) {
223 Some(v) => v,
224 None => return Err(Error(Box::new(ErrorCause::KeyExhausted))),
225 };
226
227 let mut iv = Nonce::from(self.mode.iv);
228 for (i, byte) in counter.to_le_bytes().iter().enumerate() {
229 iv[i] ^= *byte;
230 }
231
232 let mut ciphertext = value.as_bytes().to_vec();
233 let tag = Aes256Gcm::new(&self.key.into())
234 .encrypt_in_place_detached(&iv, &[], &mut ciphertext)
235 .map_err(|e| Error(Box::new(ErrorCause::AesGcm(e))))?;
236
237 let value = EncryptedValue::Aes {
238 mode: AesMode::Gcm,
239 iv: iv.to_vec(),
240 ciphertext,
241 tag: tag.to_vec(),
242 };
243
244 let value = serde_json::to_string(&value).unwrap();
245 Ok(STANDARD.encode(value.as_bytes()))
246 }
247}
248
249impl Key<ReadOnly> {
250 pub fn from_file<P>(path: P) -> Result<Option<Key<ReadOnly>>>
255 where
256 P: AsRef<Path>,
257 {
258 let s = match fs::read_to_string(path) {
259 Ok(s) => s,
260 Err(ref e) if e.kind() == io::ErrorKind::NotFound => return Ok(None),
261 Err(e) => return Err(Error(Box::new(ErrorCause::Io(e)))),
262 };
263 s.parse().map(Some)
264 }
265}
266
267impl<T> Key<T> {
268 pub fn decrypt(&self, value: &str) -> Result<String> {
270 let value = STANDARD
271 .decode(value)
272 .map_err(|e| Error(Box::new(ErrorCause::Base64(e))))?;
273
274 let (iv, mut ct, tag) = match serde_json::from_slice(&value) {
275 Ok(EncryptedValue::Aes {
276 mode: AesMode::Gcm,
277 iv,
278 ciphertext,
279 tag,
280 }) => {
281 if iv.len() != IV_LEN || tag.len() != TAG_LEN {
282 return Err(Error(Box::new(ErrorCause::InvalidLength)));
283 }
284
285 let mut iv_arr = [0; IV_LEN];
286 iv_arr.copy_from_slice(&iv);
287
288 let mut tag_arr = [0; TAG_LEN];
289 tag_arr.copy_from_slice(&tag);
290
291 (Iv::Standard(iv_arr), ciphertext, tag_arr)
292 }
293 Err(_) => {
294 if value.len() < LEGACY_IV_LEN + TAG_LEN {
295 return Err(Error(Box::new(ErrorCause::InvalidLength)));
296 }
297
298 let mut iv = [0; LEGACY_IV_LEN];
299 iv.copy_from_slice(&value[..LEGACY_IV_LEN]);
300
301 let ct = value[LEGACY_IV_LEN..value.len() - TAG_LEN].to_vec();
302
303 let mut tag = [0; TAG_LEN];
304 tag.copy_from_slice(&value[value.len() - TAG_LEN..]);
305
306 (Iv::Legacy(iv), ct, tag)
307 }
308 };
309
310 let tag = Tag::from(tag);
311
312 match iv {
313 Iv::Legacy(iv) => {
314 let iv = Nonce::from(iv);
315
316 LegacyAes256Gcm::new(&self.key.into())
317 .decrypt_in_place_detached(&iv, &[], &mut ct, &tag)
318 .map_err(|e| Error(Box::new(ErrorCause::AesGcm(e))))?;
319 }
320 Iv::Standard(iv) => {
321 let iv = Nonce::from(iv);
322
323 Aes256Gcm::new(&self.key.into())
324 .decrypt_in_place_detached(&iv, &[], &mut ct, &tag)
325 .map_err(|e| Error(Box::new(ErrorCause::AesGcm(e))))?;
326 }
327 };
328
329 let pt = String::from_utf8(ct).map_err(|e| Error(Box::new(ErrorCause::Utf8(e))))?;
330
331 Ok(pt)
332 }
333}
334
335impl<T> fmt::Display for Key<T> {
336 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
337 write!(fmt, "AES:{}", Base64Display::new(&self.key, &STANDARD))
338 }
339}
340
341impl FromStr for Key<ReadOnly> {
342 type Err = Error;
343
344 fn from_str(s: &str) -> Result<Key<ReadOnly>> {
345 if !s.starts_with(KEY_PREFIX) {
346 return Err(Error(Box::new(ErrorCause::BadPrefix)));
347 }
348
349 let key = STANDARD
350 .decode(&s[KEY_PREFIX.len()..])
351 .map_err(|e| Error(Box::new(ErrorCause::Base64(e))))?;
352
353 if key.len() != KEY_LEN {
354 return Err(Error(Box::new(ErrorCause::InvalidLength)));
355 }
356
357 let mut key_arr = [0; KEY_LEN];
358 key_arr.copy_from_slice(&key);
359
360 Ok(Key {
361 key: key_arr,
362 mode: ReadOnly(()),
363 })
364 }
365}
366
/// A nonce extracted from an encrypted value, tagged with the layout it came from.
enum Iv {
    /// 32-byte nonce from the legacy raw layout.
    Legacy([u8; LEGACY_IV_LEN]),
    /// 12-byte nonce from the JSON envelope.
    Standard([u8; IV_LEN]),
}
371
#[cfg(test)]
mod test {
    use serde::Deserialize;
    use std::fs::File;
    use std::io::Write;
    use tempfile::tempdir;

    use super::*;

    // Fixed key shared by the tests; the fixture ciphertexts below were encrypted with it.
    const KEY: &str = "AES:NwQZdNWsFmYMCNSQlfYPDJtFBgPzY8uZlFhMCLnxNQE=";

    // A key file containing a valid key loads as `Some`.
    #[test]
    fn from_file_aes() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("encrypted-config-value.key");
        let mut key = File::create(&path).unwrap();
        key.write_all(KEY.as_bytes()).unwrap();

        assert!(Key::from_file(&path).unwrap().is_some());
    }

    // Despite the name, no file is created here: this checks that a missing key file yields
    // `Ok(None)`.
    #[test]
    fn from_file_empty() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("encrypted-config-value.key");

        assert!(Key::from_file(path).unwrap().is_none());
    }

    // Decrypts a fixture in the legacy raw layout (32-byte nonce || ciphertext || tag).
    #[test]
    fn decrypt_legacy() {
        let ct =
            "5BBfGvf90H6bApwfxUjNdoKRW1W+GZCbhBuBpzEogVBmQZyWFFxcKyf+UPV5FOhrw/wrVZyoL3npoDfYj\
             PQV/zg0W/P9cVOw";
        let pt = "L/TqOWz7E4z0SoeiTYBrqbqu";

        let key = KEY.parse::<Key<ReadOnly>>().unwrap();
        let actual = key.decrypt(ct).unwrap();
        assert_eq!(actual, pt);
    }

    // Decrypts a fixture in the base64-encoded JSON envelope format.
    #[test]
    fn decrypt() {
        let ct =
            "eyJ0eXBlIjoiQUVTIiwibW9kZSI6IkdDTSIsIml2IjoiUCtRQXM5aHo4VFJVOUpNLyIsImNpcGhlcnRle\
             HQiOiJmUGpDaDVuMkR0cklPSVNXSklLcVQzSUtRNUtONVI3LyIsInRhZyI6ImlJRFIzYUtER1UyK1Brej\
             NPSEdSL0E9PSJ9";
        let pt = "L/TqOWz7E4z0SoeiTYBrqbqu";

        let key = KEY.parse::<Key<ReadOnly>>().unwrap();
        let actual = key.decrypt(ct).unwrap();
        assert_eq!(actual, pt);
    }

    // A freshly generated key round-trips a value through encrypt and decrypt.
    #[test]
    fn encrypt_decrypt() {
        let mut key = Key::random_aes().unwrap();
        let pt = "L/TqOWz7E4z0SoeiTYBrqbqu";
        let ct = key.encrypt(pt).unwrap();
        let actual = key.decrypt(&ct).unwrap();
        assert_eq!(pt, actual);
    }

    // Encrypting the same plaintext twice gives different output, because each call
    // uses a new counter-derived nonce.
    #[test]
    fn unique_ivs() {
        let mut key = Key::random_aes().unwrap();
        let pt = "L/TqOWz7E4z0SoeiTYBrqbqu";
        let ct1 = key.encrypt(pt).unwrap();
        let ct2 = key.encrypt(pt).unwrap();
        assert_ne!(ct1, ct2);
    }

    // The wrapping `Deserializer` decrypts `${enc:...}` strings. Other `${...}` strings pass
    // through unchanged.
    #[test]
    fn deserializer() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct Config {
            sub: Subconfig,
        }

        #[derive(Deserialize, PartialEq, Debug)]
        struct Subconfig {
            encrypted: Vec<String>,
            plaintext: String,
        }

        let config = r#"
{
    "sub": {
        "encrypted": [
            "${enc:5BBfGvf90H6bApwfxUjNdoKRW1W+GZCbhBuBpzEogVBmQZyWFFxcKyf+UPV5FOhrw/wrVZyoL3npoDfYjPQV/zg0W/P9cVOw}"
        ],
        "plaintext": "${foobar}"
    }
}
        "#;

        let key = KEY.parse().unwrap();
        let mut deserializer = serde_json::Deserializer::from_str(config);
        let deserializer = Deserializer::new(&mut deserializer, Some(&key));

        let config = Config::deserialize(deserializer).unwrap();

        let expected = Config {
            sub: Subconfig {
                encrypted: vec!["L/TqOWz7E4z0SoeiTYBrqbqu".to_string()],
                plaintext: "${foobar}".to_string(),
            },
        };

        assert_eq!(config, expected);
    }
}