1use crate::error::{WebullError, WebullResult};
2use crate::config::WebullConfig;
3use crate::utils::crypto::{encrypt_password, generate_signature, generate_timestamp};
4use crate::utils::serialization::{from_json, to_json};
5use chrono::{DateTime, Utc};
6use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
7use serde::{Deserialize, Serialize};
8use serde_json::json;
9use std::sync::Mutex;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Credentials {
14 pub username: String,
16
17 pub password: String,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct AccessToken {
24 pub token: String,
26
27 pub expires_at: DateTime<Utc>,
29
30 pub refresh_token: Option<String>,
32}
33
34pub trait TokenStore: Send + Sync {
36 fn get_token(&self) -> WebullResult<Option<AccessToken>>;
38
39 fn store_token(&self, token: AccessToken) -> WebullResult<()>;
41
42 fn clear_token(&self) -> WebullResult<()>;
44}
45
46#[derive(Debug, Default)]
48pub struct MemoryTokenStore {
49 token: Mutex<Option<AccessToken>>,
50}
51
52impl TokenStore for MemoryTokenStore {
53 fn get_token(&self) -> WebullResult<Option<AccessToken>> {
54 Ok(self.token.lock().unwrap().clone())
55 }
56
57 fn store_token(&self, token: AccessToken) -> WebullResult<()> {
58 *self.token.lock().unwrap() = Some(token);
59 Ok(())
60 }
61
62 fn clear_token(&self) -> WebullResult<()> {
63 *self.token.lock().unwrap() = None;
64 Ok(())
65 }
66}
67
68pub struct AuthManager {
70 credentials: Option<Credentials>,
72
73 pub token_store: Box<dyn TokenStore>,
75
76 config: WebullConfig,
78
79 client: reqwest::Client,
81}
82
83impl AuthManager {
84 pub fn new(
86 config: WebullConfig,
87 token_store: Box<dyn TokenStore>,
88 client: reqwest::Client,
89 ) -> Self {
90 Self {
91 credentials: None,
92 token_store,
93 config,
94 client,
95 }
96 }
97
98 pub async fn authenticate(&mut self, username: &str, password: &str) -> WebullResult<AccessToken> {
100 self.credentials = Some(Credentials {
102 username: username.to_string(),
103 password: password.to_string(),
104 });
105
106 let encrypted_password = encrypt_password(password, &self.config.api_secret.clone().unwrap_or_default())?;
108
109 let body = json!({
111 "username": username,
112 "password": encrypted_password,
113 "deviceId": self.config.device_id.clone().unwrap_or_default(),
114 "deviceName": "Rust API Client",
115 "deviceType": "Web",
116 });
117
118 let mut headers = HeaderMap::new();
120 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
121
122 if let Some(api_key) = &self.config.api_key {
124 headers.insert("api-key", HeaderValue::from_str(api_key).unwrap());
125 }
126
127 let timestamp = generate_timestamp();
129 let signature = if let Some(api_secret) = &self.config.api_secret {
130 let message = format!("{}{}", timestamp, to_json(&body)?);
131 generate_signature(api_secret, &message)?
132 } else {
133 String::new()
134 };
135
136 headers.insert("timestamp", HeaderValue::from_str(×tamp).unwrap());
138 headers.insert("signature", HeaderValue::from_str(&signature).unwrap());
139
140 let response = self.client.post(format!("{}/api/passport/login/v5/account", self.config.base_url))
142 .headers(headers)
143 .json(&body)
144 .send()
145 .await
146 .map_err(|e| WebullError::NetworkError(e))?;
147
148 if !response.status().is_success() {
150 let status = response.status();
151 let text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
152
153 if status.as_u16() == 401 {
154 return Err(WebullError::Unauthorized);
155 } else if status.as_u16() == 429 {
156 return Err(WebullError::RateLimitExceeded);
157 } else {
158 return Err(WebullError::ApiError {
159 code: status.as_u16().to_string(),
160 message: text,
161 });
162 }
163 }
164
165 let response_text = response.text().await
167 .map_err(|e| WebullError::NetworkError(e))?;
168
169 #[derive(Debug, Deserialize)]
170 struct LoginResponse {
171 access_token: String,
172 refresh_token: String,
173 token_type: String,
174 expires_in: i64,
175 }
176
177 let login_response: LoginResponse = from_json(&response_text)?;
178
179 let token = AccessToken {
181 token: login_response.access_token,
182 expires_at: Utc::now() + chrono::Duration::seconds(login_response.expires_in),
183 refresh_token: Some(login_response.refresh_token),
184 };
185
186 self.token_store.store_token(token.clone())?;
188
189 Ok(token)
190 }
191
192 pub async fn multi_factor_auth(&mut self, mfa_code: &str) -> WebullResult<AccessToken> {
194 let credentials = self.credentials.as_ref()
196 .ok_or_else(|| WebullError::InvalidRequest("No credentials available for MFA".to_string()))?;
197
198 let body = json!({
200 "username": credentials.username,
201 "verificationCode": mfa_code,
202 "deviceId": self.config.device_id.clone().unwrap_or_default(),
203 });
204
205 let mut headers = HeaderMap::new();
207 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
208
209 if let Some(api_key) = &self.config.api_key {
211 headers.insert("api-key", HeaderValue::from_str(api_key).unwrap());
212 }
213
214 let timestamp = generate_timestamp();
216 let signature = if let Some(api_secret) = &self.config.api_secret {
217 let message = format!("{}{}", timestamp, to_json(&body)?);
218 generate_signature(api_secret, &message)?
219 } else {
220 String::new()
221 };
222
223 headers.insert("timestamp", HeaderValue::from_str(×tamp).unwrap());
225 headers.insert("signature", HeaderValue::from_str(&signature).unwrap());
226
227 let response = self.client.post(format!("{}/api/passport/verificationCode/verify", self.config.base_url))
229 .headers(headers)
230 .json(&body)
231 .send()
232 .await
233 .map_err(|e| WebullError::NetworkError(e))?;
234
235 if !response.status().is_success() {
237 let status = response.status();
238 let text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
239
240 if status.as_u16() == 401 {
241 return Err(WebullError::Unauthorized);
242 } else if status.as_u16() == 429 {
243 return Err(WebullError::RateLimitExceeded);
244 } else {
245 return Err(WebullError::ApiError {
246 code: status.as_u16().to_string(),
247 message: text,
248 });
249 }
250 }
251
252 let response_text = response.text().await
254 .map_err(|e| WebullError::NetworkError(e))?;
255
256 #[derive(Debug, Deserialize)]
257 struct MfaResponse {
258 access_token: String,
259 refresh_token: String,
260 token_type: String,
261 expires_in: i64,
262 }
263
264 let mfa_response: MfaResponse = from_json(&response_text)?;
265
266 let token = AccessToken {
268 token: mfa_response.access_token,
269 expires_at: Utc::now() + chrono::Duration::seconds(mfa_response.expires_in),
270 refresh_token: Some(mfa_response.refresh_token),
271 };
272
273 self.token_store.store_token(token.clone())?;
275
276 Ok(token)
277 }
278
279 pub async fn refresh_token(&mut self) -> WebullResult<AccessToken> {
281 let current_token = self.token_store.get_token()?
283 .ok_or_else(|| WebullError::InvalidRequest("No token available for refresh".to_string()))?;
284
285 let refresh_token = current_token.refresh_token
287 .ok_or_else(|| WebullError::InvalidRequest("No refresh token available".to_string()))?;
288
289 let body = json!({
291 "refreshToken": refresh_token,
292 "deviceId": self.config.device_id.clone().unwrap_or_default(),
293 });
294
295 let mut headers = HeaderMap::new();
297 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
298
299 if let Some(api_key) = &self.config.api_key {
301 headers.insert("api-key", HeaderValue::from_str(api_key).unwrap());
302 }
303
304 let timestamp = generate_timestamp();
306 let signature = if let Some(api_secret) = &self.config.api_secret {
307 let message = format!("{}{}", timestamp, to_json(&body)?);
308 generate_signature(api_secret, &message)?
309 } else {
310 String::new()
311 };
312
313 headers.insert("timestamp", HeaderValue::from_str(×tamp).unwrap());
315 headers.insert("signature", HeaderValue::from_str(&signature).unwrap());
316
317 let response = self.client.post(format!("{}/api/passport/refreshToken", self.config.base_url))
319 .headers(headers)
320 .json(&body)
321 .send()
322 .await
323 .map_err(|e| WebullError::NetworkError(e))?;
324
325 if !response.status().is_success() {
327 let status = response.status();
328 let text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
329
330 if status.as_u16() == 401 {
331 return Err(WebullError::Unauthorized);
332 } else if status.as_u16() == 429 {
333 return Err(WebullError::RateLimitExceeded);
334 } else {
335 return Err(WebullError::ApiError {
336 code: status.as_u16().to_string(),
337 message: text,
338 });
339 }
340 }
341
342 let response_text = response.text().await
344 .map_err(|e| WebullError::NetworkError(e))?;
345
346 #[derive(Debug, Deserialize)]
347 struct RefreshResponse {
348 access_token: String,
349 refresh_token: String,
350 token_type: String,
351 expires_in: i64,
352 }
353
354 let refresh_response: RefreshResponse = from_json(&response_text)?;
355
356 let token = AccessToken {
358 token: refresh_response.access_token,
359 expires_at: Utc::now() + chrono::Duration::seconds(refresh_response.expires_in),
360 refresh_token: Some(refresh_response.refresh_token),
361 };
362
363 self.token_store.store_token(token.clone())?;
365
366 Ok(token)
367 }
368
369 pub async fn get_token(&self) -> WebullResult<AccessToken> {
371 match self.token_store.get_token()? {
372 Some(token) => {
373 if token.expires_at <= Utc::now() {
375 return Err(WebullError::Unauthorized);
376 }
377 Ok(token)
378 }
379 None => Err(WebullError::Unauthorized),
380 }
381 }
382
383 pub async fn revoke_token(&mut self) -> WebullResult<()> {
385 let current_token = match self.token_store.get_token()? {
387 Some(token) => token,
388 None => {
389 self.credentials = None;
391 return Ok(());
392 }
393 };
394
395 let body = json!({
397 "accessToken": current_token.token,
398 "deviceId": self.config.device_id.clone().unwrap_or_default(),
399 });
400
401 let mut headers = HeaderMap::new();
403 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
404 headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {}", current_token.token)).unwrap());
405
406 if let Some(api_key) = &self.config.api_key {
408 headers.insert("api-key", HeaderValue::from_str(api_key).unwrap());
409 }
410
411 let timestamp = generate_timestamp();
413 let signature = if let Some(api_secret) = &self.config.api_secret {
414 let message = format!("{}{}", timestamp, to_json(&body)?);
415 generate_signature(api_secret, &message)?
416 } else {
417 String::new()
418 };
419
420 headers.insert("timestamp", HeaderValue::from_str(×tamp).unwrap());
422 headers.insert("signature", HeaderValue::from_str(&signature).unwrap());
423
424 let response = self.client.post(format!("{}/api/passport/logout", self.config.base_url))
426 .headers(headers)
427 .json(&body)
428 .send()
429 .await
430 .map_err(|e| WebullError::NetworkError(e))?;
431
432 if !response.status().is_success() {
434 let status = response.status();
435 let text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
436
437 if status.as_u16() == 401 {
438 } else if status.as_u16() == 429 {
440 return Err(WebullError::RateLimitExceeded);
441 } else {
442 return Err(WebullError::ApiError {
443 code: status.as_u16().to_string(),
444 message: text,
445 });
446 }
447 }
448
449 self.token_store.clear_token()?;
451 self.credentials = None;
452
453 Ok(())
454 }
455}