use crate::error::{WebullError, WebullResult};
use crate::config::WebullConfig;
use crate::utils::crypto::{encrypt_password, generate_signature, generate_timestamp};
use crate::utils::serialization::{from_json, to_json};
use chrono::{DateTime, Utc};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::sync::Mutex;
/// Username/password pair supplied to the Webull login endpoint.
///
/// `Debug` is implemented by hand (instead of derived) so that the password is never written
/// to logs or panic messages.
#[derive(Clone, Serialize, Deserialize)]
pub struct Credentials {
    /// Account username, sent as-is to the login and MFA endpoints.
    pub username: String,
    /// Plain-text password; it is encrypted before being sent during login.
    pub password: String,
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            .field("username", &self.username)
            // Never expose the secret in diagnostic output.
            .field("password", &"<redacted>")
            .finish()
    }
}
/// An access token issued by Webull together with its expiry time.
///
/// `Debug` is implemented by hand (instead of derived) so that the token strings are redacted
/// from log output; only the expiry and whether a refresh token exists are shown.
#[derive(Clone, Serialize, Deserialize)]
pub struct AccessToken {
    /// Access token value (sent as `Authorization: Bearer <token>` on logout).
    pub token: String,
    /// Instant at or after which the token is treated as expired.
    pub expires_at: DateTime<Utc>,
    /// Token used to request a fresh access token, if the server provided one.
    pub refresh_token: Option<String>,
}

impl std::fmt::Debug for AccessToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AccessToken")
            .field("token", &"<redacted>")
            .field("expires_at", &self.expires_at)
            // Shows `Some("<redacted>")` / `None` so presence is still visible.
            .field(
                "refresh_token",
                &self.refresh_token.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
/// Pluggable storage for the current [`AccessToken`].
///
/// The `Send + Sync` bound lets a single store be shared between threads.
pub trait TokenStore: Send + Sync {
    /// Persists `access_token`; called after every successful login, MFA check or refresh.
    fn store_token(&self, access_token: AccessToken) -> WebullResult<()>;

    /// Returns the stored token, or `Ok(None)` when no token is held.
    fn get_token(&self) -> WebullResult<Option<AccessToken>>;

    /// Discards the stored token.
    fn clear_token(&self) -> WebullResult<()>;
}
/// In-process [`TokenStore`] that keeps at most one token behind a [`Mutex`].
///
/// Nothing is persisted: the token lives only as long as this value.
#[derive(Debug, Default)]
pub struct MemoryTokenStore {
    token: Mutex<Option<AccessToken>>,
}

impl MemoryTokenStore {
    /// Locks the token slot, recovering the guard if the mutex was poisoned.
    ///
    /// Every write to the slot replaces the whole `Option<AccessToken>` in one assignment, so
    /// a panic in another thread cannot leave it half-updated. Recovering is therefore sound,
    /// and it avoids turning one unrelated panic into a panic on every later token access.
    fn slot(&self) -> std::sync::MutexGuard<'_, Option<AccessToken>> {
        self.token
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

impl TokenStore for MemoryTokenStore {
    /// Returns a clone of the held token, if any.
    fn get_token(&self) -> WebullResult<Option<AccessToken>> {
        Ok(self.slot().clone())
    }

    /// Replaces any held token with `token`.
    fn store_token(&self, token: AccessToken) -> WebullResult<()> {
        *self.slot() = Some(token);
        Ok(())
    }

    /// Drops the held token.
    fn clear_token(&self) -> WebullResult<()> {
        *self.slot() = None;
        Ok(())
    }
}
/// Coordinates Webull authentication: login, MFA verification, token refresh and logout.
///
/// Tokens obtained by the methods of this type are written to `token_store`.
pub struct AuthManager {
/// Username/password recorded by `authenticate`; `multi_factor_auth` reads the username from
/// here and `revoke_token` clears it.
credentials: Option<Credentials>,
/// Storage for the current access token.
pub token_store: Box<dyn TokenStore>,
/// Settings used to build and sign requests (`base_url`, `api_key`, `api_secret`, `device_id`).
config: WebullConfig,
/// HTTP client used for every request.
client: reqwest::Client,
}
impl AuthManager {
pub fn new(
config: WebullConfig,
token_store: Box<dyn TokenStore>,
client: reqwest::Client,
) -> Self {
Self {
credentials: None,
token_store,
config,
client,
}
}
pub async fn authenticate(&mut self, username: &str, password: &str) -> WebullResult<AccessToken> {
self.credentials = Some(Credentials {
username: username.to_string(),
password: password.to_string(),
});
let encrypted_password = encrypt_password(password, &self.config.api_secret.clone().unwrap_or_default())?;
let body = json!({
"username": username,
"password": encrypted_password,
"deviceId": self.config.device_id.clone().unwrap_or_default(),
"deviceName": "Rust API Client",
"deviceType": "Web",
});
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
if let Some(api_key) = &self.config.api_key {
headers.insert("api-key", HeaderValue::from_str(api_key).unwrap());
}
let timestamp = generate_timestamp();
let signature = if let Some(api_secret) = &self.config.api_secret {
let message = format!("{}{}", timestamp, to_json(&body)?);
generate_signature(api_secret, &message)?
} else {
String::new()
};
headers.insert("timestamp", HeaderValue::from_str(×tamp).unwrap());
headers.insert("signature", HeaderValue::from_str(&signature).unwrap());
let response = self.client.post(format!("{}/api/passport/login/v5/account", self.config.base_url))
.headers(headers)
.json(&body)
.send()
.await
.map_err(|e| WebullError::NetworkError(e))?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
if status.as_u16() == 401 {
return Err(WebullError::Unauthorized);
} else if status.as_u16() == 429 {
return Err(WebullError::RateLimitExceeded);
} else {
return Err(WebullError::ApiError {
code: status.as_u16().to_string(),
message: text,
});
}
}
let response_text = response.text().await
.map_err(|e| WebullError::NetworkError(e))?;
#[derive(Debug, Deserialize)]
struct LoginResponse {
access_token: String,
refresh_token: String,
token_type: String,
expires_in: i64,
}
let login_response: LoginResponse = from_json(&response_text)?;
let token = AccessToken {
token: login_response.access_token,
expires_at: Utc::now() + chrono::Duration::seconds(login_response.expires_in),
refresh_token: Some(login_response.refresh_token),
};
self.token_store.store_token(token.clone())?;
Ok(token)
}
pub async fn multi_factor_auth(&mut self, mfa_code: &str) -> WebullResult<AccessToken> {
let credentials = self.credentials.as_ref()
.ok_or_else(|| WebullError::InvalidRequest("No credentials available for MFA".to_string()))?;
let body = json!({
"username": credentials.username,
"verificationCode": mfa_code,
"deviceId": self.config.device_id.clone().unwrap_or_default(),
});
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
if let Some(api_key) = &self.config.api_key {
headers.insert("api-key", HeaderValue::from_str(api_key).unwrap());
}
let timestamp = generate_timestamp();
let signature = if let Some(api_secret) = &self.config.api_secret {
let message = format!("{}{}", timestamp, to_json(&body)?);
generate_signature(api_secret, &message)?
} else {
String::new()
};
headers.insert("timestamp", HeaderValue::from_str(×tamp).unwrap());
headers.insert("signature", HeaderValue::from_str(&signature).unwrap());
let response = self.client.post(format!("{}/api/passport/verificationCode/verify", self.config.base_url))
.headers(headers)
.json(&body)
.send()
.await
.map_err(|e| WebullError::NetworkError(e))?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
if status.as_u16() == 401 {
return Err(WebullError::Unauthorized);
} else if status.as_u16() == 429 {
return Err(WebullError::RateLimitExceeded);
} else {
return Err(WebullError::ApiError {
code: status.as_u16().to_string(),
message: text,
});
}
}
let response_text = response.text().await
.map_err(|e| WebullError::NetworkError(e))?;
#[derive(Debug, Deserialize)]
struct MfaResponse {
access_token: String,
refresh_token: String,
token_type: String,
expires_in: i64,
}
let mfa_response: MfaResponse = from_json(&response_text)?;
let token = AccessToken {
token: mfa_response.access_token,
expires_at: Utc::now() + chrono::Duration::seconds(mfa_response.expires_in),
refresh_token: Some(mfa_response.refresh_token),
};
self.token_store.store_token(token.clone())?;
Ok(token)
}
pub async fn refresh_token(&mut self) -> WebullResult<AccessToken> {
let current_token = self.token_store.get_token()?
.ok_or_else(|| WebullError::InvalidRequest("No token available for refresh".to_string()))?;
let refresh_token = current_token.refresh_token
.ok_or_else(|| WebullError::InvalidRequest("No refresh token available".to_string()))?;
let body = json!({
"refreshToken": refresh_token,
"deviceId": self.config.device_id.clone().unwrap_or_default(),
});
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
if let Some(api_key) = &self.config.api_key {
headers.insert("api-key", HeaderValue::from_str(api_key).unwrap());
}
let timestamp = generate_timestamp();
let signature = if let Some(api_secret) = &self.config.api_secret {
let message = format!("{}{}", timestamp, to_json(&body)?);
generate_signature(api_secret, &message)?
} else {
String::new()
};
headers.insert("timestamp", HeaderValue::from_str(×tamp).unwrap());
headers.insert("signature", HeaderValue::from_str(&signature).unwrap());
let response = self.client.post(format!("{}/api/passport/refreshToken", self.config.base_url))
.headers(headers)
.json(&body)
.send()
.await
.map_err(|e| WebullError::NetworkError(e))?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
if status.as_u16() == 401 {
return Err(WebullError::Unauthorized);
} else if status.as_u16() == 429 {
return Err(WebullError::RateLimitExceeded);
} else {
return Err(WebullError::ApiError {
code: status.as_u16().to_string(),
message: text,
});
}
}
let response_text = response.text().await
.map_err(|e| WebullError::NetworkError(e))?;
#[derive(Debug, Deserialize)]
struct RefreshResponse {
access_token: String,
refresh_token: String,
token_type: String,
expires_in: i64,
}
let refresh_response: RefreshResponse = from_json(&response_text)?;
let token = AccessToken {
token: refresh_response.access_token,
expires_at: Utc::now() + chrono::Duration::seconds(refresh_response.expires_in),
refresh_token: Some(refresh_response.refresh_token),
};
self.token_store.store_token(token.clone())?;
Ok(token)
}
pub async fn get_token(&self) -> WebullResult<AccessToken> {
match self.token_store.get_token()? {
Some(token) => {
if token.expires_at <= Utc::now() {
return Err(WebullError::Unauthorized);
}
Ok(token)
}
None => Err(WebullError::Unauthorized),
}
}
pub async fn revoke_token(&mut self) -> WebullResult<()> {
let current_token = match self.token_store.get_token()? {
Some(token) => token,
None => {
self.credentials = None;
return Ok(());
}
};
let body = json!({
"accessToken": current_token.token,
"deviceId": self.config.device_id.clone().unwrap_or_default(),
});
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {}", current_token.token)).unwrap());
if let Some(api_key) = &self.config.api_key {
headers.insert("api-key", HeaderValue::from_str(api_key).unwrap());
}
let timestamp = generate_timestamp();
let signature = if let Some(api_secret) = &self.config.api_secret {
let message = format!("{}{}", timestamp, to_json(&body)?);
generate_signature(api_secret, &message)?
} else {
String::new()
};
headers.insert("timestamp", HeaderValue::from_str(×tamp).unwrap());
headers.insert("signature", HeaderValue::from_str(&signature).unwrap());
let response = self.client.post(format!("{}/api/passport/logout", self.config.base_url))
.headers(headers)
.json(&body)
.send()
.await
.map_err(|e| WebullError::NetworkError(e))?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
if status.as_u16() == 401 {
} else if status.as_u16() == 429 {
return Err(WebullError::RateLimitExceeded);
} else {
return Err(WebullError::ApiError {
code: status.as_u16().to_string(),
message: text,
});
}
}
self.token_store.clear_token()?;
self.credentials = None;
Ok(())
}
}