webull_rs/
auth.rs

1use crate::error::{WebullError, WebullResult};
2use crate::config::WebullConfig;
3use crate::utils::crypto::{encrypt_password, generate_signature, generate_timestamp};
4use crate::utils::serialization::{from_json, to_json};
5use chrono::{DateTime, Utc};
6use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
7use serde::{Deserialize, Serialize};
8use serde_json::json;
9use std::sync::Mutex;
10
11/// Credentials for authentication.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Credentials {
14    /// Username for authentication
15    pub username: String,
16
17    /// Password for authentication
18    pub password: String,
19}
20
21/// Access token for API requests.
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct AccessToken {
24    /// The access token
25    pub token: String,
26
27    /// When the token expires
28    pub expires_at: DateTime<Utc>,
29
30    /// The refresh token
31    pub refresh_token: Option<String>,
32}
33
34/// Interface for storing and retrieving tokens.
35pub trait TokenStore: Send + Sync {
36    /// Get the current access token.
37    fn get_token(&self) -> WebullResult<Option<AccessToken>>;
38
39    /// Store an access token.
40    fn store_token(&self, token: AccessToken) -> WebullResult<()>;
41
42    /// Clear the stored token.
43    fn clear_token(&self) -> WebullResult<()>;
44}
45
46/// In-memory token store.
47#[derive(Debug, Default)]
48pub struct MemoryTokenStore {
49    token: Mutex<Option<AccessToken>>,
50}
51
52impl TokenStore for MemoryTokenStore {
53    fn get_token(&self) -> WebullResult<Option<AccessToken>> {
54        Ok(self.token.lock().unwrap().clone())
55    }
56
57    fn store_token(&self, token: AccessToken) -> WebullResult<()> {
58        *self.token.lock().unwrap() = Some(token);
59        Ok(())
60    }
61
62    fn clear_token(&self) -> WebullResult<()> {
63        *self.token.lock().unwrap() = None;
64        Ok(())
65    }
66}
67
68/// Manager for authentication.
69pub struct AuthManager {
70    /// Credentials for authentication
71    credentials: Option<Credentials>,
72
73    /// Token store
74    pub token_store: Box<dyn TokenStore>,
75
76    /// Configuration
77    config: WebullConfig,
78
79    /// HTTP client
80    client: reqwest::Client,
81}
82
83impl AuthManager {
84    /// Create a new authentication manager.
85    pub fn new(
86        config: WebullConfig,
87        token_store: Box<dyn TokenStore>,
88        client: reqwest::Client,
89    ) -> Self {
90        Self {
91            credentials: None,
92            token_store,
93            config,
94            client,
95        }
96    }
97
98    /// Authenticate with username and password.
99    pub async fn authenticate(&mut self, username: &str, password: &str) -> WebullResult<AccessToken> {
100        // Store credentials for potential token refresh
101        self.credentials = Some(Credentials {
102            username: username.to_string(),
103            password: password.to_string(),
104        });
105
106        // Encrypt the password
107        let encrypted_password = encrypt_password(password, &self.config.api_secret.clone().unwrap_or_default())?;
108
109        // Create the request body
110        let body = json!({
111            "username": username,
112            "password": encrypted_password,
113            "deviceId": self.config.device_id.clone().unwrap_or_default(),
114            "deviceName": "Rust API Client",
115            "deviceType": "Web",
116        });
117
118        // Create headers
119        let mut headers = HeaderMap::new();
120        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
121
122        // Add API key if available
123        if let Some(api_key) = &self.config.api_key {
124            headers.insert("api-key", HeaderValue::from_str(api_key).unwrap());
125        }
126
127        // Generate timestamp and signature
128        let timestamp = generate_timestamp();
129        let signature = if let Some(api_secret) = &self.config.api_secret {
130            let message = format!("{}{}", timestamp, to_json(&body)?);
131            generate_signature(api_secret, &message)?
132        } else {
133            String::new()
134        };
135
136        // Add timestamp and signature to headers
137        headers.insert("timestamp", HeaderValue::from_str(&timestamp).unwrap());
138        headers.insert("signature", HeaderValue::from_str(&signature).unwrap());
139
140        // Send the request
141        let response = self.client.post(format!("{}/api/passport/login/v5/account", self.config.base_url))
142            .headers(headers)
143            .json(&body)
144            .send()
145            .await
146            .map_err(|e| WebullError::NetworkError(e))?;
147
148        // Check for errors
149        if !response.status().is_success() {
150            let status = response.status();
151            let text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
152
153            if status.as_u16() == 401 {
154                return Err(WebullError::Unauthorized);
155            } else if status.as_u16() == 429 {
156                return Err(WebullError::RateLimitExceeded);
157            } else {
158                return Err(WebullError::ApiError {
159                    code: status.as_u16().to_string(),
160                    message: text,
161                });
162            }
163        }
164
165        // Parse the response
166        let response_text = response.text().await
167            .map_err(|e| WebullError::NetworkError(e))?;
168
169        #[derive(Debug, Deserialize)]
170        struct LoginResponse {
171            access_token: String,
172            refresh_token: String,
173            token_type: String,
174            expires_in: i64,
175        }
176
177        let login_response: LoginResponse = from_json(&response_text)?;
178
179        // Create the token
180        let token = AccessToken {
181            token: login_response.access_token,
182            expires_at: Utc::now() + chrono::Duration::seconds(login_response.expires_in),
183            refresh_token: Some(login_response.refresh_token),
184        };
185
186        // Store the token
187        self.token_store.store_token(token.clone())?;
188
189        Ok(token)
190    }
191
192    /// Handle multi-factor authentication.
193    pub async fn multi_factor_auth(&mut self, mfa_code: &str) -> WebullResult<AccessToken> {
194        // Check if we have credentials
195        let credentials = self.credentials.as_ref()
196            .ok_or_else(|| WebullError::InvalidRequest("No credentials available for MFA".to_string()))?;
197
198        // Create the request body
199        let body = json!({
200            "username": credentials.username,
201            "verificationCode": mfa_code,
202            "deviceId": self.config.device_id.clone().unwrap_or_default(),
203        });
204
205        // Create headers
206        let mut headers = HeaderMap::new();
207        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
208
209        // Add API key if available
210        if let Some(api_key) = &self.config.api_key {
211            headers.insert("api-key", HeaderValue::from_str(api_key).unwrap());
212        }
213
214        // Generate timestamp and signature
215        let timestamp = generate_timestamp();
216        let signature = if let Some(api_secret) = &self.config.api_secret {
217            let message = format!("{}{}", timestamp, to_json(&body)?);
218            generate_signature(api_secret, &message)?
219        } else {
220            String::new()
221        };
222
223        // Add timestamp and signature to headers
224        headers.insert("timestamp", HeaderValue::from_str(&timestamp).unwrap());
225        headers.insert("signature", HeaderValue::from_str(&signature).unwrap());
226
227        // Send the request
228        let response = self.client.post(format!("{}/api/passport/verificationCode/verify", self.config.base_url))
229            .headers(headers)
230            .json(&body)
231            .send()
232            .await
233            .map_err(|e| WebullError::NetworkError(e))?;
234
235        // Check for errors
236        if !response.status().is_success() {
237            let status = response.status();
238            let text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
239
240            if status.as_u16() == 401 {
241                return Err(WebullError::Unauthorized);
242            } else if status.as_u16() == 429 {
243                return Err(WebullError::RateLimitExceeded);
244            } else {
245                return Err(WebullError::ApiError {
246                    code: status.as_u16().to_string(),
247                    message: text,
248                });
249            }
250        }
251
252        // Parse the response
253        let response_text = response.text().await
254            .map_err(|e| WebullError::NetworkError(e))?;
255
256        #[derive(Debug, Deserialize)]
257        struct MfaResponse {
258            access_token: String,
259            refresh_token: String,
260            token_type: String,
261            expires_in: i64,
262        }
263
264        let mfa_response: MfaResponse = from_json(&response_text)?;
265
266        // Create the token
267        let token = AccessToken {
268            token: mfa_response.access_token,
269            expires_at: Utc::now() + chrono::Duration::seconds(mfa_response.expires_in),
270            refresh_token: Some(mfa_response.refresh_token),
271        };
272
273        // Store the token
274        self.token_store.store_token(token.clone())?;
275
276        Ok(token)
277    }
278
279    /// Refresh the access token.
280    pub async fn refresh_token(&mut self) -> WebullResult<AccessToken> {
281        // Get the current token
282        let current_token = self.token_store.get_token()?
283            .ok_or_else(|| WebullError::InvalidRequest("No token available for refresh".to_string()))?;
284
285        // Check if we have a refresh token
286        let refresh_token = current_token.refresh_token
287            .ok_or_else(|| WebullError::InvalidRequest("No refresh token available".to_string()))?;
288
289        // Create the request body
290        let body = json!({
291            "refreshToken": refresh_token,
292            "deviceId": self.config.device_id.clone().unwrap_or_default(),
293        });
294
295        // Create headers
296        let mut headers = HeaderMap::new();
297        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
298
299        // Add API key if available
300        if let Some(api_key) = &self.config.api_key {
301            headers.insert("api-key", HeaderValue::from_str(api_key).unwrap());
302        }
303
304        // Generate timestamp and signature
305        let timestamp = generate_timestamp();
306        let signature = if let Some(api_secret) = &self.config.api_secret {
307            let message = format!("{}{}", timestamp, to_json(&body)?);
308            generate_signature(api_secret, &message)?
309        } else {
310            String::new()
311        };
312
313        // Add timestamp and signature to headers
314        headers.insert("timestamp", HeaderValue::from_str(&timestamp).unwrap());
315        headers.insert("signature", HeaderValue::from_str(&signature).unwrap());
316
317        // Send the request
318        let response = self.client.post(format!("{}/api/passport/refreshToken", self.config.base_url))
319            .headers(headers)
320            .json(&body)
321            .send()
322            .await
323            .map_err(|e| WebullError::NetworkError(e))?;
324
325        // Check for errors
326        if !response.status().is_success() {
327            let status = response.status();
328            let text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
329
330            if status.as_u16() == 401 {
331                return Err(WebullError::Unauthorized);
332            } else if status.as_u16() == 429 {
333                return Err(WebullError::RateLimitExceeded);
334            } else {
335                return Err(WebullError::ApiError {
336                    code: status.as_u16().to_string(),
337                    message: text,
338                });
339            }
340        }
341
342        // Parse the response
343        let response_text = response.text().await
344            .map_err(|e| WebullError::NetworkError(e))?;
345
346        #[derive(Debug, Deserialize)]
347        struct RefreshResponse {
348            access_token: String,
349            refresh_token: String,
350            token_type: String,
351            expires_in: i64,
352        }
353
354        let refresh_response: RefreshResponse = from_json(&response_text)?;
355
356        // Create the token
357        let token = AccessToken {
358            token: refresh_response.access_token,
359            expires_at: Utc::now() + chrono::Duration::seconds(refresh_response.expires_in),
360            refresh_token: Some(refresh_response.refresh_token),
361        };
362
363        // Store the token
364        self.token_store.store_token(token.clone())?;
365
366        Ok(token)
367    }
368
369    /// Get the current access token.
370    pub async fn get_token(&self) -> WebullResult<AccessToken> {
371        match self.token_store.get_token()? {
372            Some(token) => {
373                // Check if token is expired
374                if token.expires_at <= Utc::now() {
375                    return Err(WebullError::Unauthorized);
376                }
377                Ok(token)
378            }
379            None => Err(WebullError::Unauthorized),
380        }
381    }
382
383    /// Revoke the current token.
384    pub async fn revoke_token(&mut self) -> WebullResult<()> {
385        // Get the current token
386        let current_token = match self.token_store.get_token()? {
387            Some(token) => token,
388            None => {
389                // No token to revoke
390                self.credentials = None;
391                return Ok(());
392            }
393        };
394
395        // Create the request body
396        let body = json!({
397            "accessToken": current_token.token,
398            "deviceId": self.config.device_id.clone().unwrap_or_default(),
399        });
400
401        // Create headers
402        let mut headers = HeaderMap::new();
403        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
404        headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {}", current_token.token)).unwrap());
405
406        // Add API key if available
407        if let Some(api_key) = &self.config.api_key {
408            headers.insert("api-key", HeaderValue::from_str(api_key).unwrap());
409        }
410
411        // Generate timestamp and signature
412        let timestamp = generate_timestamp();
413        let signature = if let Some(api_secret) = &self.config.api_secret {
414            let message = format!("{}{}", timestamp, to_json(&body)?);
415            generate_signature(api_secret, &message)?
416        } else {
417            String::new()
418        };
419
420        // Add timestamp and signature to headers
421        headers.insert("timestamp", HeaderValue::from_str(&timestamp).unwrap());
422        headers.insert("signature", HeaderValue::from_str(&signature).unwrap());
423
424        // Send the request
425        let response = self.client.post(format!("{}/api/passport/logout", self.config.base_url))
426            .headers(headers)
427            .json(&body)
428            .send()
429            .await
430            .map_err(|e| WebullError::NetworkError(e))?;
431
432        // Check for errors
433        if !response.status().is_success() {
434            let status = response.status();
435            let text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
436
437            if status.as_u16() == 401 {
438                // Token is already invalid, so we can just clear it
439            } else if status.as_u16() == 429 {
440                return Err(WebullError::RateLimitExceeded);
441            } else {
442                return Err(WebullError::ApiError {
443                    code: status.as_u16().to_string(),
444                    message: text,
445                });
446            }
447        }
448
449        // Clear the token and credentials
450        self.token_store.clear_token()?;
451        self.credentials = None;
452
453        Ok(())
454    }
455}