1#![deny(missing_docs)]
2use aes_gcm::aead::{generic_array::typenum::U32, generic_array::GenericArray, Aead, NewAead};
5use aes_gcm::Aes256Gcm;
6use rand::Rng;
7use sha2::{Digest, Sha256};
8use std::error;
9use std::fmt;
10
/// Size in bytes of the random AES-GCM nonce stored at the front of every token (96 bits).
const NONCE_LEN: usize = 96 / 8;
/// Size in bytes of the AES-GCM authentication tag carried in every ciphertext (128 bits).
const TAG_LEN: usize = 128 / 8;
13
14#[derive(Clone, PartialEq)]
16pub struct Ec3Key(pub GenericArray<u8, U32>);
17
18impl Ec3Key {
19 pub fn new(key: &str) -> Self {
21 Self::new_raw(key.as_bytes())
22 }
23
24 pub fn new_raw(key: &[u8]) -> Self {
26 let mut hasher = Sha256::new();
27 hasher.update(key);
28 Self(hasher.finalize())
29 }
30
31 pub fn encrypt(&self, token: &str) -> String {
33 let nonce = rand::thread_rng().gen::<[u8; NONCE_LEN]>();
34 let nonce = GenericArray::from_slice(&nonce);
35 let cipher = Aes256Gcm::new(&self.0);
36
37 let mut ciphertext = cipher
38 .encrypt(nonce, token.as_bytes())
39 .expect("encryption failure!");
40
41 let mut encrypted: Vec<u8> = Vec::from(nonce.as_slice());
42 encrypted.append(&mut ciphertext);
43
44 base64::encode_config(&encrypted, base64::URL_SAFE_NO_PAD)
45 }
46
47 pub fn decrypt(&self, token: &str) -> Result<String, DecryptionError> {
49 let token = base64::decode_config(token, base64::URL_SAFE_NO_PAD)?;
50
51 if token.len() < (NONCE_LEN + TAG_LEN) as usize {
52 return Err(DecryptionError::IOError("invalid input length"));
53 }
54
55 let cipher = Aes256Gcm::new(&self.0);
56 let nonce = GenericArray::from_slice(&token[0..NONCE_LEN]);
57
58 let ciphertext = &token[NONCE_LEN..];
59
60 let plaintext = match cipher.decrypt(nonce, ciphertext) {
61 Ok(text) => text,
62 Err(_) => return Err(DecryptionError::IOError("decryption failed")),
63 };
64
65 let s = String::from_utf8(plaintext)?;
66 Ok(s)
67 }
68}
69
70pub fn encrypt_v3(key: &str, token: &str) -> String {
82 let key = Ec3Key::new(key);
83
84 key.encrypt(token)
85}
86
87pub fn decrypt_v3(key: &str, token: &str) -> Result<String, DecryptionError> {
97 let key = Ec3Key::new(key);
98
99 key.decrypt(token)
100}
101
102#[derive(Debug)]
104pub enum DecryptionError {
105 InvalidBase64(base64::DecodeError),
107 InvalidUTF8(std::string::FromUtf8Error),
109 IOError(&'static str),
111}
112
113impl fmt::Display for DecryptionError {
114 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
115 match *self {
116 DecryptionError::InvalidBase64(_) => write!(f, "Invalid base64."),
117 DecryptionError::InvalidUTF8(_) => write!(f, "Invalid UTF8 string decrypted."),
118 DecryptionError::IOError(description) => {
119 write!(f, "Input/Output error: {}", description)
120 }
121 }
122 }
123}
124
125impl error::Error for DecryptionError {
126 fn description(&self) -> &str {
127 match *self {
128 DecryptionError::InvalidBase64(_) => "invalid base64",
129 DecryptionError::InvalidUTF8(_) => "invalid UTF8 string decrypted",
130 DecryptionError::IOError(_) => "input/output error",
131 }
132 }
133
134 fn cause(&self) -> Option<&dyn error::Error> {
135 match *self {
136 DecryptionError::InvalidBase64(ref previous) => Some(previous),
137 DecryptionError::InvalidUTF8(ref previous) => Some(previous),
138 _ => None,
139 }
140 }
141}
142
143impl From<base64::DecodeError> for DecryptionError {
144 fn from(err: base64::DecodeError) -> DecryptionError {
145 DecryptionError::InvalidBase64(err)
146 }
147}
148
149impl From<std::string::FromUtf8Error> for DecryptionError {
150 fn from(err: std::string::FromUtf8Error) -> DecryptionError {
151 DecryptionError::InvalidUTF8(err)
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn it_decodes_properly() {
        let key = "mykey";
        let msg = "hello world";
        let encrypted = encrypt_v3(key, msg);
        assert_eq!(msg, decrypt_v3(key, &encrypted).expect("decrypt failed"));
    }

    #[test]
    fn it_round_trips_an_empty_message() {
        // An empty plaintext produces exactly nonce + tag bytes, the minimum accepted length.
        let key = Ec3Key::new("mykey");
        let encrypted = key.encrypt("");
        assert_eq!("", key.decrypt(&encrypted).expect("decrypt failed"));
    }

    #[test]
    fn it_uses_a_fresh_nonce_per_encryption() {
        let key = Ec3Key::new("mykey");
        assert_ne!(key.encrypt("hello"), key.encrypt("hello"));
    }

    #[test]
    fn it_returns_err_on_garbage_token() {
        let decrypted = decrypt_v3(
            "testkey123",
            "af0c6acf7906cd500aee63a4dd2e97ddcb0142601cf83aa9d622289718c4c85413",
        );

        assert!(
            decrypted.is_err(),
            "decryption should be an Error with a garbage token"
        );
    }

    #[test]
    fn it_returns_err_on_invalid_base64_string() {
        // Hex digits are all legal URL-safe base64 symbols, so use '+' and '/',
        // which are outside the URL-safe alphabet.
        let decrypted = decrypt_v3("testkey123", "not+valid/base64");

        assert!(
            matches!(decrypted, Err(DecryptionError::InvalidBase64(_))),
            "decryption should be an Error with invalid base64 encoded string"
        );
    }

    #[test]
    fn it_returns_err_on_invalid_length() {
        // "bs4W7wyy" is valid base64 but decodes to only 6 bytes.
        let decrypted = decrypt_v3("testkey123", "bs4W7wyy");

        assert!(
            matches!(
                decrypted,
                Err(DecryptionError::IOError("invalid input length"))
            ),
            "decryption should be an Error with invalid length encoded string"
        );
    }

    #[test]
    fn it_returns_err_with_wrong_key() {
        let encrypted = encrypt_v3("right key", "secret");

        assert!(matches!(
            decrypt_v3("wrong key", &encrypted),
            Err(DecryptionError::IOError("decryption failed"))
        ));
    }

    #[test]
    fn it_returns_err_on_tampered_token() {
        let key = Ec3Key::new("mykey");
        let mut raw = base64::decode_config(key.encrypt("hello"), base64::URL_SAFE_NO_PAD)
            .expect("encrypt output must be valid base64");
        let last = raw.len() - 1;
        raw[last] ^= 0x01;
        let tampered = base64::encode_config(&raw, base64::URL_SAFE_NO_PAD);

        assert!(matches!(
            key.decrypt(&tampered),
            Err(DecryptionError::IOError("decryption failed"))
        ));
    }
}