1use crate::auth::{AccessToken, Credentials};
2use crate::error::{WebullError, WebullResult};
3use serde::{Deserialize, Serialize};
4use std::path::Path;
5use std::sync::Mutex;
6
7pub trait CredentialStore: Send + Sync {
9 fn get_credentials(&self) -> WebullResult<Option<Credentials>>;
11
12 fn store_credentials(&self, credentials: Credentials) -> WebullResult<()>;
14
15 fn clear_credentials(&self) -> WebullResult<()>;
17
18 fn get_token(&self) -> WebullResult<Option<AccessToken>>;
20
21 fn store_token(&self, token: AccessToken) -> WebullResult<()>;
23
24 fn clear_token(&self) -> WebullResult<()>;
26}
27
28#[derive(Debug, Default)]
30pub struct MemoryCredentialStore {
31 credentials: Mutex<Option<Credentials>>,
33
34 token: Mutex<Option<AccessToken>>,
36}
37
38impl CredentialStore for MemoryCredentialStore {
39 fn get_credentials(&self) -> WebullResult<Option<Credentials>> {
40 Ok(self.credentials.lock().unwrap().clone())
41 }
42
43 fn store_credentials(&self, credentials: Credentials) -> WebullResult<()> {
44 *self.credentials.lock().unwrap() = Some(credentials);
45 Ok(())
46 }
47
48 fn clear_credentials(&self) -> WebullResult<()> {
49 *self.credentials.lock().unwrap() = None;
50 Ok(())
51 }
52
53 fn get_token(&self) -> WebullResult<Option<AccessToken>> {
54 Ok(self.token.lock().unwrap().clone())
55 }
56
57 fn store_token(&self, token: AccessToken) -> WebullResult<()> {
58 *self.token.lock().unwrap() = Some(token);
59 Ok(())
60 }
61
62 fn clear_token(&self) -> WebullResult<()> {
63 *self.token.lock().unwrap() = None;
64 Ok(())
65 }
66}
67
68pub struct EncryptedCredentialStore {
70 credentials_path: String,
72
73 token_path: String,
75
76 encryption_key: String,
78
79 memory_store: MemoryCredentialStore,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
85struct StoredCredentials {
86 encrypted_username: String,
88
89 encrypted_password: String,
91
92 iv: String,
94
95 salt: String,
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
101struct StoredToken {
102 encrypted_token: String,
104
105 encrypted_refresh_token: Option<String>,
107
108 expires_at: i64,
110
111 iv: String,
113
114 salt: String,
116}
117
118impl EncryptedCredentialStore {
119 pub fn new(credentials_path: String, token_path: String, encryption_key: String) -> Self {
121 Self {
122 credentials_path,
123 token_path,
124 encryption_key,
125 memory_store: MemoryCredentialStore::default(),
126 }
127 }
128
129 fn encrypt(&self, data: &str) -> WebullResult<(String, String, String)> {
131 let salt = self.generate_random_string(16);
133 let iv = self.generate_random_string(16);
134
135 let key = self.derive_key(&self.encryption_key, &salt)?;
137
138 let encrypted = self.encrypt_with_key(data, &key, &iv)?;
140
141 Ok((encrypted, iv, salt))
142 }
143
144 fn decrypt(&self, encrypted: &str, iv: &str, salt: &str) -> WebullResult<String> {
146 let key = self.derive_key(&self.encryption_key, salt)?;
148
149 self.decrypt_with_key(encrypted, &key, iv)
151 }
152
153 fn generate_random_string(&self, length: usize) -> String {
155 use rand::distributions::Alphanumeric;
156 use rand::{thread_rng, Rng};
157
158 thread_rng()
159 .sample_iter(&Alphanumeric)
160 .take(length)
161 .map(char::from)
162 .collect()
163 }
164
165 fn derive_key(&self, password: &str, salt: &str) -> WebullResult<Vec<u8>> {
167 let mut key = Vec::with_capacity(32);
172 let password_bytes = password.as_bytes();
173 let salt_bytes = salt.as_bytes();
174
175 for i in 0..32 {
176 let byte = password_bytes[i % password_bytes.len()] ^ salt_bytes[i % salt_bytes.len()];
177 key.push(byte);
178 }
179
180 Ok(key)
181 }
182
183 fn encrypt_with_key(&self, data: &str, _key: &[u8], _iv: &str) -> WebullResult<String> {
185 let encoded = base64::encode(data);
190 Ok(encoded)
191 }
192
193 fn decrypt_with_key(&self, encrypted: &str, _key: &[u8], _iv: &str) -> WebullResult<String> {
195 let decoded = base64::decode(encrypted)
200 .map_err(|e| WebullError::InvalidRequest(format!("Invalid data: {}", e)))?;
201
202 let decrypted = String::from_utf8(decoded)
203 .map_err(|e| WebullError::InvalidRequest(format!("Invalid UTF-8: {}", e)))?;
204
205 Ok(decrypted)
206 }
207
208 fn load_credentials(&self) -> WebullResult<Option<Credentials>> {
210 let path = Path::new(&self.credentials_path);
212 if !path.exists() {
213 return Ok(None);
214 }
215
216 let contents = std::fs::read_to_string(path).map_err(|e| {
218 WebullError::InvalidRequest(format!("Failed to read credentials file: {}", e))
219 })?;
220
221 let stored: StoredCredentials =
223 serde_json::from_str(&contents).map_err(|e| WebullError::SerializationError(e))?;
224
225 let username = self.decrypt(&stored.encrypted_username, &stored.iv, &stored.salt)?;
227 let password = self.decrypt(&stored.encrypted_password, &stored.iv, &stored.salt)?;
228
229 Ok(Some(Credentials { username, password }))
230 }
231
232 fn save_credentials(&self, credentials: &Credentials) -> WebullResult<()> {
234 let (encrypted_username, iv, salt) = self.encrypt(&credentials.username)?;
236 let (encrypted_password, _, _) = self.encrypt(&credentials.password)?;
237
238 let stored = StoredCredentials {
240 encrypted_username,
241 encrypted_password,
242 iv,
243 salt,
244 };
245
246 let json =
248 serde_json::to_string(&stored).map_err(|e| WebullError::SerializationError(e))?;
249
250 std::fs::write(&self.credentials_path, json).map_err(|e| {
252 WebullError::InvalidRequest(format!("Failed to write credentials file: {}", e))
253 })?;
254
255 Ok(())
256 }
257
258 fn load_token(&self) -> WebullResult<Option<AccessToken>> {
260 let path = Path::new(&self.token_path);
262 if !path.exists() {
263 return Ok(None);
264 }
265
266 let contents = std::fs::read_to_string(path).map_err(|e| {
268 WebullError::InvalidRequest(format!("Failed to read token file: {}", e))
269 })?;
270
271 let stored: StoredToken =
273 serde_json::from_str(&contents).map_err(|e| WebullError::SerializationError(e))?;
274
275 let token = self.decrypt(&stored.encrypted_token, &stored.iv, &stored.salt)?;
277
278 let refresh_token = if let Some(encrypted_refresh_token) = stored.encrypted_refresh_token {
280 Some(self.decrypt(&encrypted_refresh_token, &stored.iv, &stored.salt)?)
281 } else {
282 None
283 };
284
285 let expires_at = chrono::DateTime::from_timestamp(stored.expires_at, 0)
287 .ok_or_else(|| WebullError::InvalidRequest("Invalid timestamp".to_string()))?;
288
289 Ok(Some(AccessToken {
290 token,
291 expires_at,
292 refresh_token,
293 }))
294 }
295
296 fn save_token(&self, token: &AccessToken) -> WebullResult<()> {
298 let (encrypted_token, iv, salt) = self.encrypt(&token.token)?;
300
301 let encrypted_refresh_token = if let Some(refresh_token) = &token.refresh_token {
303 Some(self.encrypt(refresh_token)?.0)
304 } else {
305 None
306 };
307
308 let stored = StoredToken {
310 encrypted_token,
311 encrypted_refresh_token,
312 expires_at: token.expires_at.timestamp(),
313 iv,
314 salt,
315 };
316
317 let json =
319 serde_json::to_string(&stored).map_err(|e| WebullError::SerializationError(e))?;
320
321 std::fs::write(&self.token_path, json).map_err(|e| {
323 WebullError::InvalidRequest(format!("Failed to write token file: {}", e))
324 })?;
325
326 Ok(())
327 }
328}
329
330impl CredentialStore for EncryptedCredentialStore {
331 fn get_credentials(&self) -> WebullResult<Option<Credentials>> {
332 if let Some(credentials) = self.memory_store.get_credentials()? {
334 return Ok(Some(credentials));
335 }
336
337 let credentials = self.load_credentials()?;
339
340 if let Some(credentials) = &credentials {
342 self.memory_store.store_credentials(credentials.clone())?;
343 }
344
345 Ok(credentials)
346 }
347
348 fn store_credentials(&self, credentials: Credentials) -> WebullResult<()> {
349 self.memory_store.store_credentials(credentials.clone())?;
351
352 self.save_credentials(&credentials)?;
354
355 Ok(())
356 }
357
358 fn clear_credentials(&self) -> WebullResult<()> {
359 self.memory_store.clear_credentials()?;
361
362 let path = Path::new(&self.credentials_path);
364 if path.exists() {
365 std::fs::remove_file(path).map_err(|e| {
366 WebullError::InvalidRequest(format!("Failed to remove credentials file: {}", e))
367 })?;
368 }
369
370 Ok(())
371 }
372
373 fn get_token(&self) -> WebullResult<Option<AccessToken>> {
374 if let Some(token) = self.memory_store.get_token()? {
376 return Ok(Some(token));
377 }
378
379 let token = self.load_token()?;
381
382 if let Some(token) = &token {
384 self.memory_store.store_token(token.clone())?;
385 }
386
387 Ok(token)
388 }
389
390 fn store_token(&self, token: AccessToken) -> WebullResult<()> {
391 self.memory_store.store_token(token.clone())?;
393
394 self.save_token(&token)?;
396
397 Ok(())
398 }
399
400 fn clear_token(&self) -> WebullResult<()> {
401 self.memory_store.clear_token()?;
403
404 let path = Path::new(&self.token_path);
406 if path.exists() {
407 std::fs::remove_file(path).map_err(|e| {
408 WebullError::InvalidRequest(format!("Failed to remove token file: {}", e))
409 })?;
410 }
411
412 Ok(())
413 }
414}