Skip to main content

baidu_netdisk_sdk/auth/
mod.rs

//! Authentication and token management.
//!
//! Provides device-code authorization, access-token management, and automatic token refresh.
//!
//! # Features
//!
//! - **Device Code OAuth Flow**: Request a device code and poll for an access token
//! - **Token Management**: Store, validate, and refresh access tokens
//! - **Automatic Refresh**: Refresh tokens automatically before they expire
//! - **Thread Safe**: Safe concurrent access with RwLock
//!
//! # Quick Start
//!
//! ```
//! use baidu_netdisk_sdk::BaiduNetDiskClient;
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let client = BaiduNetDiskClient::builder()
//!     .app_key("your_app_key")
//!     .app_secret("your_app_secret")
//!     .build()?;
//!
//! // Device code authorization
//! let auth = client.authorize();
//! let device_code = auth.get_device_code().await?;
//! println!("Please visit: {} and enter code: {}",
//!     device_code.verification_url,
//!     device_code.user_code
//! );
//!
//! // Poll for token
//! let token = loop {
//!     if let Some(token) = auth.request_access_token(&device_code).await? {
//!         break token;
//!     }
//!     tokio::time::sleep(tokio::time::Duration::from_secs(device_code.interval as u64)).await;
//! };
//!
//! // Use token provider for auto-refresh
//! let provider = client.token_provider();
//! provider.set_access_token(token)?;
//!
//! // Get valid token (auto-refreshes if needed)
//! let valid_token = provider.get_valid_token().await?;
//! # Ok(())
//! # }
//! ```
48
49use serde::{Deserialize, Serialize};
50use std::time::{SystemTime, UNIX_EPOCH};
51
52pub mod authorization;
53pub mod token_provider;
54
55pub use self::authorization::Authorization;
56pub use self::token_provider::{TokenProvider, TokenProviderConfig};
57
58/// Device code response from Baidu OAuth API
59#[derive(Debug, Deserialize, Serialize, Clone)]
60pub struct DeviceCodeResponse {
61    /// Device code for polling
62    pub device_code: String,
63    /// User code for user input
64    pub user_code: String,
65    /// Verification URL for user to visit
66    pub verification_url: String,
67    /// QR code URL for scanning
68    pub qrcode_url: String,
69    /// Polling interval in seconds
70    pub interval: u32,
71    /// Expiration time in seconds
72    pub expires_in: u32,
73}
74
75/// Device code information with expiration timestamp
76#[derive(Debug, Deserialize, Serialize, Clone)]
77pub struct DeviceCode {
78    pub device_code: String,
79    pub user_code: String,
80    pub verification_url: String,
81    pub qrcode_url: String,
82    pub interval: u32,
83    pub expires_at: u64,
84}
85
86impl From<DeviceCodeResponse> for DeviceCode {
87    fn from(response: DeviceCodeResponse) -> Self {
88        let expires_at = SystemTime::now()
89            .duration_since(UNIX_EPOCH)
90            .unwrap_or_default()
91            .as_secs()
92            + response.expires_in as u64;
93
94        DeviceCode {
95            device_code: response.device_code,
96            user_code: response.user_code,
97            verification_url: response.verification_url,
98            qrcode_url: response.qrcode_url,
99            interval: response.interval,
100            expires_at,
101        }
102    }
103}
104
105/// Access token response from Baidu OAuth API
106///
107/// Fields match exactly what Baidu API returns:
108/// {"expires_in":2592000,"refresh_token":"...","access_token":"...","session_secret":"","session_key":"","scope":"basic netdisk"}
109#[derive(Debug, Deserialize, Serialize, Clone)]
110pub struct AccessTokenResponse {
111    /// Access token for API requests
112    pub access_token: String,
113    /// Expiration time in seconds (typically 2592000 seconds = 30 days)
114    pub expires_in: u64,
115    /// Refresh token for obtaining new access token
116    pub refresh_token: String,
117    /// Scope of permissions (e.g., "basic netdisk")
118    pub scope: String,
119    /// Session key for signed API requests (may be empty string)
120    pub session_key: String,
121    /// Session secret for signed API requests (may be empty string)
122    pub session_secret: String,
123}
124
125/// Access token information with acquisition timestamp
126#[derive(Debug, Deserialize, Serialize, Clone)]
127pub struct AccessToken {
128    /// Access token for API requests
129    pub access_token: String,
130    /// Expiration time in seconds
131    pub expires_in: u64,
132    /// Refresh token for obtaining new access token
133    pub refresh_token: String,
134    /// Scope of permissions
135    pub scope: String,
136    /// Session key for signed API requests
137    pub session_key: String,
138    /// Session secret for signed API requests
139    pub session_secret: String,
140    /// Acquisition timestamp in seconds
141    pub acquired_at: u64,
142}
143
144impl From<AccessTokenResponse> for AccessToken {
145    fn from(response: AccessTokenResponse) -> Self {
146        let acquired_at = SystemTime::now()
147            .duration_since(UNIX_EPOCH)
148            .unwrap_or_default()
149            .as_secs();
150
151        AccessToken {
152            access_token: response.access_token,
153            expires_in: response.expires_in,
154            refresh_token: response.refresh_token,
155            scope: response.scope,
156            session_key: response.session_key,
157            session_secret: response.session_secret,
158            acquired_at,
159        }
160    }
161}
162
163impl AccessToken {
164    /// Create a new AccessToken with the current timestamp
165    ///
166    /// # Arguments
167    /// * `access_token` - The access token string
168    /// * `refresh_token` - The refresh token string
169    /// * `expires_in` - Expiration time in seconds
170    /// * `scope` - Scope of permissions
171    pub fn new(
172        access_token: String,
173        refresh_token: String,
174        expires_in: u64,
175        scope: String,
176    ) -> Self {
177        let acquired_at = SystemTime::now()
178            .duration_since(UNIX_EPOCH)
179            .unwrap_or_default()
180            .as_secs();
181
182        AccessToken {
183            access_token,
184            refresh_token,
185            expires_in,
186            scope,
187            session_key: String::new(),
188            session_secret: String::new(),
189            acquired_at,
190        }
191    }
192
193    /// Create an AccessToken with all fields specified
194    ///
195    /// Use this when you need to set session_key, session_secret, or a specific acquired_at
196    /// (e.g., for testing expired tokens)
197    pub fn with_all(
198        access_token: String,
199        refresh_token: String,
200        expires_in: u64,
201        scope: String,
202        session_key: String,
203        session_secret: String,
204        acquired_at: u64,
205    ) -> Self {
206        AccessToken {
207            access_token,
208            refresh_token,
209            expires_in,
210            scope,
211            session_key,
212            session_secret,
213            acquired_at,
214        }
215    }
216
217    /// Check if the token is expired
218    /// Returns true if token has expired or will expire within 60 seconds
219    pub fn is_expired(&self) -> bool {
220        let now = SystemTime::now()
221            .duration_since(UNIX_EPOCH)
222            .unwrap_or_default()
223            .as_secs();
224        now >= self.acquired_at + self.expires_in - 60
225    }
226
227    /// Get remaining valid seconds of the token
228    pub fn remaining_seconds(&self) -> u64 {
229        let now = SystemTime::now()
230            .duration_since(UNIX_EPOCH)
231            .unwrap_or_default()
232            .as_secs();
233        let expires_at = self.acquired_at + self.expires_in;
234        expires_at.saturating_sub(now)
235    }
236
237    /// Validate the token and return its status
238    pub fn validate(&self) -> TokenStatus {
239        let remaining = self.remaining_seconds();
240        if remaining == 0 {
241            TokenStatus::Expired
242        } else if remaining <= 300 {
243            TokenStatus::ExpiringSoon
244        } else {
245            TokenStatus::Valid
246        }
247    }
248
249    /// Check if token is valid (not expired and not expiring soon)
250    pub fn is_valid(&self) -> bool {
251        matches!(self.validate(), TokenStatus::Valid)
252    }
253
254    /// Get expiration timestamp in seconds since epoch
255    pub fn expires_at(&self) -> u64 {
256        self.acquired_at + self.expires_in
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Current Unix time in whole seconds (0 if the clock reads before the epoch).
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Token with fixed placeholder strings and caller-chosen timing fields.
    fn create_test_token(acquired_at: u64, expires_in: u64) -> AccessToken {
        AccessToken {
            acquired_at,
            expires_in,
            access_token: String::from("test_token"),
            refresh_token: String::from("test_refresh"),
            scope: String::from("basic netdisk"),
            session_key: String::new(),
            session_secret: String::new(),
        }
    }

    #[test]
    fn test_token_new() {
        let token = AccessToken::new(
            "access".into(),
            "refresh".into(),
            3600,
            "basic netdisk".into(),
        );

        assert_eq!(token.access_token, "access");
        assert_eq!(token.refresh_token, "refresh");
        assert_eq!(token.expires_in, 3600);
        assert!(token.session_key.is_empty());
    }

    #[test]
    fn test_token_with_all() {
        let stamp = now_secs();
        let token = AccessToken::with_all(
            "access".into(),
            "refresh".into(),
            7200,
            "netdisk".into(),
            "session_key".into(),
            "session_secret".into(),
            stamp,
        );

        assert_eq!(token.scope, "netdisk");
        assert_eq!(token.session_key, "session_key");
        assert_eq!(token.session_secret, "session_secret");
        assert_eq!(token.acquired_at, stamp);
    }

    #[test]
    fn test_token_is_expired() {
        let now = now_secs();

        assert!(!create_test_token(now, 3600).is_expired());
        assert!(create_test_token(now - 7200, 3600).is_expired());
    }

    #[test]
    fn test_token_validate() {
        let now = now_secs();
        let cases = [
            (now, 3600, TokenStatus::Valid),
            (now, 200, TokenStatus::ExpiringSoon),
            (now - 4000, 3600, TokenStatus::Expired),
        ];

        for (acquired_at, expires_in, expected) in cases.iter() {
            let token = create_test_token(*acquired_at, *expires_in);
            assert_eq!(token.validate(), *expected);
        }
    }

    #[test]
    fn test_token_is_valid() {
        let now = now_secs();

        assert!(create_test_token(now, 3600).is_valid());
        assert!(!create_test_token(now - 4000, 3600).is_valid());
    }

    #[test]
    fn test_token_remaining_seconds() {
        let now = now_secs();

        let left = create_test_token(now, 1000).remaining_seconds();
        assert!((990..=1000).contains(&left));

        assert_eq!(create_test_token(now - 100, 50).remaining_seconds(), 0);
    }

    #[test]
    fn test_token_expires_at() {
        let now = now_secs();

        assert_eq!(create_test_token(now, 500).expires_at(), now + 500);
    }
}
385
386/// Token validation status
387#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
388pub enum TokenStatus {
389    /// Token is valid and has plenty of time left
390    Valid,
391    /// Token will expire soon (within 5 minutes)
392    ExpiringSoon,
393    /// Token has already expired
394    Expired,
395}
396
397/// API error response structure
398///
399/// Baidu NetDisk API may return error codes in different formats:
400/// - `errno` and `errmsg` (common format)
401/// - `error_code` and `error_msg` (alternative format)
402///
403/// All fields are optional to handle various response formats.
404#[derive(Debug, Deserialize, Serialize, Default)]
405pub struct ApiErrorResponse {
406    /// Error code (format 1)
407    pub errno: Option<i32>,
408    /// Error message (format 1)
409    pub errmsg: Option<String>,
410    /// Alternative error code field (format 2)
411    #[serde(rename = "error_code")]
412    pub error_code: Option<i32>,
413    /// Alternative error message field (format 2)
414    #[serde(rename = "error_msg")]
415    pub error_msg: Option<String>,
416}
417
418impl ApiErrorResponse {
419    /// Get error code, prioritizing error_code over errno
420    pub fn get_errno(&self) -> i32 {
421        self.error_code.or(self.errno).unwrap_or(-1)
422    }
423
424    /// Get error message, prioritizing error_msg over errmsg
425    pub fn get_errmsg(&self) -> &str {
426        if let Some(msg) = &self.error_msg {
427            msg
428        } else if let Some(msg) = &self.errmsg {
429            msg
430        } else {
431            "Unknown error"
432        }
433    }
434
435    /// Check if this response contains any error information
436    pub fn has_error(&self) -> bool {
437        self.errno.is_some()
438            || self.error_code.is_some()
439            || self.errmsg.is_some()
440            || self.error_msg.is_some()
441    }
442}
443
444/// Authentication error response structure
445#[derive(Debug, Deserialize, Serialize)]
446pub struct AuthErrorResponse {
447    /// Error type
448    pub error: String,
449    /// Error description
450    pub error_description: String,
451}
452
453/// User information response from Baidu NetDisk API
454#[derive(Debug, Deserialize, Serialize, Clone)]
455pub struct UserInfo {
456    /// Baidu account name
457    pub baidu_name: String,
458    /// NetDisk account name
459    pub netdisk_name: String,
460    /// Avatar URL
461    pub avatar_url: String,
462    /// VIP type: 0=normal, 1=VIP, 2=Super VIP
463    pub vip_type: i32,
464    /// User ID
465    pub uk: u64,
466}
467
468/// Quota information response from Baidu NetDisk API
469#[derive(Debug, Deserialize, Serialize, Clone)]
470pub struct QuotaInfo {
471    /// Total storage capacity in bytes
472    pub total: u64,
473    /// Used storage capacity in bytes
474    pub used: u64,
475    /// Free storage capacity in bytes
476    pub free: u64,
477}