Skip to main content

aes_256_gcm/
lib.rs

1//!
//! Inspired by the simplicity of JWT, this crate provides a high-level
//! abstraction over the `aes-gcm` crate, with enhancements such as expiration checks.
//!
//! For detailed usage, please refer to the README and the examples in the repository.
6//!
7use aes_gcm::{
8    aead::{generic_array::GenericArray, Aead, KeyInit, OsRng},
9    aes::{
10        cipher::typenum::{
11            bit::{B0, B1},
12            UInt, UTerm,
13        },
14        Aes256,
15    },
16    AeadCore, Aes256Gcm, AesGcm as AG,
17};
18
19type AesGcm = AG<
20    Aes256,
21    UInt<UInt<UInt<UInt<UTerm, B1>, B1>, B0>, B0>,
22    UInt<UInt<UInt<UInt<UInt<UTerm, B1>, B0>, B0>, B0>, B0>,
23>;
24type AesGeneric = GenericArray<u8, UInt<UInt<UInt<UInt<UTerm, B1>, B1>, B0>, B0>>;
25type AesClient = AesGcm;
26
/// Machine-readable category of an [`AesError`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum AesErrorCode {
    /// The value passed to `encrypt` could not be serialized to JSON.
    EncryptDataNotValid,
    /// The options passed to `encrypt` could not be serialized to JSON.
    EncryptOptionError,
    /// The underlying AES-GCM encryption call failed.
    EncryptFailed,

    /// The token could not be decrypted (e.g. wrong key or tampered data), or
    /// its payload could not be deserialized into the requested type.
    DecryptDataNotValid,
    /// The decrypted bytes are not valid UTF-8.
    /// (The variant name is kept for compatibility; it means "conversion".)
    DecryptStringConvention,

    /// The token carried an expiry timestamp that has already passed.
    Expired,
}

/// Error returned by the encrypt/decrypt operations of this crate.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`], so it can be
/// propagated with `?` into `Box<dyn std::error::Error>` or similar.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AesError {
    /// Machine-readable category; match on this rather than on `note`.
    pub code: AesErrorCode,
    /// Human-readable explanation of the failure.
    pub note: &'static str,
}

impl std::fmt::Display for AesError {
    /// Formats as `<code>: <note>`, e.g. `Expired: This data/token simply expired`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}: {}", self.code, self.note)
    }
}

impl std::error::Error for AesError {}
44
45#[derive(serde::Serialize, serde::Deserialize, Debug)]
46pub struct AesOptions {
47    expire: Option<String>,
48}
49
50impl AesOptions {
51    pub fn with_expire_second(expire: i64) -> Self {
52        let microsecond = expire * 1_000_000;
53        let expire = chrono::Utc::now() + chrono::Duration::microseconds(microsecond);
54        let expire = expire.to_rfc3339();
55        Self {
56            expire: Some(expire),
57        }
58    }
59
60    pub fn with_expire_date(expire: chrono::DateTime<chrono::Utc>) -> Self {
61        let expire = expire.to_rfc3339();
62        Self {
63            expire: Some(expire),
64        }
65    }
66
67    pub fn build(self) -> AesOptions {
68        self
69    }
70}
71
72#[derive(Clone)]
73pub struct Client {
74    client: AesClient,
75}
76
77impl Client {
78    pub fn new<'a>(secret: impl Into<Option<&'a str>>) -> Self {
79        let data: Option<&'a str> = secret.into();
80
81        let aes_secret = match data {
82            Some(d) => Some(d.to_string()),
83            None => None,
84        };
85
86        let aes_secret = match aes_secret {
87            Some(data) => data,
88            None => {
89                let env = std::env::var("AES_GCM_SECRET").expect(
90                    "if you are not using parameter, AES_GCM_SECRET os ENV must present or fill the Client::new(secret) parameter"
91                );
92                env
93            }
94        };
95
96        let mut aes_secret = aes_secret.as_bytes().to_vec();
97        if aes_secret.len() > 32 {
98            aes_secret.truncate(32);
99        } else {
100            while aes_secret.len() < 32 {
101                aes_secret.push(0);
102            }
103        }
104        let aes_key = GenericArray::from_slice(&aes_secret);
105        let client: AesClient = Aes256Gcm::new(&aes_key);
106        Self { client }
107    }
108
109    pub fn encrypt<T>(
110        &self,
111        data: T,
112        option: impl Into<Option<AesOptions>>,
113    ) -> Result<String, AesError>
114    where
115        T: serde::Serialize,
116    {
117        let data = match serde_json::to_string(&data) {
118            Ok(data) => data,
119            Err(_) => return Err(AesError {
120                code: AesErrorCode::EncryptDataNotValid,
121                note:
122                    "Data of encryption is not valid data, data must be able to serialize to json",
123            }),
124        };
125        let mut opt = String::new();
126        let optx: Option<AesOptions> = option.into();
127        if let Some(optn) = optx {
128            let json_opt = match serde_json::to_string(&optn) {
129                Ok(data) => data,
130                Err(_) => {
131                    return Err(AesError {
132                        code: AesErrorCode::EncryptOptionError,
133                        note: "Option failed to build, please check correct parameter",
134                    })
135                }
136            };
137            opt = json_opt;
138        }
139        let data = format!("{}::{}", data, opt);
140        let data = data.as_bytes();
141        let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
142        let encrypt = match self.client.encrypt(&nonce, data.as_ref()) {
143            Ok(data) => data,
144            Err(_) => {
145                return Err(AesError {
146                    code: AesErrorCode::EncryptFailed,
147                    note: "Failed to encrypt data, please check correct parameter",
148                })
149            }
150        };
151        let nonce = nonce
152            .iter()
153            .map(|b| format!("{:02x}", b))
154            .collect::<String>();
155        let data = encrypt
156            .iter()
157            .map(|x| format!("{:02x}", x))
158            .collect::<String>();
159        let data = format!("{}{}", data, nonce);
160        Ok(data)
161    }
162
163    pub fn decrypt<'a, T>(&self, data: &'a str) -> Result<T, AesError>
164    where
165        for<'de> T: serde::Deserialize<'de>,
166    {
167        let (data, nonce) = data.split_at(data.len() - 24);
168        let nonce = nonce
169            .chars()
170            .collect::<Vec<char>>()
171            .chunks(2)
172            .map(|x| x.iter().collect::<String>())
173            .map(|x| u8::from_str_radix(&x, 16).unwrap_or(0))
174            .collect::<AesGeneric>();
175        let data = data
176            .chars()
177            .collect::<Vec<char>>()
178            .chunks(2)
179            .map(|x| x.iter().collect::<String>())
180            .map(|x| u8::from_str_radix(&x, 16).unwrap_or(0))
181            .collect::<Vec<u8>>();
182        let decrypt = match self.client.decrypt(&nonce, data.as_ref()) {
183            Ok(data) => data,
184            Err(_) => {
185                return Err(AesError {
186                    code: AesErrorCode::DecryptDataNotValid,
187                    note: "Input data or token does not have valid encryption data",
188                })
189            }
190        };
191        // to string
192        let str_decrypt = match std::str::from_utf8(&decrypt) {
193            Ok(data) => data,
194            Err(_) => {
195                return Err(AesError {
196                    code: AesErrorCode::DecryptStringConvention,
197                    note:
198                        "String convention failed, either it is not valid string or not utf8 string",
199                })
200            }
201        };
202        let str_decrypt: Vec<&str> = str_decrypt.split("::").collect();
203        let decrypt = str_decrypt[0];
204        // decrypt to byte
205        let decrypt = decrypt.as_bytes();
206
207        if str_decrypt.len() > 1 {
208            let option_decrypt = str_decrypt[1].as_bytes();
209            let data_expiry = match serde_json::from_slice::<AesOptions>(&option_decrypt) {
210                Ok(data) => data,
211                Err(_) => AesOptions { expire: None },
212            };
213
214            if let Some(expire) = data_expiry.expire {
215                let date_now = chrono::Utc::now();
216                let date_exp = chrono::DateTime::parse_from_rfc3339(&expire);
217                match date_exp {
218                    Ok(date_exp) => {
219                        if date_now > date_exp {
220                            return Err(AesError {
221                                code: AesErrorCode::Expired,
222                                note: "This data/token simply expired",
223                            });
224                        }
225                    }
226                    Err(_) => {}
227                }
228            }
229        }
230
231        let data = match serde_json::from_slice::<T>(&decrypt) {
232            Ok(data) => data,
233            Err(_) => {
234                return Err(AesError {
235                    code: AesErrorCode::DecryptDataNotValid,
236                    note: "Data of decryption is not valid encrypted data",
237                })
238            }
239        };
240        Ok(data)
241    }
242}
243
// This file is a library, so nothing calls `main`. It is presumably kept so
// the file also compiles as a standalone binary (TODO confirm); the lint
// allowance silences the resulting dead-code warning.
#[allow(dead_code)]
fn main() {
    // Intentionally empty.
}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes every test that touches the process-wide `AES_GCM_SECRET`
    /// variable. The harness runs tests on parallel threads. Without this
    /// lock, tests 7-9 (which remove the variable) race with tests 1-6 (which
    /// set it and then call `Client::new(None)`), causing spurious panics.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Takes [`ENV_LOCK`], recovering from poisoning because `test_case_9`
    /// panics on purpose while holding it.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Builds a client keyed from `AES_GCM_SECRET="some key"`.
    ///
    /// The lock is released on return. `Client::new` reads the variable once
    /// and keeps its own copy of the key, so later env changes don't matter.
    fn env_client() -> Client {
        let _guard = env_lock();
        std::env::set_var("AES_GCM_SECRET", "some key");
        Client::new(None)
    }

    #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
    struct TestCase2 {
        pub name: String,
    }

    fn sample() -> TestCase2 {
        TestCase2 {
            name: "my name".to_string(),
        }
    }

    /// Test case 1: encrypt and decrypt a string.
    #[test]
    fn test_case_1() {
        let client = env_client();
        let encrypted = client.encrypt("my thing", None).unwrap();
        let decrypted: String = client.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, "my thing");
    }

    /// Test case 2: encrypt and decrypt a struct.
    #[test]
    fn test_case_2() {
        let client = env_client();
        let data = sample();
        let encrypted = client.encrypt(&data, None).unwrap();
        let decrypted: TestCase2 = client.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, data);
    }

    /// Test case 3: a string with an expiry decrypts before the expiry passes.
    #[test]
    fn test_case_3() {
        let client = env_client();
        let encrypted = client
            .encrypt("my thing", AesOptions::with_expire_second(3).build())
            .unwrap();
        let decrypted = client.decrypt::<String>(&encrypted).unwrap();
        assert_eq!(decrypted, "my thing");
    }

    /// Test case 4: a string is rejected with `Expired` after its expiry.
    #[test]
    fn test_case_4() {
        let client = env_client();
        let encrypted = client
            .encrypt("my thing", AesOptions::with_expire_second(1).build())
            .unwrap();
        // `sleep` waits at least this long, so we are past the 1-second expiry.
        std::thread::sleep(std::time::Duration::from_millis(1_200));
        let err = client.decrypt::<String>(&encrypted).unwrap_err();
        assert_eq!(err.code, AesErrorCode::Expired);
    }

    /// Test case 5: a struct with an expiry decrypts before the expiry passes.
    #[test]
    fn test_case_5() {
        let client = env_client();
        let encrypted = client
            .encrypt(sample(), AesOptions::with_expire_second(3).build())
            .unwrap();
        let decrypted = client.decrypt::<TestCase2>(&encrypted).unwrap();
        assert_eq!(decrypted, sample());
    }

    /// Test case 6: a struct is rejected with `Expired` after its expiry.
    #[test]
    fn test_case_6() {
        let client = env_client();
        let encrypted = client
            .encrypt(sample(), AesOptions::with_expire_second(1).build())
            .unwrap();
        // `sleep` waits at least this long, so we are past the 1-second expiry.
        std::thread::sleep(std::time::Duration::from_millis(1_200));
        let err = client.decrypt::<TestCase2>(&encrypted).unwrap_err();
        assert_eq!(err.code, AesErrorCode::Expired);
    }

    /// Test case 7: an explicit `&str` secret works without the env variable.
    #[test]
    fn test_case_7() {
        let _guard = env_lock();
        std::env::remove_var("AES_GCM_SECRET");
        let client = Client::new("my secret");
        let encrypted = client.encrypt("my thing", None).unwrap();
        let decrypted: String = client.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, "my thing");
    }

    /// Test case 8: a secret borrowed from a `String` works without the env variable.
    #[test]
    fn test_case_8() {
        let _guard = env_lock();
        std::env::remove_var("AES_GCM_SECRET");
        let secrets = String::from("my secret");
        let client = Client::new(&*secrets);
        let encrypted = client.encrypt("my thing", None).unwrap();
        let decrypted: String = client.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, "my thing");
    }

    /// Test case 9: no secret and no env variable panics in `Client::new`.
    #[test]
    #[should_panic(expected = "AES_GCM_SECRET")]
    fn test_case_9() {
        let _guard = env_lock();
        std::env::remove_var("AES_GCM_SECRET");
        Client::new(None);
    }
}