1use crate::config::WebullConfig;
2use crate::error::{WebullError, WebullResult};
3use crate::utils::crypto::{encrypt_password, generate_signature, generate_timestamp};
4use crate::utils::serialization::{from_json, to_json};
5use chrono::{DateTime, Utc};
6use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
7use serde::{Deserialize, Serialize};
8use serde_json::json;
9use std::sync::Mutex;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Credentials {
14 pub username: String,
16
17 pub password: String,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct AccessToken {
24 pub token: String,
26
27 pub expires_at: DateTime<Utc>,
29
30 pub refresh_token: Option<String>,
32}
33
34pub trait TokenStore: Send + Sync {
36 fn get_token(&self) -> WebullResult<Option<AccessToken>>;
38
39 fn store_token(&self, token: AccessToken) -> WebullResult<()>;
41
42 fn clear_token(&self) -> WebullResult<()>;
44}
45
46#[derive(Debug, Default)]
48pub struct MemoryTokenStore {
49 token: Mutex<Option<AccessToken>>,
50}
51
52impl TokenStore for MemoryTokenStore {
53 fn get_token(&self) -> WebullResult<Option<AccessToken>> {
54 Ok(self.token.lock().unwrap().clone())
55 }
56
57 fn store_token(&self, token: AccessToken) -> WebullResult<()> {
58 *self.token.lock().unwrap() = Some(token);
59 Ok(())
60 }
61
62 fn clear_token(&self) -> WebullResult<()> {
63 *self.token.lock().unwrap() = None;
64 Ok(())
65 }
66}
67
68pub struct AuthManager {
70 credentials: Option<Credentials>,
72
73 pub token_store: Box<dyn TokenStore>,
75
76 config: WebullConfig,
78
79 client: reqwest::Client,
81}
82
83impl AuthManager {
84 pub fn new(
86 config: WebullConfig,
87 token_store: Box<dyn TokenStore>,
88 client: reqwest::Client,
89 ) -> Self {
90 Self {
91 credentials: None,
92 token_store,
93 config,
94 client,
95 }
96 }
97
98 pub async fn authenticate(
100 &mut self,
101 username: &str,
102 password: &str,
103 ) -> WebullResult<AccessToken> {
104 self.credentials = Some(Credentials {
106 username: username.to_string(),
107 password: password.to_string(),
108 });
109
110 let encrypted_password = encrypt_password(
112 password,
113 &self.config.api_secret.clone().unwrap_or_default(),
114 )?;
115
116 let body = json!({
118 "username": username,
119 "password": encrypted_password,
120 "deviceId": self.config.device_id.clone().unwrap_or_default(),
121 "deviceName": "Rust API Client",
122 "deviceType": "Web",
123 });
124
125 let mut headers = HeaderMap::new();
127 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
128
129 if let Some(api_key) = &self.config.api_key {
131 headers.insert("api-key", HeaderValue::from_str(api_key).unwrap());
132 }
133
134 let timestamp = generate_timestamp();
136 let signature = if let Some(api_secret) = &self.config.api_secret {
137 let message = format!("{}{}", timestamp, to_json(&body)?);
138 generate_signature(api_secret, &message)?
139 } else {
140 String::new()
141 };
142
143 headers.insert("timestamp", HeaderValue::from_str(×tamp).unwrap());
145 headers.insert("signature", HeaderValue::from_str(&signature).unwrap());
146
147 let response = self
149 .client
150 .post(format!(
151 "{}/api/passport/login/v5/account",
152 self.config.base_url
153 ))
154 .headers(headers)
155 .json(&body)
156 .send()
157 .await
158 .map_err(|e| WebullError::NetworkError(e))?;
159
160 if !response.status().is_success() {
162 let status = response.status();
163 let text = response
164 .text()
165 .await
166 .unwrap_or_else(|_| "Unknown error".to_string());
167
168 if status.as_u16() == 401 {
169 return Err(WebullError::Unauthorized);
170 } else if status.as_u16() == 429 {
171 return Err(WebullError::RateLimitExceeded);
172 } else {
173 return Err(WebullError::ApiError {
174 code: status.as_u16().to_string(),
175 message: text,
176 });
177 }
178 }
179
180 let response_text = response
182 .text()
183 .await
184 .map_err(|e| WebullError::NetworkError(e))?;
185
186 #[derive(Debug, Deserialize)]
187 struct LoginResponse {
188 access_token: String,
189 refresh_token: String,
190 token_type: String,
191 expires_in: i64,
192 }
193
194 let login_response: LoginResponse = from_json(&response_text)?;
195
196 let token = AccessToken {
198 token: login_response.access_token,
199 expires_at: Utc::now() + chrono::Duration::seconds(login_response.expires_in),
200 refresh_token: Some(login_response.refresh_token),
201 };
202
203 self.token_store.store_token(token.clone())?;
205
206 Ok(token)
207 }
208
209 pub async fn multi_factor_auth(&mut self, mfa_code: &str) -> WebullResult<AccessToken> {
211 let credentials = self.credentials.as_ref().ok_or_else(|| {
213 WebullError::InvalidRequest("No credentials available for MFA".to_string())
214 })?;
215
216 let body = json!({
218 "username": credentials.username,
219 "verificationCode": mfa_code,
220 "deviceId": self.config.device_id.clone().unwrap_or_default(),
221 });
222
223 let mut headers = HeaderMap::new();
225 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
226
227 if let Some(api_key) = &self.config.api_key {
229 headers.insert("api-key", HeaderValue::from_str(api_key).unwrap());
230 }
231
232 let timestamp = generate_timestamp();
234 let signature = if let Some(api_secret) = &self.config.api_secret {
235 let message = format!("{}{}", timestamp, to_json(&body)?);
236 generate_signature(api_secret, &message)?
237 } else {
238 String::new()
239 };
240
241 headers.insert("timestamp", HeaderValue::from_str(×tamp).unwrap());
243 headers.insert("signature", HeaderValue::from_str(&signature).unwrap());
244
245 let response = self
247 .client
248 .post(format!(
249 "{}/api/passport/verificationCode/verify",
250 self.config.base_url
251 ))
252 .headers(headers)
253 .json(&body)
254 .send()
255 .await
256 .map_err(|e| WebullError::NetworkError(e))?;
257
258 if !response.status().is_success() {
260 let status = response.status();
261 let text = response
262 .text()
263 .await
264 .unwrap_or_else(|_| "Unknown error".to_string());
265
266 if status.as_u16() == 401 {
267 return Err(WebullError::Unauthorized);
268 } else if status.as_u16() == 429 {
269 return Err(WebullError::RateLimitExceeded);
270 } else {
271 return Err(WebullError::ApiError {
272 code: status.as_u16().to_string(),
273 message: text,
274 });
275 }
276 }
277
278 let response_text = response
280 .text()
281 .await
282 .map_err(|e| WebullError::NetworkError(e))?;
283
284 #[derive(Debug, Deserialize)]
285 struct MfaResponse {
286 access_token: String,
287 refresh_token: String,
288 token_type: String,
289 expires_in: i64,
290 }
291
292 let mfa_response: MfaResponse = from_json(&response_text)?;
293
294 let token = AccessToken {
296 token: mfa_response.access_token,
297 expires_at: Utc::now() + chrono::Duration::seconds(mfa_response.expires_in),
298 refresh_token: Some(mfa_response.refresh_token),
299 };
300
301 self.token_store.store_token(token.clone())?;
303
304 Ok(token)
305 }
306
307 pub async fn refresh_token(&mut self) -> WebullResult<AccessToken> {
309 let current_token = self.token_store.get_token()?.ok_or_else(|| {
311 WebullError::InvalidRequest("No token available for refresh".to_string())
312 })?;
313
314 let refresh_token = current_token
316 .refresh_token
317 .ok_or_else(|| WebullError::InvalidRequest("No refresh token available".to_string()))?;
318
319 let body = json!({
321 "refreshToken": refresh_token,
322 "deviceId": self.config.device_id.clone().unwrap_or_default(),
323 });
324
325 let mut headers = HeaderMap::new();
327 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
328
329 if let Some(api_key) = &self.config.api_key {
331 headers.insert("api-key", HeaderValue::from_str(api_key).unwrap());
332 }
333
334 let timestamp = generate_timestamp();
336 let signature = if let Some(api_secret) = &self.config.api_secret {
337 let message = format!("{}{}", timestamp, to_json(&body)?);
338 generate_signature(api_secret, &message)?
339 } else {
340 String::new()
341 };
342
343 headers.insert("timestamp", HeaderValue::from_str(×tamp).unwrap());
345 headers.insert("signature", HeaderValue::from_str(&signature).unwrap());
346
347 let response = self
349 .client
350 .post(format!(
351 "{}/api/passport/refreshToken",
352 self.config.base_url
353 ))
354 .headers(headers)
355 .json(&body)
356 .send()
357 .await
358 .map_err(|e| WebullError::NetworkError(e))?;
359
360 if !response.status().is_success() {
362 let status = response.status();
363 let text = response
364 .text()
365 .await
366 .unwrap_or_else(|_| "Unknown error".to_string());
367
368 if status.as_u16() == 401 {
369 return Err(WebullError::Unauthorized);
370 } else if status.as_u16() == 429 {
371 return Err(WebullError::RateLimitExceeded);
372 } else {
373 return Err(WebullError::ApiError {
374 code: status.as_u16().to_string(),
375 message: text,
376 });
377 }
378 }
379
380 let response_text = response
382 .text()
383 .await
384 .map_err(|e| WebullError::NetworkError(e))?;
385
386 #[derive(Debug, Deserialize)]
387 struct RefreshResponse {
388 access_token: String,
389 refresh_token: String,
390 token_type: String,
391 expires_in: i64,
392 }
393
394 let refresh_response: RefreshResponse = from_json(&response_text)?;
395
396 let token = AccessToken {
398 token: refresh_response.access_token,
399 expires_at: Utc::now() + chrono::Duration::seconds(refresh_response.expires_in),
400 refresh_token: Some(refresh_response.refresh_token),
401 };
402
403 self.token_store.store_token(token.clone())?;
405
406 Ok(token)
407 }
408
409 pub async fn get_token(&self) -> WebullResult<AccessToken> {
411 match self.token_store.get_token()? {
412 Some(token) => {
413 if token.expires_at <= Utc::now() {
415 return Err(WebullError::Unauthorized);
416 }
417 Ok(token)
418 }
419 None => Err(WebullError::Unauthorized),
420 }
421 }
422
423 pub async fn revoke_token(&mut self) -> WebullResult<()> {
425 let current_token = match self.token_store.get_token()? {
427 Some(token) => token,
428 None => {
429 self.credentials = None;
431 return Ok(());
432 }
433 };
434
435 let body = json!({
437 "accessToken": current_token.token,
438 "deviceId": self.config.device_id.clone().unwrap_or_default(),
439 });
440
441 let mut headers = HeaderMap::new();
443 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
444 headers.insert(
445 AUTHORIZATION,
446 HeaderValue::from_str(&format!("Bearer {}", current_token.token)).unwrap(),
447 );
448
449 if let Some(api_key) = &self.config.api_key {
451 headers.insert("api-key", HeaderValue::from_str(api_key).unwrap());
452 }
453
454 let timestamp = generate_timestamp();
456 let signature = if let Some(api_secret) = &self.config.api_secret {
457 let message = format!("{}{}", timestamp, to_json(&body)?);
458 generate_signature(api_secret, &message)?
459 } else {
460 String::new()
461 };
462
463 headers.insert("timestamp", HeaderValue::from_str(×tamp).unwrap());
465 headers.insert("signature", HeaderValue::from_str(&signature).unwrap());
466
467 let response = self
469 .client
470 .post(format!("{}/api/passport/logout", self.config.base_url))
471 .headers(headers)
472 .json(&body)
473 .send()
474 .await
475 .map_err(|e| WebullError::NetworkError(e))?;
476
477 if !response.status().is_success() {
479 let status = response.status();
480 let text = response
481 .text()
482 .await
483 .unwrap_or_else(|_| "Unknown error".to_string());
484
485 if status.as_u16() == 401 {
486 } else if status.as_u16() == 429 {
488 return Err(WebullError::RateLimitExceeded);
489 } else {
490 return Err(WebullError::ApiError {
491 code: status.as_u16().to_string(),
492 message: text,
493 });
494 }
495 }
496
497 self.token_store.clear_token()?;
499 self.credentials = None;
500
501 Ok(())
502 }
503}