openconnect_core/
storage.rs

1use chacha20poly1305::{
2    aead::{Aead, AeadCore, KeyInit, OsRng},
3    XChaCha20Poly1305, XNonce,
4};
5use rand::SeedableRng;
6use serde::{Deserialize, Serialize};
7use sha2::Digest;
8use std::{collections::HashMap, path::PathBuf};
9
10#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct StoredConfigsJson {
12    default: Option<String>,
13    servers: Vec<StoredServer>,
14}
15
16impl StoredConfigsJson {
17    pub fn decrypted_by(&self, encryptor: &PassEncryptor) -> Self {
18        let servers = self
19            .servers
20            .iter()
21            .map(|server| match server {
22                StoredServer::Oidc(oidc_server) => StoredServer::Oidc(oidc_server.clone()),
23                StoredServer::Password(password_server) => {
24                    StoredServer::Password(password_server.decrypted_by(encryptor))
25                }
26            })
27            .collect();
28        Self {
29            default: self.default.clone(),
30            servers,
31        }
32    }
33}
34
35impl TryFrom<(StoredConfigsJson, PathBuf)> for StoredConfigs {
36    type Error = StoredConfigError;
37
38    fn try_from(
39        (json, config_file): (StoredConfigsJson, PathBuf),
40    ) -> Result<StoredConfigs, StoredConfigError> {
41        let mut servers = HashMap::new();
42        for server in json.servers {
43            let name = match &server {
44                StoredServer::Oidc(OidcServer { name, .. }) => name,
45                StoredServer::Password(PasswordServer { name, .. }) => name,
46            };
47
48            if servers.contains_key(name) {
49                return Err(StoredConfigError::ParseError(format!(
50                    "Duplicated server name: {}, check your config file",
51                    name
52                )));
53            }
54
55            servers.insert(name.clone(), server);
56        }
57
58        Ok(StoredConfigs {
59            default: json.default,
60            servers,
61            cipher: PassEncryptor::default(),
62            config_file,
63        })
64    }
65}
66
67impl From<StoredConfigs> for StoredConfigsJson {
68    fn from(config: StoredConfigs) -> StoredConfigsJson {
69        StoredConfigsJson {
70            default: config.default,
71            servers: config.servers.into_values().collect(),
72        }
73    }
74}
75
76#[derive(Clone, Debug, Serialize, Deserialize)]
77#[serde(rename_all = "camelCase")]
78pub struct OidcServer {
79    pub name: String,
80    pub server: String,
81    pub issuer: String,
82    pub client_id: String,
83    pub client_secret: Option<String>,
84    pub allow_insecure: Option<bool>,
85    pub updated_at: Option<String>,
86}
87
88#[derive(Clone, Debug, Serialize, Deserialize)]
89#[serde(rename_all = "camelCase")]
90pub struct PasswordServer {
91    pub name: String,
92    pub server: String,
93    pub username: String,
94    pub password: Option<String>,
95    pub allow_insecure: Option<bool>,
96    pub updated_at: Option<String>,
97}
98
99impl PasswordServer {
100    pub fn decrypted_by(&self, encryptor: &PassEncryptor) -> Self {
101        let password = self
102            .password
103            .as_ref()
104            .and_then(|p| encryptor.decrypt(p).ok());
105        Self {
106            name: self.name.clone(),
107            server: self.server.clone(),
108            username: self.username.clone(),
109            password,
110            allow_insecure: self.allow_insecure,
111            updated_at: self.updated_at.clone(),
112        }
113    }
114
115    pub fn encrypted_by(&self, encryptor: &PassEncryptor) -> Self {
116        let password = self
117            .password
118            .as_ref()
119            .and_then(|p| encryptor.encrypt(p).ok());
120        Self {
121            name: self.name.clone(),
122            server: self.server.clone(),
123            username: self.username.clone(),
124            password,
125            allow_insecure: self.allow_insecure,
126            updated_at: self.updated_at.clone(),
127        }
128    }
129}
130
131#[derive(Clone, Debug, Serialize, Deserialize)]
132#[serde(rename_all = "camelCase", tag = "authType")]
133pub enum StoredServer {
134    #[serde(rename_all = "camelCase")]
135    Oidc(OidcServer),
136
137    #[serde(rename_all = "camelCase")]
138    Password(PasswordServer),
139}
140
141impl TryFrom<&StoredServer> for OidcServer {
142    type Error = StoredConfigError;
143
144    fn try_from(server: &StoredServer) -> Result<OidcServer, StoredConfigError> {
145        match server {
146            StoredServer::Oidc(oidc_server) => Ok(oidc_server.clone()),
147            StoredServer::Password(_) => Err(StoredConfigError::ParseError(
148                "Server is not OIDC type".to_string(),
149            )),
150        }
151    }
152}
153
154impl TryFrom<&StoredServer> for PasswordServer {
155    type Error = StoredConfigError;
156
157    fn try_from(server: &StoredServer) -> Result<PasswordServer, StoredConfigError> {
158        match server {
159            StoredServer::Password(password_server) => Ok(password_server.clone()),
160            StoredServer::Oidc(_) => Err(StoredConfigError::ParseError(
161                "Server is not Password type".to_string(),
162            )),
163        }
164    }
165}
166
167#[derive(Clone, Debug)]
168pub struct StoredConfigs {
169    pub default: Option<String>,
170    pub servers: HashMap<String, StoredServer>,
171    pub cipher: PassEncryptor,
172    pub config_file: PathBuf,
173}
174
175#[derive(Debug, thiserror::Error)]
176pub enum StoredConfigError {
177    #[error("Bad input: {0}")]
178    BadInput(String),
179
180    #[error("Failed to parse stored config: {0}")]
181    ParseError(String),
182
183    #[error("IO error: {0}")]
184    IoError(#[from] std::io::Error),
185
186    #[error("Cipher error")]
187    CipherError(#[from] PassEncryptorError),
188}
189
190impl StoredConfigs {
191    pub fn new(pass_key: Option<String>, config_file: PathBuf) -> Self {
192        Self {
193            default: None,
194            servers: HashMap::new(),
195            cipher: PassEncryptor::new(pass_key),
196            config_file,
197        }
198    }
199
200    pub fn default_server(&self) -> Option<&StoredServer> {
201        self.default
202            .as_ref()
203            .and_then(|name| self.servers.get(name))
204    }
205
206    pub fn getorinit_config_file() -> Result<PathBuf, StoredConfigError> {
207        let home_dir = home::home_dir().ok_or(StoredConfigError::IoError(std::io::Error::new(
208            std::io::ErrorKind::NotFound,
209            "Home directory not found",
210        )))?;
211
212        let config_folder = home_dir.join(".oidcvpn");
213        if !config_folder.exists() {
214            std::fs::create_dir(&config_folder)?;
215        }
216
217        let config_file = config_folder.join("config.json");
218        if !config_file.exists() {
219            std::fs::write(&config_file, br#"{"default":null,"servers":[]}"#)?;
220        }
221
222        Ok(config_file)
223    }
224
225    pub async fn save_to_file(&self) -> Result<&Self, StoredConfigError> {
226        let json = serde_json::to_string(&StoredConfigsJson::from(self.clone())).map_err(|e| {
227            StoredConfigError::ParseError(format!("Failed to serialize config: {}", e))
228        })?;
229
230        tokio::fs::write(&self.config_file, json).await?;
231
232        Ok(self)
233    }
234
235    pub async fn read_from_file(&mut self) -> Result<&mut Self, StoredConfigError> {
236        let content = tokio::fs::read(&self.config_file).await?;
237        let config_json: StoredConfigsJson = serde_json::from_slice(&content).map_err(|e| {
238            StoredConfigError::ParseError(format!("Failed to parse config file: {}", e))
239        })?;
240        let config = StoredConfigs::try_from((config_json, self.config_file.clone()))?;
241
242        self.default = config.default;
243        self.servers = config.servers;
244
245        Ok(self)
246    }
247
248    pub fn get_server_as_oidc_server(&self, name: &str) -> Result<&OidcServer, StoredConfigError> {
249        self.servers
250            .get(name)
251            .and_then(|server| match server {
252                StoredServer::Oidc(oidc) => Some(oidc),
253                _ => None,
254            })
255            .ok_or(StoredConfigError::ParseError(format!(
256                "Server '{}' not found",
257                name
258            )))
259    }
260
261    pub fn get_server_as_password_server(
262        &self,
263        name: &str,
264    ) -> Result<&PasswordServer, StoredConfigError> {
265        self.servers
266            .get(name)
267            .and_then(|server| match server {
268                StoredServer::Password(password_server) => Some(password_server),
269                _ => None,
270            })
271            .ok_or(StoredConfigError::ParseError(format!(
272                "Server '{}' not found",
273                name
274            )))
275    }
276
277    pub async fn upsert_server(
278        &mut self,
279        server: StoredServer,
280    ) -> Result<&mut Self, StoredConfigError> {
281        let updated_at = chrono::Utc::now().to_rfc3339();
282        let mut server = server.clone();
283        let name = match &mut server {
284            StoredServer::Oidc(oidc_server) => {
285                oidc_server.updated_at = Some(updated_at);
286                oidc_server.name.to_owned()
287            }
288            StoredServer::Password(password_server) => {
289                password_server.updated_at = Some(updated_at);
290                *password_server = password_server.encrypted_by(&self.cipher);
291                password_server.name.to_owned()
292            }
293        };
294
295        *self.servers.entry(name).or_insert(server) = server.clone();
296        self.save_to_file().await?;
297        Ok(self)
298    }
299
300    pub async fn remove_server(&mut self, name: &str) -> Result<&mut Self, StoredConfigError> {
301        if self.default.as_ref().is_some_and(|d| d == name) {
302            return Err(StoredConfigError::BadInput(format!(
303                "Cannot remove default server {}",
304                name
305            )));
306        }
307        self.servers.remove(name);
308        self.save_to_file().await?;
309        Ok(self)
310    }
311
312    pub async fn set_default_server(&mut self, name: &str) -> Result<&mut Self, StoredConfigError> {
313        if !self.servers.contains_key(name) {
314            return Err(StoredConfigError::ParseError(format!(
315                "Server {} not found",
316                name
317            )));
318        }
319
320        self.default = Some(name.to_string());
321        self.save_to_file().await?;
322        Ok(self)
323    }
324}
325
326#[derive(Clone, Debug)]
327pub struct PassEncryptor {
328    secret: chacha20poly1305::Key,
329}
330
331#[derive(Debug, thiserror::Error)]
332pub enum PassEncryptorError {
333    #[error("Cipher error: {0}")]
334    CipherError(String),
335}
336
337impl Default for PassEncryptor {
338    fn default() -> Self {
339        Self::new(None)
340    }
341}
342
343impl PassEncryptor {
344    pub fn new(unique_key: Option<String>) -> Self {
345        // by default, use machine uid as unique key to generate encryption key
346        let unique_key = unique_key
347            .or_else(|| machine_uid::get().ok())
348            .unwrap_or("openconnect-rs-2024".to_string());
349
350        let mut hasher: sha2::Sha256 = sha2::digest::Digest::new();
351        hasher.update(unique_key.as_bytes());
352        let hash = hasher.finalize(); // hash is absolutely 32 bytes
353        let mut seed = rand::rngs::StdRng::from_seed(hash.into());
354        let key = XChaCha20Poly1305::generate_key(&mut seed);
355        Self { secret: key }
356    }
357
358    pub fn encrypt(&self, plaintext: &str) -> Result<String, PassEncryptorError> {
359        let cipher = XChaCha20Poly1305::new(&self.secret);
360        let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
361        let encypted = cipher.encrypt(&nonce, plaintext.as_ref()).map_err(|e| {
362            PassEncryptorError::CipherError(format!("Failed to encrypt password: {}", e))
363        })?;
364        let combined = [nonce.to_vec(), encypted].concat();
365        Ok(hex::encode(combined))
366    }
367
368    pub fn decrypt(&self, ciphertext: &str) -> Result<String, PassEncryptorError> {
369        let cipher = XChaCha20Poly1305::new(&self.secret);
370        let ciphertext = hex::decode(ciphertext).map_err(|e| {
371            PassEncryptorError::CipherError(format!("Failed to decrypt password: {}", e))
372        })?;
373        let nonce = XNonce::from_slice(&ciphertext[..24]);
374        let plaintext = cipher.decrypt(nonce, &ciphertext[24..]).map_err(|e| {
375            PassEncryptorError::CipherError(format!("Failed to decrypt password: {}", e))
376        })?;
377        String::from_utf8(plaintext).map_err(|e| {
378            PassEncryptorError::CipherError(format!("Failed to decrypt password: {}", e))
379        })
380    }
381}
382
/// Round-trips a password and checks the properties callers rely on:
/// - the ciphertext differs from the plaintext;
/// - nonces are fresh on every call;
/// - a different key cannot decrypt.
#[test]
fn test_pass_enc() {
    let encryptor = PassEncryptor::default();
    let password = "password";

    let encrypted = encryptor.encrypt(password).unwrap();
    assert_ne!(encrypted, password, "ciphertext must not equal the plaintext");
    let decrypted = encryptor.decrypt(&encrypted).unwrap();
    assert_eq!(password, decrypted);

    // A random nonce per call means the same input encrypts differently each time.
    let again = encryptor.encrypt(password).unwrap();
    assert_ne!(encrypted, again);
    assert_eq!(encryptor.decrypt(&again).unwrap(), password);

    // A cipher derived from another key must fail authentication.
    let other = PassEncryptor::new(Some("some-other-key".to_string()));
    assert!(other.decrypt(&encrypted).is_err());
}
391
392#[tokio::test]
393async fn test_read_config() {
394    let config_file = StoredConfigs::getorinit_config_file().unwrap();
395    let mut stored_configs = StoredConfigs::new(None, config_file);
396    stored_configs.read_from_file().await.unwrap();
397    println!("parsed struct: {:#?}", stored_configs);
398
399    let stored_configs_json = StoredConfigsJson::from(stored_configs.clone());
400    let json = serde_json::to_string(&stored_configs_json).unwrap();
401    println!("json: {}", json);
402}
403
404#[tokio::test]
405async fn test_save_config() {
406    let server = StoredServer::Oidc(OidcServer {
407        name: "test".to_string(),
408        server: "https://example.com".to_string(),
409        issuer: "https://example.com".to_string(),
410        client_id: "client_id".to_string(),
411        client_secret: Some("client_secret".to_string()),
412        allow_insecure: Some(true),
413        updated_at: None,
414    });
415
416    let config_file = StoredConfigs::getorinit_config_file().unwrap();
417    let mut stored_config = StoredConfigs::new(None, config_file.clone());
418    let config = stored_config
419        .read_from_file()
420        .await
421        .unwrap()
422        .upsert_server(server)
423        .await
424        .unwrap()
425        .save_to_file()
426        .await
427        .unwrap();
428
429    println!("saved: {:?}", config);
430    println!(
431        "read: {:?}",
432        StoredConfigs::new(None, config_file)
433            .read_from_file()
434            .await
435            .unwrap()
436    );
437}
438
/// Pins the JSON shape of both `StoredServer` variants: the `authType` tag comes
/// first, then every field in declaration order with camelCase keys.
///
/// The previous expectations omitted `name` and `allowInsecure`, which serde
/// always emits, so the assertions could never pass. Nothing is awaited, so a
/// plain `#[test]` suffices.
#[test]
fn test_config_type() {
    let server = StoredServer::Oidc(OidcServer {
        name: "oidc_server".to_string(),
        server: "https://example.com".to_string(),
        issuer: "https://example.com".to_string(),
        client_id: "client_id".to_string(),
        client_secret: None,
        allow_insecure: Some(true),
        updated_at: None,
    });

    let json = serde_json::to_string(&server).unwrap();
    assert_eq!(
        json,
        r#"{"authType":"oidc","name":"oidc_server","server":"https://example.com","issuer":"https://example.com","clientId":"client_id","clientSecret":null,"allowInsecure":true,"updatedAt":null}"#
    );
    let parsed: StoredServer = serde_json::from_str(&json).unwrap();
    assert!(matches!(parsed, StoredServer::Oidc(ref s) if s.name == "oidc_server"));

    let server = StoredServer::Password(PasswordServer {
        name: "password_server".to_string(),
        server: "https://example.com".to_string(),
        username: "username".to_string(),
        password: Some("password".to_string()),
        allow_insecure: Some(true),
        updated_at: None,
    });

    let json = serde_json::to_string(&server).unwrap();
    assert_eq!(
        json,
        r#"{"authType":"password","name":"password_server","server":"https://example.com","username":"username","password":"password","allowInsecure":true,"updatedAt":null}"#
    );
    let parsed: StoredServer = serde_json::from_str(&json).unwrap();
    assert!(matches!(parsed, StoredServer::Password(ref s) if s.username == "username"));
}