use crate::auth::{AccessToken, Credentials};
use crate::error::{WebullError, WebullResult};
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::sync::Mutex;
/// Storage backend for Webull login credentials and the current access token.
///
/// The `Send + Sync` bound lets one store be shared across threads. In this module, each
/// getter returns `Ok(None)` when nothing has been stored.
pub trait CredentialStore: Send + Sync {
/// Returns the stored credentials, or `None` if there are none.
fn get_credentials(&self) -> WebullResult<Option<Credentials>>;
/// Saves `credentials`; the implementations in this module overwrite any previous value.
fn store_credentials(&self, credentials: Credentials) -> WebullResult<()>;
/// Removes any stored credentials.
fn clear_credentials(&self) -> WebullResult<()>;
/// Returns the stored access token, or `None` if there is none.
///
/// Implementations in this module return the token as stored, without checking its expiry.
fn get_token(&self) -> WebullResult<Option<AccessToken>>;
/// Saves `token`; the implementations in this module overwrite any previous token.
fn store_token(&self, token: AccessToken) -> WebullResult<()>;
/// Removes any stored access token.
fn clear_token(&self) -> WebullResult<()>;
}
/// Process-local [`CredentialStore`] that keeps values only in memory.
///
/// Credentials and token each sit behind their own `Mutex`, so they can be read or replaced
/// independently from several threads. Nothing is persisted; the values live only as long as
/// the store itself.
///
/// NOTE(review): the derived `Debug` formats `Credentials` and `AccessToken` through their own
/// `Debug` impls. If those do not redact secrets, logging this store leaks them — confirm.
#[derive(Debug, Default)]
pub struct MemoryCredentialStore {
/// Most recently stored credentials; `None` on construction (`Default`) or after clearing.
credentials: Mutex<Option<Credentials>>,
/// Most recently stored access token; `None` on construction (`Default`) or after clearing.
token: Mutex<Option<AccessToken>>,
}
impl CredentialStore for MemoryCredentialStore {
    // Lock poisoning only signals that some thread panicked while holding a guard. Every
    // critical section below is a single clone of, or a single assignment to, an `Option`, so
    // the protected value is always a valid `Option`. Recovering the guard with
    // `PoisonError::into_inner` is therefore safe, and it avoids turning one panic into a panic
    // for every later caller of this library type.

    /// Returns a clone of the stored credentials, or `None` if none are stored.
    fn get_credentials(&self) -> WebullResult<Option<Credentials>> {
        Ok(self
            .credentials
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone())
    }

    /// Replaces the stored credentials with `credentials`.
    fn store_credentials(&self, credentials: Credentials) -> WebullResult<()> {
        *self
            .credentials
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(credentials);
        Ok(())
    }

    /// Drops any stored credentials.
    fn clear_credentials(&self) -> WebullResult<()> {
        *self
            .credentials
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner) = None;
        Ok(())
    }

    /// Returns a clone of the stored access token, or `None` if none is stored.
    fn get_token(&self) -> WebullResult<Option<AccessToken>> {
        Ok(self
            .token
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone())
    }

    /// Replaces the stored access token with `token`.
    fn store_token(&self, token: AccessToken) -> WebullResult<()> {
        *self
            .token
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(token);
        Ok(())
    }

    /// Drops any stored access token.
    fn clear_token(&self) -> WebullResult<()> {
        *self
            .token
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner) = None;
        Ok(())
    }
}
/// [`CredentialStore`] that persists credentials and the access token as JSON files and keeps
/// an in-memory copy in a [`MemoryCredentialStore`].
///
/// SECURITY(review): field values go through `encrypt_with_key`, which in this module only
/// base64-encodes them and ignores the derived key and IV. The files are therefore effectively
/// plaintext; protect them accordingly until a real cipher is wired in.
pub struct EncryptedCredentialStore {
/// Path of the JSON file holding the serialized `StoredCredentials`.
credentials_path: String,
/// Path of the JSON file holding the serialized `StoredToken`.
token_path: String,
/// Secret combined with each file's salt to derive the per-file key.
encryption_key: String,
/// Cache consulted before the files are read.
memory_store: MemoryCredentialStore,
}
/// On-disk JSON form of [`Credentials`], stored at the store's `credentials_path`.
///
/// The field names are the JSON keys (there are no serde renames), so renaming a field breaks
/// files that were already written.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct StoredCredentials {
/// Username as produced by the store's encryption helper.
encrypted_username: String,
/// Password as produced by the store's encryption helper.
encrypted_password: String,
/// IV used when decrypting both encrypted fields.
iv: String,
/// Salt combined with the store's encryption key to derive the key for both fields.
salt: String,
}
/// On-disk JSON form of [`AccessToken`], stored at the store's `token_path`.
///
/// The field names are the JSON keys (there are no serde renames), so renaming a field breaks
/// files that were already written.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct StoredToken {
/// Access-token string as produced by the store's encryption helper.
encrypted_token: String,
/// Encrypted refresh token, or `None` when the token had no refresh token.
encrypted_refresh_token: Option<String>,
/// Expiry in whole Unix seconds (`DateTime::timestamp`); sub-second precision is not kept.
expires_at: i64,
/// IV used when decrypting the token and the refresh token.
iv: String,
/// Salt combined with the store's encryption key to derive the key for both tokens.
salt: String,
}
impl EncryptedCredentialStore {
pub fn new(credentials_path: String, token_path: String, encryption_key: String) -> Self {
Self {
credentials_path,
token_path,
encryption_key,
memory_store: MemoryCredentialStore::default(),
}
}
fn encrypt(&self, data: &str) -> WebullResult<(String, String, String)> {
let salt = self.generate_random_string(16);
let iv = self.generate_random_string(16);
let key = self.derive_key(&self.encryption_key, &salt)?;
let encrypted = self.encrypt_with_key(data, &key, &iv)?;
Ok((encrypted, iv, salt))
}
fn decrypt(&self, encrypted: &str, iv: &str, salt: &str) -> WebullResult<String> {
let key = self.derive_key(&self.encryption_key, salt)?;
self.decrypt_with_key(encrypted, &key, iv)
}
fn generate_random_string(&self, length: usize) -> String {
use rand::{thread_rng, Rng};
use rand::distributions::Alphanumeric;
thread_rng()
.sample_iter(&Alphanumeric)
.take(length)
.map(char::from)
.collect()
}
fn derive_key(&self, password: &str, salt: &str) -> WebullResult<Vec<u8>> {
let mut key = Vec::with_capacity(32);
let password_bytes = password.as_bytes();
let salt_bytes = salt.as_bytes();
for i in 0..32 {
let byte = password_bytes[i % password_bytes.len()] ^ salt_bytes[i % salt_bytes.len()];
key.push(byte);
}
Ok(key)
}
fn encrypt_with_key(&self, data: &str, _key: &[u8], _iv: &str) -> WebullResult<String> {
let encoded = base64::encode(data);
Ok(encoded)
}
fn decrypt_with_key(&self, encrypted: &str, _key: &[u8], _iv: &str) -> WebullResult<String> {
let decoded = base64::decode(encrypted)
.map_err(|e| WebullError::InvalidRequest(format!("Invalid data: {}", e)))?;
let decrypted = String::from_utf8(decoded)
.map_err(|e| WebullError::InvalidRequest(format!("Invalid UTF-8: {}", e)))?;
Ok(decrypted)
}
fn load_credentials(&self) -> WebullResult<Option<Credentials>> {
let path = Path::new(&self.credentials_path);
if !path.exists() {
return Ok(None);
}
let contents = std::fs::read_to_string(path)
.map_err(|e| WebullError::InvalidRequest(format!("Failed to read credentials file: {}", e)))?;
let stored: StoredCredentials = serde_json::from_str(&contents)
.map_err(|e| WebullError::SerializationError(e))?;
let username = self.decrypt(&stored.encrypted_username, &stored.iv, &stored.salt)?;
let password = self.decrypt(&stored.encrypted_password, &stored.iv, &stored.salt)?;
Ok(Some(Credentials {
username,
password,
}))
}
fn save_credentials(&self, credentials: &Credentials) -> WebullResult<()> {
let (encrypted_username, iv, salt) = self.encrypt(&credentials.username)?;
let (encrypted_password, _, _) = self.encrypt(&credentials.password)?;
let stored = StoredCredentials {
encrypted_username,
encrypted_password,
iv,
salt,
};
let json = serde_json::to_string(&stored)
.map_err(|e| WebullError::SerializationError(e))?;
std::fs::write(&self.credentials_path, json)
.map_err(|e| WebullError::InvalidRequest(format!("Failed to write credentials file: {}", e)))?;
Ok(())
}
fn load_token(&self) -> WebullResult<Option<AccessToken>> {
let path = Path::new(&self.token_path);
if !path.exists() {
return Ok(None);
}
let contents = std::fs::read_to_string(path)
.map_err(|e| WebullError::InvalidRequest(format!("Failed to read token file: {}", e)))?;
let stored: StoredToken = serde_json::from_str(&contents)
.map_err(|e| WebullError::SerializationError(e))?;
let token = self.decrypt(&stored.encrypted_token, &stored.iv, &stored.salt)?;
let refresh_token = if let Some(encrypted_refresh_token) = stored.encrypted_refresh_token {
Some(self.decrypt(&encrypted_refresh_token, &stored.iv, &stored.salt)?)
} else {
None
};
let expires_at = chrono::DateTime::from_timestamp(stored.expires_at, 0)
.ok_or_else(|| WebullError::InvalidRequest("Invalid timestamp".to_string()))?;
Ok(Some(AccessToken {
token,
expires_at,
refresh_token,
}))
}
fn save_token(&self, token: &AccessToken) -> WebullResult<()> {
let (encrypted_token, iv, salt) = self.encrypt(&token.token)?;
let encrypted_refresh_token = if let Some(refresh_token) = &token.refresh_token {
Some(self.encrypt(refresh_token)?.0)
} else {
None
};
let stored = StoredToken {
encrypted_token,
encrypted_refresh_token,
expires_at: token.expires_at.timestamp(),
iv,
salt,
};
let json = serde_json::to_string(&stored)
.map_err(|e| WebullError::SerializationError(e))?;
std::fs::write(&self.token_path, json)
.map_err(|e| WebullError::InvalidRequest(format!("Failed to write token file: {}", e)))?;
Ok(())
}
}
impl CredentialStore for EncryptedCredentialStore {
fn get_credentials(&self) -> WebullResult<Option<Credentials>> {
if let Some(credentials) = self.memory_store.get_credentials()? {
return Ok(Some(credentials));
}
let credentials = self.load_credentials()?;
if let Some(credentials) = &credentials {
self.memory_store.store_credentials(credentials.clone())?;
}
Ok(credentials)
}
fn store_credentials(&self, credentials: Credentials) -> WebullResult<()> {
self.memory_store.store_credentials(credentials.clone())?;
self.save_credentials(&credentials)?;
Ok(())
}
fn clear_credentials(&self) -> WebullResult<()> {
self.memory_store.clear_credentials()?;
let path = Path::new(&self.credentials_path);
if path.exists() {
std::fs::remove_file(path)
.map_err(|e| WebullError::InvalidRequest(format!("Failed to remove credentials file: {}", e)))?;
}
Ok(())
}
fn get_token(&self) -> WebullResult<Option<AccessToken>> {
if let Some(token) = self.memory_store.get_token()? {
return Ok(Some(token));
}
let token = self.load_token()?;
if let Some(token) = &token {
self.memory_store.store_token(token.clone())?;
}
Ok(token)
}
fn store_token(&self, token: AccessToken) -> WebullResult<()> {
self.memory_store.store_token(token.clone())?;
self.save_token(&token)?;
Ok(())
}
fn clear_token(&self) -> WebullResult<()> {
self.memory_store.clear_token()?;
let path = Path::new(&self.token_path);
if path.exists() {
std::fs::remove_file(path)
.map_err(|e| WebullError::InvalidRequest(format!("Failed to remove token file: {}", e)))?;
}
Ok(())
}
}