1use aes_gcm::{
2 aead::{generic_array::GenericArray, Aead, KeyInit},
3 Aes256Gcm, Nonce,
4};
5use base64::engine::general_purpose;
6use base64::Engine;
7use serde::de::DeserializeOwned;
8use std::{fmt::Debug, str::FromStr, string::FromUtf8Error};
9
10#[derive(Debug, thiserror::Error)]
11pub enum ConfigError {
12 #[error("get path error")]
13 PathError(#[from] std::io::Error),
14
15 #[error("parse toml error")]
16 TomlError(#[from] toml::de::Error),
17
18 #[error("encrypt or decrypt error")]
19 AesError(aes_gcm::aead::Error),
20
21 #[error("string from utf8 error")]
22 StringFromUtf8Error(#[from] FromUtf8Error),
23
24 #[error("unknown config type: {0}")]
25 UnknownConfigType(String),
26
27 #[error("salt need 32 bity")]
28 SaltLenError,
29
30 #[error("base64 decode error")]
31 Base64Error(#[from] base64::DecodeError),
32}
33
34impl From<aes_gcm::aead::Error> for ConfigError {
35 fn from(e: aes_gcm::aead::Error) -> Self {
36 Self::AesError(e)
37 }
38}
39
/// Format of a configuration file; selects the parser used when loading it.
// The all-caps variant names are already public API, so renaming them to
// `Toml`/`Json`/`Ini` would break callers; silence the lint instead.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConfigType {
    /// TOML, parsed from the name `"toml"`; the only format currently deserialized.
    TOML,
    /// JSON, parsed from the name `"json"`; not deserialized yet.
    JSON,
    /// INI, parsed from the name `"ini"`; not deserialized yet.
    INI,
}
46
47#[derive(Debug)]
48pub struct ConfigInfo {
49 path: String,
50 salt: Option<String>,
51 file_type: ConfigType,
52}
53
/// 96-bit nonce passed to every AES-256-GCM encryption and decryption in this module.
///
/// NOTE(review): AES-GCM requires a nonce that is unique per encryption under a given
/// key. Reusing this constant means two configs encrypted with the same key reveal the
/// XOR of their plaintexts and allow GCM tag forgery. Moving to a random per-file nonce
/// stored next to the ciphertext would change the on-disk format. TODO: plan a migration.
const FIXED_NONCE: [u8; 12] = *b"\x12\x34\x56\x70\x9a\xba\x99\xf9\x12\x34\x56\x78";
57
58impl FromStr for ConfigType {
59 type Err = ConfigError;
60 fn from_str(s: &str) -> Result<Self, Self::Err> {
61 match s {
62 "toml" => Ok(ConfigType::TOML),
63 "json" => Ok(ConfigType::JSON),
64 "ini" => Ok(ConfigType::INI),
65 unknown => Err(ConfigError::UnknownConfigType(unknown.to_string())),
66 }
67 }
68}
69
70impl ConfigInfo {
71 pub fn new(
72 path: String,
73 mut salt: Option<String>,
74 file_type: ConfigType,
75 ) -> Result<Self, ConfigError> {
76 if let Some(key) = salt.clone() {
77 if key.len() != 32 {
78 return Err(ConfigError::SaltLenError);
79 }
80 } else {
81 if let Ok(key) = std::env::var("AES_CONFIG_KEY") {
82 if key.len() != 32 {
83 return Err(ConfigError::SaltLenError);
84 }
85 salt = Some(key);
86 }
87 }
88
89 Ok(Self {
90 path,
91 salt,
92 file_type,
93 })
94 }
95
96 pub fn try_get_config<T: DeserializeOwned>(&self) -> Result<T, ConfigError> {
97 let config_string = self.try_decrypt_config()?;
98
99 match self.file_type {
100 ConfigType::TOML => {
101 let t: T = toml::from_str(&config_string)?;
102 Ok(t)
103 }
104 ConfigType::JSON => {
105 todo!()
106 }
107 ConfigType::INI => {
108 todo!()
109 }
110 }
111 }
112
113 pub fn try_encrypt_config(&self) -> Result<String, ConfigError> {
114 let config_string = std::fs::read_to_string(&self.path)?;
115 if let Some(salt) = &self.salt {
116 let salt = salt.as_bytes().try_into().unwrap();
117 return encrypt_config(config_string.as_bytes(), salt);
118 } else {
119 return Err(ConfigError::SaltLenError);
120 }
121 }
122
123 pub fn try_decrypt_config(&self) -> Result<String, ConfigError> {
124 let config_string = std::fs::read_to_string(&self.path)?;
125 if let Some(salt) = &self.salt {
126 let salt = salt.as_bytes().try_into().unwrap();
127 let plain = decrypt_config(config_string, salt)?;
128 let config_string = String::from_utf8(plain)?;
129 return Ok(config_string);
130 } else {
131 return Ok(config_string);
132 }
133 }
134}
135
136fn encrypt_config(config: &[u8], key: &[u8; 32]) -> Result<String, ConfigError> {
137 let key = GenericArray::from_slice(key);
138 let nonce = Nonce::from_slice(&FIXED_NONCE);
139 let cipher = Aes256Gcm::new(key).encrypt(nonce, config)?;
140 Ok(general_purpose::STANDARD.encode(cipher))
141}
142
143fn decrypt_config(cipher: String, key: &[u8; 32]) -> Result<Vec<u8>, ConfigError> {
144 let cipher = general_purpose::STANDARD.decode(cipher)?;
145 let key = GenericArray::from_slice(key);
146 let nonce = Nonce::from_slice(&FIXED_NONCE);
147 let plain = Aes256Gcm::new(key).decrypt(nonce, cipher.as_slice())?;
148 Ok(plain)
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encrypt_config_should_work() {
        let config = "hello world";
        let key = [0u8; 32];
        let cipher = encrypt_config(config.as_bytes(), &key).unwrap();
        let plain = decrypt_config(cipher, &key).unwrap();
        assert_eq!(config.as_bytes(), plain.as_slice());
        assert_eq!(config, String::from_utf8(plain).unwrap());
    }

    /// GCM authentication must reject a ciphertext opened with a different key.
    #[test]
    fn decrypt_with_wrong_key_should_fail() {
        let cipher = encrypt_config(b"hello world", &[0u8; 32]).unwrap();
        let result = decrypt_config(cipher, &[1u8; 32]);
        assert!(matches!(result, Err(ConfigError::AesError(_))));
    }

    /// Flipping a single ciphertext bit must be detected rather than silently decrypted.
    #[test]
    fn tampered_cipher_should_fail() {
        let key = [7u8; 32];
        let cipher = encrypt_config(b"secret = 1", &key).unwrap();
        let mut raw = general_purpose::STANDARD.decode(cipher).unwrap();
        raw[0] ^= 0x01;
        let tampered = general_purpose::STANDARD.encode(raw);
        assert!(matches!(
            decrypt_config(tampered, &key),
            Err(ConfigError::AesError(_))
        ));
    }

    #[test]
    fn invalid_base64_should_fail() {
        let result = decrypt_config("not base64!".to_string(), &[0u8; 32]);
        assert!(matches!(result, Err(ConfigError::Base64Error(_))));
    }

    #[test]
    fn config_type_from_str() {
        assert!(matches!("toml".parse::<ConfigType>(), Ok(ConfigType::TOML)));
        assert!(matches!("json".parse::<ConfigType>(), Ok(ConfigType::JSON)));
        assert!(matches!("ini".parse::<ConfigType>(), Ok(ConfigType::INI)));
        assert!(matches!(
            "yaml".parse::<ConfigType>(),
            Err(ConfigError::UnknownConfigType(s)) if s == "yaml"
        ));
    }

    /// An explicit key of the wrong length is rejected without consulting the environment.
    #[test]
    fn new_rejects_short_key() {
        let result = ConfigInfo::new(
            "config.toml".to_string(),
            Some("too short".to_string()),
            ConfigType::TOML,
        );
        assert!(matches!(result, Err(ConfigError::SaltLenError)));
    }
}